【模板】很强的KD-Tree模板

之所以说超强,是因为这个模板又短(我见过最短的kdtree)跑得又快(我用来写了某道题在vj上跑了第一)也易于修改(之前拿某个大板子来改,不仅不好改而且改了跑得贼慢)。无需初始化任何变量,直接build + query!

其实还有优化空间,把取模操作都换成手动的会更快。

HDU4347 - The Closeset M Points

 1 #include<queue>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<algorithm>
 5 using namespace std;
 6 const int N=55555,K=5;
 7 const int inf=0x3f3f3f3f;
 8 
 9 #define sqr(x) (x)*(x)
10 int k,n,idx;   //k为维数,n为点数
11 struct point
12 {
13     int x[K];
14     bool operator < (const point &u) const
15     {
16      return x[idx]<u.x[idx];
17     }
18 }po[N];
19 
20 typedef pair<double,point>tp;
21 priority_queue<tp>nq;
22 
23 struct kdTree
24 {
25     point pt[N<<2];
26     int son[N<<2];
27 
28     void build(int l,int r,int rt=1,int dep=0)
29     {
30         if(l>r) return;
31         son[rt]=r-l;
32         son[rt*2]=son[rt*2+1]=-1;
33         idx=dep%k;
34         int mid=(l+r)/2;
35         nth_element(po+l,po+mid,po+r+1);
36         pt[rt]=po[mid];
37         build(l,mid-1,rt*2,dep+1);
38         build(mid+1,r,rt*2+1,dep+1);
39     }
40     void query(point p,int m,int rt=1,int dep=0)
41     {
42         if(son[rt]==-1) return;
43         tp nd(0,pt[rt]);
44         for(int i=0;i<k;i++) nd.first+=sqr(nd.second.x[i]-p.x[i]);
45         int dim=dep%k,x=rt*2,y=rt*2+1,fg=0;
46         if(p.x[dim]>=pt[rt].x[dim]) swap(x,y);
47         if(~son[x]) query(p,m,x,dep+1);
48         if(nq.size()<m) nq.push(nd),fg=1;
49         else
50         {
51             if(nd.first<nq.top().first) nq.pop(),nq.push(nd);
52             if(sqr(p.x[dim]-pt[rt].x[dim])<nq.top().first) fg=1;
53         }
54         if(~son[y]&&fg) query(p,m,y,dep+1);
55     }
56 }kd;
57 void print(point &p)
58 {
59     for(int j=0;j<k;j++) printf("%d%c",p.x[j],j==k-1?'
':' ');
60 }
61 int main()
62 {
63     while(scanf("%d%d",&n,&k)!=EOF)
64     {
65         for(int i=0;i<n;i++) for(int j=0;j<k;j++) scanf("%d",&po[i].x[j]);
66         kd.build(0,n-1);
67         int t,m;
68         for(scanf("%d",&t);t--;)
69         {
70              point ask;
71              for(int j=0;j<k;j++) scanf("%d",&ask.x[j]);
72              scanf("%d",&m); kd.query(ask,m);
73              printf("the closest %d points are:
", m);
74              point pt[20];
75              for(int j=0;!nq.empty();j++) pt[j]=nq.top().second,nq.pop();
76              for(int j=m-1;j>=0;j--) print(pt[j]);
77         }
78     }
79     return 0;
80 }
View Code

HDU4347 - The Closeset M Points

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int N=200555,K=2;
 4 const int inf=0x3f3f3f3f;
 5 
 6 typedef long long ll;
 7 #define sqr(x) ((x)*(x))
 8 int k=2,idx;   //k为维数,n为点数
 9 struct Point
10 {
11     int pr,id;
12     ll x[K];
13     bool operator < (const Point &u) const
14     {
15         return x[idx]<u.x[idx];
16     }
17 }po[N];
18 
19 typedef pair<ll,Point>tp;
20 priority_queue<tp>nq;
21 
22 struct KDTree
23 {
24     Point pt[N<<2];
25     int son[N<<2],mn[N<<2];
26 
27     void init(){
28         memset(son,0,sizeof son);
29         memset(mn,0,sizeof mn);
30     }
31     void build(int l,int r,int rt=1,int dep=0)
32     {
33         if(l>r) return;
34         son[rt]=r-l;
35         son[rt*2]=son[rt*2+1]=-1;
36         idx=dep%k;
37         int mid=(l+r)/2;
38         nth_element(po+l,po+mid,po+r+1);
39         pt[rt]=po[mid];
40         mn[rt]=po[mid].pr;
41         build(l,mid-1,rt*2,dep+1);
42         build(mid+1,r,rt*2+1,dep+1);
43         mn[rt]=min(mn[rt],min(mn[rt*2],mn[rt*2+1]));
44     }
45     void query(Point p,int m,int rt=1,int dep=0)
46     {
47         if(son[rt]==-1) return;
48 //        printf("//
");
49         tp nd(0,pt[rt]);
50         for(int i=0;i<k;i++) nd.first+=sqr(nd.second.x[i]-p.x[i]);
51         int dim=dep%k,x=rt*2,y=rt*2+1,fg=0;
52         if(p.x[dim]>=pt[rt].x[dim]) swap(x,y);
53 //        printf("//1
");
54         if(nd.second.pr>p.pr) nd.first=1e18;
55         if(~son[x] && mn[x]<=p.pr) query(p,m,x,dep+1);
56         if(nq.size()<m) nq.push(nd),fg=1;
57         else
58         {
59             if(nd.first<nq.top().first) nq.pop(),nq.push(nd);
60             else if(nd.first==nq.top().first&&nd.second.id<nq.top().second.id) nq.pop(),nq.push(nd);
61             if(sqr(p.x[dim]-pt[rt].x[dim])<=nq.top().first) fg=1;
62         }
63 //        printf("//2
");
64         if(~son[y]&&fg && mn[y]<=p.pr) query(p,m,y,dep+1);
65     }
66 }kdt;
67 
68 int T,n,m,root;
69 
70 int main(){
71 //    freopen("in.txt","r",stdin);
72     cin >> T;
73     while(T--){
74         scanf("%d%d",&n,&m);
75         for(int i = 1;i <= n;++i){
76             scanf("%I64d%I64d%d",&po[i].x[0],&po[i].x[1],&po[i].pr);
77             po[i].id = i;
78         }
79         kdt.build(1,n);
80         while(m--){
81             Point p;
82             scanf("%I64d%I64d%d",&p.x[0],&p.x[1],&p.pr);
83             kdt.query(p,1);
84             if(nq.empty()){
85                 printf("-1
");
86                 continue;
87             }
88             printf("%I64d %I64d %d
",nq.top().second.x[0],nq.top().second.x[1],nq.top().second.pr);
89             nq.pop();
90         }
91     }
92 
93     return 0;
94 }
View Code
原文地址:https://www.cnblogs.com/doub7e/p/7643705.html