牛客国庆集训派对Day3 B Tree

Tree

思路:

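树形dp + 换根：第一遍 dfs 自底向上求 cnt[u] = ∏(cnt[v]+1)（v 取 u 的所有儿子），即 u 的子树中包含 u 的连通点集个数；第二遍 DFS 自顶向下求 _cnt[u]，即 u 往父亲方向延伸的方案数再加 1（根的 _cnt 为 1）。点 u 的答案为 cnt[u]·_cnt[u] mod 1e9+7。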

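注意 0 不存在逆元：任何数乘以 0 都会变成 0，原来的信息就丢失了，无法再靠乘逆元把它“除”回去。所以当 cnt[v]+1 ≡ 0 (mod 1e9+7) 时，不能用 cnt[u] 除以 cnt[v]+1 来得到其余儿子的乘积，这种情况要特殊处理（例如暴力把其余儿子的 cnt+1 乘起来）。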

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Contest shorthand macros. They expand textually at the use site, so names
// such as pair / ios / cin / memset resolve against whatever is in scope there
// (this file relies on `using namespace std;`).
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pli, int>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirects stdin/stdout to in.txt/out.txt (the original spelled `stout`,
// which does not compile wherever the macro is used).
// NOTE: this object-like macro hides the C library name `fopen`; any later
// call to fopen(...) in this file would expand to this text instead.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head

// Capacity for vertex-indexed arrays and the modulus used for every answer.
const int N = 1000000 + 10;
const int MOD = 1000000007;

// Edge pool for a linked-list adjacency representation: head[u] is the index
// of u's most recently added edge and `next` chains to the previous one
// (-1 ends the chain). Each undirected edge is stored twice, hence N * 2.
struct edge {
    int to;    // endpoint the edge leads to
    int next;  // index of the next edge leaving the same vertex
} edge[N * 2];

// cnt[u]  : product over u's children c of (cnt[c] + 1), mod MOD.
// _cnt[u] : 1 + the number of connected vertex sets that contain u's parent but
//           no vertex of u's subtree, mod MOD (1 at the root).
LL cnt[N];
LL _cnt[N];

int head[N];  // first edge index per vertex
int tot = 0;  // next free slot in the edge pool
// Adds the directed edge u -> v by taking the next free slot in the edge pool
// and linking it in front of u's existing edge chain.
void add_edge(int u, int v) {
    const int id = tot++;
    edge[id].to = v;
    edge[id].next = head[u];
    head[u] = id;
}
// Modular exponentiation by repeated squaring: returns n^k mod MOD.
// Precondition: k >= 0. A negative k now returns 1 instead of looping forever,
// because the loop condition is k > 0 and >>= on a negative value never reaches 0.
// The base is first reduced into [0, MOD), so n * n cannot overflow long long
// even when the caller passes n >= MOD or a negative n.
LL q_pow(LL n, LL k) {
    n %= MOD;
    if (n < 0) n += MOD;
    LL ans = 1;
    while (k > 0) {
        if (k & 1) ans = (ans * n) % MOD;
        n = (n * n) % MOD;
        k >>= 1;
    }
    return ans;
}
// Bottom-up pass of the tree DP over the subtree of u, whose parent is o.
// (The root is passed as o == u, which works because the tree has no self-loops.)
// For every vertex x in that subtree it sets
//   cnt[x] = product over children c of x of (cnt[c] + 1)   (mod MOD),
// i.e. the number of connected vertex sets inside x's subtree that contain x.
// It uses an explicit stack instead of recursion. With up to ~1e6 vertices, a
// path-shaped tree would otherwise recurse ~1e6 levels deep and can overflow
// the call stack.
void dfs(int o, int u) {
    // Pre-order list of (vertex, parent); every child appears after its parent.
    std::vector<std::pair<int, int> > order;
    std::vector<std::pair<int, int> > stk;
    stk.push_back(std::make_pair(u, o));
    while (!stk.empty()) {
        const std::pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        for (int i = head[cur.first]; ~i; i = edge[i].next) {
            const int v = edge[i].to;
            if (v != cur.second) stk.push_back(std::make_pair(v, cur.first));
        }
    }
    // Walk the list backwards so that every child is finished before its parent.
    for (int j = (int)order.size() - 1; j >= 0; --j) {
        const int x = order[j].first;
        const int p = order[j].second;
        cnt[x] = 1;
        for (int i = head[x]; ~i; i = edge[i].next) {
            const int v = edge[i].to;
            if (v != p) cnt[x] = (cnt[x] * (cnt[v] + 1)) % MOD;
        }
    }
}
void DFS(int o, int u) {
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if(v != o) {
            LL t = 1;
            if(cnt[v]+1 != MOD) t = (cnt[u] * q_pow(cnt[v]+1, MOD-2)) % MOD;
            else {
                for (int i = head[u]; ~i; i = edge[i].next) {
                    int vv = edge[i].to;
                    if(vv != o && vv != v) {
                        t = (t * (cnt[vv]+1)) % MOD;
                    }
                }
            }
            _cnt[v] = (_cnt[u]*t + 1) % MOD;
            DFS(u, v);
        }
    }
}
int main() {
    int n, u, v;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) head[i] = -1;
    for (int i = 1; i < n; i++) {
        scanf("%d %d", &u, &v);
        add_edge(u, v);
        add_edge(v, u);
    }
    dfs(1, 1);
    _cnt[1] = 1;
    DFS(1, 1);
    for (int i = 1; i <= n; i++) {
        printf("%lld
", (cnt[i] * _cnt[i]) % MOD);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/widsom/p/9742345.html