bzoj2599: [IOI2011]Race(点分治)

写了四五道点分治的题目了,算是比较理解点分治是什么东西了吧= =

点分治主要用来解决点对之间的问题的,比如距离为不大于K的点有多少对。

这道题要求距离等于K的点对中连接两点的最小边数。

那么其实道理是一样的。先找重心,然后先从重心开始求距离dis和边数num,更新ans,再从重心的儿子开始求得dis和num,减去这部分答案

因为这部分的答案中,从重心开始的两条链有重叠部分,所以要剪掉

基本算是模板题,但是减去儿子的答案的那部分还有双指针那里调了好久,所以还不算特别熟练。。

PS跑了27秒慢到飞起,不过代码短一点,看起来比较清晰

 1 #include<stdio.h>
 2 #include<string.h>
 3 #include<algorithm>
 4 #define INF 0x3f3f3f3f
 5 using namespace std;
 6 const int maxn = 200010;
 7 struct node{
 8     int to,next,cost;
 9 }e[maxn*2];
10 struct data{
11     int l,e;
12 }dis[maxn];
13 int n,m,K,tot,head[maxn],size[maxn],sz,total,vis[maxn],root,p,ans[maxn*2];
14 
15 void insert(int u, int v, int w){
16     e[++tot].to=v; e[tot].next=head[u]; head[u]=tot; e[tot].cost=w;
17 }
18 
19 void getroot(int u, int f){
20     int mx=0; size[u]=1;
21     for (int i=head[u],v; i; i=e[i].next){
22         if (vis[v=e[i].to] || v==f) continue;
23         getroot(v,u);
24         size[u]+=size[v];
25         mx=max(mx,size[v]);
26     }
27     mx=max(mx,total-size[u]);
28     if (mx<sz) sz=mx,root=u;
29 }
30 
31 void getdis(int u, int f, int len, int num){
32     dis[++p].l=len; dis[p].e=num;
33     for (int i=head[u],v; i; i=e[i].next){
34         if (vis[v=e[i].to] || v==f) continue;
35         getdis(v,u,len+e[i].cost,num+1);
36     }
37 }
38 
39 bool cmp(data a, data b){
40     if (a.l==b.l) return a.e<b.e; return a.l<b.l;  /////
41 }
42 
43 void count(int u, int len, int num, int f){
44     p=0; getdis(u,0,len,num); //这里要将len和num传下去。。
45     int l=1,r=p;
46     sort(dis+1,dis+1+p,cmp);
47     while (l<=r){  //注意这里要用<=,WA了几发
48         while (l<r && dis[l].l+dis[r].l>K) r--;
49         for (int k=r; dis[l].l+dis[k].l==K; k--) ans[dis[l].e+dis[k].e]+=f;
50         l++;
51     }
52 }
53 
54 void work(int u){
55     total=size[u]?size[u]:n;
56     sz=INF;
57     getroot(u,0); u=root;
58     vis[u]=1; count(u,0,0,1);
59     for (int i=head[u],v; i; i=e[i].next){
60         if (vis[v=e[i].to]) continue;
61         count(v,e[i].cost,1,-1);
62         work(v);
63     }
64 }
65 
66 int main(){
67     scanf("%d%d", &n, &K);
68     for (int i=1,u,v,w; i<n; i++) scanf("%d%d%d", &u, &v, &w),u++,v++,insert(u,v,w),insert(v,u,w);
69     work(1);
70     for (int i=1; i<n; i++) if (ans[i]){printf("%d
", i); return 0;} puts("-1");
71     return 0;
72 } 
原文地址:https://www.cnblogs.com/mzl0707/p/6183847.html