学习链接:http://www.cnblogs.com/eyeszjwang/articles/2429382.html
下面实现的kdtree支持以下操作:
(1) 插入一个节点
(2) 插入n个节点
(3) 查找与某个给定点距离最小的K个节点。
插入的时候可能会导致树严重不平衡,这个时候会重建某个子树。
#include <algorithm>
#include <cstddef>
#include <queue>
#include <vector>

/***
  KD-tree supporting single insertion, bulk insertion and k-nearest search.
  When an insertion leaves a subtree badly unbalanced, that subtree is rebuilt.

  _NodeType:        node (point) type stored in the tree
  _CompareFuncType: comparison function type; one function per dimension, each
                    takes two _NodeType and returns 0 or 1
  _UpdateFuncType:  update function type; takes three _NodeType* arguments
                    (father, left son, right son; a missing son is nullptr) and
                    refreshes the aggregated data kept in the father
  _DIMENSION:       number of dimensions
  _DISTANCE_TYPE:   type of the distance between two _NodeType
  _REBUILD_ALPHA:   balance parameter in percent; it should be greater than 50
                    and at most 100

  The tree owns its nodes and frees them in the destructor. Copying is disabled
  because a member-wise copy would make two trees own the same nodes; moving is
  supported.
***/
template<class _NodeType,
         class _CompareFuncType,
         class _UpdateFuncType,
         int _DIMENSION,
         class _DISTANCE_TYPE,
         int _REBUILD_ALPHA=80>
class KDTree
{
    static_assert(_DIMENSION>0,"KDTree needs at least one dimension");

private:
    struct TreeNode
    {
        TreeNode(const _NodeType& _Node)
            :m_Node(_Node),m_Left(nullptr),m_Right(nullptr),m_Size(1) {}
        TreeNode():m_Node(),m_Left(nullptr),m_Right(nullptr),m_Size(1) {}

        _NodeType m_Node;
        TreeNode* m_Left;
        TreeNode* m_Right;
        unsigned int m_Size;   // number of nodes in the subtree rooted here
    };

    TreeNode* NewNode(const _NodeType& _Node)
    {
        return new TreeNode(_Node);
    }

public:
    /***
      _CompareFuncs: array of _DIMENSION comparison functions, one per axis
      _UpdateFunc:   function called whenever a node's children change
    ***/
    KDTree(_CompareFuncType** _CompareFuncs,_UpdateFuncType* _UpdateFunc)
        :m_Root(nullptr),m_UpdateFunc(_UpdateFunc)
    {
        std::copy(_CompareFuncs,_CompareFuncs+_DIMENSION,m_CompareFuncs);
    }

    ~KDTree()
    {
        ClearSubTrees(m_Root);
    }

    KDTree(const KDTree&)=delete;
    KDTree& operator=(const KDTree&)=delete;

    KDTree(KDTree&& _Other) noexcept
        :m_Root(_Other.m_Root),m_UpdateFunc(_Other.m_UpdateFunc)
    {
        std::copy(_Other.m_CompareFuncs,_Other.m_CompareFuncs+_DIMENSION,m_CompareFuncs);
        _Other.m_Root=nullptr;
    }

    KDTree& operator=(KDTree&& _Other) noexcept
    {
        if(this!=&_Other)
        {
            ClearSubTrees(m_Root);
            m_Root=_Other.m_Root;
            _Other.m_Root=nullptr;
            std::copy(_Other.m_CompareFuncs,_Other.m_CompareFuncs+_DIMENSION,m_CompareFuncs);
            m_UpdateFunc=_Other.m_UpdateFunc;
        }
        return *this;
    }

    /***
      Insert one node. If the insertion leaves some subtree badly unbalanced,
      that subtree is rebuilt.
    ***/
    void insert(const _NodeType& _Value)
    {
        TreeNode* BadTreeNode=nullptr;
        TreeNode* BadTreeNodeParent=nullptr;
        unsigned int BadTreeNodeDimension=0;
        m_Root=Insert(m_Root,NewNode(_Value),0,
                      BadTreeNode,BadTreeNodeParent,BadTreeNodeDimension);

        if(BadTreeNode==nullptr) return;

        if(BadTreeNodeParent==nullptr)
        {
            // The unbalanced subtree is the whole tree.
            m_Root=RebuildTree(BadTreeNode,BadTreeNodeDimension);
            return;
        }

        // Pick the parent's link BEFORE the old subtree is destroyed and write
        // the rebuilt subtree back through it; otherwise the parent would keep
        // pointing at freed nodes.
        TreeNode*& Link=(BadTreeNodeParent->m_Left==BadTreeNode)
                        ?BadTreeNodeParent->m_Left
                        :BadTreeNodeParent->m_Right;
        Link=RebuildTree(BadTreeNode,BadTreeNodeDimension);
    }

