hdu 2966 In case of failure (KDTree)

http://acm.hdu.edu.cn/showproblem.php?pid=2966

  一道KD树的题。题意是,给出n个不重合的点,求出这n个点的最邻近点的距离的平方。

  什么是KE树就不介绍了,网上有许多KD的资料,做这题前先阅读材料。我的方法参考的是http://blog.csdn.net/zhjchengfeng5/article/details/7855241这个博客的代码,划分的过程直接调用STL中的nth_element,从而减少代码量。

  我的做法是直接用点集数组构建线性存储的一棵KD树,然后用类似于线段树操作对点集进行划分和查找。因为题目的特殊性,于是我们可以将查找的时候,重合的(也就是距离为0的)点间的距离赋值为inf,这样子就可以直接利用查找最近点的方法找到目标点了。

代码如下:

 1 #include <cstdio>
 2 #include <algorithm>
 3 #include <vector>
 4 #include <cstring>
 5 #include <iostream>
 6 
 7 using namespace std;
 8 
 9 typedef long long LL;
10 
11 const int N = 111111;
12 struct Point {
13     LL x[3];
14 } p[N], ori[N];
15 int split[20], cur, dim;
16 
17 bool cmp(Point a, Point b) {
18     return a.x[cur] < b.x[cur];
19 }
20 
21 #define lson l, m - 1, depth + 1
22 #define rson m + 1, r, depth + 1
23 
24 void build(int l, int r, int depth) {
25     if (l >= r) return ;
26     int m = l + r >> 1;
27     cur = depth % dim;
28     nth_element(p + l, p + m, p + r + 1, cmp);
29     build(lson);
30     build(rson);
31 }
32 
33 template <class T> T sqr(T x) { return x * x;}
34 const LL inf = 0x7777777777777777ll;
35 
36 LL dist(Point x, Point y) {
37     LL ret = 0;
38     for (int i = 0; i < dim; i++) {
39         ret += sqr(x.x[i] - y.x[i]);
40     }
41     return ret ? ret : inf;
42 }
43 
44 LL find(Point x, int l, int r, int depth) {
45     int cur = depth % dim;
46     if (l >= r) {
47         if (l == r) return dist(x, p[l]);
48         return inf;
49     }
50     int m = l + r >> 1;
51     LL ret = dist(x, p[m]), tmp;
52     if (x.x[cur] < p[m].x[cur]) {
53         tmp = find(x, lson);
54         if (tmp > sqr(x.x[cur] - p[m].x[cur])) {
55             tmp = min(tmp, find(x, rson));
56         }
57     } else {
58         tmp = find(x, rson);
59         if (tmp > sqr(x.x[cur] - p[m].x[cur])) {
60             tmp = min(tmp, find(x, lson));
61         }
62     }
63     return min(ret, tmp);
64 }
65 
66 int main() {
67 //    freopen("in", "r", stdin);
68     int n, T;
69     scanf("%d", &T);
70     while (T-- && scanf("%d", &n)) {
71         dim = 2;
72         for (int i = 0; i < n; i++) {
73             for (int j = 0; j < 2; j++) {
74                 scanf("%I64d", &ori[i].x[j]);
75             }
76             p[i] = ori[i];
77         }
78         build(0, n - 1, 0);
79         for (int i = 0; i < n; i++) {
80             printf("%I64d\n", find(ori[i], 0, n - 1, 0));
81         }
82     }
83     return 0;
84 }
View Code

——written by Lyon

原文地址:https://www.cnblogs.com/LyonLys/p/hdu_2966_Lyon.html