// 树分治模板 — tree divide-and-conquer (centroid decomposition) template

#include <cstdio>
#include <cmath>
#include <algorithm>
#include <utility>
#include <vector>
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;

 8 #define MP make_pair
 9 #define PB push_back
10 typedef long long LL;
11 typedef pair<int,int> PII;
12 const double eps=1e-8;
13 const double pi=acos(-1.0);
14 const int K=1e5+7;
15 const int mod=1e9+7;
16 
17 vector<PII >mp[K];
18 int cnt,ans,vis[K],tdis[K],dis[K];
19 struct CenterTree
20 {
21     int n,ret,mx,son[K];
22     void dfs(int x,int f)
23     {
24         son[x]=1;
25         int tmp=0;
26         for(int i=0;i<mp[x].size();i++)
27         if(f!=mp[x][i].first && !vis[mp[x][i].first])
28         {
29             dfs(mp[x][i].first,x);
30             son[x]+=son[mp[x][i].first];
31             tmp=max(tmp,son[mp[x][i].first]);
32         }
33         tmp=max(tmp,n-son[x]);
34         if(tmp<mx)
35             mx=tmp,ret=x;
36     }
37     int getCenter(int x,int num)
38     {
39         n=num,mx=0x3f3f3f3f;
40         dfs(x,0);
41         return ret;
42     }
43 }center;
44 void getArray(int x,int f)
45 {
46     tdis[++cnt]=dis[x];
47     for(int i=0;i<mp[x].size();i++)
48     if(mp[x][i].first!=f && !vis[mp[x][i].first])
49     {
50         dis[mp[x][i].first]=dis[x]+mp[x][i].second;
51         getArray(mp[x][i].first,x);
52     }
53 }
54 int calc(int x,int d,int mx)
55 {
56     dis[x]=d,cnt=0;
57     getArray(x,0);
58     sort(tdis+1,tdis+1+cnt);
59     int ret=0,l=1,r=cnt;
60     while(l<r)
61         if(tdis[l]+tdis[r]<=mx) ret+=r-l,l++;
62         else r--;
63     return ret;
64 }
65 void solve(int x,int mx)
66 {
67     ans+=calc(x,0,mx);
68     vis[x]=1;
69     for(int i=0;i<mp[x].size();i++)
70     if(!vis[mp[x][i].first])
71     {
72         ans-=calc(mp[x][i].first,mp[x][i].second,mx);
73         solve(center.getCenter(mp[x][i].first,center.son[mp[x][i].first]),mx);
74     }
75 }
76 int main(void)
77 {
78     int n,mx;
79     while(~scanf("%d%d",&n,&mx)&&n)
80     {
81         for(int i=1;i<=n;i++)
82             mp[i].clear(),vis[i]=0;
83         for(int i=1,x,y,z;i<n;i++)
84             scanf("%d%d%d",&x,&y,&z),mp[x].PB(MP(y,z)),mp[y].PB(MP(x,z));
85         ans=0;
86         solve(center.getCenter(1,n),mx);
87         printf("%d
",ans);
88     }
89     return 0;
90 }
// 原文地址 (original article): https://www.cnblogs.com/weeping/p/6875298.html