POJ 1741 Tree

POJ 1741 Tree

POJ传送门

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts how many valid pairs there are in the given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

题目大意:

给定一棵有n个节点的带边权无根树。求长度不超过k的路径有多少条(即满足 dist(u,v) ≤ k 的无序点对 (u,v) 的个数)。


题解:

点分治入门题。

关于点分治,可参考以下文章:

详解点分治

代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity: up to 10000 vertices plus a little slack. Every undirected edge
// is stored twice (x->y and y->x), hence the doubled edge arrays.
const int maxn = 10000 + 10;

// Current test case parameters and the running answer.
int n, k, ans;

// Adjacency lists as linked edge chains: head[x] is the index of x's first
// edge, nxt[e] the next edge leaving the same vertex (0 ends the chain),
// to[e] the edge's far endpoint and val[e] its length. tot = edges used.
int tot;
int head[maxn];
int nxt[maxn << 1];
int to[maxn << 1];
int val[maxn << 1];

// v[x] is set once x has been used as a centroid (removed from the tree).
bool v[maxn];

// size[x]: subtree size computed by getroot; mp[x]: size of the largest
// piece left if x were removed; dist[1..s]: distances gathered by getdis.
// NOTE(review): with `using namespace std;` under C++17 an unqualified
// `size` may be ambiguous with std::size; `::size` names this array.
int size[maxn];
int mp[maxn];
int dist[maxn];

// sum: vertex count of the component searched for a centroid;
// root: best centroid candidate so far; s: entries written to dist.
int sum, root, s;
// Append the directed edge x -> y with length z to the front of x's edge
// chain. Edge indices start at 1 so that 0 can terminate a chain.
void add(int x,int y,int z)
{
	const int e=++tot;
	to[e]=y;
	val[e]=z;
	nxt[e]=head[x];	// old first edge of x now follows the new one
	head[x]=e;
}
void getroot(int x,int f)
{
	size[x]=1,mp[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==f||v[y])
			continue;
		getroot(y,x);
		size[x]+=size[y];
		mp[x]=max(mp[x],size[y]);
	}
	mp[x]=max(mp[x],sum-size[x]);
	if(mp[x]<mp[root])
		root=x;
}
// Append to dist[++s] the distance of every vertex reachable from x without
// entering removed vertices (v[] set) or stepping back to the DFS parent f.
// d is the distance already accumulated when arriving at x.
void getdis(int x,int f,int d)
{
	dist[++s]=d;
	for(int e=head[x];e!=0;e=nxt[e])
	{
		const int nx=to[e];
		if(nx!=f&&!v[nx])
			getdis(nx,x,d+val[e]);
	}
}
// Count unordered pairs of vertices in the component reachable from x whose
// collected distances (each = len + distance from x) sum to at most k.
//   x   - start vertex of the collecting DFS
//   len - offset added to every collected distance
// Returns the pair count; overwrites the globals s and dist[1..s].
// dfz uses calc(c,0) for a centroid c, and calc(child,w) with w the length of
// edge c-child to count pairs inside one child subtree with
// d(c,u)+d(c,v) <= k, which it then subtracts.
int calc(int x,int len)
{
	s=0;
	// No clearing of dist[] needed: getdis writes dist[1..s] and only that
	// prefix is sorted and read below. A full memset here would cost
	// O(maxn) on every call.
	getdis(x,0,len);
	sort(dist+1,dist+s+1);
	// Two pointers over sorted distances: if dist[l]+dist[r] <= k, then l
	// pairs validly with every index in (l, r], so count them and advance l;
	// otherwise dist[r] is too large for any remaining l, so drop r.
	int l=1,r=s,cnt=0;
	while(l<r)
	{
		if(dist[l]+dist[r]<=k)
			cnt+=r-l,l++;
		else
			r--;
	}
	return cnt;
}
void dfz(int x)
{
	ans+=calc(x,0);
	v[x]=1;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(v[y])
			continue;
		ans-=calc(y,val[i]);
		sum=size[y],root=0;
		getroot(y,0);
		dfz(root);
	}
}
// Read test cases until the terminating "0 0" line (or end of input) and
// print, for each tree, the number of vertex pairs at distance <= k.
int main()
{
	// Require that both numbers were actually read: at end of input scanf
	// returns EOF, which is nonzero, so a bare truthiness test would leave
	// n,k unchanged and loop forever. Input ends with "0 0", so stop only
	// when both are zero (a case with k == 0 is still a valid test case).
	while(scanf("%d%d",&n,&k)==2&&(n||k))
	{
		memset(head,0,sizeof(head));
		memset(v,0,sizeof(v));
		tot=0;ans=0;
		for(int i=1;i<n;i++)
		{
			int x,y,z;
			scanf("%d%d%d",&x,&y,&z);
			add(x,y,z);
			add(y,x,z);
		}
		// mp[0]=n is the baseline getroot must beat. root has to be reset:
		// otherwise the previous test case's last centroid, with its stale
		// small mp[] value, can survive getroot unchanged, even when that
		// vertex is not part of this tree.
		mp[0]=sum=n;
		root=0;
		getroot(1,0);
		dfz(root);
		printf("%d\n",ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/fusiwei/p/13822136.html