[十二省联考2019]希望

题解

题解?不讲了不讲了,一张图说明一切(要素完备)。(注:原文此处的图片未随文本保留。)

这里只是用来记录博主过了此题

代码不长,也就132行

#include <bits/stdc++.h>

// Loop and utility macros used throughout the file.
// NOTE: the bounds `a` / `b` are pasted unparenthesised into the expansion.

// Ascending inclusive loop: i = a, a+1, ..., b. The upper bound is evaluated
// once and cached in a variable named `<i>end`.
#define rep(i, a, b) for (int i = a, i##end = b; i <= i##end; ++i)
// Descending inclusive loop: i = a, a-1, ..., b (bound cached the same way).
#define per(i, a, b) for (int i = a, i##end = b; i >= i##end; --i)
// Half-open ascending loop: i = 0, 1, ..., a-1 (bound cached as an int).
#define rep0(i, a) for (int i = 0, i##end = a; i < i##end; ++i)
// Descending loop: i = a-1, ..., 0. `~i` becomes 0 exactly when i == -1.
#define per0(i, a) for (int i = a-1; ~i; --i)
// Function-like max/min. Each argument may be evaluated twice, so arguments
// with side effects are unsafe. Any later `max(` / `min(` token sequence,
// including `std::max(`, is rewritten by these macros.
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
// In-place "a = min(a, b)" / "a = max(a, b)".
#define chkmin(a, b) a = min(a, b)
#define chkmax(a, b) a = max(a, b)
// Shorthand for pair members. WARNING: these replace EVERY identifier token
// `x` / `y` that follows, so a variable or parameter written as `x` below is
// really named `first` after preprocessing.
#define x first
#define y second
// Expands to a puts("") call (prints a newline); the semicolon is part of the macro.
#define E puts("");

// 64-bit signed integer alias.
typedef long long ll;

// Capacity of every fixed-size global array in this file (1e6 plus slack).
const int maxn = 1000000 + 5;
// Modulus for all arithmetic. get_inv() raises to the power P-2 to obtain a
// modular inverse, which relies on P being prime.
const int P = 998244353;

// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped; if any skipped
// character is '-', the result is negated. The first non-digit after the
// number is consumed as well.
//
// The character is kept in an `int` so EOF can be told apart from real input:
// at end of input the function stops and returns what it has parsed so far
// (0 if no digit was seen) instead of looping forever. Keeping it as an `int`
// also means isdigit() never receives a negative `char` value.
inline int read() {
	int value = 0, sign = 1;
	int ch = std::getchar();
	// Skip everything up to the first digit, remembering a minus sign.
	while (ch != EOF && !std::isdigit(ch)) {
		if (ch == '-') sign = -1;
		ch = std::getchar();
	}
	// Accumulate the run of digits; isdigit(EOF) is false, so EOF ends the loop.
	while (std::isdigit(ch)) {
		value = value * 10 + (ch - '0');
		ch = std::getchar();
	}
	return value * sign;
}

using std::vector;
using std::pair;

// Modular addition: returns a + b, with P subtracted once when the sum reaches P.
int inc(int a, int b) {
	const int sum = a + b;
	if (sum >= P) return sum - P;
	return sum;
}
// Modular multiplication: the product is formed in 64-bit arithmetic, then reduced mod P.
int mul(int a, int b) {
	const long long product = static_cast<long long>(a) * b;
	return static_cast<int>(product % P);
}
// Binary exponentiation: returns a^b mod P by repeated squaring.
int qpow(int a, int b) {
	int result = 1;
	while (b) {
		// Fold the current power of `a` in when the lowest exponent bit is set.
		if (b & 1) result = mul(a, result);
		a = mul(a, a);
		b >>= 1;
	}
	return result;
}

// n, L, k: the three integers main() reads first (n-1 edges follow, so n is
// the node count). ans: running result accumulated by get_g() and printed by main().
int n, L, k, ans = 0;

