HDU 5378 (2015多校第七场1010) 概率DP

题意是给一棵树,有n个节点,求能组成k个leader的方案数?每个节点有一个val值(1~n且每个节点的val值不相同)。leader的定义,如果一个子树中最大的val值是根节点对应的val值,那么我们称这个节点是leader。

我们用x[i],y[i]分别代表这个节点能够成为leader和不能够成为leader的概率。

cnt[i] 代表以i节点为根的子树的节点数。

那么x[i] = 1/cnt[i],y[i] = 1-(1/cnt[i])。因为这里面出现了分数,所有我们用逆元处理一下。

我们设dp[i][j]表示编号为1,2...i的节点中有j个leader的概率。

那么转移方程就是 dp[i][j] = dp[i-1][j-1] * x[i] + dp[i-1][j] * y[i]。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <vector>
#define ll long long
#define FOR(i,x,y)  for(int i = x;i < y;i ++)
#define IFOR(i,x,y) for(int i = x;i > y;i --)
#define MOD 1000000007
#define N 1100

using namespace std;

ll dp[N][N],fac[N],cnt[N],x[N],y[N];
int n,k;
vector <int> G[N];

void init(){
    fac[0] = 1;
    FOR(i,1,N){
        fac[i] = fac[i-1]*i;
        fac[i] %= MOD;
    }
}

void gcd(ll a,ll b,ll& d,ll& x,ll& y){
    if(!b)  {d = a; x = 1; y = 0;}
    else{gcd(b,a%b,d,y,x);y -= x*(a/b);}
}

ll inv(ll a){
    ll d,x,y;
    gcd(a,MOD,d,x,y);
    return d == 1 ? (x+MOD)%MOD : -1;
}

void dfs(int u,int fa){
    cnt[u] = 1;
    FOR(i,0,G[u].size()){
        int v = G[u][i];
        if(v == fa) continue;
        dfs(v,u);
        cnt[u] += cnt[v];
    }
}

void calc(){
    FOR(i,1,n+1){
        if(cnt[i] == 0){
            x[i] = 1;
            y[i] = 0;
            continue;
        }
        x[i] = inv(cnt[i]);
        y[i] = (cnt[i]-1) * x[i];
        y[i] %= MOD;
    }
}

ll solve(){
    calc();
    dp[1][0] = y[1];
    dp[1][1] = x[1];
    FOR(i,2,n+1){
        dp[i][0] = dp[i-1][0] * y[i];
        dp[i][0] %= MOD;
    }
    FOR(i,1,n+1){
        FOR(j,i+1,n+1){
            dp[i][j] = 0;
        }
    }
    FOR(i,2,n+1){
        int lim = min(i+1,k+1);
        FOR(j,1,lim){
            ll t1 = (dp[i-1][j-1] * x[i])%MOD;
            ll t2 = (dp[i-1][j] * y[i])%MOD;
            dp[i][j] = (t1+t2)%MOD;
            dp[i][j] %= MOD;
        }
    }
    ll ans = dp[n][k] * fac[n];
    ans %= MOD;
    return ans;
}

int main()
{
    //freopen("test.in","r",stdin);
    init();
    int T,tCase = 0;
    scanf("%d",&T);
    while(T--){
        printf("Case #%d: ",++tCase);
        FOR(i,0,N)  G[i].clear();
        scanf("%d%d",&n,&k);
        int u,v;
        FOR(i,0,n-1){
            scanf("%d%d",&u,&v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        dfs(1,-1);
        printf("%I64d
",solve());
    }
    return 0;
}


版权声明:本文为博主原创文章,未经博主允许不得转载。

原文地址:https://www.cnblogs.com/hqwhqwhq/p/4811894.html