k近邻法(二)

上一篇文章讲了k近邻法,以及使用kd树构造数据结构,使得提高最近邻点搜索效率,但是这在数据点N 远大于 2^n 时可以有效的降低算法复杂度,n为数据点的维度,否则,由于需要向上回溯比较距离,使得实际效率总是很低(接近线性扫描)。比如SIFT特征矢量128维,SURF特征矢量64维,维度都比较大,N 远大于 2^n 可能无法满足。此外,由于每个最近邻点都需要回溯到根节点才算结束,那么,在获取k个近邻点时,必然存在大量不必要的回溯点,这些都需要另寻其他查询方法。

一个简单的改进思路就是将“查询路径”上的结点进行排序,如按各自分割超平面(也称bin)与查询点的距离排序,也就是说,回溯检查总是从优先级最高(Best Bin)的树结点开始

所以这篇文章讨论这种改进的方法 Best Bin First(BBF)。

主要思想是,使用一个优先列表,按节点与目标点的距离排序,从优先列表中取出第一项(当前距离最近的)节点,按照某个规则决定是访问其左子节点或者右子节点,同时将另一个子节点(如果存在)存储到优先列表中,循环以上操作,直到优先列表为空。

步骤:

  1. 将根结点store in 优先列表Priority中。声明一个最近节点对象nearest,先令其指向根节点,一个当前节点对象current
  2. 取出第一项,根据访问规则,访问其左子节点或者右子节点,并令current指向它,然后将另一个子节点store in Priority中,递归向下,直到遇到叶节点,此时,比较current和nearest哪个距离目标节点更近,更新nearest,
  3. 如果Priority中还有项,则继续步骤2,否则返回nearest,此为最近邻点

由于比较简单,这里不再详述,直接给出代码。

     private List<Tuple<TreeNode, double>> _priorities = new List<Tuple<TreeNode, double>>();
        /// <summary>
        /// 按priority升序排序插入
        /// </summary>
        /// <param name="node"></param>
        /// <param name="priority"></param>
        private void InsertByPriority(TreeNode node, double priority)
        {
            if(_priorities.Count == 0)
            {
                _priorities.Add(new Tuple<TreeNode, double>(node, priority));
            }
            else
            {
                for(int i = 0; i < _priorities.Count; i++)
                {
                    if(_priorities[i].Item2 >= priority)
                    {
                        _priorities.Insert(i, new Tuple<TreeNode, double>(node, priority));
                        break;
                    }
                }
            }
        }
        private double GetPriority(TreeNode node, Point p, int axis) => Math.Abs(node.point.vector[axis] - p.vector[axis]);

        public Point BBFSearchNearestNode(Point p)
        {
            var rootPriority = GetPriority(root, p, root.axis);
            InsertByPriority(root, rootPriority);
            var nearest = root;

            TreeNode topNode = null;        // 优先级最高的节点
            TreeNode curNode = null;        
            while(_priorities.Count > 0)
            {
                topNode = _priorities[0].Item1;
                _priorities.RemoveAt(0);

                while(topNode != null)
                {
                    if(topNode.left != null || topNode.right != null)
                    {
                        var axis = topNode.axis;
                        if(p.vector[axis] <= topNode.point.vector[axis])
                        {
                            // wanna to go down left child node 
                            if(topNode.right != null)                                       // 将右子节点添加到优先列表
                            {
                                InsertByPriority(topNode.right, GetPriority(topNode.right, p, axis));
                            }
                            topNode = topNode.left;
                        }
                        else
                        {
                            // wanna to go down right child node
                            if(topNode.left != null)
                            {
                                InsertByPriority(topNode.left, GetPriority(topNode.left, p, axis));
                            }
                            topNode = topNode.right;
                        }
                    }
                    else
                    {
                        curNode = topNode;
                        topNode = null;
                    }

                    if(curNode != null && p.Distance(curNode.point) < p.Distance(nearest.point))        // find a nearer node
                    {
                        nearest = curNode;
                    }
                }
            }
            return nearest.point;
        }

 上面的代码仅仅是返回了最近的那一个点,如果要返回k个近邻点,则只需对上面代码稍作修改,将 上面每次的current保存到一个按距离排序的列表中,这样前k个点就是所求的k近邻点,代码如下:

        /// <summary>
        /// 最大检测次数
        /// </summary>
        public int max_nn_chks = 0x1000;
        /// <summary>
        /// 搜索k近邻点
        /// </summary>
        /// <param name="p"></param>
        /// <param name="k"></param>
        /// <returns></returns>
        public List<TreeNode> BBFSearchKNearest(Point p, int k)
        {
            var list = new List<BBFData>();    //
            var pq = new MinPQ();
            pq.insert(new PQNode(root, 0));
            int t = 0;
            while(pq.nodes.Count > 0 && t < max_nn_chks)
            {
                var expl = pq.pop_min_default().data;
                expl = Explore2Leaf(expl, p, pq);

                var bbf = new BBFData(expl, expl.point.Distance(p));
                insert(list, k, bbf);

                t++;
            }
            return list.Select(l => l.data).ToList();
        }
        /// <summary>
        /// 向下访问叶节点,并将slide添加到优先列表中
        /// </summary>
        /// <param name="node"></param>
        /// <param name="p"></param>
        /// <param name="pq"></param>
        /// <returns></returns>
        private TreeNode Explore2Leaf(TreeNode node, Point p, MinPQ pq)
        {
            TreeNode unexpl;
            var expl = node;
            TreeNode prev;
            while(expl != null && (expl.left != null || expl.right != null))
            {
                prev = expl;
                var axis = expl.axis;
                var val = expl.point.vector[axis];

                if(p.vector[axis] <= val)
                {
                    unexpl = expl.right;
                    expl = expl.left;
                }
                else
                {
                    unexpl = expl.left;
                    expl = expl.right;
                }
                if(unexpl != null)
                {
                    pq.insert(new PQNode(unexpl, Math.Abs(val - p.vector[axis])));
                }
                if(expl == null)
                {
                    return prev;
                }
            }
            return expl;
        }
        /// <summary>
        /// 将节点按距离插入列表中
        /// </summary>
        /// <param name="list"></param>
        /// <param name="k"></param>
        /// <param name="bbf"></param>
        private void insert(List<BBFData> list, int k, BBFData bbf)
        {
            if(list.Count == 0)
            {
                list.Add(bbf);
                return;
            }

            int ret = 0;
            int oldCount = list.Count;
            var last = list[list.Count - 1];
            var df = bbf.d;
            var dn = last.d;
            if(df >= dn)        // bbf will be appended to list
            {
                if(oldCount == k)     // already has k nearest neighbors, nothing should be done
                {
                    return;
                }
                list.Add(bbf);      // append directively
                return;
            }

            // bbf will be inserted into list internally

            if(oldCount < k)
            {
                // suppose bbf be inserted at idx1, all elements after idx1 should be moved 1 backward respectively
                // first we move the last element
                list.Add(last);     
            }
            // from backer to former, move related elements
            int i = oldCount - 2;
            while(i > -1)
            {
                if (list[i].d <= df)
                    break;

                list[i + 1] = list[i];      // move backward
                i--;
            }
            i++;
            list[i] = bbf;
        }

