poj 1741 Tree(点分治)

Tree
Time Limit: 1000MS   Memory Limit: 30000K
Total Submissions: 15548   Accepted: 5054

Description

Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source

【思路】

       设 i 到当前 root 的距离为 d[i]；belong[i] 为当前 root 的某个儿子，且 i 位于以 belong[i] 为根的子树中（root 自身单独成组）。设 Sum{E} 为满足条件 E 的点对数。

       情况分为两种:

    1)      经过根节点

    2)      不经过根节点，完全位于根节点的某一棵子树中。

   其中2)可以递归求解。

   对于1)我们要求的是Sum{d[i]+d[j]<=k && belong[i]!=belong[j]},即为

   Sum{d[i]+d[j]<=k}  -  Sum{ d[i]+d[j]<=k && belong[i]==belong[j]}

       前后两项都可以转化为：对一个序列 a，求满足 i<j 且 d[a[i]]+d[a[j]]<=k 的点对 (i,j) 的个数。

       先将 a 按照 d 值排序（O(n log n)），再利用单调性用双指针在 O(n) 时间内完成统计。于是问题得到解决。

       总的时间复杂度为 O(n log² n)：每次以重心为根分治，递归共 O(log n) 层，每层的排序代价为 O(n log n)。

【代码】

  1 #include<cstdio>
  2 #include<vector>
  3 #include<cstring>
  4 #include<algorithm>
  5 using namespace std;
  6 
  7 const int N =  10000+10;
  8 const int INF = 1e9;
  9 
 10 struct Edge{
 11     int u,v,w;
 12     Edge(int u=0,int v=0,int w=0):u(u),v(v),w(w){};
 13 };
 14 int n,K,l1,l2,tl,ans;
 15 int siz[N],d[N],list[N],f[N],can[N];
 16 vector<Edge> es;
 17 vector<int> g[N];
 18 void adde(int u,int v,int w) {
 19     es.push_back(Edge(u,v,w));
 20     int m=es.size(); g[u].push_back(m-1);
 21 }
 22 
 23 void init() {
 24     ans=0; es.clear();
 25     memset(can,1,sizeof(can));
 26     for(int i=0;i<=n;i++) g[i].clear();
 27 }
 28 void dfs1(int u,int fa) {
 29     siz[u]=1; 
 30     list[++tl]=u;
 31     for(int i=0;i<g[u].size();i++) {
 32         int v=es[g[u][i]].v;
 33         if(v!=fa && can[v]) { 
 34             dfs1(v,u);
 35             f[v]=u; siz[u]+=siz[v];
 36         }
 37     }
 38 }
 39 int getroot(int u,int fa) {                    //寻找u子树重心 
 40     int pos,mn=INF;
 41     tl=0;
 42     dfs1(u,fa); 
 43     for(int i=1;i<=tl;i++) {
 44         int y=list[i],d=0;
 45         for(int j=0;j<g[y].size();j++) {
 46             int v=es[g[y][j]].v;
 47             if(v!=f[y] && can[v]) d=max(d,siz[v]);
 48         }
 49         if(y!=u) d=max(d,siz[u]-siz[y]);      //上方 
 50         if(d<mn) mn=d , pos=y;                //使大子结点数最小
 51     }
 52     return pos;
 53 }
 54 void dfs2(int u,int fa,int dis) {
 55     list[++l1]=u; d[u]=dis;
 56     for(int i=0;i<g[u].size();i++) {
 57         int v=es[g[u][i]].v;
 58         if(v!=fa && can[v]) dfs2(v,u,dis+es[g[u][i]].w);
 59     }
 60 }
 61 int getans(int* a,int l,int r) {
 62     int res=0,j=r;
 63     for(int i=l;i<=r;i++) {
 64         while(d[a[i]]+d[a[j]]>K && j>i) j--;
 65         res+=j-i; if(i==j) break;
 66     }
 67     return res;
 68 }
 69 bool cmp(const int& x,const int& y) { return d[x]<d[y]; }
 70 void solve(int u,int fa) {
 71     int root=getroot(u,fa);
 72     l1=l2=0;
 73     for(int i=0;i<g[root].size();i++) {        //统计 d[i]+d[j]<=K && belong[i]==belong[j] 
 74         int v=es[g[root][i]].v;
 75         if(can[v]) {
 76             l2=l1;
 77             dfs2(v,root,es[g[root][i]].w);    //insert[以v为根的子树]
 78             sort(list+l2+1,list+l1+1,cmp);
 79             ans-=getans(list,l2+1,l1);
 80         }
 81     }
 82     list[++l1]=root; d[root]=can[root]=0;
 83     sort(list+1,list+l1+1,cmp);
 84     ans+=getans(list,1,l1);                    //统计d[i]+d[j]<=K 
 85     for(int i=0;i<g[root].size();i++) {        //递归<-分治 
 86         int v=es[g[root][i]].v;
 87         if(v!=fa && can[v]) solve(v,root);
 88     }
 89 }
 90 
 91 int main() {
 92     while(scanf("%d%d",&n,&K)==2 && (n&&K)) {
 93         int u,v,w;
 94         init();
 95         for(int i=0;i<n-1;i++) {
 96             scanf("%d%d%d",&u,&v,&w);
 97             adde(u,v,w) , adde(v,u,w);
 98         }
 99         solve(1,-1);
100         printf("%d
",ans);
101     }
102     return 0;
103 }
原文地址:https://www.cnblogs.com/lidaxin/p/5186971.html