[PKUWC 2018]随机游走

Description

题库链接

给定一棵 (n) 个结点的树,你从点 (x) 出发,每次等概率随机选择一条与所在点相邻的边走过去。

(Q) 次询问,每次询问给定一个集合 (S) ,求如果从 (x) 出发一直随机游走,直到点集 (S) 中所有点都至少经过一次的话,期望游走几步。

特别地,点 (x)(即起点)视为一开始就被经过了一次。

答案对 (998244353) 取模。

Solution

不妨设 (f_{i,S}) 表示在点 (i) 时,要遍历集合 (S) 的期望步数。那么对于一个询问 (S) ,答案就是 (f_{x,S})

从两个方面来考虑如何求 (f)

  1. 如果 (u otin S) ,由套路,显然满足 [f_{u,S}=frac{sum_{ ext{v is the neighbor of u}}f_{v,S}}{degree_u}+1]
  2. 如果 (uin S)
    1. ({u}=S) ,显然 (f_{u,S}=0)
    2. ({u} eq S) ,容易得到 (f_{u,S}=f_{u,S-{u}})

这样我们对于同一个状态 (S) 可以得到若干个方程,那么在这一个状态内高斯消元即可。

由于是树上消元,所以可以用[Codeforces 802L]Send the Fool Further! (hard)的方法化成 (f_u=k_uf_{fa_u}+b_u) 的形式 (O(n)) 求解。

总复杂度是 (O(nlog(n)2^n+Q)) ,其中 (log(n)) 是求逆元的复杂度。

Code

#include <bits/stdc++.h>
using namespace std;
const int N = 20, SIZE = (1<<18)+5, yzh = 998244353;

int n, q, x, u, v, bin[N], dg[N], S;
struct tt {int to, next; }edge[N<<1];
int path[N], top, k[N], b[N], f[N][SIZE];

int quick_pow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b&1) ans = 1ll*ans*a%yzh;
        b >>= 1, a = 1ll*a*a%yzh;
    }
    return ans;
}
void dfs(int u, int fa) {
    k[u] = b[u] = 0;
    for (int i = path[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa) dfs(v, u);
    if (!(bin[u-1]&S)) {
        if (dg[u] == 1 && x != u) k[u] = b[u] = 1;
        else {
            k[u] = dg[u], b[u] = dg[u];
            for (int i = path[u], v; i; i = edge[i].next)
                if ((v = edge[i].to) != fa) {
                    (k[u] -= k[v]) %= yzh; (b[u] += b[v]) %= yzh;
                }
            k[u] = quick_pow(k[u], yzh-2);
            b[u] = 1ll*b[u]*k[u]%yzh;
        }
    }else {
        if (S^bin[u-1]) {
            k[u] = 0; b[u] = f[u][S^bin[u-1]];
        }else k[u] = b[u] = 0;
    }
}
void cal(int u, int fa) {
    f[u][S] = (1ll*k[u]*f[fa][S]%yzh+b[u])%yzh;
    for (int i = path[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa) cal(v, u);
}
void add(int u, int v) {edge[++top] = (tt){v, path[u]}, path[u] = top; ++dg[v]; }
void work() {
    scanf("%d%d%d", &n, &q, &x);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }
    bin[0] = 1; for (int i = 1; i < N; i++) bin[i] = (bin[i-1]<<1);
    for (int i = 1; i < bin[n]; i++) S = i, dfs(x, 0), cal(x, 0);
    while (q--) {
        S = 0; scanf("%d", &u);
        for (int i = 1; i <= u; i++) scanf("%d", &v), S |= bin[v-1];
        printf("%d
", (f[x][S]+yzh)%yzh);
    }
}
int main() {work(); return 0; }
原文地址:https://www.cnblogs.com/NaVi-Awson/p/9277596.html