KDTree  C++实现

参考:https://blog.csdn.net/qing101hua/article/details/53228668

#include <iostream>
#include <vector>
#include <stack>
#include <cmath>

#define KDtreeSize 1000

#define UL unsigned long

using namespace std;


struct coordinate
{
    double x = 0;
    double y = 0;
    UL index = 0;
};

struct TreeNode
{
    struct coordinate dom_elt;
    UL split = 0;
    struct TreeNode* left = nullptr;
    struct TreeNode* right = nullptr;
};

bool cmp1(const coordinate& a, const coordinate& b){
    return a.x < b.x;
}

bool cmp2(const coordinate& a, const coordinate& b){
    return a.y < b.y;
}

bool equal(const coordinate& a, const coordinate& b){
    return (a.x == b.x && a.y == b.y);
}

void ChooseSplit(coordinate exm_set[], UL size, UL &split, coordinate &SplitChoice){
    /*compute the variance on every dimension. Set split as the dismension that have the biggest
     variance. Then choose the instance which is the median on this split dimension.*/
    /*compute variance on the x,y dimension. DX=EX^2-(EX)^2*/
    double tmp1, tmp2;
    tmp1 = tmp2 = 0;
    for (int i = 0; i < size; ++i)
    {
        tmp1 += 1.0 / (double)size * exm_set[i].x * exm_set[i].x;
        tmp2 += 1.0 / (double)size * exm_set[i].x;
    }
    double v1 = tmp1 - tmp2 * tmp2;  //compute variance on the x dimension

    tmp1 = tmp2 = 0;
    for (int i = 0; i < size; ++i)
    {
        tmp1 += 1.0 / (double)size * exm_set[i].y * exm_set[i].y;
        tmp2 += 1.0 / (double)size * exm_set[i].y;
    }
    double v2 = tmp1 - tmp2 * tmp2;  //compute variance on the y dimension

    split = v1 > v2 ? 0:1; //set the split dimension

    if (split == 0)
    {
        sort(exm_set,exm_set + size, cmp1);
    }
    else{
        sort(exm_set,exm_set + size, cmp2);
    }

    //set the split point value
    SplitChoice.x = exm_set[size / 2].x;
    SplitChoice.y = exm_set[size / 2].y;

}


TreeNode* build_kdtree(coordinate exm_set[], UL size, TreeNode* T){
    //call function ChooseSplit to choose the split dimension and split point
    if (size == 0){
        return nullptr;
    }
    else{
        UL split;
        coordinate dom_elt;
        ChooseSplit(exm_set, size, split, dom_elt);
        coordinate exm_set_right [KDtreeSize];
        coordinate exm_set_left [KDtreeSize];
        UL size_left ,size_right;
        size_left = size_right = 0;

        if (split == 0)
        {
            for (UL i = 0; i < size; ++i)
            {

                if (!equal(exm_set[i],dom_elt) && exm_set[i].x <= dom_elt.x)
                {
                    exm_set_left[size_left].x = exm_set[i].x;
                    exm_set_left[size_left].y = exm_set[i].y;
                    size_left++;
                }
                else if (!equal(exm_set[i],dom_elt) && exm_set[i].x > dom_elt.x)
                {
                    exm_set_right[size_right].x = exm_set[i].x;
                    exm_set_right[size_right].y = exm_set[i].y;
                    size_right++;
                }
            }
        }
        else{
            for (UL i = 0; i < size; ++i)
            {

                if (!equal(exm_set[i],dom_elt) && exm_set[i].y <= dom_elt.y)
                {
                    exm_set_left[size_left].x = exm_set[i].x;
                    exm_set_left[size_left].y = exm_set[i].y;
                    size_left++;
                }
                else if (!equal(exm_set[i],dom_elt) && exm_set[i].y > dom_elt.y)
                {
                    exm_set_right[size_right].x = exm_set[i].x;
                    exm_set_right[size_right].y = exm_set[i].y;
                    size_right++;
                }
            }
        }
        T = new TreeNode;
        T->dom_elt.x = dom_elt.x;
        T->dom_elt.y = dom_elt.y;
        T->split = split;
        T->left = build_kdtree(exm_set_left, size_left, T->left);
        T->right = build_kdtree(exm_set_right, size_right, T->right);
        return T;

    }
}