其中用到的辅助类如下:

    public class BBFData
    {
        public TreeNode data;
        /// <summary>
        /// 节点与目标点的距离
        /// </summary>
        public double d;

        public BBFData(TreeNode data, double d)
        {
            this.data = data;
            this.d = d;
        }
    }

    public class PQNode
    {
        public TreeNode data;
        /// <summary>
        /// 目标点与当前节点的超平面的距离
        /// </summary>
        public double d;

        public PQNode(TreeNode data, double d)
        {
            this.data = data;
            this.d = d;
        }
    }

    public class MinPQ
    {
        public List<PQNode> nodes;
        // 将节点插入优先列表中
        public void insert(PQNode node)    
        {
            nodes.Add(node);

            int i = nodes.Count - 1;
            int p = parent(i);
            PQNode tmp;
            while(i > 0 && p >= 0 && nodes[i].d < nodes[p].d)
            {
                tmp = nodes[p];
                nodes[p] = nodes[i];
                nodes[i] = tmp;
                i = p;
                p = parent(i);
            }
        }

        public PQNode get_min_default() => nodes.Count > 0 ? nodes[0] : null;
        public PQNode pop_min_default()
        {
            if (nodes.Count == 0) return null;

            var ret = nodes[0];
            nodes[0] = nodes[nodes.Count - 1];
            nodes.RemoveAt(nodes.Count - 1);
            restore_minpq_order(0, nodes.Count);

            return ret;
        }

        private void restore_minpq_order(int i, int n)
        {
            int l = left(i);
            int r = right(i);
            int min = i;

            if (l < n && nodes[l].d < nodes[i].d)
                min = l;
            if (r < n && nodes[r].d < nodes[min].d)
                min = r;
            if(min != i)
            {
                var tmp = nodes[min];
                nodes[min] = nodes[i];
                nodes[i] = tmp;
            }
        }

        public static int parent(int i) => (i - 1) / 2;
        public static int right(int i) => 2 * (i + 1);
        public static int left(int i) => 2 * i + 1;
    }

注意,上面代码中,优先列表是使用最小堆实现。

后记:

以上代码搜索k近邻,仅仅是一定程度上得到最大可能的近似k近邻,因为有max_nn_chks的检测次数限制。

假设没有这个限制,则实际上应该对全部训练数据集中的数据点做检测的,而不用k-d树结构存储数据集时也是要检测全部数据点的,不过,两者区别还是有的,也许使用了k-d树后,由于是从优先列表中选择数据点进行检测,导致insert结果列表的操作平均时间复杂度低(当然了,这些我此时并没有去仔细想),而且使用了k-d树后,在数据集数量很大时,需要max_nn_chks限制,此时近似k近邻还是比不使用k-d树得到的近似k近邻更加准确吧(概率意义上)^^!

原文地址:https://www.cnblogs.com/sjjsxl/p/6863586.html