HDU 4347 (The Closest M Points): C++ k-d tree solution

#include<bits/stdc++.h>
// Square of an expression. Fully parenthesised so it composes safely inside
// larger expressions such as `a / sq(b)`; note the argument is evaluated twice.
#define sq(x) ((x)*(x))
// Capacity of the global point array P.
#define N (55555)
 
using namespace std;
 
// idx: coordinate axis compared by Node::operator<; k: coordinates per point;
// n: number of points; m: neighbours requested by the current query; q: query count.
int idx,k,n,m,q;
 
// A point with up to five integer coordinates. Its ordering looks only at the
// coordinate selected by the global `idx`, so that axis must be chosen before
// any comparison-based algorithm runs on Nodes.
struct Node
{
    int x[5];

    bool operator<(const Node &rhs) const
    {
        const int axis = idx;
        return x[axis] < rhs.x[axis];
    }
} P[N];

// (squared distance, point) candidates; the default max-heap keeps the
// farthest of the current best candidates on top.
typedef pair<double, Node> PDN;
priority_queue<PDN> que;
 
struct KD_Tree
{
    int sz[N<<2]; Node p[N<<2];
 
    void nthelement(int s,int k,int e,int idx)
    {
        if (s >= e)
        {
            return;
        }
        //int mid = random.Next(s, e);
        // int mid = (s + e) / 2;
        // Node tmp = P[mid];
        // P[mid] = P[s];
        // P[s] = tmp;
        Node md = P[s];
        int l = 0;
        for (int i = s + 1; i < e; i++)
        {
            if (P[i].x[idx] <= md.x[idx])
            {
                ++l;
                Node tmp = P[s + l];
                P[s + l] = P[i];
                P[i] = tmp;
            }
        }

        P[s] = P[l + s];
        P[l + s] = md;
        if (l + 1 == k)
            return;
        if (l + 1 < k)
        {
            //ss.Push(new F2(s + l + 1, k - l - 1, e));
            nthelement(s + l + 1, k - l - 1, e, idx);
        }
        else
        {
            //ss.Push(new F2(s, l + 1 - k, s + l + 1));
            nthelement(s, k, s + l, idx);
        }
    }

    void build(int i,int l,int r,int dep)
    {
        if (l>r) return;
        int mid=(l+r)>>1;
        idx=dep%k;sz[i]=r-l;
        sz[i<<1]=sz[i<<1|1]=-1;
        //nth_element(P+l,P+mid,P+r+1);
        nthelement(l,mid+1-l,r+1,idx);
        p[i]=P[mid];
        build(i<<1,l,mid-1,dep+1);
        build(i<<1|1,mid+1,r,dep+1);
    }
 
    void query(int i,int m,int dep,Node a)
    {
        if (sz[i]==-1) return;
        PDN tmp=PDN(0,p[i]);
        for(int j=0;j<k;j++)
            tmp.first+=sq(tmp.second.x[j]-a.x[j]);
        int lc=i<<1,rc=i<<1|1,dim=dep%k,flag=0;
        if (a.x[dim]>=p[i].x[dim]) swap(lc,rc);
        if (~sz[lc]) query(lc,m,dep+1,a);
        if (que.size()<m) que.push(tmp),flag=1;
        else
        {
            if (tmp.first<que.top().first) que.pop(),que.push(tmp);
            if (sq(a.x[dim]-p[i].x[dim])<que.top().first) flag=1;
        }
        if (~sz[rc]&&flag) query(rc,m,dep+1,a);
    }
} KDT;
 
int main()
{
    while(~scanf("%d%d",&n,&k))
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                scanf("%d",&P[i].x[j]);
        
        KDT.build(1,0,n-1,0);
        scanf("%d",&q);
        while(q--)
        {
            Node now;
            for(int i=0;i<k;i++)
                scanf("%d",&now.x[i]);
            scanf("%d",&m); int t=0;
            KDT.query(1,m,0,now); Node pp[21];
            for(;!que.empty();que.pop())
                pp[++t]=que.top().second;
            printf("the closest %d points are:
",t);
            for(int i=m;i>0;i--)
            {
                printf("%d",pp[i].x[0]);
                for(int j=1;j<k;j++)
                    printf(" %d",pp[i].x[j]);
                puts("");
            }
        }
    }
 
    return 0;
 
}

