[Tree DP] Distance in Tree

[CF161.D] Distance in Tree
Time limit per test: 3 seconds
Memory limit per test: 512 megabytes

A tree is a connected graph that doesn't contain any cycles.

The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.

You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.

Input

The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.

Next n - 1 lines describe the edges as "ai bi" (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.

Output

Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.

Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.

Examples
input
5 2
1 2
2 3
3 4
2 5
output
4
input
5 3
1 2
2 3
3 4
4 5
output
2
Note

In the first sample the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
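A brute-force counter makes the distance definition concrete and gives a way to check a faster solution on small trees. The following is a sketch added for illustration and is not part of the original post. The function name `bruteForcePairs` and its edge-list parameter are hypothetical. Vertices are assumed to be 1-indexed, as in the input format. The sketch runs one BFS per vertex, so it is O(n²) and far too slow for n = 50000.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical brute-force reference (not from the original post):
// run a BFS from every vertex and count pairs (src, v) with src < v whose
// distance is exactly k, so each unordered pair is counted once.
// O(n^2) time: only for small trees. Vertices are 1-indexed.
long long bruteForcePairs(int n, int k,
                          const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(edges.size()) == n - 1);  // input must be a tree
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    long long count = 0;
    std::vector<int> dist(n + 1);
    for (int src = 1; src <= n; ++src) {
        std::fill(dist.begin(), dist.end(), -1);  // -1 marks "not reached yet"
        std::queue<int> q;
        dist[src] = 0;
        q.push(src);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
        for (int v = src + 1; v <= n; ++v)
            if (dist[v] == k) ++count;
    }
    return count;
}
```

On the first sample, `bruteForcePairs(5, 2, {{1, 2}, {2, 3}, {3, 4}, {2, 5}})` returns 4, matching the pairs listed in the Note.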

Problem summary: given a tree with N vertices, how many distinct pairs (u, v) have a shortest-path distance of exactly K?

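Analysis: root the tree at vertex 1 and let dp[i][j] (an N × (K+1) table) be the number of vertices reachable from i by walking exactly j steps down into i's subtree.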

     Initialization: dp[i][0] = 1 (without moving, i reaches only itself).

     Transition: dp[i][j] = sum of dp[s][j-1] over all children s of i (the first step goes from i down to s).
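The definition, initialization and transition above can be sketched as a standalone table builder. `subtreeDepthCounts` is a hypothetical helper, not code from the original post. It assumes 1-indexed vertices given as an edge list and a caller-chosen root. Unlike the recursive dfs in the solution, it processes vertices in reverse BFS order, so every child is finished before its parent.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: dp[v][j] = number of vertices in v's subtree that are
// exactly j edges below v, for 0 <= j <= k, with the tree rooted at `root`.
// Vertices are 1-indexed; index 0 serves as the root's "no parent" marker.
std::vector<std::vector<long long>> subtreeDepthCounts(
        int n, int k, const std::vector<std::pair<int, int>>& edges, int root) {
    assert(root >= 1 && root <= n);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // BFS order: each vertex appears after its parent.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order(1, root);
    for (int i = 0; i < static_cast<int>(order.size()); ++i) {
        int u = order[i];
        for (int v : adj[u])
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
    }
    std::vector<std::vector<long long>> dp(n + 1, std::vector<long long>(k + 1, 0));
    // Walk the BFS order backwards so every child is finished before its parent.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
        int u = order[i];
        dp[u][0] = 1;  // initialization: zero steps reach only u itself
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            for (int j = 1; j <= k; ++j) dp[u][j] += dp[v][j - 1];  // transition
        }
    }
    return dp;
}
```

Take the first sample rooted at vertex 1 with K = 3. Row 1 comes out as [1, 1, 2, 1]: vertex 1 itself, then {2}, then {3, 5}, then {4}.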

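     The answer is collected in two parts. The first part counts pairs where one endpoint is i itself and the other lies K steps below it: dp[i][K].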

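     The second part counts pairs whose lowest common ancestor is i and neither endpoint is i. Consider each child s and each 1 ≤ t ≤ K-1. One endpoint is t steps below i through s, which gives dp[s][t-1] choices. The other endpoint is K-t steps below i but not through s, which gives dp[i][K-t] - dp[s][K-t-1] choices. The contribution is therefore dp[s][t-1]*(dp[i][K-t]-dp[s][K-t-1]).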

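     Summing these terms into tmp counts every such pair twice, once from each endpoint's child subtree. Since (u, v) and (v, u) are the same pair, ans is increased by tmp/2.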
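The second part of the count for a single vertex can be checked by hand with a small helper. `pairsThroughLca` is a hypothetical name. It assumes the caller has already merged x's row (`dpX`) and passes each child's row separately, with every row of length K+1. The dp[i][K] pairs from the first part are not included.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: number of pairs whose lowest common ancestor is x and
// neither endpoint is x. dpX is x's merged row; children holds each child's
// row. All rows have length k + 1.
// For each child s and split t (1 <= t <= k-1):
//   one endpoint is t edges below x through s       -> dpS[t-1] choices
//   the other is k-t edges below x, NOT through s   -> dpX[k-t] - dpS[k-t-1]
long long pairsThroughLca(int k, const std::vector<long long>& dpX,
                          const std::vector<std::vector<long long>>& children) {
    long long tmp = 0;
    for (const auto& dpS : children)
        for (int t = 1; t < k; ++t)
            tmp += dpS[t - 1] * (dpX[k - t] - dpS[k - t - 1]);
    assert(tmp % 2 == 0);  // every pair is seen once from each endpoint's side
    return tmp / 2;
}
```

Take the first sample with K = 2 and root 1. Vertex 2 has row [1, 2, 1], and its children 3 and 5 have rows [1, 1, 0] and [1, 0, 0]. The helper returns 1, which is the pair (3, 5). Separately, dp[2][2] = 1 contributes the pair (2, 4).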

#include<cctype>    // isdigit(), used by read()
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<algorithm>
using namespace std;

// Fast reader for a (possibly negative) integer from stdin.
inline int read(){
	int x=0,f=1;char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
	for(;isdigit(c);c=getchar()) x=x*10+c-'0';
	return x*f;
}
const int MAXN=100001;
const int INF=999999;
int N,K;
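long long dp[50001][501];   // dp[i][j]: vertices in i's subtree exactly j edges below i (about 200 MB)
vector<int> vec[50001];     // adjacency lists
long long ans;              // final answer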

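void dfs(int x,int fa){
	dp[x][0]=1;                          // zero steps: x reaches only itself
	for(int i=0;i<vec[x].size();i++){    // fill every child's table first
		if(vec[x][i]==fa) continue;
		dfs(vec[x][i],x);
	}
	for(int i=0;i<vec[x].size();i++){    // dp[x][j] = sum of dp[child][j-1]
		if(vec[x][i]==fa) continue;
		for(int j=1;j<=K;j++) dp[x][j]+=dp[vec[x][i]][j-1];
	}
	ans+=dp[x][K];                       // pairs (x, vertex K edges below x)
	long long tmp=0;                     // pairs with LCA x; before halving it can exceed 2^31
	for(int i=0;i<vec[x].size();i++){
		int s=vec[x][i];
		if(s==fa) continue;
		// one endpoint j edges below x through s, the other K-j edges below x not through s
		for(int j=1;j<K;j++) tmp+=dp[s][j-1]*(dp[x][K-j]-dp[s][K-j-1]);
	}
	ans+=tmp/2;                          // each such pair was counted once from each side
}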

int main(){
    N=read(),K=read();
    for(int i=1;i<N;i++){
    	int u=read(),v=read();
		vec[u].push_back(v);
		vec[v].push_back(u); 
	}
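	dfs(1,-1);             // root the tree at vertex 1
	cout<<ans<<endl;       // ans is long long, so %d was wrong; the statement advises against %lld
	return 0;
}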
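The recursive dfs can go 50000 levels deep on a path-shaped tree, and whether that fits depends on the judge's stack size. An explicit BFS order avoids the question. The following is a sketch of the same DP written that way. It assumes the edges have already been read into a vector and leaves the I/O out. `countPairsAtDistance` is a hypothetical name, not code from the original post. The table is allocated with std::vector and still needs roughly 200 MB for n = 50000, K = 500.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical iterative version of the same tree DP, rooted at vertex 1.
// Vertices are 1-indexed; index 0 is the root's "no parent" marker.
long long countPairsAtDistance(int n, int k,
                               const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && k >= 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // BFS order: every parent appears before all of its children.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order(1, 1);
    for (int i = 0; i < static_cast<int>(order.size()); ++i) {
        int u = order[i];
        for (int v : adj[u])
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
    }
    // dp[v][j]: vertices in v's subtree exactly j edges below v.
    std::vector<std::vector<long long>> dp(n + 1, std::vector<long long>(k + 1, 0));
    long long ans = 0;
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
        int x = order[i];
        dp[x][0] = 1;
        for (int s : adj[x]) {
            if (s == parent[x]) continue;
            for (int j = 1; j <= k; ++j) dp[x][j] += dp[s][j - 1];
        }
        ans += dp[x][k];    // pairs (x, vertex k edges below x)
        long long tmp = 0;  // pairs with LCA x in two different child subtrees
        for (int s : adj[x]) {
            if (s == parent[x]) continue;
            for (int t = 1; t < k; ++t)
                tmp += dp[s][t - 1] * (dp[x][k - t] - dp[s][k - t - 1]);
        }
        ans += tmp / 2;     // each such pair was counted twice
    }
    return ans;
}
```

It returns 4 and 2 on the two samples, the same answers as the recursive solution.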
Original post: https://www.cnblogs.com/wxjor/p/7266934.html