hdu4705 Y 树形DP

给出一颗数,求没有一条路径穿过的节点三元集合个数。

这样的三元集合呈现Y字形,求出反面情况,三点为子节点和两个祖先节点,或一个祖先节点与它子树中非父子关系的节点。可由树形DP求得。

 1 #pragma comment(linker, "/STACK:16777216")
 2 #include<stdio.h>
 3 #include<string.h>
 4 typedef long long ll;
 5 const int maxm=1e5+10;
 6 int d[maxm],son[maxm],fa[maxm];
 7 ll dp[maxm],Dp[maxm],sum,Sum;
 8 int head[maxm],point[maxm<<1],nxt[maxm<<1],size;
 9 int n;
10 
11 inline void add(int a,int b){
12     point[size]=b;
13     nxt[size]=head[a];
14     head[a]=size++;
15     point[size]=a;
16     nxt[size]=head[b];
17     head[b]=size++;
18 }
19 
20 int dfs1(int r){
21     for(int i=head[r];~i;i=nxt[i]){
22         int j=point[i];
23         if(!d[j]){
24             d[j]=d[r]+1;
25             fa[j]=r;
26             son[r]+=dfs1(j);
27         }
28     }
29     return son[r]+1;
30 }
31 
32 ll dfs2(int r){
33     for(int i=head[r];~i;i=nxt[i]){
34         int j=point[i];
35         if(d[j]==d[r]+1){
36             dp[r]+=son[j]+dfs2(j);
37         }
38     }
39     sum+=dp[r];
40     return dp[r];
41 }
42 
43 void dfs3(int r){
44     Dp[r]=-son[r]-1+Dp[fa[r]]+dp[fa[r]]-dp[r]-son[r]+n-d[fa[r]];
45     Sum+=Dp[r];
46     for(int i=head[r];~i;i=nxt[i]){
47         int j=point[i];
48         if(d[j]==d[r]+1){
49             dfs3(j);
50         }
51     }
52 }
53 
54 inline int read(){
55     int x=0;
56     char c=getchar();
57     while(c>'9'||c<'0')c=getchar();
58     while(c>='0'&&c<='9'){
59         x=x*10+c-'0';
60         c=getchar();
61     }
62     return x;
63 }
64 
65 int main(){
66     while(scanf("%d",&n)!=EOF){
67         memset(son,0,sizeof(son));
68         memset(head,-1,sizeof(head));
69         size=0;
70         memset(d,0,sizeof(d));
71         memset(dp,0,sizeof(dp));
72         memset(Dp,0,sizeof(Dp));
73         sum=Sum=0;
74         int i;
75         for(i=1;i<=n-1;i++){
76             int a,b;
77             scanf("%d%d",&a,&b);
78     //        int a=read();
79     //        int b=read();
80             add(a,b);
81         }
82         d[1]=1;
83         fa[1]=0;
84         dfs1(1);
85         son[0]=son[1]+1;
86         dfs2(1);
87         for(i=head[1];~i;i=nxt[i]){
88             dfs3(point[i]);
89         }
90         ll ans=((ll)n*(n-1)*(n-2)/2/3)-sum-Sum/2;
91         printf("%lld
",ans);
92     }
93     return 0;}
View Code



原文地址:https://www.cnblogs.com/cenariusxz/p/6598558.html