[做题记录-计数] [九省联考2018]秘密袭击coat

秘密袭击解题报告

远古题解丢一下。

题意

求出树上随机选择的连通块中第(k)大的权值之和。

题解

高级算法综合练习题
吐了
看到第(k)大先(Min-Max)出来, 可以得到一个计算式子。

  • 扩展(Min-Max)容斥
    • [kthmax(S)=sum_{Tsubseteq S}(-1)^{|T|-k}dbinom{|T|-1}{k-1}min(T) ]

这题值域不大,提示我们可以直接考虑计算以一个某个值为最小值的方案数目乘枚举的最小值来求出答案。

(dp[x][i][j])表示以(x)为根的子树选择(i)个点形成连通块且包含(x)这个点,最小值为(j)的方案数目。

考虑写出转移(dp[x][i][j]=sum_{a + b = i,min(c, d) = j} dp[x][a][c] imes dp[y][b][d])

值得注意的是这个转移在细节上有点问题, 这没有包括在儿子中不选的方案数目, 考虑用某种方法补充进去。比如把在(i = 0, j = inf)的时候,我们考虑把后者设为(1),那就可以把当前点自身作为一个连通块的情况计入了。

然后你发现, (a + b = i)这个式子就很(NTT)
然后模数是64123
所以考虑写成生成函数的形式, 把(i)这一维度写到生成函数里面去,(F(x, j))表示(sum_{i= 0}^{inf}dp[x][i][j]x^i)
然后就有转移(F(x,j)=F(x,a) imes (F(y,b) + [j==inf]))
当然这里也要注意刚刚的问题。
边界就是(F(x, va[x])=x, F(x, inf)=1)。(注意一个(x)是节点, 一个(x)是多项式)。
然后根据套路这个东西可以线段树优化转移。但是你发现上面是个卷积貌似不太好搞, 可以考虑维护点值。这样多项式的卷积只需要直接对应项相乘即可。具体实现的时候取(1 -> n +1)(n+ 1)个点值带入, 最后插值回来即可。
答案就是(sum_{x}sum_{j} F(x, j) imes j)
接下来是一些实现问题。

  • 线段树合并维护(dp)

    • 发现是可以看成合并两个序列, 仔细分析一下可以在线段树上维护乘法标记,(f)和,(f imes j)的和实现。
  • 拉格朗日插值法。

    • 可以用背包求出分子的结果。具体的, 设(f[i][j])表示前(i)个数, 当前次数为(j)的这一位上的权值和, 转移就是(f[i][j] = f[i - 1][j - 1] + f[i - 1][j] * (- x[i]))
    • 然后枚举每一个数计算的时候只要反着背包一下,除去贡献就好了。
#include <bits/stdc++.h>
using namespace std;

#define LL long long
//#define int long long

