一篇系列——KD-Tree 一篇就够了

Python代码部分


import heapq
from collections import Counter

maxn = 10 ** 5  # length of the preallocated point / node pools used by KDTree
K = 13  # number of coordinates (dimensions) of every point

class Pt:
    """A K-dimensional point plus an attached value (``val``, e.g. a label)."""

    def __init__(self, x = None) -> None:
        # Do not give ``x`` a list as its default value: a mutable default is
        # shared by every call (classic Python pitfall, see
        # https://www.cnblogs.com/jclian91/p/10325849.html). Instead build a
        # fresh zero vector per instance. A caller-supplied ``x`` is kept as-is
        # (not copied).
        self.x = [0 for _ in range(K)] if x is None else x
        self.val = 0

class Node:
    def __init__(self, d = 0, id = 0) -> None:
        """Heap entry for ``heapq``: a distance ``d`` and a tree-node index ``id``.

        ``heapq`` is a min-heap; ``__lt__`` below is inverted, so the entry with
        the LARGEST ``d`` ends up at ``heap[0]`` (i.e. the heap acts as a max-heap).
        """
        self.d, self.id = d, id

    def __lt__(self, other):
        # Inverted ordering: "smaller" means "farther away".
        return other.d < self.d

    def __str__(self) -> str:
        return "(d: %.2f, id: %d)" % (self.d, self.id)

class KDTree_Node:
    """One slot of the KD-tree node pool; children are referenced by pool index."""

    def __init__(self) -> None:
        self.SZ = 0                          # number of nodes in this subtree
        self.lc, self.rc = 0, 0              # left / right child index, 0 = no child
        self.maxn = [0 for _ in range(K)]    # per-dimension max over the subtree
        self.minn = [0 for _ in range(K)]    # per-dimension min over the subtree
        self.place = Pt()                    # point stored at this node

