学习笔记—点分治

点分治,是一种针对可带权树上简单路径统计问题的算法。

就 POJ 1741 来说:

问题:给一棵边带权树,问两点之间的距离小于等于$k$的点对有多少个。

解决:

当前有一个节点$u$,那么树上的路径可分为两种:(1) 经过节点$u$的   (2) 不经过节点$u$的

第 (2) 种路径,一定在$u$的某个子节点构成的子树中。在各个子树中找一点递归下去即可。

 1 void solve(int u) {
 2     ans+=calc(u,0); vis[u]=1;
 3     for(int i=fro[u];i;i=nxt[i]) {
 4         int v=to[i];
 5         if(vis[v]) continue;
 6         ans-=calc(v,w[i]); 
 7         //合并路径时,u的同一个子树下的两点合并出的路径是不存在的。在此减去。
 8         root=0,size=sz[v];
 9         findrt(v,0); solve(root);
10     }
11 }

找什么样的点递归下去使得效率最高?

递归层数要最少,所以应选一棵树中最大子树最小的点,即树的重心

 1 //size表示整棵树的大小
 2 void findrt(int u,int fa) {
 3     sz[u]=1; f[u]=0;
 4     for(int i=fro[u];i;i=nxt[i]) {
 5         int v=to[i];
 6         if(vis[v]||v==fa) continue;
 7         findrt(v,u);
 8         sz[u]+=sz[v]; f[u]=max(f[u],sz[v]);
 9     }
10     f[u]=max(f[u],size-sz[u]); 
11     if(f[u]<f[root]) root=u;
12 }

本题完整代码:

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 using namespace std;
 5 const int N=1e4+5;
 6 int n,k,ans,root,size,tot,d[N],dep[N],sz[N],f[N];
 7 int cnt,fro[N],to[N<<1],w[N<<1],nxt[N<<1];
 8 bool vis[N];
 9 
10 inline int read() {
11     int x=0; char c=getchar();
12     while(c<'0'||c>'9') c=getchar();
13     while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-'0',c=getchar();
14     return x;
15 }
16 void add(int x,int y,int z) {
17     to[++cnt]=y,w[cnt]=z,nxt[cnt]=fro[x]; fro[x]=cnt;
18 }
19 
20 void findrt(int u,int fa) {
21     sz[u]=1; f[u]=0;
22     for(int i=fro[u];i;i=nxt[i]) {
23         int v=to[i];
24         if(vis[v]||v==fa) continue;
25         findrt(v,u);
26         sz[u]+=sz[v],f[u]=max(f[u],sz[v]);
27     }
28     f[u]=max(f[u],size-sz[u]);
29     if(f[u]<f[root]) root=u;
30 }
31 void getdeep(int u,int fa) {
32     d[++tot]=dep[u];
33     for(int i=fro[u];i;i=nxt[i]) {
34         int v=to[i];
35         if(vis[v]||v==fa) continue;
36         dep[v]=dep[u]+w[i]; 
37         getdeep(v,u);
38     }
39 }
40 int clac(int u) {
41     tot=0; getdeep(u,0);
42     sort(d+1,d+tot+1);
43     int sum=0,l=1,r=tot;
44     while(l<r) {
45         if(d[l]+d[r]<=k) sum+=r-l,l++;
46         else r--; 
47     }
48     return sum;
49 }
50 void solve(int u) {
51     vis[u]=1;  
52     dep[u]=0; ans+=clac(u);
53     for(int i=fro[u];i;i=nxt[i]) {
54         int v=to[i];
55         if(vis[v]) continue;
56         dep[v]=w[i]; ans-=clac(v);
57         root=0; size=sz[v];
58         findrt(v,0); solve(root);
59     }
60 }
61 
62 int main() {
63     while(scanf("%d%d",&n,&k)&&(n||k)) {
64         cnt=0,ans=0;
65         memset(fro,0,sizeof(fro));
66         memset(vis,0,sizeof(vis));
67         for(int i=1;i<n;i++) {
68             int x=read(),y=read(),z=read();
69             add(x,y,z); add(y,x,z);
70         }
71         root=0; size=f[0]=n;
72         findrt(1,0); solve(root);
73         printf("%d
",ans);
74     }
75     return 0;
76 }

其它例题

P2634 聪聪可可

把路径长度对$3$取模后答案为$0,1,2$的路径条数分别保存为$t[0],t[1],t[2]$。

答案:$2 imes t[1] imes t[2]+t[0] imes t[0]$。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int N=2e4+5;
 4 int n,ans,root,size,t[3],dep[N],sz[N],f[N];
 5 int cnt,to[N<<1],w[N<<1],nxt[N<<1],fro[N];
 6 bool vis[N];
 7 
 8 inline int read() {
 9     int x=0; char c=getchar();
10     while(c<'0'||c>'9') c=getchar();
11     while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-'0',c=getchar();
12     return x;
13 }
14 void add(int x,int y,int z) {
15     to[++cnt]=y,w[cnt]=z,nxt[cnt]=fro[x]; fro[x]=cnt;
16 }
17 int gcd(int a,int b) {return b?gcd(b,a%b):a;}
18 
19 void findrt(int u,int fa) {
20     sz[u]=1; f[u]=0;
21     for(int i=fro[u];i;i=nxt[i]) {
22         int v=to[i];
23         if(vis[v]||v==fa) continue;
24         findrt(v,u);
25         sz[u]+=sz[v]; f[u]=max(f[u],sz[v]);
26     }
27     f[u]=max(f[u],size-sz[u]);
28     if(f[u]<f[root]) root=u;
29 }
30 void query(int u,int fa) {
31     t[dep[u]]++;
32     for(int i=fro[u];i;i=nxt[i]) {
33         int v=to[i];
34         if(vis[v]||v==fa) continue;
35         dep[v]=(dep[u]+w[i])%3;
36         query(v,u);
37     }
38 }
39 int calc(int u,int d0) {
40     dep[u]=d0; 
41     t[0]=t[1]=t[2]=0;
42     query(u,0);
43     return t[0]*t[0]+2*t[1]*t[2];
44 }
45 void solve(int u) {
46     ans+=calc(u,0); vis[u]=1;
47     for(int i=fro[u];i;i=nxt[i]) {
48         int v=to[i];
49         if(vis[v]) continue;
50         ans-=calc(v,w[i]);
51         root=0,size=sz[v];
52         findrt(v,0); solve(root);
53     }
54 }
55 
56 int main() {
57     n=read();
58     for(int i=1;i<n;i++) {
59         int x=read(),y=read(),z=read()%3;
60         add(x,y,z),add(y,x,z);
61     }
62     size=f[0]=n;
63     findrt(1,0); solve(root);
64     int t=gcd(ans,n*n);
65     printf("%d/%d
",ans/t,n*n/t);
66 }

  如有错误、疑问请联系作者(见公告)!感谢。

原文地址:https://www.cnblogs.com/qq8260573/p/10803815.html