namespace IO {
	const int N = 2e6;
	char buf[N], *p1 = buf, *p2 = buf;
	inline char gc() {
		if(p1 == p2) p2 = (p1 = buf) + fread(buf, 1, N, stdin);
		return p1 == p2 ? EOF : *(p1 ++);
	}
	template <typename T>
	inline void read(T &x) {
		x = 0; bool f = 0; char a = gc(); 
		for(; ! isdigit(a); a = gc()) if(a == '-') f = 1; 
		for(; isdigit(a); a = gc()) x = x * 10 + (a ^ 48); 
		if(f) x = -x; 
	}
	inline int read() { int x; read(x); return x; }
	inline LL readll() { LL x; read(x); return x; }
	inline char readchar() {
		char a = gc();
		while(a == ' ' || a == '
') a = gc(); return a;
	}
	inline int reads(char *s) {
		char *O = s, a = gc();
		while(! isalpha(a) && ! isdigit(a)) a = gc();
		while(isalpha(a) || isdigit(a)) *O = a, O ++, a = gc();
		return O - s;
	}
	inline void outs(int len, char *s, char ch = '
') {
		char *O = s;
		while(len --) putchar(*O), O ++; putchar(ch);
	}
	char Of[105], *O1 = Of, *O2 = Of;
	template <typename T>
	inline void print(T n, char ch = '
') {
		if(n < 0) putchar('-'), n = -n;
		if(n == 0) putchar('0');
		while(n) *(O1 ++) = (n % 10) ^ 48, n /= 10;
		while(O1 != O2) putchar(*(-- O1));
		putchar(ch);
	}
}

using IO :: read;
using IO :: print;
using IO :: reads;
using IO :: outs;
using IO :: readchar;

const int N = 2e3 + 10;
const int P = 64123;

int power(int x, int k) {
	int res = 1;
	while(k) {
		if(k & 1) res = (LL) res * x % P;
		x = (LL) x * x % P; k >>= 1;
	} return res;
}

int fac[N], ifac[N];
void init(int n = 2000) {
	fac[0] = 1;
	for(int i = 1; i <= n; i ++) fac[i] = (LL) fac[i - 1] * i % P;
	ifac[n] = power(fac[n], P - 2);
	for(int i = n - 1; i >= 0; i --) ifac[i] = (LL) ifac[i + 1] * (i + 1) % P;
}

inline int C(int x, int y) {
	return (LL) fac[x] * ifac[y] % P * ifac[x - y] % P;
}

#define ls(x) t[x].ls
#define rs(x) t[x].rs

struct SGT {
	int l, r, ls, rs;
	int mul, sum, val;
}t[N << 5];
int Rt[N], tot;
int build(int l, int r) {
	int x = ++ tot;
	t[x] = {l, r, 0, 0, 1, 0, 0};
	return x;
}

void update(int x) {
	t[x].sum = (t[ls(x)].sum + t[rs(x)].sum) % P;
	t[x].val = (t[ls(x)].val + t[rs(x)].val) % P;
}

void pushdown(int x) {
	int mid = (t[x].l + t[x].r) / 2;
	if(t[x].mul != 1) {
		if(ls(x)) {
			t[ls(x)].val = (LL) t[ls(x)].val * t[x].mul % P;
			t[ls(x)].sum = (LL) t[ls(x)].sum * t[x].mul % P;
			t[ls(x)].mul = (LL) t[ls(x)].mul * t[x].mul % P;
		}
		if(rs(x)) {
			t[rs(x)].val = (LL) t[rs(x)].val * t[x].mul % P;
			t[rs(x)].sum = (LL) t[rs(x)].sum * t[x].mul % P;
			t[rs(x)].mul = (LL) t[rs(x)].mul * t[x].mul % P;
		}
		t[x].mul = 1;
	}
}

int merge(int x, int y, int sx, int sy) {
	if(! x && ! y) return 0;
	if(! x) {
		t[y].mul = (LL) t[y].mul * sx % P;
		t[y].sum = (LL) t[y].sum * sx % P;
		t[y].val = (LL) t[y].val * sx % P;
		return y;
	}
	if(! y) {
		t[x].mul = (LL) t[x].mul * sy % P;
		t[x].sum = (LL) t[x].sum * sy % P;
		t[x].val = (LL) t[x].val * sy % P;
		return x;
	}
	pushdown(x); pushdown(y);
	if(t[x].l == t[x].r) {
		t[x].sum = ((LL) t[x].sum * sy % P + (LL) sx * t[y].sum % P + (LL) t[x].sum * t[y].sum % P) % P;
		t[x].val = (LL) t[x].sum * t[x].l % P;
		return x;
	}
	int sx1 = (sx + t[rs(x)].sum) % P;
	int sy1 = (sy + t[rs(y)].sum) % P;
	t[x].ls = merge(t[x].ls, t[y].ls, sx1, sy1);
	t[x].rs = merge(t[x].rs, t[y].rs, sx, sy);
	update(x); return x;
}

void modify(int x, int pos, int v) {
	if(t[x].l == t[x].r) {
		t[x].sum += v; t[x].sum %= P; 
		t[x].val = (LL) t[x].sum * t[x].l % P; 
		return ;
	}
	pushdown(x);
	int mid = (t[x].l + t[x].r) >> 1;
	if(pos <= mid) {
		if(! ls(x)) ls(x) = build(t[x].l, mid);
		modify(ls(x), pos, v);
	}
	else {
		if(! rs(x)) rs(x) = build(mid + 1, t[x].r);
		modify(rs(x), pos, v);
	}
	update(x);
}

int va[N];
struct edge {
	int to, next;
}e[N << 1];
int cnt, head[N];
void add(int x, int y) {
	e[++ cnt] = {y, head[x]}; head[x] = cnt;
}

int n, m, w;
int y[N];

void dfs(int x, int fx, int k) {
	Rt[x] = build(1, w + 1);
	modify(Rt[x], va[x], k); modify(Rt[x], w + 1, 1);
	for(int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if(y == fx) continue;
		dfs(y, x, k);
		modify(Rt[y], w + 1, 1);
		Rt[x] = merge(Rt[x], Rt[y], 0, 0);
	}
	y[k] = (y[k] + t[Rt[x]].val) % P;
}

int f[N], ans[N];

signed main() {
	#ifdef IN
	freopen("a.in", "r", stdin); 
	//freopen("a.out", "w", stdout);
	#endif
	init();
	n = read(); m = read(); w = read();
	for(int i = 1; i <= n; i ++) va[i] = read();
	for(int i = 1; i < n; i ++) {
		int x = read(), y = read();
		add(x, y); add(y, x);
	}
	for(int i = 1; i <= n + 1; i ++) {
		tot = 0;
		dfs(1, 0, i);
	}
	f[0] = 1;
	for(int i = 1; i <= n + 1; i ++) {
		for(int j = n + 1; j; j --)
			f[j] = ((LL) (P - i) * f[j] % P + f[j - 1]) % P;
		f[0] = (LL) f[0] * (P - i) % P;
	}
	for(int i = 1; i <= n + 1; i ++) {
		int inv = power(P - i, P - 2);
		f[0] = (LL) f[0] * inv % P;
		for(int j = 1; j <= n + 1; j ++)
			f[j] = (LL) (f[j] - f[j - 1] + P) * inv % P;
		int tmp = y[i];
		for(int j = 1; j <= n + 1; j ++) {
			if(i == j) continue;
			tmp = (LL) tmp * power(i - j + P, P - 2) % P;
		}
		for(int j = 0; j <= n; j ++) ans[j] = (ans[j] + (LL) f[j] * tmp % P) % P;
		for(int j = n + 1; j; j --)
			f[j] = ((LL) (P - i) * f[j] + f[j - 1]) % P;
		f[0] = (LL) f[0] * (P - i) % P;
	}
	int Ans = 0;
	for(int i = m; i <= n; i ++)
		Ans = (Ans + (LL) ans[i] * ((LL) ((i - m) & 1 ? (P - 1) : 1) * C(i - 1, m - 1)) % P) % P;
	print(Ans);
	return 0;
}
原文地址:https://www.cnblogs.com/clover4/p/15340054.html