KDTree

学习链接:http://www.cnblogs.com/eyeszjwang/articles/2429382.html

下面实现的kdtree支持以下操作:
(1) 插入一个节点
(2) 插入n个节点
(3) 查找与某个给定点距离最小的K个节点(K近邻查询)。

插入可能导致树严重不平衡;此时会把失衡的子树重建为平衡的子树。

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <utility>
#include <vector>
  3 
  4 /***
  5   _NodeType: 节点类型
  6   _CompareFuncType: 比较函数类型,它应该接受两个_NodeType并返回0或者1
  7   _UpdateFuncType:  更新函数类型,它有三个参数,分别表示Father,leftSon,rightSon,参数的类型的 _NodeType*
  8   _DIMENSION:  维数
  9   _DISTANCE_TYPE: 两个_NodeType之间距离的类型
 10   _REBUILD_ALPHA: 这个参数用来维持树大致平衡,它应该大于50小于等于100
 11 ***/
 12 template<class _NodeType,
 13          class _CompareFuncType,
 14          class _UpdateFuncType,
 15          int _DIMENSION,
 16          class _DISTANCE_TYPE,
 17          int _REBUILD_ALPHA=80>
 18 class KDTree
 19 {
 20 private:
 21     struct TreeNode
 22     {
 23         TreeNode(const _NodeType& _Node):m_Left(nullptr),m_Right(nullptr),m_Size(1),
 24             m_Node(_Node) {}
 25         TreeNode() {}
 26 
 27         _NodeType m_Node;
 28         TreeNode* m_Left;
 29         TreeNode* m_Right;
 30         unsigned int m_Size;
 31     };
 32 
 33     TreeNode* NewNode(const _NodeType& _Node)
 34     {
 35         return new TreeNode(_Node);
 36     }
 37 
 38 public:
 39     KDTree(_CompareFuncType** _CompareFuncs,_UpdateFuncType* _UpdateFunc):m_Root(nullptr)
 40     {
 41         for(unsigned int Idx=0;Idx<_DIMENSION;++Idx)
 42         {
 43             m_CompareFuncs[Idx]=_CompareFuncs[Idx];
 44         }
 45         m_UpdateFunc=_UpdateFunc;
 46     }
 47 
 48     void insert(const _NodeType& _Value)
 49     {
 50         TreeNode* BadTreeNode=nullptr;
 51         TreeNode* BadTreeNodeParent=nullptr;
 52         unsigned int BadTreeNodeDimension=0;
 53         m_Root=Insert(m_Root,NewNode(_Value),0,
 54             BadTreeNode,BadTreeNodeParent,BadTreeNodeDimension);
 55 
 56         if(BadTreeNode!=nullptr)
 57         {
 58             if(BadTreeNodeParent==nullptr)
 59             {
 60                 m_Root=RebuildTree(BadTreeNode,BadTreeNodeDimension);
 61             }
 62             else
 63             {
 64                 BadTreeNodeParent=RebuildTree(BadTreeNode,BadTreeNodeDimension);
 65             }
 66         }
 67     }
 68 
 69 
 70     
 71     template<class _NodeTypeBegin>
 72     void insert(_NodeTypeBegin _ValueArray,const unsigned int _ValueArraySize)
 73     {
 74         if(_ValueArraySize<=0) return;
 75         if(m_Root==nullptr)
 76         {
 77             m_Root=BuildGroup(_ValueArray,0,_ValueArraySize-1,0);
 78         }
 79         else
 80         {
 81             if(m_Root->m_Size<=(unsigned int)(_ValueArraySize*_REBUILD_ALPHA/100))
 82             {
 83                 _NodeType* NewValueArray=new _NodeType[_ValueArraySize+m_Root->m_Size];
 84                 unsigned int NewValueArraySize=StoreSubtreeNodeIntoArray(m_Root,NewValueArray);
 85                 ClearSubTrees(m_Root);
 86                 for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
 87                 {
 88                     NewValueArray[NewValueArraySize++]=_ValueArray[Idx];
 89                 }
 90                 m_Root=BuildGroup(NewValueArray,0,NewValueArraySize-1,0);
 91                 delete[] NewValueArray;
 92             }
 93             else
 94             {
 95                 for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
 96                 {
 97                     insert(_ValueArray[Idx]);
 98                 }
 99             }
100         }
101     }
102 
103 
104     /***
105       查找距离_Value "最近" 的_SearchNumber个元素 存储在_StoreAnswerArray
106 
107       _ComputeMinDistanceFunc接受两个_NodeType(first,second) 用来计算second范围内所有点到first的 "最近"距离
108              _DISTANCE_TYPE是它的返回值类型
109       _ComputeDistanceFunc接受两个_NodeType(first,second) 计算first和second的距离
110              _DISTANCE_TYPE是它的返回值类型
111       _CompareDistanceFunc它接受两个_DISTANCE_TYPE(first,second) 并返回0或者1
112              1表示first小于second
113       函数返回查找到的元素个数(有可能小于_SearchNumber)
114     ***/
115     template<class _ComputeDistanceFuncType,
116              class _CompareDistanceFuncType>
117     unsigned int searchKNear(const _NodeType& _Value,_NodeType* _StoreAnswerArray,
118         const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
119         _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
120         _ComputeDistanceFuncType* _ComputeDistanceFunc)
121     {
122         if(_SearchNumber==0) return 0;
123         unsigned int AnswerArrayElementNumber=0;
124         SearchKNear(m_Root,_Value,_StoreAnswerArray,AnswerArrayElementNumber,_SearchNumber,
125             _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
126         return AnswerArrayElementNumber;
127     }
128 
129     unsigned int size() const
130     {
131         if(m_Root) return m_Root->m_Size;
132         return 0;
133     }
134 
135 private:
136     template<class _ComputeDistanceFuncType,
137              class _CompareDistanceFuncType>
138     void SearchKNear(
139            TreeNode* _Root,const _NodeType& _Value,_NodeType* _StoreAnswerArray,
140            unsigned int &_CurAnswerArrayElementNumber,
141            const unsigned int _SearchNumber,_CompareDistanceFuncType* _CompareDistanceFunc,
142            _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
143            _ComputeDistanceFuncType* _ComputeDistanceFunc)
144     {
145         if(_Root==nullptr) return;
146         if(_CurAnswerArrayElementNumber==0)
147         {
148             _StoreAnswerArray[0]=_Root->m_Node;
149             ++_CurAnswerArrayElementNumber;
150         }
151         else
152         {
153             _DISTANCE_TYPE CurNodeDis=_ComputeDistanceFunc(_Value,_Root->m_Node);
154             for(unsigned int Idx=_CurAnswerArrayElementNumber-1;;--Idx)
155             {
156                 _DISTANCE_TYPE PreNodeDis=_ComputeDistanceFunc(_Value,_StoreAnswerArray[Idx]);
157                 if(_CompareDistanceFunc(CurNodeDis,PreNodeDis))
158                 {
159                     if(Idx+1<_SearchNumber)
160                     {
161                         _StoreAnswerArray[Idx+1]=_StoreAnswerArray[Idx];
162                     }
163                 }
164                 else
165                 {
166                     if(Idx+1<_SearchNumber) _StoreAnswerArray[Idx+1]=_Root->m_Node;
167                     if(_CurAnswerArrayElementNumber<_SearchNumber)
168                     {
169                         ++_CurAnswerArrayElementNumber;
170                     }
171                     break;
172                 }
173                 if(0==Idx)
174                 {
175                     _StoreAnswerArray[0]=_Root->m_Node;
176                     if(_CurAnswerArrayElementNumber<_SearchNumber)
177                     {
178                         ++_CurAnswerArrayElementNumber;
179                     }
180                     break;
181                 }
182             }
183         }
184         if(_Root->m_Left&&_Root->m_Right)
185         {
186             _DISTANCE_TYPE LSonMinDis=_ComputeMinDistanceFunc(_Value,_Root->m_Left->m_Node);
187             _DISTANCE_TYPE RSonMinDis=_ComputeMinDistanceFunc(_Value,_Root->m_Right->m_Node);
188             _DISTANCE_TYPE CurMaxDis=_ComputeDistanceFunc(
189                 _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]);
190 
191             if(_CompareDistanceFunc(LSonMinDis,RSonMinDis))
192             {
193                 if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(LSonMinDis,CurMaxDis))
194                 {
195                     SearchKNear(_Root->m_Left,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
196                         _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
197                         _ComputeDistanceFunc);
198                 }
199                 CurMaxDis=_ComputeDistanceFunc(
200                     _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]);
201                 if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(RSonMinDis,CurMaxDis))
202                 {
203                     SearchKNear(_Root->m_Right,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
204                         _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
205                         _ComputeDistanceFunc);
206                 }
207             }
208             else
209             {
210                 if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(RSonMinDis,CurMaxDis))
211                 {
212                     SearchKNear(_Root->m_Right,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
213                         _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
214                         _ComputeDistanceFunc);
215                 }
216                 CurMaxDis=_ComputeDistanceFunc(
217                     _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]);
218                 if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(LSonMinDis,CurMaxDis))
219                 {
220                     SearchKNear(_Root->m_Left,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
221                         _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
222                         _ComputeDistanceFunc);
223                 }
224             }
225         }
226         else if(_Root->m_Left)
227         {
228             _DISTANCE_TYPE LSonMinDis=_ComputeMinDistanceFunc(_Value,_Root->m_Left->m_Node);
229             _DISTANCE_TYPE CurMaxDis=_ComputeDistanceFunc(
230                 _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]);
231             if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(LSonMinDis,CurMaxDis))
232             {
233                 SearchKNear(_Root->m_Left,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
234                     _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
235                     _ComputeDistanceFunc);
236             }
237         }
238         else if(_Root->m_Right)
239         {
240             _DISTANCE_TYPE RSonMinDis=_ComputeMinDistanceFunc(_Value,_Root->m_Right->m_Node);
241             _DISTANCE_TYPE CurMaxDis=_ComputeDistanceFunc(
242                 _Value,_StoreAnswerArray[_CurAnswerArrayElementNumber-1]);
243             if(_CurAnswerArrayElementNumber<_SearchNumber||_CompareDistanceFunc(RSonMinDis,CurMaxDis))
244             {
245                 SearchKNear(_Root->m_Right,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
246                     _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
247                     _ComputeDistanceFunc);
248             }
249         }
250     }
251     
252     unsigned int StoreSubtreeNodeIntoArray(TreeNode* _Root,_NodeType* _NodeArray)
253     {
254         if(_Root==nullptr) return 0;
255         unsigned int NodeArraySize=0;
256         std::queue<TreeNode*> TmpQue;
257         TmpQue.push(_Root);
258         _NodeArray[NodeArraySize++]=_Root->m_Node;
259         while(!TmpQue.empty())
260         {
261             TreeNode* Tmp=TmpQue.front(); TmpQue.pop();
262             if(Tmp->m_Left)
263             {
264                 TmpQue.push(Tmp->m_Left);
265                 _NodeArray[NodeArraySize++]=Tmp->m_Left->m_Node;
266             }
267             if(Tmp->m_Right)
268             {
269                 TmpQue.push(Tmp->m_Right);
270                 _NodeArray[NodeArraySize++]=Tmp->m_Right->m_Node;
271             }
272         }
273         return NodeArraySize;
274     }
275 
276     void ClearSubTrees(TreeNode* _Root)
277     {
278         if(_Root)
279         {
280             if(_Root->m_Left) ClearSubTrees(_Root->m_Left);
281             if(_Root->m_Right) ClearSubTrees(_Root->m_Right);
282             delete _Root;
283         }
284     }
285 
286     TreeNode* RebuildTree(TreeNode* _Root,const unsigned int _Dimension)
287     {
288         _NodeType* TmpPool=new _NodeType[_Root->m_Size];
289         unsigned int TmpPoolSize=StoreSubtreeNodeIntoArray(_Root,TmpPool);
290         ClearSubTrees(_Root);
291         if(TmpPoolSize==0) return nullptr;
292         TreeNode* Tmp=BuildGroup(TmpPool,0,TmpPoolSize-1,_Dimension);
293         delete[] TmpPool;
294         return Tmp;
295     }
296 
297     template<class _NodeTypeBegin>
298     TreeNode* BuildGroup(_NodeTypeBegin _TmpPool,const unsigned int _Left,const unsigned int _Right,const unsigned int _CurLayerDimention)
299        {
300         if(_Left>_Right) return nullptr;
301 
302         const unsigned int MidPos=(_Left+_Right)>>1;
303         std::nth_element(_TmpPool+_Left,_TmpPool+MidPos,_TmpPool+_Right+1,
304             m_CompareFuncs[_CurLayerDimention]);
305 
306         TreeNode* CurNode=NewNode(_TmpPool[MidPos]);
307         if(_Left+1<=MidPos)
308         {
309             CurNode->m_Left=BuildGroup(_TmpPool,_Left,MidPos-1,(_CurLayerDimention+1)%_DIMENSION);
310         }
311         CurNode->m_Right=BuildGroup(_TmpPool,MidPos+1,_Right,(_CurLayerDimention+1)%_DIMENSION);
312 
313         PushUp(CurNode);
314         return CurNode;
315     }
316 
317 
318     void PushUp(TreeNode* _Root)
319     {
320         if(_Root)
321         {
322             _Root->m_Size=1+SonSize(_Root->m_Left)+SonSize(_Root->m_Right);
323             _NodeType *LsonNode=_Root->m_Left?&(_Root->m_Left->m_Node):nullptr;
324             _NodeType *RsonNode=_Root->m_Right?&_Root->m_Right->m_Node:nullptr;
325             m_UpdateFunc(&_Root->m_Node,LsonNode,RsonNode);
326         }
327     }
328 
329     TreeNode* Insert(
330         TreeNode* _Root,
331         TreeNode* _InsertNode,
332         const unsigned int _CurLayerDimention,
333         TreeNode* &_BadTreeNode,
334         TreeNode* &_BadTreeNodeParent,
335         unsigned int& _BadTreeNodeDimension)
336     {
337         if(nullptr==_Root)
338         {
339             PushUp(_InsertNode);
340             return _InsertNode;
341         }
342         if(m_CompareFuncs[_CurLayerDimention](_InsertNode->m_Node,_Root->m_Node))
343         {
344             _Root->m_Left=Insert(_Root->m_Left,_InsertNode,(_CurLayerDimention+1)%_DIMENSION,
345                 _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
346         }
347         else
348         {
349             _Root->m_Right=Insert(_Root->m_Right,_InsertNode,(_CurLayerDimention+1)%_DIMENSION,
350                 _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
351         }
352 
353         PushUp(_Root);
354 
355         if(_BadTreeNode==nullptr)
356         {
357             if(IsSeriousBadTree(_Root))
358             {
359                 _BadTreeNode=_Root;
360                 _BadTreeNodeDimension=_CurLayerDimention;
361             }
362         }
363         else if(_BadTreeNode==_Root->m_Left||_BadTreeNode==_Root->m_Right)
364         {
365             _BadTreeNodeParent=_Root;
366         }
367         return _Root;
368     }
369 
370     unsigned int SonSize(TreeNode* _Node)
371     {
372         if(_Node==nullptr) return 0;
373         return _Node->m_Size;
374     }
375 
376     bool IsSeriousBadTree(TreeNode* _Root)
377     {
378         if(SonSize(_Root)==0) return false;
379         return std::max(SonSize(_Root->m_Left),SonSize(_Root->m_Right))
380             >=(unsigned int)(SonSize(_Root)*_REBUILD_ALPHA/100)+5;
381     }
382 
383     TreeNode* m_Root;
384     _CompareFuncType* m_CompareFuncs[_DIMENSION];
385     _UpdateFuncType* m_UpdateFunc;
386 };

下面是一段简单的测试代码,其中两点之间的距离定义为曼哈顿距离。

 1 struct node
 2 {
 3     int x,y;
 4     int MinX,MaxX,MinY,MaxY;
 5 
 6     node(int _x=0,int _y=0):x(_x),y(_y) {}
 7 
 8 };
 9 
10 typedef int Func(const node&,const node&);
11 typedef void Func1(node*,node*,node*);
12 
13 int cmp0(const node &a,const node &b)
14 {
15     return a.x<b.x;
16 }
17 
18 int cmp1(const node &a,const node &b)
19 {
20     return a.y<b.y;
21 }
22 
23 void pushUp(node *Fa,node *lson,node *rson)
24 {
25     Fa->MinX=Fa->MaxX=Fa->x;
26     Fa->MinY=Fa->MaxY=Fa->y;
27     for(int i=0;i<2;++i)
28     {
29         node* p=i==0?lson:rson;
30         if(!p) continue;
31 
32         Fa->MinX=min(Fa->MinX,p->MinX);
33         Fa->MaxX=max(Fa->MaxX,p->MaxX);
34         Fa->MinY=min(Fa->MinY,p->MinY);
35         Fa->MaxY=max(Fa->MaxY,p->MaxY);
36     }
37 }
38 
39 int caldis(const node &a,const node &b)
40 {
41     return abs(a.x-b.x)+abs(a.y-b.y);
42 }
43 
44 int calMinDis(const node &a,const node &b)
45 {
46     int xx=0;
47     int yy=0;
48     if(a.x<b.MinX) xx=b.MinX-a.x;
49     else if(a.x>b.MaxX) xx=a.x-b.MaxX;
50 
51     if(a.y<b.MinY) yy=b.MinY-a.y;
52     else if(a.y>b.MaxY) yy=a.y-b.MaxY;
53 
54     return xx+yy;
55 }
56 
// "Less than" on distances, returned as 0/1 as KDTree::searchKNear expects.
int cmp(int x,int y)
{
    return y>x;
}
61 
62 
63 
64 int main()
65 {
66     Func* funs[2]={cmp0,cmp1};
67     KDTree<node,Func,Func1,2,int> *T=
68         new KDTree<node,Func,Func1,2,int>(funs,pushUp);
69 
70     node a[4]={node(2,3),node(-1,4),node(5,6),node(10,1)};
71     T->insert(a,4);
72     T->insert(node(5,9));
73     unsigned int Num=T->searchKNear(node(3,7),a,3,cmp,calMinDis,caldis);
74     for(unsigned int Idx=0;Idx<Num;++Idx)
75     {
76         printf("%d %d
",a[Idx].x,a[Idx].y);
77     }
78     /**
79     5 6
80     5 9
81     2 3
82     **/
83 }
原文地址:https://www.cnblogs.com/jianglangcaijin/p/5998004.html