POJ 1741 Tree 点分治

题目传送门

题意:给定一棵 n 个点、边带权的树,求有多少对点之间的距离 <= k。

先吐槽一下这个题目:题目中给出了边长 l 和点数 n 的范围,我以为 k 的范围也小于 1001,结果 k 的范围根本没有给出。我直接写了一个依赖 k 范围的树状数组,结果疯狂 RE(运行时错误)……

题解:很裸的点分治。

1.找重心。

2.统计经过重心且距离 <= k 的点对数:先把以重心为根的所有点的距离排序,用双指针统计距离和 <= k 的点对;再对重心的每棵子树做同样的统计并减去,因为这些点对位于同一棵子树内,并不真正经过重心。

3.递归处理所有子树。

代码:

  1 #include<cstdio>
  2 #include<algorithm>
  3 #include<vector>
  4 #include<queue>
  5 #include<iostream>
  6 #include<cstring>
  7 using namespace std;
  8 #define Fopen freopen("_in.txt","r",stdin); freopen("_out.txt","w",stdout);
  9 #define LL long long
 10 #define ULL unsigned LL
 11 #define fi first
 12 #define se second
 13 #define pb push_back
 14 #define lson l,m,rt<<1
 15 #define rson m+1,r,rt<<1|1
 16 #define lch(x) tr[x].son[0]
 17 #define rch(x) tr[x].son[1]
 18 #define max3(a,b,c) max(a,max(b,c))
 19 #define min3(a,b,c) min(a,min(b,c))
 20 typedef pair<int,int> pll;
 21 const int inf = 0x3f3f3f3f;
 22 const LL INF = 0x3f3f3f3f3f3f3f3f;
 23 const LL mod =  (int)1e9+7;
 24 const int N = 1e5 + 100;
 25 int sz[N], vis[N];
 26 int head[N], to[N<<1], ct[N<<1], nt[N<<1], tot;
 27 int n, m; int ans;
 28 void add(int u, int v, int w){
 29     to[tot] = v; ct[tot] = w;
 30     nt[tot] = head[u]; head[u] = tot++;
 31 }
 32 int rtsz, rt;
 33 void get_rt(int o, int u, int num){
 34     sz[u] = 1;
 35     int v, mxnum = 0;
 36     for(int i = head[u]; ~i; i = nt[i]){
 37         v = to[i];
 38         if(vis[v] || o == v) continue;
 39         get_rt(u, v, num);
 40         sz[u] += sz[v];
 41         mxnum = max(mxnum, sz[v]);
 42     }
 43     if(o)  mxnum = max(mxnum, num - sz[u]);
 44     if(mxnum < rtsz){
 45         rtsz = mxnum;
 46         rt = u;
 47     }
 48     return ;
 49 }
 50 int bit[N];
 51 void Add(int x, int v){
 52     while(x <= n){
 53         bit[x] += v;
 54         x += x & (-x);
 55     }
 56     return ;
 57 }
 58 int Query(int x){
 59     int ret = 0;
 60     while(x > 0){
 61         ret += bit[x];
 62         x -= x & (-x);
 63     }
 64     return ret;
 65 }
 66 int d[N], dcnt;
 67 void dfs(int o, int u, int w){
 68     d[++dcnt] = w;
 69     sz[u] = 1;
 70     for(int i = head[u]; ~i; i = nt[i]){
 71         int v = to[i];
 72         if(vis[v] || o == v) continue;
 73         dfs(u, v, w+ct[i]);
 74         sz[u] += sz[v];
 75     }
 76     return ;
 77 }
 78 int cal(){
 79     sort(d+1, d+1+dcnt);
 80     int l = 1, r = dcnt, ret = 0;
 81     while(l < r){
 82         if(d[l] + d[r] <= m) ret += r - l, l++;
 83         else r--;
 84     }
 85     return ret;
 86 }
 87 
 88 void solve(int u, int num){
 89     if(num <= 1) return ;
 90     rtsz = inf;
 91     get_rt(0, u, num);
 92     vis[rt] = 1;
 93     int v;
 94     dcnt = 0;
 95     dfs(0,rt,0);
 96     ans += cal();
 97     for(int i = head[rt]; ~i; i = nt[i]){
 98         v = to[i];
 99         if(vis[v]) continue;
100         dcnt = 0;
101         dfs(0, v, ct[i]);
102         ans -= cal();
103     }
104     for(int i = head[rt]; ~i; i = nt[i]){
105         v = to[i];
106         if(vis[v]) continue;
107         solve(v, sz[v]);
108     }
109     return ;
110 }
111 int main(){
112     int u, v, w;
113     while(~scanf("%d%d", &n, &m) && n+m){
114         tot = 0; ans = 0;
115         for(int i = 1; i <= n; i++){
116             vis[i] = 0;
117             head[i] = -1;
118         }
119         for(int i = 1; i < n; i++){
120             scanf("%d%d%d", &u, &v, &w);
121             add(u, v, w); add(v, u, w);
122         }
123         solve(1, n);
124         printf("%d
", ans);
125     }
126     return 0;
127 }
View Code
原文地址:https://www.cnblogs.com/MingSD/p/9871556.html