BZOJ 2152 聪聪可可(树形DP)

给出一颗n个点带边权的树(n<=20000),求随机选择两个点,使得它们之间的路径边权是3的倍数的概率是多少。

首先总的对数是n*n,那么只需要统计路径边权是3的倍数的点对数量就行了。

考虑将无根树化为有根树,令dp[x][i]表示以x点为路径起点,x的某个子孙为路径终点的边权值模3为i的点对数量。

那么显然有dp[x][i]+=dp[son[x]][(i-w)%3].

考虑点对之间的路径,要么是它们的LCA是点对中的一个点,要么不在点对中,因此统计一下以每个点x为LCA时的路径边权值%3为i的点对数量。

而这两个统计都可以在一次树形DP中完成。因此总复杂度为O(n).

# include <cstdio>
# include <cstring>
# include <cstdlib>
# include <iostream>
# include <vector>
# include <queue>
# include <stack>
# include <map>
# include <bitset>
# include <set>
# include <cmath>
# include <algorithm>
using namespace std;
# define lowbit(x) ((x)&(-x))
# define pi acos(-1.0)
# define eps 1e-8
# define MOD 30031
# define INF 1000000000
# define mem(a,b) memset(a,b,sizeof(a))
# define FOR(i,a,n) for(int i=a; i<=n; ++i)
# define FO(i,a,n) for(int i=a; i<n; ++i)
# define bug puts("H");
# define lch p<<1,l,mid
# define rch p<<1|1,mid+1,r
# define mp make_pair
# define pb push_back
typedef pair<int,int> PII;
typedef vector<int> VI;
# pragma comment(linker, "/STACK:1024000000,1024000000")
typedef long long LL;
int Scan() {
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
const int N=20005;
//Code begin...

struct Edge{int p, next, w;}edge[N<<1];
int head[N], cnt=1, dp[N][3], son[N];

void add_edge(int u, int v, int w){edge[cnt].p=v; edge[cnt].w=w; edge[cnt].next=head[u]; head[u]=cnt++;}
void dfs(int x, int fa){
    for (int i=head[x]; i; i=edge[i].next) {
        int v=edge[i].p;
        if (v==fa) continue;
        dfs(v,x);
        FO(j,0,3) dp[x][j]+=dp[v][((j-edge[i].w)%3+3)%3];
    }
    for (int i=head[x]; i; i=edge[i].next) {
        int v=edge[i].p;
        if (v==fa) continue;
        int y0=((-edge[i].w)%3+3)%3, y1=((1-edge[i].w)%3+3)%3, y2=((2-edge[i].w)%3+3)%3;
        son[x]+=dp[v][y0]*(dp[x][0]-dp[v][y0])+dp[v][y1]*(dp[x][2]-dp[v][y2])+dp[v][y2]*(dp[x][1]-dp[v][y1]);
    }
    dp[x][0]+=1;
}
int main ()
{
    int n, ans=0, sum=0, u, v, w;
    scanf("%d",&n);
    FO(i,1,n) scanf("%d%d%d",&u,&v,&w), add_edge(u,v,w%3), add_edge(v,u,w%3);
    dfs(1,0);
    FOR(i,1,n) ans+=dp[i][0];
    ans=ans*2-n; sum=n*n;
    FOR(i,1,n) ans+=son[i];
    int gcd=__gcd(ans,sum);
    ans/=gcd; sum/=gcd;
    printf("%d/%d
",ans,sum);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/lishiyao/p/6882519.html