C# port of the same solution (reported to exceed the time limit, TLE)

using System;
using System.Collections.Generic;
using System.Data.SqlTypes;
using System.IO;
using System.Linq;

namespace kdtree
{
    /// <summary>
    /// Bundles the arguments of one kd-tree query step (node slot <c>i</c>,
    /// requested neighbour count <c>cnt</c>, depth <c>k</c> and query point
    /// <c>nd</c>) so a traversal can keep them on an explicit stack.
    /// </summary>
    public class F
    {
        public int i, cnt, k;
        public Node nd;

        public F(int i, int cnt, int k, Node nd)
        {
            this.nd = nd;
            this.k = k;
            this.cnt = cnt;
            this.i = i;
        }
    }

    /// <summary>
    /// One pending selection sub-problem: the half-open range
    /// [<c>s</c>, <c>e</c>) and the 1-based rank <c>k</c> sought within it.
    /// </summary>
    public class F2
    {
        public int s, k, e;

        public F2(int s, int k, int e)
        {
            this.e = e;
            this.k = k;
            this.s = s;
        }
    }
    /// <summary>
    /// A point in k-dimensional space plus the bookkeeping used to rank it as a
    /// query candidate. Instances order by <c>d</c>, then by <c>u</c>.
    /// </summary>
    public class Node : IComparable<Node>
    {
        /// <summary>Coordinates of the point.</summary>
        public List<int> a;
        /// <summary>Secondary sort key; treekd stores the tree slot index here so equal distances stay distinct.</summary>
        public int u;
        /// <summary>Primary sort key; treekd stores the squared distance to the query point here.</summary>
        public int d;

        /// <summary>Compares by <c>d</c>, breaking ties with <c>u</c>.</summary>
        /// <param name="obj">The node to compare with; may be null.</param>
        /// <returns>Negative, zero or positive as this node sorts before, with, or after <paramref name="obj"/>.</returns>
        public int CompareTo(Node obj)
        {
            // IComparable<T> contract: every instance compares greater than null.
            if (obj == null)
                return 1;
            int byDistance = d.CompareTo(obj.d);
            return byDistance != 0 ? byDistance : u.CompareTo(obj.u);
        }
    }

    /// <summary>
    /// Static k-d tree in implicit heap layout (the children of slot i are 2i and
    /// 2i+1) answering "m nearest points" queries.
    /// </summary>
    public class treekd
    {
        /// <summary>Number of coordinates per point.</summary>
        public int dim;

        /// <summary>Allocates storage for a tree over <paramref name="n"/> points.</summary>
        /// <param name="n">Number of points <see cref="build"/> will index.</param>
        /// <param name="dim">Coordinates per point.</param>
        /// <param name="points">The points; reordered in place by <see cref="build"/>.</param>
        public void init(int n, int dim, List<Node> points)
        {
            // A median-split tree over n points occupies slots below 2n and build()
            // writes the sentinel into the children of used slots, so 4n slots
            // suffice. The floor of 4 keeps slot 1 addressable when n == 0.
            int size = Math.Max(4, n << 2);
            sz = new int[size];
            for (int j = 0; j < size; j++)
                sz[j] = -1; // start with every slot empty, so an empty tree answers nothing
            p = new Node[size];
            this.dim = dim;
            this.points = points;
        }

        // sz[i] == -1 marks an empty slot; otherwise slot i holds p[i].
        int[] sz;
        List<Node> points;
        Node[] p;

        /// <summary>Best candidates of the current query, closest first (ordered by d, then u).</summary>
        public SortedSet<Node> que;

        /// <summary>Clears the candidate set before a new query.</summary>
        public void initquery()
        {
            que = new SortedSet<Node>();
        }

        /// <summary>Placeholder; currently does nothing.</summary>
        public void createnode()
        {

        }