class KDTree:
    """KD-tree over K-dimensional points, kept balanced scapegoat-tree style.

    All storage is preallocated: nodes live in ``self.Tr`` and are referred to
    by integer index. Index 0 is the "no child" sentinel; its ``SZ`` stays 0,
    so size arithmetic on missing children works. ``self.p`` is a 1-indexed
    scratch array of points. Callers fill ``p[1..n]`` before :meth:`build`,
    and :meth:`rebuild` reuses it when flattening a subtree. Distances are
    squared Euclidean.
    """

    def __init__(self, n = 0, alpha= 0.75, capacity = None) -> None:
        """Allocate the node and point pools.

        Args:
            n: number of points the caller loads into ``self.p``. :meth:`query`
                requires ``k <= n``.
            alpha: scapegoat rebuild factor. A subtree is rebuilt when one child
                holds more than ``alpha`` times the subtree's node count.
            capacity: length of the preallocated pools (usable node indices are
                ``1 .. capacity - 1``). Defaults to the module constant
                ``maxn``; pass a smaller value to avoid allocating ``maxn``
                nodes for small data sets.
        """
        if capacity is None:
            capacity = maxn
        self.n = n
        self.p = [Pt() for _ in range(capacity)]
        self.Tr = [KDTree_Node() for _ in range(capacity)]
        # Candidate heap for query(). Node.__lt__ is inverted, so hp[0] is the
        # farthest of the current best candidates (a bounded max-heap).
        self.hp = []
        self.alpha = alpha
        self.cuK = 0  # splitting dimension of the build() call in progress
        # top: number of recycled indices on the `store` stack;
        # tot: number of node indices ever handed out by New().
        self.top, self.tot = 0, 0
        self.store = [0] * capacity  # stack (1-based) of recycled node indices
        self.root = 1  # root index; callers are expected to store build()'s result here

    def New(self) -> int:
        """Return a free node index, preferring indices recycled by rebuild()."""
        if self.top != 0:
            self.top -= 1
            return self.store[self.top + 1]
        self.tot += 1
        return self.tot

    def update(self, x) -> None:
        """Recompute node x's bounding box (maxn/minn) and subtree size from its children."""
        lc = self.Tr[x].lc; rc = self.Tr[x].rc
        for i in range(K):
            self.Tr[x].maxn[i] = self.Tr[x].minn[i] = self.Tr[x].place.x[i]
            if (lc != 0):
                self.Tr[x].maxn[i] = max(self.Tr[x].maxn[i], self.Tr[lc].maxn[i])
                self.Tr[x].minn[i] = min(self.Tr[x].minn[i], self.Tr[lc].minn[i])
            if (rc != 0):
                self.Tr[x].maxn[i] = max(self.Tr[x].maxn[i], self.Tr[rc].maxn[i])
                self.Tr[x].minn[i] = min(self.Tr[x].minn[i], self.Tr[rc].minn[i])
        # Tr[0].SZ is 0, so a missing child contributes nothing.
        self.Tr[x].SZ = self.Tr[lc].SZ + self.Tr[rc].SZ + 1

    def build(self, l: int, r: int, dep= 0) -> int:
        """Build a balanced subtree over ``self.p[l..r]``; return its root index (0 if empty).

        Points are split by the median along dimension ``dep % K``.
        """
        if (l > r): return 0
        self.cuK = dep % K
        mid = (l + r) >> 1
        x = self.New()

        # Sort the slice along the current dimension so p[mid] is the median.
        temp_p = self.p[l:r + 1]
        temp_p.sort(key= lambda arr: arr.x[self.cuK])
        self.p[l:r + 1] = temp_p

        self.Tr[x].place = self.p[mid]
        self.Tr[x].lc = self.build(l, mid - 1, dep + 1)
        self.Tr[x].rc = self.build(mid + 1, r, dep + 1)
        self.update(x)
        return x

    def rebuild(self, x: int, base= 0):
        """Flatten subtree x in order into ``self.p[base + 1 ..]`` and recycle its node indices."""
        lc = self.Tr[x].lc
        rc = self.Tr[x].rc
        if lc != 0: self.rebuild(lc, base)
        self.p[base + self.Tr[lc].SZ + 1] = self.Tr[x].place
        self.top += 1; self.store[self.top] = x
        if rc != 0: self.rebuild(rc, base + self.Tr[lc].SZ + 1)

    def check(self, x: int, dep: int) -> int:
        """Rebuild subtree x if either child exceeds ``alpha`` of its size; return the (new) root."""
        if (self.alpha * self.Tr[x].SZ < self.Tr[self.Tr[x].lc].SZ or self.alpha * self.Tr[x].SZ < self.Tr[self.Tr[x].rc].SZ):
            self.rebuild(x)
            # Tr[x].SZ is read before build() can reuse index x.
            x = self.build(1, self.Tr[x].SZ, dep)
        return x

    def insert(self, inP: Pt, loc: int, dep= 0) -> int:
        """Insert ``inP`` into the subtree rooted at ``loc``; return the subtree's new root."""
        if loc == 0:
            loc = self.New()
            self.Tr[loc].place = inP
            self.Tr[loc].lc = self.Tr[loc].rc = 0
            self.update(loc)
            return loc
        if (inP.x[dep % K] <= self.Tr[loc].place.x[dep % K]):
            self.Tr[loc].lc = self.insert(inP, self.Tr[loc].lc, dep + 1)
        else :
            self.Tr[loc].rc = self.insert(inP, self.Tr[loc].rc, dep +1)
        self.update(loc)
        loc = self.check(loc, dep)
        return loc

    def getdis(self, temp: Pt, x: int):
        """Squared distance from ``temp`` to the bounding box of subtree x.

        This is a LOWER bound on the squared distance from ``temp`` to any point
        stored in that subtree. It is used to prune the search.
        """
        res = 0
        for i in range(K):
            res += (max(0, temp.x[i] - self.Tr[x].maxn[i]) + max(0, self.Tr[x].minn[i] - temp.x[i])) ** 2
        return res

    def dist(self, a, b):
        """Squared Euclidean distance between points a and b over K dimensions."""
        res = 0
        for i in range(K):
            res += (a.x[i] - b.x[i]) ** 2
        return res

    def _query(self, ask: Pt, x: int, k: int) -> None:
        """Recursive k-NN search of the subtree rooted at the real node ``x`` (x != 0)."""
        heapq.heappush(self.hp, Node(self.dist(ask, self.Tr[x].place), x))
        if len(self.hp) > k:
            heapq.heappop(self.hp)  # drop the current farthest candidate
        lc = self.Tr[x].lc; rc = self.Tr[x].rc
        # Only real children are candidates. Node 0 is the null sentinel, and
        # visiting it would inject a fake zero point (id 0) into the results.
        candidates = []
        if lc != 0:
            candidates.append((self.getdis(ask, lc), lc))
        if rc != 0:
            candidates.append((self.getdis(ask, rc), rc))
        # Visit the closer bounding box first; the left child wins ties (stable sort).
        candidates.sort(key=lambda c: c[0])
        for bound, child in candidates:
            # Re-read the pruning radius each time: the first child may have shrunk it.
            if len(self.hp) < k or bound < self.hp[0].d:
                self._query(ask, child, k)

    def query(self, ask: Pt, x: int, k= 1):
        """Find the k points nearest to ``ask`` in the subtree rooted at ``x``.

        Returns a list of ``[squared_distance, node_index]`` pairs in heap order
        (not sorted). ``self.Tr[node_index].place`` is the matched point. An
        empty subtree (``x == 0``) yields an empty list.
        """
        assert k > 0 and k <= self.n
        self.hp = []
        if x != 0:
            self._query(ask, x, k)
        return [[e.d, e.id] for e in self.hp]