    /***
      Insert _ValueArraySize nodes read from _ValueArray[0.._ValueArraySize-1].
      The caller's sequence is only read, never reordered.

      If the tree is empty or small compared with the batch, the whole tree is
      rebuilt from the old and new nodes; otherwise the nodes are inserted one
      by one.
    ***/
    template<class _NodeTypeBegin>
    void insert(_NodeTypeBegin _ValueArray,const unsigned int _ValueArraySize)
    {
        if(_ValueArraySize==0) return;

        // 64-bit arithmetic so that size*alpha cannot wrap around.
        const unsigned long long RebuildLimit=
            static_cast<unsigned long long>(_ValueArraySize)*_REBUILD_ALPHA/100;
        if(m_Root!=nullptr&&m_Root->m_Size>RebuildLimit)
        {
            for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
            {
                insert(_ValueArray[Idx]);
            }
            return;
        }

        std::vector<_NodeType> Pool;
        Pool.reserve(static_cast<std::size_t>(size())+_ValueArraySize);
        StoreSubtreeNodeIntoArray(m_Root,Pool);
        for(unsigned int Idx=0;Idx<_ValueArraySize;++Idx)
        {
            Pool.push_back(_ValueArray[Idx]);
        }

        // Build the replacement first so the old tree stays valid if building throws.
        TreeNode* NewRoot=BuildGroup(Pool.data(),0,
                                     static_cast<unsigned int>(Pool.size()-1),0);
        ClearSubTrees(m_Root);
        m_Root=NewRoot;
    }