// e[u]: adjacency list of u while the input is read. main() later clears it
// and refills it with only the non-`son` children of u, largest len first.
vector<int> e[maxn];
// c[h]: (parent, child) pairs recorded by dfs() for every non-`son` child whose
// len equals h; main() uses these buckets to rebuild e in len order.
vector<pair<int, int> > c[maxn];
// son[u]: the child of u with the largest len (first one on ties), 0 if u has
// no child. len[u] = len[son[u]] + 1, with len[0] == 0.
int len[maxn], son[maxn];
// w[u]: (product of w over the children of u) + 1, mod P.
// inv[u]: modular inverse of w[u], filled by get_inv() only where w[u] != 0.
int w[maxn], inv[maxn];
void dfs(int u, int fa) {
	son[u] = 0, w[u] = 1; int v;
	rep0(i, e[u].size())
		if ((v = e[u][i]) != fa)
			dfs(v, u), w[u] = mul(w[u], w[v]), len[v] > len[son[u]] && (son[u] = v);
	len[u] = len[son[u]]+1; w[u] = inc(w[u], 1);
	// sort
	rep0(i, e[u].size())
		if ((v = e[u][i]) != fa && v != son[u]) c[len[v]].push_back(std::make_pair(u, v));
}

void get_inv() {
	int prei[maxn], sufi[maxn]; // delete static
	prei[0] = sufi[n+1] = 1;
	rep(i, 1, n) prei[i] = w[i] ? mul(prei[i-1], w[i]) : prei[i-1];
	per(i, n, 1) sufi[i] = w[i] ? mul(sufi[i+1], w[i]) : sufi[i+1];
	int S = qpow(prei[n], P-2);
	rep(i, 1, n) if (w[i]) inv[i] = mul(mul(prei[i-1], sufi[i+1]), S);
}

// S[i]: rollback log number i, a list of (address, value stored there before the change).
vector<pair<int*, int> > S[maxn];

// Records the current value of `slot` in log i so that undo(i) can restore it.
// Call this BEFORE overwriting `slot`.
void ins(int i, int &slot) {
	int *const address = &slot;
	const int previous = slot;
	S[i].push_back(std::make_pair(address, previous));
}
void undo(int i) { per0(j, S[i].size()) *S[i][j].x = S[i][j].y; S[i].clear(); }

// Storage for the lazily-tagged per-node arrays read through Gf()/Sf():
//   f      - shared pool; pf is the next free slot; F[u] points at u's first slot
//   Mf, Af - multiplicative / additive tag: logical value = stored * Mf[u] + Af[u]
//   If     - factor Sf() applies when storing; maintained alongside Mf (multiplied
//            by inv[v] whenever Mf is multiplied by w[v])
//   Lf     - number of explicitly stored entries; indices >= Lf[u] read Zf[u] instead
//   Zf     - stored (untagged) value used for every index >= Lf[u]
int f[maxn], *pf = f, *F[maxn], Af[maxn], If[maxn], Lf[maxn], Mf[maxn], Zf[maxn];
// Same layout for the arrays read through Gg()/Sg(), which index the storage
// as G[u][L - i]. s[0..t] is a scratch array of running products and t its
// last valid index, both used by get_g().
int g[maxn], *pg = g, *G[maxn], Ag[maxn], Ig[maxn], Lg[maxn], Mg[maxn], Zg[maxn], s[maxn], t;
// Logical value of entry i of u's f-array: the stored entry (or Zf[u] once i
// is at or beyond Lf[u]) scaled by Mf[u] and shifted by Af[u], mod P.
int Gf(int u, int i) {
	const int stored = (i < Lf[u]) ? F[u][i] : Zf[u];
	return inc(mul(stored, Mf[u]), Af[u]);
}
// Stores logical value v into entry i of u's f-array by removing the additive
// tag Af[u] and then applying If[u], mod P.
void Sf(int u, int i, int v) {
	const int untagged = inc(v, P - Af[u]);
	F[u][i] = mul(untagged, If[u]);
}
// Builds the lazily-tagged f-array of u (read with Gf, written with Sf) by
// first processing son[u] and then folding in every child listed in e[u].
// Every value overwritten while folding child v is first saved with ins(v, ...),
// so undo(v) can later restore u's state from just before v was folded in.
// Expects e[u] to hold only the non-son children (main() rebuilds e that way).
void get_f(int u) {
	// Reserve one pool slot for u. A node without a son gets Mf = If = Lf = Af = 1
	// and Zf = 0, then jumps to END where Af is incremented once.
	if (F[u] = pf++, !son[u]) { Mf[u] = If[u] = Lf[u] = Af[u] = 1; Zf[u] = 0; goto END; }
	// The son is processed immediately after u's slot was taken, so
	// F[son[u]] == F[u] + 1 and F[u][i+1] is the same cell as F[son[u]][i].
	// u copies all of the son's tags, gets one more explicit entry, and its
	// entry 0 is set to logical value 1.
	get_f(son[u]), Af[u] = Af[son[u]], Mf[u] = Mf[son[u]], If[u] = If[son[u]], Zf[u] = Zf[son[u]], Lf[u] = Lf[son[u]]+1; Sf(u, 0, 1);
	// (The loop variable written `x` is named `first` after the file's `#define x first`.)
	rep0(x, e[u].size()) {
		int v = e[u][x]; get_f(v);
		// Make entries up to index Lf[v] explicit: each new cell receives the
		// logical value it had while it was still read through Zf[u].
		while (Lf[u] < Lf[v]+1) ins(v, F[u][Lf[u]]), Sf(u, Lf[u], Gf(u, Lf[u])), ins(v, Lf[u]), Lf[u]++;
		// Entry i+1 of u is multiplied by entry i of v, for every explicit entry of v.
		rep0(i, Lf[v]) ins(v, F[u][i+1]), Sf(u, i+1, mul(Gf(u, i+1), Gf(v, i)));
		if (w[v]) {
			// Scale the WHOLE array by w[v] through the tags (Af and Mf times w[v],
			// If times inv[v]). Entries 0..Lf[v] are multiplied by inv[v] first so
			// that the tag change leaves their logical value as computed above.
			rep0(i, Lf[v]+1) ins(v, F[u][i]), Sf(u, i, mul(Gf(u, i), inv[v]));
			ins(v, Af[u]), Af[u] = mul(Af[u], w[v]), ins(v, Mf[u]), Mf[u] = mul(Mf[u], w[v]), ins(v, If[u]), If[u] = mul(If[u], inv[v]);
		// w[v] == 0 (no inverse): cut the explicit part to Lf[v]+1 entries and pick
		// Zf so that stored*Mf + Af is 0 beyond it — presumably relying on
		// If[u]*Mf[u] == 1 (mod P); NOTE(review): confirm that invariant.
		} else ins(v, Lf[u]), Lf[u] = Lf[v]+1, ins(v, Zf[u]), Zf[u] = mul(P-Af[u], If[u]);
	}
	// Save Af[u] in the son's log (nodes that jumped to END skip this), then add
	// 1 to every logical entry through the additive tag.
	ins(son[u], Af[u]); END:Af[u] = inc(Af[u], 1);
}

