[学习笔记]K-D树

简述

K-D树的本质是一棵二叉查找树,但每一层划分的标准变为某一维度,以垂直于某一坐标轴的超平面将当前区域划分为两个区域

但和二叉查找树不同的是K-D树每个节点储存了一个样本,简单理解为每个节点都代表插入的一个点

构建

考虑当前区域按第(dim)维划分,为了让树尽量平衡,将这个区域内所有点按第(dim)维排序后,从第(mid)个点处划分最优,(C++)可以用(nth\_element)快速求出这个点,顺便把排在这个点之前的点全都放到它左边

当前节点存下第(mid)个点的信息,然后递归的构建左子树和右子树

代码(二维)如下:

KD_Tree::Node *KD_Tree::build(int l, int r, int dim) {
	if (l > r) return NULL;
	Node *rt = new Node();
	int mid = (l + r + 1) >> 1;
	std::nth_element(pt + l, pt + mid, pt + r + 1, cmp[dim]);
	rt->cur = pt[mid];
	rt->ls = build(l, mid - 1, dim ^ 1);
	rt->rs = build(mid + 1, r, dim ^ 1);
	push_up(rt);
	return rt;
}

其中(l)(r)的点是区域内的点,(dim)代表当前维度

查询

以查询距点(p)最近的点为例

依然是递归查找

首先查看当前节点所代表的点是否更优,然后估计左右子树距这个点可能的最近距离

如果估计值更优,就继续往子树查找,否则就不用找下去了

当左右子树都更优时,显然先找较优的子树不会更差,若找完较优的子树后另一棵子树还有可能更优,再到查找另一棵子树中查找

代码(二维)如下:

LL queryMin(KD_Tree *rt, Point &tar) {
	LL res = dist(rt->data, tar);
	if (!res) res = INF;
	LL dl = (rt->son[0] ? getMin(rt->son[0], tar) : INF), dr = (rt->son[1] ? getMin(rt->son[1], tar) : INF);
	if (dl > dr) {
		if (dr < res) res = std::min(res, queryMin(rt->son[1], tar));
		if (dl < res) res = std::min(res, queryMin(rt->son[0], tar));
	} else {
		if (dl < res) res = std::min(res, queryMin(rt->son[0], tar));
		if (dr < res) res = std::min(res, queryMin(rt->son[1], tar));
	}
	return res;
}

节点记录的东西

通常需要记录左右儿子、每一维坐标的最大及最小值、所代表的点的信息,划分的维度可以记录,也可以递归过程中处理出来

其它信息根据题目要求

总结

看起来非常高端的K-D树其实是很朴素的搜索加上很神奇的剪枝

两句话概括就是:

  1. 构建——循环用每一维构建二叉查找树,记录信息
  2. 查询——如果子树中可能有更优解就进入查找,否则退出,优先查找可能的解更优的那颗子树

是不是异常简单?(然而我学了半个星期才打出模板)

代码

最近点对的找不到了,放个这题[CQOI2016]K远点对吧,也挺裸的

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#define sqr(x) ((x) * (x))
#define MAXN 100005

typedef long long LL;
struct Point {
	LL cor[2];
	Point(LL d0 = 0, LL d1 = 1) { cor[0] = d0, cor[1] = d1; }
} pt[MAXN];
struct KD_Tree {
	struct Node {
		Node *ls, *rs;
		Point cur;
		LL maxc[2], minc[2];
	} * root;
	void push_up(Node *);
	Node* build(int, int, int);
	void query(Node *, int, const Point &);
} kd;
int N, K;
std::priority_queue<LL, std::vector<LL>, std::greater<LL> > que;

bool cmp0(const Point &, const Point &);
bool cmp1(const Point &, const Point &);
bool (* cmp[])(const Point &, const Point &) = {cmp0, cmp1};
inline LL dist(const Point &p1, const Point &p2) { return sqr(p1.cor[0] - p2.cor[0]) + sqr(p1.cor[1] - p2.cor[1]); }
inline LL max_dist(const Point &p, KD_Tree::Node *rt) {
	return std::max(sqr(p.cor[0] - rt->maxc[0]), sqr(p.cor[0] - rt->minc[0]))
		 + std::max(sqr(p.cor[1] - rt->maxc[1]), sqr(p.cor[1] - rt->minc[1]));
}
int main() {
	std::ios::sync_with_stdio(false);
	std::cin >> N >> K;
	K <<= 1;
	for (int i = 1; i <= K; ++i) que.push(0);
	for (int i = 1; i <= N; ++i)
		std::cin >> pt[i].cor[0] >> pt[i].cor[1];
	kd.root = kd.build(1, N, 0);
	for (int i = 1; i <= N; ++i)
		kd.query(kd.root, 0, pt[i]);
	std::cout << que.top() << std::endl;
		
	return 0;
}
inline void max(LL &a, LL b) { a = std::max(a, b); }
inline void min(LL &a, LL b) { a = std::min(a, b); }
bool cmp0(const Point &p1, const Point &p2) { return p1.cor[0] < p2.cor[0]; }
bool cmp1(const Point &p1, const Point &p2) { return p1.cor[1] < p2.cor[1]; }
KD_Tree::Node *KD_Tree::build(int l, int r, int dim) {
	if (l > r) return NULL;
	Node *rt = new Node();
	int mid = (l + r + 1) >> 1;
	std::nth_element(pt + l, pt + mid, pt + r + 1, cmp[dim]);
	rt->cur = pt[mid];
	rt->ls = build(l, mid - 1, dim ^ 1);
	rt->rs = build(mid + 1, r, dim ^ 1);
	push_up(rt);
	return rt;
}
void KD_Tree::query(Node *rt, int dim, const Point &p) {
	if (!rt) return;
	if (dist(p, rt->cur) > que.top()) { que.pop(); que.push(dist(p, rt->cur)); }
	LL dl = -0x3f3f3f3f3f3f3f3f, dr = -0x3f3f3f3f3f3f3f3f;
	if (rt->ls) dl = max_dist(p, rt->ls);
	if (rt->rs) dr = max_dist(p, rt->rs);
	if (dl > dr) {
		if (dl > que.top()) query(rt->ls, dim ^ 1, p);
		if (dr > que.top()) query(rt->rs, dim ^ 1, p);
	} else {
		if (dr > que.top()) query(rt->rs, dim ^ 1, p);
		if (dl > que.top()) query(rt->ls, dim ^ 1, p);
	}
}
void KD_Tree::push_up(Node *rt) {
	rt->maxc[0] = rt->minc[0] = rt->cur.cor[0];
	rt->maxc[1] = rt->minc[1] = rt->cur.cor[1];
	if (rt->ls) {
		max(rt->maxc[0], rt->ls->maxc[0]); min(rt->minc[0], rt->ls->minc[0]);
		max(rt->maxc[1], rt->ls->maxc[1]); min(rt->minc[1], rt->ls->minc[1]);
	}
	if (rt->rs) {
		max(rt->maxc[0], rt->rs->maxc[0]); min(rt->minc[0], rt->rs->minc[0]);
		max(rt->maxc[1], rt->rs->maxc[1]); min(rt->minc[1], rt->rs->minc[1]);
	}
}
//Rhein_E
原文地址:https://www.cnblogs.com/Rhein-E/p/10448355.html