        /// <summary>Placeholder; currently does nothing.</summary>
        public void update()
        {

        }

        public int sq(int a) { return a * a; }

        /// <summary>Swaps two integers (temp-variable swap; safe for any values).</summary>
        public void swap(ref int a, ref int b)
        {
            int t = a;
            a = b;
            b = t;
        }

        /// <summary>
        /// Adds to <see cref="que"/> the <paramref name="cnt"/> points of the subtree
        /// at slot <paramref name="i"/> nearest to <paramref name="nd"/>.
        /// </summary>
        /// <param name="i">Tree slot to search (1 for the root).</param>
        /// <param name="cnt">How many neighbours to keep.</param>
        /// <param name="k">Depth of slot <paramref name="i"/>; the split axis is k % dim.</param>
        /// <param name="nd">Query point.</param>
        public void query(int i, int cnt, int k, Node nd)
        {
            // Plain recursion is safe: the median-split tree is only ~log2(n) deep.
            // It also restores the intended order. Search the near child first,
            // then decide about the far child using the improved candidate set.
            // The earlier explicit-stack version popped the far child first and
            // decided on it before the near side was searched, so it hardly pruned.
            if (cnt <= 0 || sz[i] == -1) return;

            Node here = p[i];
            Node tmp = new Node() { u = i, a = here.a };
            for (int j = 0; j < dim; j++)
                tmp.d += sq(tmp.a[j] - nd.a[j]);

            int axis = k % dim;
            int near = i << 1, far = i << 1 | 1;
            if (nd.a[axis] >= here.a[axis]) swap(ref near, ref far);

            if (sz[near] != -1)
                query(near, cnt, k + 1, nd);

            bool visitFar;
            if (que.Count < cnt)
            {
                que.Add(tmp);
                visitFar = true;
            }
            else
            {
                // SortedSet.Max is O(log n); Enumerable.Last() walked the whole set.
                if (tmp.d < que.Max.d)
                {
                    que.Remove(que.Max);
                    que.Add(tmp);
                }
                // The far side can only help if the splitting plane is nearer than the worst kept point.
                visitFar = sq(nd.a[axis] - here.a[axis]) < que.Max.d;
            }

            if (visitFar && sz[far] != -1)
                query(far, cnt, k + 1, nd);
        }

        static readonly Random random = new Random();

        /// <summary>
        /// Rearranges list[s, e) so the k-th smallest element (1-based) by coordinate
        /// <paramref name="idx"/> ends at list[s + k - 1], with nothing larger before
        /// it and nothing smaller after it.
        /// </summary>
        public void nth_element(ref List<Node> list, int s, int k, int e, int idx)
        {
            if (k < 1 || k > e - s) return;
            int target = s + k - 1;
            int lo = s, hi = e;
            // Random pivot + three-way partition: expected O(n) even on sorted
            // input or long runs of equal coordinates (the old first-element
            // pivot with a "<=" split was O(n^2) there).
            while (hi - lo > 1)
            {
                int pivot = list[random.Next(lo, hi)].a[idx];
                int lt = lo, cur = lo, gt = hi;
                // Invariant: [lo, lt) < pivot, [lt, cur) == pivot, [gt, hi) > pivot.
                while (cur < gt)
                {
                    int v = list[cur].a[idx];
                    if (v < pivot)
                        SwapAt(list, lt++, cur++);
                    else if (v > pivot)
                        SwapAt(list, cur, --gt);
                    else
                        cur++;
                }
                if (target < lt) hi = lt;
                else if (target >= gt) lo = gt;
                else return; // target sits in the block equal to the pivot
            }
        }

        // Exchanges list[x] and list[y].
        private static void SwapAt(List<Node> list, int x, int y)
        {
            Node t = list[x];
            list[x] = list[y];
            list[y] = t;
        }