// Logical value of entry i of u's g-array. Storage is indexed as G[u][L - i];
// once i is at or beyond Lg[u], Zg[u] is read instead. The stored value is
// scaled by Mg[u] and shifted by Ag[u], mod P.
int Gg(int u, int i) {
	const int stored = (i < Lg[u]) ? G[u][L - i] : Zg[u];
	return inc(mul(stored, Mg[u]), Ag[u]);
}
// Stores logical value v into entry i of u's g-array (cell G[u][L - i]) by
// removing the additive tag Ag[u] and then applying Ig[u], mod P.
void Sg(int u, int i, int v) {
	const int untagged = inc(v, P - Ag[u]);
	G[u][L - i] = mul(untagged, Ig[u]);
}
// Sets up the g-array of node 1: reserves len[1] cells from the g pool and
// installs the starting tags (Ag = Mg = Ig = 1, Zg = 0, Lg = L + 1).
void initG() {
	G[1] = pg;
	pg += len[1];

	Ag[1] = 1;
	Mg[1] = 1;
	Ig[1] = 1;
	Zg[1] = 0;
	Lg[1] = L + 1;
}
// Second pass: with u's g-array already prepared by the caller, adds u's
// contribution to ans, then prepares the g-array of every child and recurses.
// The children in e[u] are visited in the reverse of the order get_f() folded
// them, and undo() rolls u's f-array back one child at a time.
// NOTE: this function permanently reverses e[u] and overwrites the global
// scratch s[] / t.
void get_g(int u) {
	// Reverse the child order and reset the running-product scratch to {1}.
	std::reverse(e[u].begin(), e[u].end()); s[t = 0] = 1;
	// Set logical entry 0 (cell G[u][L]) to 1, only when len[u] > L.
	if (L < len[u]) Sg(u, 0, 1);
	// ans += ((Gf(u,L) - 1) * Gg(u,L))^k - ((Gf(u,L-1) - 1) * (Gg(u,L) - 1))^k  (mod P).
	ans = inc(ans, inc(qpow(mul(inc(Gf(u, L), P-1), Gg(u, L)), k), P-qpow(mul(inc(Gf(u, L-1), P-1), inc(Gg(u, L), P-1)), k)));
	// Nothing below applies to a node without a son. Otherwise replay the son's
	// log, which holds the Af[u] value get_f() saved before its final increment.
	if (son[u]) undo(son[u]); else return;
	// Pass 1: build the g-array of every child in e[u].
	rep0(x, e[u].size()) {
		// Roll u's f-array back to its state just before v was folded in.
		int v = e[u][x]; undo(v);
		// Reserve len[v] cells for v; Ag[v] stays 0 while the entries are filled.
		G[v] = pg, pg += len[v]; Ag[v] = 0; Mg[v] = Ig[v] = 1, Lg[v] = L+1, Zg[v] = 0;
		// i runs over L-len[v]+1 .. L, so L-i covers exactly the cells 0 .. len[v]-1.
		// Entry i = Gg(u, i-1) times, for i > 1, Gf(u, i-1) * s[min(i-2, t)]
		// (the scratch index is clamped to its last valid position t).
		rep(i, L-len[v]+1, L) Sg(v, i, mul(Gg(u, i-1), i>1?mul(Gf(u, i-1), s[min(i-2, t)]):1));
		// Extend the scratch to len[v] entries by repeating its last value, then
		// multiply v's f-values into it for the children handled after v.
		while (t < len[v]-1) s[t+1] = s[t], t++;
		// Only the s[i] update belongs to the loop; Ag[v] = 1 runs once afterwards
		// and adds 1 to every logical entry of v's g-array.
		rep0(i, len[v]) s[i] = mul(s[i], Gf(v, i)); Ag[v] = 1;
	}
	// The son shares u's storage shifted by one cell and starts from u's tags;
	// its explicit length is capped at L+1.
	G[son[u]] = G[u]+1; Ag[son[u]] = Ag[u], Mg[son[u]] = Mg[u], Ig[son[u]] = Ig[u], Lg[son[u]] = min(L+1, Lg[u]+1), Zg[son[u]] = Zg[u];
	// Pass 2: fold every child of e[u] into the son's g-array, then recurse into it.
	rep0(x, e[u].size()) {
		int v = e[u][x], now = son[u];
		// Make entries explicit up to index min(L, Lf[v]+1), keeping their logical value.
		while (Lg[now] <= L && Lg[now] < Lf[v]+2) Sg(now, Lg[now], Gg(now, Lg[now])), Lg[now]++;
		// Entry i+2 of the son is multiplied by entry i of v, restricted to
		// indices whose cell L-(i+2) lies inside the son's len[now] cells.
		rep0(i, Lf[v]) if (i+2>L-len[now] && i+2<=L) Sg(now, i+2, mul(Gg(now, i+2), Gf(v, i)));
		if (w[v]) {
			// Same tag trick as in get_f: pre-divide the first Lf[v]+2 entries by
			// w[v], then scale the whole array by w[v] through Ag/Mg/Ig.
			rep0(i, Lf[v]+2) if (i>L-len[now] && i<=L) Sg(now, i, mul(Gg(now, i), inv[v]));
			Ag[now] = mul(Ag[now], w[v]), Mg[now] = mul(Mg[now], w[v]), Ig[now] = mul(Ig[now], inv[v]);
		// w[v] == 0: shorten the explicit part and choose Zg so entries beyond it
		// read as 0 — presumably relying on Ig*Mg == 1 (mod P); NOTE(review): confirm.
		} else Lg[now] = max(L-len[now]+1, min(Lf[v]+2, L+1)), Zg[now] = mul(P-Ag[now], Ig[now]);
		// Recurse here: the nested call resets s / t, which this loop no longer reads.
		get_g(v);
	}
	// Add 1 to every logical entry of the son's g-array, then continue down the son.
	Ag[son[u]] = inc(Ag[son[u]], 1), get_g(son[u]);
}

int main() {
	n = read(), L = read(), k = read();
	rep(i, 1, n-1) {
		int u = read(), v = read();
		e[u].push_back(v), e[v].push_back(u);
	}
	if (!L) return printf("%d", n), 0;
	dfs(1, 1);
	rep(i, 1, n) e[i].clear();
	per(i, n, 1)
		rep0(j, c[i].size())
			e[c[i][j].x].push_back(c[i][j].y);
	get_inv(); get_f(1);
	initG(); get_g(1);
	return printf("%d", ans), 0;
}

(细节贼多,真心写吐了)

原文地址:https://www.cnblogs.com/ac-evil/p/13156431.html