double Distance(coordinate a, coordinate b){
    double tmp = (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
    return sqrt(tmp);
}


void searchNearest(TreeNode * Kd, coordinate target, coordinate &nearestpoint, double & distance){

    //1. 如果Kd是空的,则设dist为无穷大返回

    //2. 向下搜索直到叶子结点

    stack<TreeNode*> search_path;
    TreeNode* pSearch = Kd;
    coordinate nearest;
    double dist;

    while(pSearch != nullptr)
    {
        //pSearch加入到search_path中;
        search_path.push(pSearch);

        if (pSearch->split == 0)
        {
            if(target.x <= pSearch->dom_elt.x) /* 如果小于就进入左子树 */
            {
                pSearch = pSearch->left;
            }
            else
            {
                pSearch = pSearch->right;
            }
        }
        else{
            if(target.y <= pSearch->dom_elt.y) /* 如果小于就进入左子树 */
            {
                pSearch = pSearch->left;
            }
            else
            {
                pSearch = pSearch->right;
            }
        }
    }
    //取出search_path最后一个赋给nearest
    nearest.x = search_path.top()->dom_elt.x;
    nearest.y = search_path.top()->dom_elt.y;
    search_path.pop();


    dist = Distance(nearest, target);
    //3. 回溯搜索路径

    TreeNode* pBack;

    while(search_path.empty())
    {
        //取出search_path最后一个结点赋给pBack
        pBack = search_path.top();
        search_path.pop();

        if(pBack->left == nullptr && pBack->right == nullptr) /* 如果pBack为叶子结点 */

        {

            if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
            {
                nearest = pBack->dom_elt;
                dist = Distance(pBack->dom_elt, target);
            }

        }

        else

        {

            UL s = pBack->split;
            if (s == 0)
            {
                if( fabs(pBack->dom_elt.x - target.x) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */
                {
                    if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
                    {
                        nearest = pBack->dom_elt;
                        dist = Distance(pBack->dom_elt, target);
                    }
                    if(target.x <= pBack->dom_elt.x) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */
                        pSearch = pBack->right;
                    else
                        pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */
                    if(pSearch != nullptr)
                        //pSearch加入到search_path中
                        search_path.push(pSearch);
                }
            }
            else {
                if( fabs(pBack->dom_elt.y - target.y) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */
                {
                    if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
                    {
                        nearest = pBack->dom_elt;
                        dist = Distance(pBack->dom_elt, target);
                    }
                    if(target.y <= pBack->dom_elt.y) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */
                        pSearch = pBack->right;
                    else
                        pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */
                    if(pSearch != nullptr)
                        // pSearch加入到search_path中
                        search_path.push(pSearch);
                }
            }

        }
    }

    nearestpoint.x = nearest.x;
    nearestpoint.y = nearest.y;
    distance = dist;

}



void test_kdtree(){
    coordinate exm_set[6];
    exm_set[0].x = 2;    exm_set[0].y = 3;
    exm_set[1].x = 5;    exm_set[1].y = 4;
    exm_set[2].x = 9;    exm_set[2].y = 6;
    exm_set[3].x = 4;    exm_set[3].y = 7;
    exm_set[4].x = 8;    exm_set[4].y = 1;
    exm_set[5].x = 7;    exm_set[5].y = 2;

    struct TreeNode * root = nullptr;
    root = build_kdtree(exm_set, 6, root);

    coordinate nearestpoint;
    double distance;
    coordinate target;
    target.x = 2.1;
    target.y = 3.2;

    searchNearest(root, target, nearestpoint, distance);
    cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<endl;
}


int main(){
    test_kdtree();
    return 0;
}
原文地址:https://www.cnblogs.com/theodoric008/p/9475189.html