树的点分治

题意:给定一棵带边权的树,统计有多少个无序点对 (u, v)(u ≠ v)满足两点之间的路径长度不超过 k。
树分治模板题
 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <algorithm>
 5 using namespace std;
 6 const int maxn = 1e5 + 10;
 7 const int inf = 0x3f3f3f3f;
 8 int n, k;
 9 int rt, snode, ans;
10 int size[maxn], d[maxn], vis[maxn];
11 int dep[maxn];
12 int maxson[maxn];
13 struct Edge{
14     int v, w, nxt;
15     Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
16 }e[maxn<<1];
17 int head[maxn], cnt;
18 void init(){
19     cnt = 0;
20     memset(head, -1, sizeof head);
21 }
22 void add(int u, int v, int w){
23     e[cnt] = Edge(v, w, head[u]);
24     head[u] = cnt++;
25 }
26 
27 void getrt(int u, int f){
28     size[u] = 1;
29     maxson[u] = 0;
30     for(int i = head[u]; ~i; i = e[i].nxt){
31         int v = e[i].v;
32         if(v == f || vis[v]) continue;
33         getrt(v, u);
34         size[u] += size[v];
35         maxson[u] = max(maxson[u], size[v]);
36     }
37     maxson[u] = max(maxson[u], snode - size[u]);
38     if(maxson[u] < maxson[rt])  rt = u;
39 }
40 void getdep(int u, int f){
41     dep[++dep[0]] = d[u];
42     for(int i = head[u]; ~i; i = e[i].nxt){
43         int v = e[i].v;
44         if(v == f || vis[v]) continue;
45         d[v] = d[u] + e[i].w;
46         getdep(v, u);
47     }
48 }
49 int cal(int u, int w){
50     d[u] = w;
51     dep[0] = 0;
52     getdep(u, 0);
53     sort(dep + 1, dep + 1 + dep[0]);
54     int sum = 0;
55     int l =  1, r = dep[0];
56     while(l < r){
57         if(dep[l] + dep[r] <= k) {
58             sum += r - l;
59             l++;
60         }else r--;
61     }
62     return sum;
63 }
64 
65 void solve(int u){
66     vis[u] =  1;
67     ans += cal(u, 0);
68     for(int i = head[u]; ~i; i = e[i].nxt){
69         int v = e[i].v;
70         if(vis[v]) continue;
71         ans -= cal(v, e[i].w);
72         rt = 0;
73         snode = size[v];
74         getrt(v, u);
75         solve(rt);
76     }
77 }
78 
79 int main(){
80     while(scanf("%d %d", &n, &k) && (n || k)){
81         init();
82         memset(vis, 0, sizeof vis);
83         int u, v, w;
84         for(int i = 1; i < n; i++){
85             scanf("%d %d %d", &u, &v, &w);
86             add(u, v, w);
87             add(v, u, w);
88         }
89         rt = ans = 0;
90         snode = n;
91         maxson[0] = inf;
92         getrt(1, 0);
93         solve(rt);
94         printf("%d
", ans);
95     }
96 
97 }
View Code

Distance in Tree

CodeForces - 161D

题意:给定一棵 n 个点的树(边权均为 1),统计有多少个无序点对满足两点之间的距离恰好为 k。

点分治

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 const int maxn = 50010;
 4 const int maxd = 510;
 5 const int inf = 0x3f3f3f3f;
 6 int n, k;
 7 int ct[maxd], temp[maxd];
 8 int rt, snode, ans;
 9 int size[maxn];
10 int maxson[maxn], vis[maxn];
11 
12 struct Edge{
13     int v, nxt;
14     Edge(int v = 0, int nxt = 0) : v(v), nxt(nxt) {}
15 }e[maxn<<1];
16 int head[maxn], cnt;
17 void init(){
18     cnt = 0;
19     memset(head, -1, sizeof head);
20 }
21 void add(int u, int v){
22     e[cnt] = Edge(v, head[u]);
23     head[u] = cnt++;
24 }
25 
26 void getrt(int u, int f){
27     size[u] = 1;
28     maxson[u] = 0;
29     for(int i = head[u]; ~i; i = e[i].nxt){
30         int v = e[i].v;
31         if(vis[v] || v == f) continue;
32         getrt(v, u);
33         size[u] += size[v];
34         maxson[u] = max(maxson[u], size[v]);
35     }
36     maxson[u] = max(maxson[u], snode - size[u]);
37     if(maxson[u] < maxson[rt]) rt = u;
38 }
39 
40 void dfs(int u, int f, int d){
41     if(d > k) return ;
42     ans += ct[k - d];
43     ++temp[d];
44     for(int i = head[u]; ~i; i = e[i].nxt){
45         int v = e[i].v;
46         if(vis[v] || v == f) continue;
47         dfs(v, u, d + 1);
48     }
49 }
50 
51 void cal(int u){
52     memset(ct, 0, sizeof ct);
53     ct[0] = 1;
54     for(int i = head[u]; ~i; i = e[i].nxt){
55         int v = e[i].v;
56         if(vis[v]) continue;
57         memset(temp, 0, sizeof temp);
58         dfs(v, u, 1);
59         for(int i = 1; i <= k ; i++) ct[i] += temp[i];
60     }
61 }
62 
63 void divide(int u){
64     getrt(u, u);
65     u = rt;
66     vis[u] = 1;
67     cal(u);
68     for(int i = head[u]; ~i; i = e[i].nxt){
69         int v = e[i].v;
70         if(vis[v]) continue;
71         rt = 0;
72         snode = size[v];
73         divide(v);
74     }
75 }
76 
77 int main(){
78     while(scanf("%d %d", &n, &k) != EOF){
79         int u, v;
80         init();
81         memset(vis, 0, sizeof vis);
82         for(int i = 1; i  < n; i++){
83             scanf("%d %d", &u, &v);
84             add(u, v);
85             add(v, u);
86         }
87         rt = ans  = 0;
88         snode = n;
89         maxson[0] = inf;
90         divide(1);
91         printf("%d
", ans);
92 
93     }
94 }
View Code

树DP

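题解:here。思路:dp1 自底向上求 dp[u][i],即 u 子树内与 u 距离为 i 的点数;dp2 自顶向下换根,把经父亲方向到达的点也计入,得到全树中与 u 距离为 i 的点数。每个无序点对在 Σdp[u][k] 中被统计两次,所以答案要除以 2。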

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 const int maxn = 50010;
 4 const int maxd = 510;
 5 int dp[maxn][maxd];
 6 struct Edge{
 7     int v, nxt;
 8     Edge(int v = 0, int nxt = 0) : v(v), nxt(nxt){}
 9 }e[maxn<<1];
10 int head[maxn], cnt;
11 void init(){
12     cnt = 0;
13     memset(head, -1, sizeof head);
14 }
15 void add(int u, int v){
16     e[cnt] = Edge(v, head[u]);
17     head[u] = cnt++;
18 }
19 
20 int n, k;
21 
22 void dp1(int u, int f){
23     dp[u][0] = 1;
24     for(int i = 1; i <= k; i++) dp[u][i] = 0;
25     for(int i = head[u]; ~i; i = e[i].nxt){
26         int v = e[i].v;
27         if(v == f) continue;
28         dp1(v, u);
29         for(int i = 1; i <= k; i++) dp[u][i] += dp[v][i - 1];
30     }
31 }
32 void dp2(int u, int f){
33     for(int i = head[u]; ~i; i = e[i].nxt){
34         int v = e[i].v;
35         if(v == f) continue;
36         for(int i = k; i >= 1; i--){
37             dp[v][i] += dp[u][i - 1];
38             if(i > 1) dp[v][i] -= dp[v][i - 2];
39         }
40         dp2(v, u);
41     }
42 }
43 
44 int main(){
45     while(scanf("%d %d", &n, &k) != EOF){
46         init();
47         int u, v;
48         for(int i = 1; i < n; i++){
49             scanf("%d %d", &u, &v);
50             add(u, v); add(v, u);
51         }
52         dp1(1, 0);
53         dp2(1, 0);
54         long long ans = 0;
55         for(int i = 1; i <= n; i++){
56             ans += dp[i][k];
57         }
58         printf("%lld
", ans / 2);
59     }
60 }
View Code

聪聪可可

HYSBZ - 2152
题意:从树上独立、等概率地选两个点(有序,可以是同一个点),求两点距离为 3 的倍数的概率,以最简分数 p/q 的形式输出。
 
两种写法:第一种用容斥(以重心为根整体计数,再减去落在同一棵子树内的点对);第二种按子树逐个合并计数,天然不会统计同一子树内的点对。
 
 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <algorithm>
 5 using namespace std;
 6 const int maxn = 1e5 + 10;
 7 const int inf = 0x3f3f3f3f;
 8 int n, k;
 9 int rt, snode, ans;
10 int size[maxn], d[maxn], vis[maxn];
11 int dep[maxn];
12 int maxson[maxn];
13 struct Edge{
14     int v, w, nxt;
15     Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
16 }e[maxn<<1];
17 int head[maxn], cnt;
18 void init(){
19     cnt = 0;
20     memset(head, -1, sizeof head);
21 }
22 void add(int u, int v, int w){
23     e[cnt] = Edge(v, w, head[u]);
24     head[u] = cnt++;
25 }
26 
27 void getrt(int u, int f){
28     size[u] = 1;
29     maxson[u] = 0;
30     for(int i = head[u]; ~i; i = e[i].nxt){
31         int v = e[i].v;
32         if(v == f || vis[v]) continue;
33         getrt(v, u);
34         size[u] += size[v];
35         maxson[u] = max(maxson[u], size[v]);
36     }
37     maxson[u] = max(maxson[u], snode - size[u]);
38     if(maxson[u] < maxson[rt])  rt = u;
39 }
40 void getdep(int u, int f){
41     dep[d[u]]++;
42     for(int i = head[u]; ~i; i = e[i].nxt){
43         int v = e[i].v;
44         if(v == f || vis[v]) continue;
45         d[v] = (d[u] + e[i].w) % 3;
46         getdep(v, u);
47     }
48 }
49 int cal(int u, int w){
50     dep[0] = dep[1] = dep[2] = 0;
51     d[u] = w;
52     getdep(u, 0);
53     return dep[0] * dep[0] + dep[1] * dep[2] * 2;
54 }
55 
56 void solve(int u){
57     vis[u] =  1;
58     ans += cal(u, 0);
59     for(int i = head[u]; ~i; i = e[i].nxt){
60         int v = e[i].v;
61         if(vis[v]) continue;
62         ans -= cal(v, e[i].w);
63         rt = 0;
64         snode = size[v];
65         getrt(v, u);
66         solve(rt);
67     }
68 }
69 
70 int main(){
71     while(scanf("%d", &n) != EOF){
72         init();
73         memset(vis, 0, sizeof vis);
74         int u, v, w;
75         for(int i = 1; i < n; i++){
76             scanf("%d %d %d", &u, &v, &w);
77             w %= 3;
78             add(u, v, w);
79             add(v, u, w);
80         }
81         rt = ans = 0;
82         snode = n;
83         maxson[0] = inf;
84         getrt(1, 0);
85         solve(rt);
86         int temp = n *n;
87         int g = __gcd(temp, ans);
88         printf("%d/%d
", ans/g, temp/ g);
89     }
90 
91 }
+---
  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 const int maxn = 20010;
  4 const int maxd = 4;
  5 const int inf = 0x3f3f3f3f;
  6 int n, k;
  7 int ct[maxd], temp[maxd];  // 开太大会TLE...反复memset耗时严重...
  8 int rt, snode, ans;
  9 int size[maxn];
 10 int maxson[maxn], vis[maxn];
 11 
 12 struct Edge{
 13     int v, w, nxt;
 14     Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
 15 }e[maxn<<1];
 16 int head[maxn], cnt;
 17 void init(){
 18     cnt = 0;
 19     memset(head, -1, sizeof head);
 20 }
 21 void add(int u, int v, int w){
 22     e[cnt] = Edge(v, w, head[u]);
 23     head[u] = cnt++;
 24 }
 25 
 26 void getrt(int u, int f){
 27     size[u] = 1;
 28     maxson[u] = 0;
 29     for(int i = head[u]; ~i; i = e[i].nxt){
 30         int v = e[i].v;
 31         if(vis[v] || v == f) continue;
 32         getrt(v, u);
 33         size[u] += size[v];
 34         maxson[u] = max(maxson[u], size[v]);
 35     }
 36     maxson[u] = max(maxson[u], snode - size[u]);
 37     if(maxson[u] < maxson[rt]) rt = u;
 38 }
 39 
 40 void dfs(int u, int f, int d){
 41     if(d == 0) ans += ct[0];   //特殊处理
 42     else ans += ct[k - d];
 43     ++temp[d];
 44     for(int i = head[u]; ~i; i = e[i].nxt){
 45         int v = e[i].v;
 46         if(vis[v] || v == f) continue;
 47         dfs(v, u, (d + e[i].w) % 3);
 48     }
 49 }
 50 
 51 void cal(int u){
 52     memset(ct, 0, sizeof ct);  
 53     ct[0] = 1;
 54     for(int i = head[u]; ~i; i = e[i].nxt){
 55         int v = e[i].v;
 56         if(vis[v]) continue;
 57         memset(temp, 0, sizeof temp);
 58         dfs(v, u, e[i].w);
 59         for(int i = 0; i < k ; i++) ct[i] = ct[i] + temp[i];
 60     }
 61 }
 62 
 63 void divide(int u){
 64     getrt(u, u);
 65     u = rt;
 66     vis[u] = 1;
 67     cal(u);
 68     for(int i = head[u]; ~i; i = e[i].nxt){
 69         int v = e[i].v;
 70         if(vis[v]) continue;
 71         rt = 0;
 72         snode = size[v];
 73         divide(v);
 74     }
 75 }
 76 
 77 int main(){
 78     //freopen("in.txt", "r", stdin);
 79     //freopen("out1.txt", "w", stdout);
 80     while(scanf("%d", &n) != EOF){
 81         int u, v, w;
 82         k = 3;
 83         init();
 84         memset(vis, 0, sizeof vis);
 85         for(int i = 1; i  < n; i++){
 86             scanf("%d %d %d", &u, &v, &w);
 87             w %= 3;
 88             add(u, v, w);
 89             add(v, u, w);
 90         }
 91         rt = ans  = 0;
 92         snode = n;
 93         maxson[0] = inf;
 94         divide(1);
 95         ans = ans * 2 + n;
 96         int temp = n * n;
 97         int g = __gcd(ans, temp);
 98         printf("%d/%d
", ans/g, temp/g);
 99 
100     }
101 }
++++
 

 http://blog.csdn.net/ALPS233/article/details/51398629

 
原文地址:https://www.cnblogs.com/yijiull/p/8335195.html