Tree

题目链接

https://ac.nowcoder.com/acm/contest/6226/C

tag

换根DP

solution

(d[u])表示(u)的子树对(u)的贡献,根据乘法原理可以得到, (d[u] = d[u] * (1 + d[v]))(v)(u)的儿子,第一次从(root)节点(dfs)可以求出每个节点(d[i]),同时我们可以得到(ans[root] = d[root]),然后我们第二次(dfs)换根统计其他节点的答案,(g[u])表示除去(u)(u)的子树,即(u)的父亲节点对答案的贡献,对于节点(u, v), (v)(u)的儿子,不难得到(ans[v] = d[v] * (1 + g[v])),当((1 + d[v]) % mod != 0) 时,(g[v] = ans[u] /(1 + d[v])) , 在取模意义下((1+d[v]))为0时,我们将(u)的出边分为三类,(u)的父亲节点, (v),剩余(u)的儿子节点(son),我们可以暴力更新(g[v]),首先我们已经得到了(u)的父亲节点对答案的贡献即(g[u]),然后我们需要计算(u)除去(v)以外的儿子节点对答案的贡献,即(d[son]),根据乘法原理, (g[v] *= (1 + g[u]),g[v] *= (1 + d[son])),然后根据(ans[v] = d[v] * (1 + g[v])),换根便可以得到所有点的答案

code

//created by pyoxiao on 2020/07/07
#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define CL(a, b) memset(a, b, sizeof(a))
using namespace std;
const int mod = 1e9 + 7;
LL fpow(LL a, LL b, LL p = mod){LL ans = 1; a %= p; while(b) {if(b & 1) ans = ans * a % p; b >>= 1; a = a * a % p;} return ans;}
LL gcd(LL a, LL b){return b == 0 ? a : gcd(b, a % b);}
LL inv(LL x) {return fpow(x, mod - 2); }
const int N = 1e6 + 7;
vector<int> ver[N];
int n;
LL d[N], f[N], g[N], ffa[N];
void dfs(int u, int fa) {
    d[u] = 1; ffa[u] = fa; 
    for(auto to : ver[u]) {
        if(to == fa) continue;
        dfs(to, u);
        d[u] = d[u] * (1 + d[to]) % mod;
    }
}
void dfs2(int u, int fa) {
    if(u != 1) {
        if((d[u] + 1) % mod) {
            g[u] = f[fa] * inv(d[u] + 1) % mod;
        } else {
            LL res = 1 + g[fa];
            for(auto to : ver[fa]) {
                if(to == u || to == ffa[fa]) continue;
                res *= (1 + d[to]);
                res %= mod;
            }
            g[u] = res;
        }
        f[u] = d[u] * (1 + g[u]) % mod;
    }
    for(auto to : ver[u]) {
        if(to == fa) continue;
        dfs2(to, u);
    }
}
void solve() {
    scanf("%d", &n);
    for(int i = 2; i <= n; i ++) {
        int u, v; 
        scanf("%d %d", &u, &v);
        ver[u].pb(v);
        ver[v].pb(u);
    }
    dfs(1, 0);
    f[1] = d[1];
    dfs2(1, 0);
    for(int i = 1; i <= n; i ++) printf("%lld
", (f[i] % mod + mod) % mod);
}
int main() {
    int T = 1;
    // scanf("%d", &T);
    while(T --) 
        solve();
    return 0;
}
原文地址:https://www.cnblogs.com/pyoxiao/p/13267738.html