        /// <summary>
        /// Builds the subtree for points[l..r] in slot <paramref name="i"/>, splitting
        /// on axis k % dim at the median.
        /// </summary>
        public void build(int i, int l, int r, int k)
        {
            if (l > r) return;
            int mid = (l + r) >> 1;
            int idx = k % dim;
            sz[i] = r - l;
            sz[i << 1] = sz[i << 1 | 1] = -1;
            nth_element(ref points, l, mid + 1 - l, r + 1, idx);
            p[i] = points[mid];
            build(i << 1, l, mid - 1, k + 1);
            build(i << 1 | 1, mid + 1, r, k + 1);
        }
    }
    class Program
    {
        static readonly Random rng = new Random();

        /// <summary>Reads one line from standard input and parses its integers.</summary>
        /// <returns>The parsed integers (empty for a blank line), or null at end of input.</returns>
        public static List<int> GetLine()
        {
            var input = Console.ReadLine();
            if (input == null)
                return null;
            // Split on runs of whitespace: Split(' ') turned repeated or trailing
            // spaces into empty tokens, and int.Parse("") throws.
            return input.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries)
                        .Select(p => int.Parse(p))
                        .ToList();
        }

        // Like GetLine, but skips blank lines. Returns null at end of input.
        static List<int> GetNonEmptyLine()
        {
            List<int> line;
            while ((line = GetLine()) != null && line.Count == 0)
            {
            }
            return line;
        }

        /// <summary>Reads the first integer of the next non-blank line.</summary>
        /// <exception cref="InvalidOperationException">Input ended first.</exception>
        public static int GetInt()
        {
            var line = GetNonEmptyLine();
            if (line == null)
                throw new InvalidOperationException("Unexpected end of input.");
            return line[0];
        }

        /// <summary>
        /// Rearranges list[s, e) so the k-th smallest element (1-based) ends at
        /// list[s + k - 1], with nothing larger before it and nothing smaller after it.
        /// Does nothing when k is outside 1..(e - s).
        /// </summary>
        public static void nth_element(ref List<int> list, int s, int k, int e)
        {
            if (k < 1 || k > e - s)
                return;
            int target = s + k - 1;
            int lo = s, hi = e;
            // Random pivot + three-way partition: expected O(n) even on sorted or
            // duplicate-heavy input, and iterative so there is no deep recursion.
            while (hi - lo > 1)
            {
                int pivot = list[rng.Next(lo, hi)];
                int lt = lo, cur = lo, gt = hi;
                while (cur < gt)
                {
                    int v = list[cur];
                    if (v < pivot)
                    {
                        list[cur] = list[lt];
                        list[lt] = v;
                        lt++;
                        cur++;
                    }
                    else if (v > pivot)
                    {
                        gt--;
                        list[cur] = list[gt];
                        list[gt] = v;
                    }
                    else
                    {
                        cur++;
                    }
                }
                if (target < lt) hi = lt;
                else if (target >= gt) lo = gt;
                else return;
            }
        }

        static void Main(string[] args)
        {
            // Console.WriteLine flushes on every call, which is very slow for large
            // outputs; write through one buffered writer, flushed when disposed.
            using (var output = new StreamWriter(Console.OpenStandardOutput()))
            {
                List<int> input;
                while ((input = GetNonEmptyLine()) != null)
                {
                    var n = input[0];
                    var k = input[1];
                    var nodes = new List<Node>();
                    for (int i = 0; i < n; i++)
                        nodes.Add(new Node { a = GetNonEmptyLine() });

                    var kdt = new treekd();
                    kdt.init(n, k, nodes);
                    kdt.build(1, 0, n - 1, 0);

                    var t = GetInt();
                    while (t-- > 0)
                    {
                        kdt.initquery();
                        var now = new Node { a = GetNonEmptyLine() };
                        int cnt = GetInt();
                        kdt.query(1, cnt, 0, now);
                        output.WriteLine("the closest " + cnt + " points are:");
                        // que is a SortedSet keyed by distance, so enumeration yields the closest first.
                        foreach (var item in kdt.que.Take(cnt))
                            output.WriteLine(string.Join(" ", item.a));
                    }
                }
            }
        }
    }
}
原文地址:https://www.cnblogs.com/HaibaraAi/p/nearest_k_points.html