CF 161D Distance in Tree 树形DP

一棵树,边长都是1,问这棵树有多少点对的距离刚好为k


令tree(i)表示以i为根的子树


dp[i][j][1]:在tree(i)中,经过节点i,长度为j,其中一个端点为i的路径的个数
dp[i][j][0]:在tree(i)中,经过节点i,长度为j,端点不在i的路径的个数


则目标:∑(dp[i][k][0]+dp[i][k][1])
初始化:dp[i][0][1]=1,其余为0


siz[i]:tree(i)中,i与离i最远的点的距离
递推:
dp[i][j][0]+=dp[i][j-l][1]*dp[soni][l-1][1]
dp[i][j][1]=∑dp[soni][j-1][1]


注意:在更新每一个dp[i]时,先更新dp[i][j][0],再更新dp[i][j][1]

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 
  5 using namespace std;
  6 
  7 const int maxn=50000+5;
  8 const int maxk=505;
  9 #define LL long long
 10 
 11 inline int max(int a,int b)
 12 {
 13     return a>b?a:b;
 14 }
 15 
 16 inline int min(int a,int b)
 17 {
 18     return a<b?a:b;
 19 }
 20 
 21 LL dp[maxn][maxk][2];
 22 int siz[maxn];
 23 struct Edge
 24 {
 25     int to,next;
 26 };
 27 Edge edge[maxn<<1];
 28 int head[maxn];
 29 int tot;
 30 int k;
 31 
 32 void init()
 33 {
 34     memset(head,-1,sizeof head);
 35     tot=0;
 36     memset(dp,0,sizeof dp);
 37 }
 38 
 39 void addedge(int u,int v)
 40 {
 41     edge[tot].to=v;
 42     edge[tot].next=head[u];
 43     head[u]=tot++;
 44 }
 45 
 46 void solve(int ,int );
 47 void dfs(int ,int );
 48 
 49 int main()
 50 {
 51     init();
 52     int n;
 53     scanf("%d %d",&n,&k);
 54     for(int i=1;i<n;i++)
 55     {
 56         int u,v;
 57         scanf("%d %d",&u,&v);
 58         addedge(u,v);
 59         addedge(v,u);
 60     }
 61     solve(n,k);
 62     return 0;
 63 }
 64 
 65 void solve(int n,int k)
 66 {
 67     dfs(1,0);
 68     LL ans=0;
 69     for(int i=1;i<=n;i++)
 70     {
 71         ans+=(dp[i][k][0]+dp[i][k][1]);
 72     }
 73     cout<<ans<<endl;
 74     return ;
 75 }
 76 
 77 void dfs(int u,int pre)
 78 {
 79     dp[u][0][1]=1;
 80     siz[u]=0;
 81     for(int i=head[u];~i;i=edge[i].next)
 82     {
 83         int v=edge[i].to;
 84         if(v==pre)
 85             continue;
 86         dfs(v,u);
 87         for(int l=1;l<=siz[v]+1;l++)
 88         {
 89             for(int j=l+1;j<=siz[u]+l;j++)
 90             {
 91                 if(j>k)
 92                     continue;
 93                 dp[u][j][0]+=dp[u][j-l][1]*dp[v][l-1][1];
 94             }
 95         }
 96         for(int j=1;j<=siz[v]+1;j++)
 97             dp[u][j][1]+=dp[v][j-1][1];
 98         siz[u]=max(siz[u],siz[v]+1);
 99         siz[u]=min(siz[u],k);
100     }
101 }
View Code
原文地址:https://www.cnblogs.com/-maybe/p/4764393.html