[poj1741][tree] (tree / centroid decomposition)

Description

Given a tree with n vertices, each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many pairs are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n, k (n<=10000). The following n-1 lines each contain three integers u, v, l, meaning there is an edge between nodes u and v of length l.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Solution

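1. Centroid decomposition + sorting

First find the centroid of the tree and compute the answer for it. For each centroid, compute the distance from every vertex of the current component to the centroid and count all pairs whose two distances sum to at most k; call this count ans1.

These pairs include some that must not be counted:

Let p be a child of the centroid of any component produced by the decomposition. Two vertices in the subtree of p may both have used the edge from p to the centroid, so the path obtained by joining them at the centroid is not their shortest path.

To remove these cases, for each child p we count in the same way the pairs that lie inside the subtree of p, with distances still measured from the centroid; call this count ans2.

The contribution of the centroid is ans1-sum(ans2). Summing it over all centroids found recursively gives the final answer.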
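
The counting step described above can be sketched on its own. The block below is a minimal illustration, not the submitted code: the names countPairs and pairsThroughCentroid and the layout of childDepths (one list of centroid distances per child subtree) are assumptions made for this example.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Count unordered pairs (i, j), i < j, with depths[i] + depths[j] <= k.
// The vector is taken by value so the caller's data stays unsorted.
long long countPairs(std::vector<int> depths, int k) {
    if (depths.size() < 2) return 0;
    std::sort(depths.begin(), depths.end());
    long long pairs = 0;
    std::size_t l = 0, r = depths.size() - 1;
    while (l < r) {
        if (depths[l] + depths[r] <= k) {
            // depths[l] can be paired with every element in (l, r].
            pairs += static_cast<long long>(r - l);
            ++l;
        } else {
            --r;
        }
    }
    return pairs;
}

// ans1 - sum(ans2) for one centroid.
// childDepths[i] holds the distances to the centroid of all vertices
// in the i-th child subtree (hypothetical layout for this sketch).
long long pairsThroughCentroid(const std::vector<std::vector<int> >& childDepths, int k) {
    std::vector<int> all(1, 0);  // the centroid itself, at distance 0
    long long sameSubtree = 0;   // sum(ans2)
    for (std::size_t i = 0; i < childDepths.size(); ++i) {
        all.insert(all.end(), childDepths[i].begin(), childDepths[i].end());
        sameSubtree += countPairs(childDepths[i], k);
    }
    return countPairs(all, k) - sameSubtree;  // ans1 - sum(ans2)
}
```

On the sample (k=4) the centroid is vertex 1 and the child subtrees have distances {3}, {1,2} and {2}. Here ans1 is 8 and the subtree {3,5} contributes ans2=1, so 7 pairs pass through vertex 1. The remaining component {3,5} adds 1 more pair, giving the expected total of 8.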

Submission 16456848: user ksq2013, problem 1741, Accepted, memory 760K, time 172MS, language C++, code length 1547B, submitted 2017-01-07 13:50:58

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define N 10010
#define inf (~0U>>1) // largest int value; parenthesised so the macro expands safely
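// Import only what is used: a global array named size clashes with std::size under C++17.
using std::max;
using std::sort;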
int fst[N],ecnt,ans;
struct edge{
  int v,w,nxt;
}e[N<<1];
// Add a directed edge x -> y of weight w to the adjacency list of x.
inline void link(int x,int y,int w){
  e[++ecnt].v=y;
  e[ecnt].w=w;
  e[ecnt].nxt=fst[x];
  fst[x]=ecnt;
}
bool vis[N];
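// m: distance limit k; f[x]: largest component left when x is removed;
// size[x]: subtree size; d[x]: distance to the current centroid;
// deep[1..top]: collected distances; sum: size of the current component.
int n,m,root,f[N],size[N],d[N],deep[N],sum,top;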
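// Compute subtree sizes in the unvisited component of x and set root to its centroid,
// i.e. the vertex whose removal leaves the smallest largest piece.
void getroot(int x,int fa){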
  f[x]=0;
  size[x]=1;
  for(int j=fst[x];j;j=e[j].nxt)
    if(e[j].v^fa&&!vis[e[j].v])
      getroot(e[j].v,x),
      size[x]+=size[e[j].v],
      f[x]=max(f[x],size[e[j].v]);
  f[x]=max(f[x],sum-size[x]);
  if(f[x]<=f[root])root=x;
}
// Append d[] of every unvisited vertex reachable from x to deep[].
void getdeep(int x,int fa){
  deep[++top]=d[x];
  for(int j=fst[x];j;j=e[j].nxt)
    if(e[j].v^fa&&!vis[e[j].v])
      d[e[j].v]=d[x]+e[j].w,
      getdeep(e[j].v,x);
}
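// Count pairs in the component of x whose distances sum to at most m,
// with d[x] initialised to v; uses sorting plus two pointers.
int cal(int x,int v){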
  d[x]=v;top=0;
  getdeep(x,0);
  sort(deep+1,deep+1+top);
  int t=0;
  for(int l=1,r=top;l<r;)
    if(deep[l]+deep[r]<=m)
      t+=r-l,l++;
    else r--;
  return t;
}
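// Count pairs at centroid x, subtract the pairs that stay inside one child subtree,
// then recurse into each remaining component through its own centroid.
void solve(int x){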
  vis[x]=1;
  ans+=cal(x,0);
  for(int j=fst[x];j;j=e[j].nxt)
    if(!vis[e[j].v])
      ans-=cal(e[j].v,e[j].w),
      root=0,sum=size[e[j].v],
      getroot(e[j].v,root),
      solve(root);
}
int main(){
  while(scanf("%d%d",&n,&m)==2&&n){ // also stops on EOF or malformed input
    ans=ecnt=0;memset(fst,0,sizeof(fst));
    memset(vis,0,sizeof(vis));
    for(int i=1;i<n;i++){
      int x,y,w;
      scanf("%d%d%d",&x,&y,&w);
      link(x,y,w);link(y,x,w);
    }
    root=0;f[0]=inf;sum=n;
    getroot(1,0);
    solve(root);
    printf("%d\n",ans);
  }
  return 0;
}
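
For checking the solution on small inputs, a brute-force counter is handy. The sketch below is not part of the accepted submission: the Edge struct and the function name bruteForceCount are assumptions for this example. It walks the tree from every vertex and runs in O(n^2), so it is only suitable for testing.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical edge record for this sketch: endpoints u, v and length w.
struct Edge { int u, v, w; };

// Count unordered pairs (s, x), s < x, with dist(s, x) <= k by walking
// the tree from every vertex s. Vertices are numbered 1..n.
long long bruteForceCount(int n, int k, const std::vector<Edge>& edges) {
    std::vector<std::vector<std::pair<int, int> > > adj(n + 1);
    for (std::size_t i = 0; i < edges.size(); ++i) {
        adj[edges[i].u].push_back(std::make_pair(edges[i].v, edges[i].w));
        adj[edges[i].v].push_back(std::make_pair(edges[i].u, edges[i].w));
    }
    long long pairs = 0;
    std::vector<int> dist(n + 1, 0);
    for (int s = 1; s <= n; ++s) {
        // Iterative DFS; each entry is (vertex, parent). 0 means "no parent".
        std::vector<std::pair<int, int> > todo(1, std::make_pair(s, 0));
        dist[s] = 0;
        while (!todo.empty()) {
            int x = todo.back().first;
            int parent = todo.back().second;
            todo.pop_back();
            if (x > s && dist[x] <= k) ++pairs;  // count each pair once
            for (std::size_t j = 0; j < adj[x].size(); ++j) {
                int y = adj[x][j].first;
                if (y == parent) continue;       // a tree has no other way back
                dist[y] = dist[x] + adj[x][j].second;
                todo.push_back(std::make_pair(y, x));
            }
        }
    }
    return pairs;
}
```

Called with the sample tree (n=5, k=4, edges 1-2 of length 3, 1-3 of length 1, 1-4 of length 2, 3-5 of length 1) it returns 8, matching the sample output.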
Original post: https://www.cnblogs.com/keshuqi/p/6259253.html