树点分治入门题 POJ 1741 (introductory problem on centroid decomposition of trees)

Tree
Time Limit: 1000MS   Memory Limit: 30000K
Total Submissions: 24253   Accepted: 8060

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs there are in the given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning that there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity limits. N bounds vertex ids (the statement guarantees n <= 10000).
// M bounds stored half-edges: every undirected edge is stored twice, so
// 2*(n-1) slots are needed.
const int N = 11111;
const int M = 55555;
const int INF = 1000000000;

// One directed half of an undirected tree edge, chained into per-vertex lists.
struct node {
    int v;     // vertex this half-edge points to
    int next;  // index in e[] of the next half-edge leaving the same vertex, -1 ends the list
    int w;     // edge length
};
node e[M];

int head[N];  // head[u]: index in e[] of u's first half-edge (-1 after init() if none)
int tot;      // number of half-edges stored in e[] so far

int n;        // vertex count of the current test case
int k;        // distance threshold of the current test case
int vis[N];   // vis[u] becomes 1 once u has been used as a centroid
int ans;      // running count of valid pairs for the current test case
int root;     // centroid chosen by the most recent dfsroot() walk
int num;      // number of distances collected into dis[] by dfsdis()
// Resets per-test-case state: every adjacency list becomes empty (-1),
// no vertex is marked as a used centroid, and the edge count and answer are zeroed.
void init(){
    tot = 0;
    ans = 0;
    memset(head, -1, sizeof head);
    memset(vis, 0, sizeof vis);
}
void add(int u,int v,int w){
   e[tot].v=v;e[tot].w=w;e[tot].next=head[u];head[u]=tot++;
}
// Scratch state shared by the centroid-decomposition helpers.
int mx[N];    // mx[u]: size of the largest piece left if u were removed (filled by dfssize/dfsroot)
// NOTE(review): a global named `size` can become ambiguous with C++17's std::size
// when <iterator> is pulled in under `using namespace std;` -- consider renaming.
int size[N];  // size[u]: subtree size of u in the current component (filled by dfssize)
int mi;       // smallest mx[] seen so far during dfsroot
int dis[N];   // distances collected by dfsdis; the first `num` entries are valid
// Roots the current component at u and fills, for every vertex x reached:
//   size[x] - number of vertices in x's subtree
//   mx[x]   - size of x's largest child subtree
// Vertices with vis set are treated as removed. fa is the parent of u
// (0 at the start; vertex ids are 1-based).
void dfssize(int u,int fa){
    size[u] = 1;
    mx[u] = 0;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int to = e[i].v;
        if (to == fa || vis[to]) continue;
        dfssize(to, u);
        size[u] += size[to];
        if (mx[u] < size[to]) mx[u] = size[to];
    }
}
// Walks the component that dfssize(r,0) just measured and records in `root`
// the vertex whose largest remaining piece (mx[]) is smallest -- the centroid.
// The caller must preset `mi` above any possible piece size.
void dfsroot(int r,int u,int fa){
    // The piece "above" u is everything outside u's subtree.
    const int outside = size[r] - size[u];
    if (outside > mx[u]) mx[u] = outside;
    if (mx[u] < mi) {
        mi = mx[u];
        root = u;
    }
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int to = e[i].v;
        if (to != fa && !vis[to]) dfsroot(r, to, u);
    }
}
// Appends to dis[] (advancing num) the distance of every vertex reachable from u
// without passing through removed (vis) vertices; u itself is recorded with distance d.
void dfsdis(int u,int d,int fa){
    dis[num] = d;
    ++num;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int to = e[i].v;
        if (to == fa || vis[to]) continue;
        dfsdis(to, d + e[i].w, u);
    }
}
// Collects the distances of all vertices reachable from u (u counted as d) and
// returns how many unordered pairs of them have a distance sum <= k.
// Uses sort + two pointers: for each lo, hi is shrunk until dis[lo]+dis[hi] <= k,
// and every index in (lo, hi] then pairs validly with lo.
int calc(int u,int d){
    num = 0;
    dfsdis(u, d, 0);
    sort(dis, dis + num);
    int pairs = 0;
    for (int lo = 0, hi = num - 1; lo < hi; ++lo) {
        while (lo < hi && dis[lo] + dis[hi] > k) --hi;
        pairs += hi - lo;
    }
    return pairs;
}
// Centroid-decomposition driver: adds to `ans` the number of valid pairs
// inside the (not yet removed) component that contains u.
//
// Complexity: the centroid split keeps recursion depth at O(log N) levels.
// Each level handles at most N vertices, but calc() sorts them, so a level
// costs O(N log N) and the whole run is O(N log^2 N), not O(N log N).
void dfs(int u){
    // Locate the centroid of u's component.
    mi = n;
    dfssize(u, 0);
    dfsroot(u, u, 0);
    // Keep the centroid in a local: the recursive dfs() calls below overwrite the global `root`.
    const int c = root;

    // Count pairs whose path length measured through c is <= k.
    ans += calc(c, 0);
    vis[c] = 1;
    for (int i = head[c]; ~i; i = e[i].next) {
        const int v = e[i].v;
        if (vis[v]) continue;
        // Pairs with both endpoints in v's subtree were counted above via c,
        // but their real path does not pass through c; subtract them.
        ans -= calc(v, e[i].w);
        dfs(v);
    }
}
int main(){
    while(scanf("%d%d",&n,&k)!=EOF){
        if(!n&&!k) break;
        init();
        int u,v,w;
        for(int i=0;i<n-1;++i) {
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        dfs(1);
        printf("%d
",ans);
    }
}
原文地址:https://www.cnblogs.com/mfys/p/7563439.html