Codeforces 1156D 0-1-Tree(树形dp)

传送:http://codeforces.com/contest/1156/problem/D

题意:有一棵$n$($nleq200000$)个结点的树,$n-1$条边,每条边有一个值$(0,1)$,对于从$x$到$y$的唯一路径不能从0边到1边,问有多少点对符合要求。

分析:

  考虑这样一个dp方程,$dp[i][0/1]$,$dp[i][0]$代表从结点$i$出发的权值为0的边,可以到达点的个数,同理:$dp[i][1]$代表从结点$i$出发的权值为1的边,可以到达点的个数。

  那么就是说,我做两边树的遍历:

第一遍可以先处理出儿子继承父亲的答案;

第二遍处理出父亲“继承”儿子的答案(同时需要去除掉本身儿子继承父亲的答案)。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int maxn=2e5+10;
 5 struct node{
 6     int to,w,nxt;
 7 }e[maxn*2];
 8 int head[maxn],tot;
 9 ll ans;
10 ll dp[maxn][2];
11 void add(int x,int y,int w){
12     e[tot]={y,w,head[x]};
13     head[x]=tot++;
14 }
15 void dfs(int x,int fa){
16     for (int i=head[x];i!=-1;i=e[i].nxt){
17         if (e[i].to==fa) continue;
18         dfs(e[i].to,x);
19         if (e[i].w==0) dp[x][0]+=dp[e[i].to][0];
20         else dp[x][1]+=dp[e[i].to][0]+dp[e[i].to][1];
21     }
22 }
23 void dfs2(int x,int fa){
24     for (int i=head[x];i!=-1;i=e[i].nxt){
25         if (e[i].to==fa) continue;
26         if (e[i].w==0) dp[e[i].to][0]+=(dp[x][0]-dp[e[i].to][0]);
27         else dp[e[i].to][1]+=(dp[x][0]-dp[e[i].to][0])+(dp[x][1]-dp[e[i].to][1]);
28         dfs2(e[i].to,x);
29     }
30 } 
31 int main(){
32     int n,x,y,z; scanf("%d",&n);
33     tot=0;
34     for (int i=1;i<=n;i++) head[i]=-1;
35     for (int i=0;i<n-1;i++){
36         scanf("%d%d%d",&x,&y,&z);
37         add(x,y,z);
38         add(y,x,z);
39         dp[x][z]++;
40     }
41     ans=0;
42     dfs(1,0);
43     dfs2(1,0);
44     for (int i=1;i<=n;i++){
45         cout << dp[i][0] << "  " << dp[i][1] << endl;
46         ans+=(dp[i][0]+dp[i][1]);
47     }
48     printf("%lld
",ans);
49     return 0; 
50 }

 来自学妹的并查集做法:

维护两个块,一个全为1,一个全为0。全为1或者全为0的块内答案为num*(num-1);

同时如果一个点可以连接起全0块或全1块,那么答案为(num1-1)*(num2-1)。

 1 #include<bits/stdc++.h>
 2 #define ll long long
 3 using namespace std;
 4 const int maxn=2e5+10;
 5 int pre[2][maxn],num[2][maxn];
 6 int n,a,b,c;
 7 int find(int mark,int x)
 8 {
 9     return x==pre[mark][x]?x:pre[mark][x]=find(mark,pre[mark][x]); 
10 }
11 void merge(int mark,int x,int y)
12 {
13     int fx=find(mark,x),fy=find(mark,y);
14     if (fx!=fy)
15     {
16         pre[mark][fx]=fy;
17         num[mark][fy]+=num[mark][fx];
18     }
19 }
20 int main()
21 {
22     scanf("%d",&n);
23     for (int i=0;i<=n;i++) pre[0][i]=pre[1][i]=i,num[0][i]=num[1][i]=1;
24     for (int i=1;i<n;i++)
25     {
26         scanf("%d%d%d",&a,&b,&c);
27         merge(c,a,b);
28     }
29     ll ans=0;
30     for (int i=1;i<=n;i++)
31     {
32         if (pre[0][i]==i) ans+=1ll*num[0][i]*(num[0][i]-1);
33         if (pre[1][i]==i) ans+=1ll*num[1][i]*(num[1][i]-1);
34         int xx=find(0,i),yy=find(1,i);
35         ans+=1ll*(num[0][xx]-1)*(num[1][yy]-1);
36     }
37     printf("%lld
",ans);
38     return 0;
39 } 
原文地址:https://www.cnblogs.com/changer-qyz/p/10827258.html