bzoj 3053: The Closest M Points【KD-tree】

多维 KD-tree 模板(k 近邻查询)。
估价函数:把左右儿子子树的包围盒 mn~mx 当作每一维上的区间,假设区间里的点都存在,取询问点到该包围盒的最小平方距离作为剪枝下界;建树时 k 个维度轮流作为分割维度。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
#include<cstring>
using namespace std;
// Capacity: points are stored 1-based in a[1..n], so n must stay below N.
const int N = 50005;
int n;       // number of points in the current test case
int k;       // dimensionality: coordinates 0..k-1 are used
int m;       // number of queries still to answer
int rt;      // root node index returned by build()
int w;       // coordinate compared by qwe::operator< during nth_element
int ans[15]; // node indices of the answer points, filled 1-based from q
// Max-heap of (squared distance, node index): keeps the s best candidates
// of the current query, with the worst of them on top.
std::priority_queue<std::pair<int, int> > q;
// A point with up to 5 integer coordinates (only the first k are used).
// operator< orders points by the single coordinate selected by the global
// `w`; build() sets w before calling nth_element on a range of points.
struct qwe
{
	int a[5];
	// Unchecked coordinate access; x is expected to lie in [0, 5).
	int& operator [] (int x) { return a[x]; }
	bool operator < (const qwe &b) const
	{
		const int axis = w;
		return a[axis] < b.a[axis];
	}
} a[N], b; // a[1..n]: input points; b: coordinates of the current query
// One KD-tree node. build() uses the median position `mid` inside a[1..n]
// as the node id, so valid ids are 1..n (< N, the capacity of a[]); id 0 is
// the "no child" sentinel. Sizing t at N (instead of N<<2) is therefore
// enough, and it keeps the per-test-case memset in main() 4x smaller.
struct KD
{
	int ls;  // left child id (0 if none)
	int rs;  // right child id (0 if none)
	qwe d;   // the point stored at this node
	qwe mn;  // per-coordinate minimum over this node's subtree (set by build/ud)
	qwe mx;  // per-coordinate maximum over this node's subtree (set by build/ud)
} t[N];
// Reads the next integer from stdin, skipping any non-digit characters before
// it. A '-' anywhere in the skipped prefix makes the result negative.
// At end of input the function returns what it has accumulated so far (0 if no
// digit was seen) instead of looping forever waiting for a digit.
int read()
{
	int r=0,f=1;
	// Keep getchar()'s result as int: storing it in char makes EOF
	// indistinguishable from byte 0xFF (or never equal to EOF where char is unsigned).
	int p=getchar();
	while(p!=EOF&&(p>'9'||p<'0'))
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+(p-'0');
		p=getchar();
	}
	return r*f;
}
// In-place minimum: lowers x to y when y is strictly smaller.
void minn(int &x,int y)
{
	if (y < x)
		x = y;
}
// In-place maximum: raises x to y when y is strictly larger.
void maxx(int &x,int y)
{
	if (x < y)
		x = y;
}
void ud(int ro)
{
	if(t[ro].ls)
		for(int i=0;i<k;i++)
			minn(t[ro].mn[i],t[t[ro].ls].mn[i]),maxx(t[ro].mx[i],t[t[ro].ls].mx[i]);
	if(t[ro].rs)
		for(int i=0;i<k;i++)
			minn(t[ro].mn[i],t[t[ro].rs].mn[i]),maxx(t[ro].mx[i],t[t[ro].rs].mx[i]);
}
int build(int l,int r,int f)
{
	if(l>r)
		return 0;
	w=f;
	int mid=(l+r)>>1;
	nth_element(a+l,a+mid,a+r+1);
	t[mid].mn=t[mid].mx=t[mid].d=a[mid];
	t[mid].ls=build(l,mid-1,(f+1)%k);
	t[mid].rs=build(mid+1,r,(f+1)%k);
	ud(mid);
	return mid;
}
// Squared Euclidean distance between two points over coordinates 0..k-1.
int dis(qwe a,qwe b)
{
	int sum = 0;
	for (int i = 0; i < k; ++i)
	{
		const int diff = a[i] - b[i];
		sum += diff * diff;
	}
	return sum;
}
int wk(int ro)
{
	int r=0;
	for(int i=0;i<k;i++)
	{
		if(b[i]<t[ro].mn[i])
			r+=(t[ro].mn[i]-b[i])*(t[ro].mn[i]-b[i]);
		if(b[i]>t[ro].mx[i])
			r+=(t[ro].mx[i]-b[i])*(t[ro].mx[i]-b[i]);
	}
	return r;
}
void ques(int ro,int f)
{
	if(!ro)
		return;
	int dm=dis(t[ro].d,b),dl=t[ro].ls?wk(t[ro].ls):1e9,dr=t[ro].rs?wk(t[ro].rs):1e9;//cerr<<"OK"<<dm<<endl;
	if(q.top().first>dm)
		q.pop(),q.push(make_pair(dm,ro));
	if(dl<dr)
	{
		if(dl<q.top().first)
			ques(t[ro].ls,(f+1)%k);
		if(dr<q.top().first)
			ques(t[ro].rs,(f+1)%k);
	}
	else
	{
		if(dr<q.top().first)
			ques(t[ro].rs,(f+1)%k);
		if(dl<q.top().first)
			ques(t[ro].ls,(f+1)%k);
	}
}
// Driver. For each test case: read n points of dimension k, build the KD-tree,
// then answer m queries, each a point b followed by a count s, by printing the
// s nearest points, closest first.
int main()
{
	// Require both header values. `~scanf(...)` kept looping when scanf returned 0
	// (malformed input is not consumed -> infinite loop) or 1 (stale k).
	while(scanf("%d%d",&n,&k)==2)
	{
		memset(t,0,sizeof(t));
		for(int i=1;i<=n;i++)
			for(int j=0;j<k;j++)
				a[i][j]=read();
		rt=build(1,n,0);
		m=read();
		while(m--)
		{
			for(int i=0;i<k;i++)
				b[i]=read();
			int s=read(); // NOTE(review): assumes 1 <= s <= 14 so ans[1..s] fits ans[15] — confirm against problem limits
			// Seed the heap with s sentinels at distance 1e9 so q.top() is always
			// the current s-th best candidate while searching.
			for(int i=1;i<=s;i++)
				q.push(make_pair(1e9,0));
			ques(rt,0);
			// The max-heap pops farthest first: ans[1] is the farthest kept point,
			// ans[s] the closest. Popping all s entries also leaves q empty.
			for(int i=1;i<=s;i++)
				ans[i]=q.top().second,q.pop();
			printf("the closest %d points are:\n",s);
			for(int i=s;i>=1;i--)
			{
				for(int j=0;j<k;j++)
					printf("%d ",t[ans[i]].d[j]);
				puts("");
			}
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lokiii/p/10098965.html