# KNN-KDTree
class KNN_KDTree:
    """k-nearest-neighbour classifier whose neighbour search is a KDTree.

    Neighbours are ranked by the tree's squared Euclidean distance over the
    first K coordinates of each point.
    """

    def __init__(self, X_train, y_train, n_neighbors=3, p=2):  # n_neighbors sets k
        """Load the training set into a fresh KDTree and build it.

        parameter: X_train      training points; every row must provide at least
                                K coordinates (only the first K are used)
        parameter: y_train      labels aligned with X_train by position
        parameter: n_neighbors  number of neighbours (k) that vote in predict
        parameter: p            distance-metric order; stored only -- the tree
                                always uses squared Euclidean distance
        """
        self.n = n_neighbors
        self.p = p
        self.X_train = X_train
        self.y_train = y_train
        self.KDTree = KDTree()
        self.KDTree.n = len(X_train)

        # The tree's point array is 1-indexed: training row i-1 goes to p[i].
        for i, row in enumerate(self.X_train, start=1):
            pt = self.KDTree.p[i]
            for j in range(K):
                pt.x[j] = row[j]
            pt.val = self.y_train[i - 1]
        self.KDTree.root = self.KDTree.build(1, len(self.X_train))

    def predict(self, X):
        """Return the majority label among the ``n_neighbors`` training points nearest to X.

        Among labels tied for the highest count, the one Counter saw last wins
        (stable sort by count, take the last entry).
        """
        res = self.KDTree.query(Pt(X), self.KDTree.root, self.n)
        labels = [self.KDTree.Tr[node_id].place.val for _, node_id in res]
        count_pairs = Counter(labels)
        return sorted(count_pairs.items(), key=lambda kv: kv[1])[-1][0]

    def score(self, X_test, y_test):
        """Return the fraction of X_test whose predicted label equals y_test.

        Raises ZeroDivisionError when X_test is empty.
        """
        right_count = sum(1 for X, y in zip(X_test, y_test) if self.predict(X) == y)
        return right_count / len(X_test)

C++代码部分


#include <bits/stdc++.h>
using namespace std;

typedef long long ll;           // 64-bit integer shorthand
typedef double db;              // floating-point shorthand
constexpr int maxn = 300050;    // pool capacity (3e5 + 50) for points and tree nodes
constexpr int K = 2;            // number of dimensions
constexpr db alpha = 0.75;      // scapegoat rebuild factor
int cuK;                        // split dimension consulted by Pt::operator< during build
int top, tot;                   // top: stack pointer into store; tot: node indices handed out so far
int store[maxn];                // stack of node indices recycled by rebuild()


struct Pt{
    int x[K];   // coordinates, one per dimension
    int val;    // value attached to the point

    // Orders points by their coordinate in the current split dimension cuK;
    // used by nth_element in build().
    bool operator< (const Pt &other) const {
        return x[cuK] < other.x[cuK];
    }
}p[maxn];

struct KDTree{
    int SZ;         // number of nodes in the subtree rooted here
    int lc;         // left child index into Tr (0 = none)
    int rc;         // right child index into Tr (0 = none)
    int maxn[K];    // maxn[i]: largest i-th coordinate inside the subtree
    int minn[K];    // minn[i]: smallest i-th coordinate inside the subtree
    Pt place;       // the point stored at (and splitting) this node
}Tr[maxn];

inline int New() {
    // Pop an index recycled by rebuild() if any; otherwise allocate the next fresh one.
    return top ? store[top--] : ++tot;
}

#define chmax(x, y) x = max(x, y)
#define chmin(x, y) x = min(x, y)
inline void update(int x){
    /* use to update the father node info */
    int lc = Tr[x].lc, rc = Tr[x].rc;
    for (int i = 0; i < K; ++ i){
        Tr[x].maxn[i] = Tr[x].minn[i] = Tr[x].place.x[i];
        if (lc) chmax(Tr[x].maxn[i], Tr[lc].maxn[i]), chmin(Tr[x].minn[i], Tr[lc].minn[i]);
        if (rc) chmax(Tr[x].maxn[i], Tr[rc].maxn[i]), chmin(Tr[x].minn[i], Tr[rc].minn[i]);
    }
    Tr[x].SZ = Tr[lc].SZ + Tr[rc].SZ + 1; // update size 
}