    /***
      Find the _SearchNumber elements "nearest" to _Value and store them in
      _StoreAnswerArray, nearest first. _StoreAnswerArray must have room for
      _SearchNumber elements.

      _ComputeMinDistanceFunc takes two _NodeType (first, second) and computes
          the "nearest" distance from first to all points within second's range;
          it returns _DISTANCE_TYPE
      _ComputeDistanceFunc takes two _NodeType (first, second) and computes the
          distance between first and second; it returns _DISTANCE_TYPE
      _CompareDistanceFunc takes two _DISTANCE_TYPE (first, second) and returns
          0 or 1; 1 means first is less than second

      Returns the number of elements found (may be less than _SearchNumber).
    ***/
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    unsigned int searchKNear(const _NodeType& _Value,_NodeType* _StoreAnswerArray,
                             const unsigned int _SearchNumber,
                             _CompareDistanceFuncType* _CompareDistanceFunc,
                             _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
                             _ComputeDistanceFuncType* _ComputeDistanceFunc) const
    {
        if(_SearchNumber==0) return 0;
        unsigned int AnswerArrayElementNumber=0;
        SearchKNear(m_Root,_Value,_StoreAnswerArray,AnswerArrayElementNumber,_SearchNumber,
                    _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
        return AnswerArrayElementNumber;
    }

    /*** Number of nodes currently stored in the tree. ***/
    unsigned int size() const
    {
        return SonSize(m_Root);
    }

private:
    /***
      Offer _Candidate to the answer list, which is kept sorted by increasing
      distance to _Value and holds at most _SearchNumber entries. A candidate
      that is not closer than an existing entry goes after that entry.
    ***/
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void OfferCandidate(const _NodeType& _Candidate,const _NodeType& _Value,
                        _NodeType* _StoreAnswerArray,unsigned int& _Count,
                        const unsigned int _SearchNumber,
                        _CompareDistanceFuncType* _CompareDistanceFunc,
                        _ComputeDistanceFuncType* _ComputeDistanceFunc) const
    {
        const _DISTANCE_TYPE CandidateDis=_ComputeDistanceFunc(_Value,_Candidate);

        // Walk back from the farthest entry to find the insertion slot.
        unsigned int Pos=_Count;
        while(Pos>0&&
              _CompareDistanceFunc(CandidateDis,
                                   _ComputeDistanceFunc(_Value,_StoreAnswerArray[Pos-1])))
        {
            --Pos;
        }
        if(Pos>=_SearchNumber) return;   // list is full and the candidate is no better

        // Shift the tail one slot right; when the list is full the last entry drops out.
        const unsigned int Last=(_Count<_SearchNumber)?_Count:_SearchNumber-1;
        for(unsigned int Idx=Last;Idx>Pos;--Idx)
        {
            _StoreAnswerArray[Idx]=_StoreAnswerArray[Idx-1];
        }
        _StoreAnswerArray[Pos]=_Candidate;
        if(_Count<_SearchNumber) ++_Count;
    }

    /***
      Search _Child unless the answer list is already full and the child's
      minimum distance is not less than the current farthest answer.
    ***/
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void VisitChild(TreeNode* _Child,const _DISTANCE_TYPE& _ChildMinDis,
                    const _NodeType& _Value,_NodeType* _StoreAnswerArray,
                    unsigned int& _Count,const unsigned int _SearchNumber,
                    _CompareDistanceFuncType* _CompareDistanceFunc,
                    _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
                    _ComputeDistanceFuncType* _ComputeDistanceFunc) const
    {
        if(_Count>=_SearchNumber)
        {
            const _DISTANCE_TYPE WorstDis=
                _ComputeDistanceFunc(_Value,_StoreAnswerArray[_Count-1]);
            if(!_CompareDistanceFunc(_ChildMinDis,WorstDis)) return;
        }
        SearchKNear(_Child,_Value,_StoreAnswerArray,_Count,_SearchNumber,
                    _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
    }

    /***
      Recursive k-nearest search: offer the current node, then visit the
      children, the one with the smaller minimum distance first.
    ***/
    template<class _ComputeDistanceFuncType,
             class _CompareDistanceFuncType>
    void SearchKNear(TreeNode* _Root,const _NodeType& _Value,_NodeType* _StoreAnswerArray,
                     unsigned int& _CurAnswerArrayElementNumber,
                     const unsigned int _SearchNumber,
                     _CompareDistanceFuncType* _CompareDistanceFunc,
                     _ComputeDistanceFuncType* _ComputeMinDistanceFunc,
                     _ComputeDistanceFuncType* _ComputeDistanceFunc) const
    {
        if(_Root==nullptr) return;

        OfferCandidate(_Root->m_Node,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                       _SearchNumber,_CompareDistanceFunc,_ComputeDistanceFunc);

        TreeNode* const LSon=_Root->m_Left;
        TreeNode* const RSon=_Root->m_Right;

        if(LSon&&RSon)
        {
            const _DISTANCE_TYPE LSonMinDis=_ComputeMinDistanceFunc(_Value,LSon->m_Node);
            const _DISTANCE_TYPE RSonMinDis=_ComputeMinDistanceFunc(_Value,RSon->m_Node);
            if(_CompareDistanceFunc(LSonMinDis,RSonMinDis))
            {
                VisitChild(LSon,LSonMinDis,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                           _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
                           _ComputeDistanceFunc);
                VisitChild(RSon,RSonMinDis,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                           _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
                           _ComputeDistanceFunc);
            }
            else
            {
                VisitChild(RSon,RSonMinDis,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                           _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
                           _ComputeDistanceFunc);
                VisitChild(LSon,LSonMinDis,_Value,_StoreAnswerArray,_CurAnswerArrayElementNumber,
                           _SearchNumber,_CompareDistanceFunc,_ComputeMinDistanceFunc,
                           _ComputeDistanceFunc);
            }
        }
        else if(LSon)
        {
            VisitChild(LSon,_ComputeMinDistanceFunc(_Value,LSon->m_Node),_Value,
                       _StoreAnswerArray,_CurAnswerArrayElementNumber,_SearchNumber,
                       _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
        }
        else if(RSon)
        {
            VisitChild(RSon,_ComputeMinDistanceFunc(_Value,RSon->m_Node),_Value,
                       _StoreAnswerArray,_CurAnswerArrayElementNumber,_SearchNumber,
                       _CompareDistanceFunc,_ComputeMinDistanceFunc,_ComputeDistanceFunc);
        }
    }

    /*** Append every node of the subtree rooted at _Root to _Out, breadth-first. ***/
    void StoreSubtreeNodeIntoArray(TreeNode* _Root,std::vector<_NodeType>& _Out) const
    {
        if(_Root==nullptr) return;
        std::queue<TreeNode*> Pending;
        Pending.push(_Root);
        while(!Pending.empty())
        {
            TreeNode* Cur=Pending.front();
            Pending.pop();
            _Out.push_back(Cur->m_Node);
            if(Cur->m_Left) Pending.push(Cur->m_Left);
            if(Cur->m_Right) Pending.push(Cur->m_Right);
        }
    }

    /*** Delete every node of the subtree rooted at _Root (nullptr is allowed). ***/
    void ClearSubTrees(TreeNode* _Root)
    {
        if(_Root==nullptr) return;
        ClearSubTrees(_Root->m_Left);
        ClearSubTrees(_Root->m_Right);
        delete _Root;
    }

    /***
      Replace the subtree rooted at _Root by a balanced subtree holding the same
      nodes. _Dimension is the splitting dimension at _Root's layer.
      Returns the new subtree root; the old nodes are deleted.
    ***/
    TreeNode* RebuildTree(TreeNode* _Root,const unsigned int _Dimension)
    {
        if(_Root==nullptr) return nullptr;
        std::vector<_NodeType> Pool;
        Pool.reserve(_Root->m_Size);
        StoreSubtreeNodeIntoArray(_Root,Pool);
        TreeNode* Rebuilt=BuildGroup(Pool.data(),0,
                                     static_cast<unsigned int>(Pool.size()-1),_Dimension);
        ClearSubTrees(_Root);
        return Rebuilt;
    }

    /***
      Build a balanced subtree from _TmpPool[_Left.._Right] (inclusive), using
      _CurLayerDimention as the splitting dimension of the subtree root.
      The pool is reordered in place.
    ***/
    TreeNode* BuildGroup(_NodeType* _TmpPool,const unsigned int _Left,
                         const unsigned int _Right,const unsigned int _CurLayerDimention)
    {
        if(_Left>_Right) return nullptr;

        const unsigned int MidPos=_Left+((_Right-_Left)>>1);
        std::nth_element(_TmpPool+_Left,_TmpPool+MidPos,_TmpPool+_Right+1,
                         m_CompareFuncs[_CurLayerDimention]);

        TreeNode* CurNode=NewNode(_TmpPool[MidPos]);
        const unsigned int NextDimension=(_CurLayerDimention+1)%_DIMENSION;
        // Guard against MidPos-1 wrapping around when MidPos is 0.
        if(MidPos>_Left)
        {
            CurNode->m_Left=BuildGroup(_TmpPool,_Left,MidPos-1,NextDimension);
        }
        CurNode->m_Right=BuildGroup(_TmpPool,MidPos+1,_Right,NextDimension);

        PushUp(CurNode);
        return CurNode;
    }

    /*** Recompute _Root's subtree size and call the update function on it. ***/
    void PushUp(TreeNode* _Root)
    {
        if(_Root==nullptr) return;
        _Root->m_Size=1+SonSize(_Root->m_Left)+SonSize(_Root->m_Right);
        _NodeType* LsonNode=_Root->m_Left?&_Root->m_Left->m_Node:nullptr;
        _NodeType* RsonNode=_Root->m_Right?&_Root->m_Right->m_Node:nullptr;
        m_UpdateFunc(&_Root->m_Node,LsonNode,RsonNode);
    }

    /***
      Recursive insertion of _InsertNode below _Root. On the way back up it
      records in _BadTreeNode the first badly unbalanced subtree it meets,
      its splitting dimension in _BadTreeNodeDimension, and its parent in
      _BadTreeNodeParent (left as nullptr when the bad node has no parent).
      Returns the root of the updated subtree.
    ***/
    TreeNode* Insert(TreeNode* _Root,
                     TreeNode* _InsertNode,
                     const unsigned int _CurLayerDimention,
                     TreeNode* &_BadTreeNode,
                     TreeNode* &_BadTreeNodeParent,
                     unsigned int& _BadTreeNodeDimension)
    {
        if(nullptr==_Root)
        {
            PushUp(_InsertNode);
            return _InsertNode;
        }

        const unsigned int NextDimension=(_CurLayerDimention+1)%_DIMENSION;
        if(m_CompareFuncs[_CurLayerDimention](_InsertNode->m_Node,_Root->m_Node))
        {
            _Root->m_Left=Insert(_Root->m_Left,_InsertNode,NextDimension,
                                 _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
        }
        else
        {
            _Root->m_Right=Insert(_Root->m_Right,_InsertNode,NextDimension,
                                  _BadTreeNode,_BadTreeNodeParent,_BadTreeNodeDimension);
        }

        PushUp(_Root);

        if(_BadTreeNode==nullptr)
        {
            if(IsSeriousBadTree(_Root))
            {
                _BadTreeNode=_Root;
                _BadTreeNodeDimension=_CurLayerDimention;
            }
        }
        else if(_BadTreeNode==_Root->m_Left||_BadTreeNode==_Root->m_Right)
        {
            _BadTreeNodeParent=_Root;
        }
        return _Root;
    }

    /*** Subtree size of _Node; 0 for nullptr. ***/
    unsigned int SonSize(const TreeNode* _Node) const
    {
        return _Node?_Node->m_Size:0;
    }

    /***
      True when the larger child holds at least _REBUILD_ALPHA percent of the
      subtree plus 5 nodes; the +5 keeps tiny subtrees from being rebuilt.
    ***/
    bool IsSeriousBadTree(const TreeNode* _Root) const
    {
        const unsigned int Total=SonSize(_Root);
        if(Total==0) return false;
        const unsigned long long Threshold=
            static_cast<unsigned long long>(Total)*_REBUILD_ALPHA/100+5;
        return std::max(SonSize(_Root->m_Left),SonSize(_Root->m_Right))>=Threshold;
    }

    TreeNode* m_Root;
    _CompareFuncType* m_CompareFuncs[_DIMENSION];
    _UpdateFuncType* m_UpdateFunc;
};
下面是一个简单的测试代码,两个点之间的距离定义为曼哈顿距离。
1 struct node 2 { 3 int x,y; 4 int MinX,MaxX,MinY,MaxY; 5 6 node(int _x=0,int _y=0):x(_x),y(_y) {} 7 8 }; 9 10 typedef int Func(const node&,const node&); 11 typedef void Func1(node*,node*,node*); 12 13 int cmp0(const node &a,const node &b) 14 { 15 return a.x<b.x; 16 } 17 18 int cmp1(const node &a,const node &b) 19 { 20 return a.y<b.y; 21 } 22 23 void pushUp(node *Fa,node *lson,node *rson) 24 { 25 Fa->MinX=Fa->MaxX=Fa->x; 26 Fa->MinY=Fa->MaxY=Fa->y; 27 for(int i=0;i<2;++i) 28 { 29 node* p=i==0?lson:rson; 30 if(!p) continue; 31 32 Fa->MinX=min(Fa->MinX,p->MinX); 33 Fa->MaxX=max(Fa->MaxX,p->MaxX); 34 Fa->MinY=min(Fa->MinY,p->MinY); 35 Fa->MaxY=max(Fa->MaxY,p->MaxY); 36 } 37 } 38 39 int caldis(const node &a,const node &b) 40 { 41 return abs(a.x-b.x)+abs(a.y-b.y); 42 } 43 44 int calMinDis(const node &a,const node &b) 45 { 46 int xx=0; 47 int yy=0; 48 if(a.x<b.MinX) xx=b.MinX-a.x; 49 else if(a.x>b.MaxX) xx=a.x-b.MaxX; 50 51 if(a.y<b.MinY) yy=b.MinY-a.y; 52 else if(a.y>b.MaxY) yy=a.y-b.MaxY; 53 54 return xx+yy; 55 } 56 57 int cmp(int x,int y) 58 { 59 return x<y; 60 } 61 62 63 64 int main() 65 { 66 Func* funs[2]={cmp0,cmp1}; 67 KDTree<node,Func,Func1,2,int> *T= 68 new KDTree<node,Func,Func1,2,int>(funs,pushUp); 69 70 node a[4]={node(2,3),node(-1,4),node(5,6),node(10,1)}; 71 T->insert(a,4); 72 T->insert(node(5,9)); 73 unsigned int Num=T->searchKNear(node(3,7),a,3,cmp,calMinDis,caldis); 74 for(unsigned int Idx=0;Idx<Num;++Idx) 75 { 76 printf("%d %d ",a[Idx].x,a[Idx].y); 77 } 78 /** 79 5 6 80 5 9 81 2 3 82 **/ 83 }