bzoj 3053 HDU 4347 : The Closest M Points kd树

bzoj 3053 HDU 4347 : The Closest M Points  kd树

题目大意:求k维空间内某点的前k近的点。

就是一般的kd树,根据实测发现,kd树的两种建树方式,即按照方差较大的维度分开(建树常数大)或者每一位轮换分割(询问常数大),后者更快也更好些,以后就果断写第二种了。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAXN 510000
#define MAXT MAXN
#define MAXM 6
#define sqr(x) ((qword)(x)*(x))
#define INF 0x3f3f3f3f
typedef long long qword;
int n,m;
struct point
{
        int a[MAXM];
        qword dis;
        void pm()
        {
                printf("(%d",a[0]);
                for (int i=1;i<m;i++)
                        printf(",%d",a[i]);
                printf(")");
        }
        void pm2()
        {
                for (int i=0;i<m-1;i++)
                        printf("%d ",a[i]);
                printf("%d
",a[m-1]);
        }
}pl[MAXN];
bool cmp_0(const point &p1,const point &p2){
        return p1.a[0]<p2.a[0];
}
bool cmp_1(const point &p1,const point &p2){
        return p1.a[1]<p2.a[1];
}
bool cmp_2(const point &p1,const point &p2){
        return p1.a[2]<p2.a[2];
}
bool cmp_3(const point &p1,const point &p2){
        return p1.a[3]<p2.a[3];
}
bool cmp_4(const point &p1,const point &p2){
        return p1.a[4]<p2.a[4];
}
bool cmp_d(const point &p1,const point &p2){
        return p1.dis<p2.dis;
}
struct kdt_node
{
        int a[MAXM];
        int dd,mxp[MAXM],mnp[MAXM];
        int ch[2];
        int ptr;
}kdt[MAXT];
qword get_dis(point &pt,kdt_node &pn)
{
        qword ret=0;
        for (int j=0;j<m;j++)
                if (pt.a[j]<pn.mnp[j] || pt.a[j]>pn.mxp[j])
                        ret+=min(sqr(pt.a[j]-pn.mxp[j]),sqr(pt.a[j]-pn.mnp[j]));
        return ret;
}
qword get_dis(point &p1,point &p2)
{
        qword ret=0;
        for (int i=0;i<m;i++)
                ret+=sqr(p1.a[i]-p2.a[i]);
        return ret;
}
int topt=0;
void Build_kdt(int &now,int l,int r,int d)
{ 
        if (l>r)return;
           now=++topt;
        int i,j;
        for (j=0;j<m;j++)kdt[now].mxp[j]=-INF,kdt[now].mnp[j]=INF;
        for (i=l;i<=r;i++)
                for (j=0;j<m;j++)
                {
                        kdt[now].mxp[j]=max(kdt[now].mxp[j],pl[i].a[j]);
                        kdt[now].mnp[j]=min(kdt[now].mnp[j],pl[i].a[j]);
                }
/*        double ave[MAXM];
        double sqv[MAXM];
        memset(ave,0,sizeof(ave));
        memset(sqv,0,sizeof(sqv));
        for (i=l;i<=r;i++)
                for (j=0;j<m;j++)
                        ave[j]+=pl[i].a[j];
        for (j=0;j<m;j++)
                ave[j]/=(r-l+1);
        for (i=l;i<=r;i++)
                for (j=0;j<m;j++)
                        sqv[j]+=sqr(pl[i].a[j]-ave[j]);
        kdt[now].dd=0;
        for (j=0;j<m;j++)
                if (sqv[j]>sqv[kdt[now].dd])
                        kdt[now].dd=j;*/
        kdt[now].dd=d;
        switch (kdt[now].dd)
        {
                case 0:nth_element(&pl[l],&pl[(l+r)>>1],&pl[r+1],cmp_0);break;
                case 1:nth_element(&pl[l],&pl[(l+r)>>1],&pl[r+1],cmp_1);break;
                case 2:nth_element(&pl[l],&pl[(l+r)>>1],&pl[r+1],cmp_2);break;
                case 3:nth_element(&pl[l],&pl[(l+r)>>1],&pl[r+1],cmp_3);break;
                case 4:nth_element(&pl[l],&pl[(l+r)>>1],&pl[r+1],cmp_4);break;
        }
        kdt[now].ptr=(l+r)>>1;
        Build_kdt(kdt[now].ch[0],l,((r+l)>>1)-1,(d+1)%m);
        Build_kdt(kdt[now].ch[1],((r+l)>>1)+1,r,(d+1)%m);
}
point h[MAXN];
int atot;
int toph=0;
void search_point(int now,point &pt)
{
        if (!now)return ;
        qword cdis=get_dis(pt,pl[kdt[now].ptr]);
        if (toph<atot || cdis<h[0].dis)
        {
                if (toph==atot)pop_heap(&h[0],&h[toph--],cmp_d);
                h[toph]=pl[kdt[now].ptr];
                h[toph].dis=cdis;
                push_heap(&h[0],&h[++toph],cmp_d);
        }
        int t;
        if (get_dis(pt,kdt[kdt[now].ch[0]]) < get_dis(pt,kdt[kdt[now].ch[1]]))
                t=0;
        else
                t=1;
        search_point(kdt[now].ch[t],pt);
        if (toph<atot || get_dis(pt,kdt[kdt[now].ch[t^1]]) < h[0].dis)
        {
                search_point(kdt[now].ch[t^1],pt);
        }
}
int main()
{
    //    freopen("input.txt","r",stdin);
        //freopen("output.txt","w",stdout);
        int i,j,k,x,y,z;
        int root=0;
        while (~scanf("%d%d",&n,&m))
        {
                for (i=0;i<n;i++)
                        for (j=0;j<m;j++)
                                scanf("%d",&pl[i].a[j]);
                Build_kdt(root,0,n-1,0);
                int q;
                scanf("%d",&q);
                point pt;
                for (i=0;i<q;i++)
                {
                        for (j=0;j<m;j++)
                                scanf("%d",&pt.a[j]);
                        scanf("%d",&atot);
                        search_point(root,pt);
                        printf("the closest %d points are:
",atot);
                        while (toph)
                        {
                                pop_heap(&h[0],&h[toph--],cmp_d);
                        }
                        for (j=0;j<atot;j++)
                                h[j].pm2();
                }
        }
}
by mhy12345(http://www.cnblogs.com/mhy12345/) 未经允许请勿转载

本博客已停用,新博客地址:http://mhy12345.xyz

原文地址:https://www.cnblogs.com/mhy12345/p/4151328.html