int build(int l, int r, int dep= 0){
    /* Build a balanced subtree over p[l..r], splitting on dimension dep % K; return its root (0 if empty). */
    if (l > r) return 0;
    cuK = dep % K;
    const int mid = static_cast<int>((1ll * l + r) >> 1);
    const int x = New();
    std::nth_element(p + l, p + mid, p + r + 1);   // median along cuK lands at p[mid]
    Tr[x].place = p[mid];
    Tr[x].lc = build(l, mid - 1, dep + 1);
    Tr[x].rc = build(mid + 1, r, dep + 1);
    update(x);
    return x;
}

void rebuild(int x, int base= 0){
    /* Flatten subtree x in order into p[base + 1 ..] and push every node index onto store. */
    const int lc = Tr[x].lc, rc = Tr[x].rc;
    const int slot = base + Tr[lc].SZ + 1;   // position of x's point after its left subtree
    if (lc) rebuild(lc, base);
    p[slot] = Tr[x].place;
    store[++top] = x;
    if (rc) rebuild(rc, slot);
}

void check(int &x, int dep){
    /* if meet the limit, rebuild the sub-tree */
    if (alpha * Tr[x].SZ < Tr[Tr[x].lc].SZ || alpha * Tr[x].SZ < Tr[Tr[x].rc].SZ)
        rebuild(x), x = build(1, Tr[x].SZ, dep);
}

void insert(Pt inP, int &loc, int dep= 0){
    /* Insert inP below loc (loc is rewritten in place), then refresh and rebalance on the way back up. */
    if (!loc) {
        loc = New();
        Tr[loc].place = inP;
        Tr[loc].lc = 0;
        Tr[loc].rc = 0;
        update(loc);
        return;
    }
    const int d = dep % K;
    int &child = (inP.x[d] <= Tr[loc].place.x[d]) ? Tr[loc].lc : Tr[loc].rc;
    insert(inP, child, dep + 1);
    update(loc);
    check(loc, dep);
}


int getdis(Pt temp, int x){
    /* Manhattan distance from temp to subtree x's bounding box: a lower bound for any point inside it. */
    int gap = 0;
    for (int i = 0; i < K; ++i) {
        if (temp.x[i] > Tr[x].maxn[i]) gap += temp.x[i] - Tr[x].maxn[i];
        if (temp.x[i] < Tr[x].minn[i]) gap += Tr[x].minn[i] - temp.x[i];
    }
    return gap;
}

int dist(Pt a, Pt b){
    /* Manhattan (L1) distance between a and b over all K dimensions. */
    int total = 0;
    for (int i = 0; i < K; ++i) {
        const int diff = a.x[i] - b.x[i];
        total += diff < 0 ? -diff : diff;
    }
    return total;
}

int ans;                        // best (smallest) Manhattan distance found by query()
const int inf = 0x3f3f3f3f;     // "no answer yet" sentinel; callers set ans = inf first
inline void query(Pt ask, int x){
    /* Nearest-neighbour search in subtree x; lowers the global ans in place.
       Node 0 is the null child (its zero-initialised point is not a real point),
       so it must never be visited. This matters for an empty tree (root 0) and
       when ans is still >= inf. */
    if (!x) return;
    ans = std::min(ans, dist(ask, Tr[x].place));
    int lc = Tr[x].lc, rc = Tr[x].rc;
    int dl = inf, dr = inf;
    if (lc) dl = getdis(ask, lc);
    if (rc) dr = getdis(ask, rc);
    if (dl > dr) std::swap(dl, dr), std::swap(lc, rc);   // search the closer box first
    if (lc && dl < ans) query(ask, lc);
    if (rc && dr < ans) query(ask, rc);
}

void solve(){
    int n, m; std::cin >> n >> m;
    for (int i = 1; i <= n; ++ i) std::cin >> p[i].x[0] >> p[i].x[1];
    int root = build(1, n);
    for (int i = 0; i < m; ++ i){
        int type; std::cin >> type;
        Pt ask; std::cin >> ask.x[0] >> ask.x[1];
        if (type == 1) insert(ask, root);
        else { ans = inf; query(ask, root); std::cout << ans << "
"; }
    }
}

int main(){
    // Unsynchronised, untied C++ streams for fast I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    solve();
    return 0;
}

原文地址:https://www.cnblogs.com/Last--Whisper/p/14550592.html