[九省联考2018]秘密袭击coat

CXLV.[九省联考2018]秘密袭击coat

首先先讲一种暴力但能过的方法。

很容易就会往每个值各被计算几次的方向去想。于是我们枚举每个节点,计算有多少种可能下该节点是目标节点。

为了避免相同的值的影响,我们在值相同的点间也决出一种顺序,即,若两个值相同的点在作比较,依照上文定下的那种顺序决定。

于是我们考虑从该枚举的点 \(x\) 出发,遍历整棵子树同时DP。设 \(f_{i,j}\) 表示 \(i\) 子树中有 \(j\) 个点的危险程度 \(\geq d_x\)。于是就直接背包转移就行了。

看上去复杂度是 \(O(n^3)\),但是加上下述两个优化就可以过了:

  1. 第二维最大只枚举到 \(m\)(这里的 \(m\) 即题面中的 \(k\),因为 \(k\) 这个字母我们接下来还要用)

  2. 第二维最大只枚举到子树大小 \(sz\)

然后就过了,跑的还比正解都要快。

代码:

#include<bits/stdc++.h>
using namespace std;
const int mod=64123;
int n,m,W,d[2010],f[2010][2010],p[2010],q[2010],res,sz[2010];
vector<int>v[2010];
void dfs(int x,int fa,int lim){
	for(int i=0;i<=sz[x];i++)f[x][i]=0;
	f[x][sz[x]=(q[x]>=lim)]=1;
	for(auto y:v[x]){
		if(y==fa)continue;
		dfs(y,x,lim);
//		printf("%d:",x);for(int i=0;i<=m;i++)printf("%d ",f[x][i]);puts("");
//		printf("%d:",y);for(int i=0;i<=m;i++)printf("%d ",f[y][i]);puts("");
		for(int i=sz[x];i>=0;i--)for(int j=min(m-i,sz[y]);j>=0;j--)(f[x][i+j]+=1ll*f[x][i]*f[y][j]%mod)%=mod;
		sz[x]=min(sz[x]+sz[y],m);
//		printf("%d:",x);for(int i=0;i<=m;i++)printf("%d ",f[x][i]);puts("\n");
	}
}
int main(){
	scanf("%d%d%d",&n,&m,&W);
	for(int i=1;i<=n;i++)scanf("%d",&d[i]),p[i]=i;
	sort(p+1,p+n+1,[](int u,int v){return d[u]<d[v];});
	for(int i=1;i<=n;i++)q[p[i]]=i;
	for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),v[x].push_back(y),v[y].push_back(x);
	for(int i=1;i<=n;i++){
//		printf("%d\n",q[i]);
		if(q[i]>n-m+1)continue;
		dfs(i,0,q[i]);
		(res+=1ll*d[i]*f[i][m]%mod)%=mod;
	}
	printf("%d\n",res);
	return 0;
}

然后是正解。

我们要求 \(\sum\limits_{\mathbb S\subseteq\mathbb T}\text{Kth of }\mathbb S\)

考虑枚举该 \(\text{Kth}\) 的值为 \(i\),则要求 \(\sum\limits_{i=1}i\sum\limits_{\mathbb S\subseteq\mathbb T}[\text{Kth of }\mathbb S=i]\)

考虑让每个 \(i\) 拆分成在所有的 \(j\leq i\) 的位置上各被计算一次,则要求 \(\sum\limits_{i=1}\sum\limits_{\mathbb S\subseteq\mathbb T}[\text{Kth of }\mathbb S\geq i]\)。(这同时也是期望DP的经典套路)

考虑令 \(cnt_i\) 表示 \(\geq i\) 的数的总数。则要求 \(\sum\limits_{i=1}\sum\limits_{\mathbb S\subseteq\mathbb T}[cnt_i\geq m]\)。(注意这里的 \(m\) 即为 \(k\)

考虑对于每个连通块,在树上最高处计算它的贡献。设 \(f_{i,j,k}\) 表示以 \(i\) 为根的子树内,当前统计的是 \(cnt_j\),且 \(cnt_j=k\) 的方案数。转移是裸的背包卷积。

考虑如何求答案。因为我们是在最高处计算贡献,所以就要求 \(\sum\limits_{i=1}^n\sum\limits_{j=1}^W\sum\limits_{k=m}^nf_{i,j,k}\)

因为我们是卷积,所以考虑FFT转移。又因为它一直在卷,所以我们干脆考虑压根不把它复原成系数式,就纯粹用点集表示。

更准确地说,因为 \(f_{i,j}\) 的生成函数是 \(\sum\limits_{k=1}^nf_{i,j,k}x^k\),一个 \(n\) 次多项式,所以我们直接枚举 \(x\in[1,n+1]\),然后分别求出这时的生成函数的值,最后拉格朗日插值一下就插回系数式了。

则,在合并父亲 \(x\) 和儿子 \(y\)\(f\) 数组时,因为是点集式,所以直接对应位置相乘就行了。

但是就算有了这么伟大的思路,我们算算复杂度,还是 \(O(n^3)\) 的。有没有更棒的复杂度?

我们发现,最终要对所有 \(x\)\(f\) 数组的和,倒不如在正常处理的过程中就顺便维护了。于是我们设 \(g_{i,j}=\sum\limits_{k\in\text{subtree}_i}f_{k,j}\),则最终要求的就是 \(\sum\limits_{i=m}^ng_{1,i}\)。当然,依据分配律,我们还是可以直接一股脑求出 \(\sum\limits_{i=1}^ng_{1,i}\),待插出系数式后再取 \(\geq m\) 的项。

我们思考当合并 \(f_{x,i}\)\(f_{y,i}\) 时会发生什么:

\(f_{x,i}\rightarrow f_{x,i}\times(f_{y,i}+1)\)(采用 \(f_{y,i}\) 或不用)

\(g_{x,i}\rightarrow g_{x,i}+g_{y,i}\)

在全体结束后,再有 \(g_{x,i}\rightarrow g_{x,i}+f_{x,i}\)

同时,为了便于合并,我们采用线段树合并来维护DP数组。(这种操作被称作整体DP

我们考虑初始化,发现假如我们最外层枚举的点值是 \(X\),则所有 \(\forall i\leq d_x\)\(f_{x,i}\) 在结束时都要乘上一个 \(X\)

明显这个初始状态大体是区间的,非常适合线段树合并。

但是,就算这样,暴力合并的复杂度仍然 \(O(n^3)\),必须考虑在区间上把它做掉。

于是,一个天才般的想法诞生了:

观察到每次的变换都是 \((f,g)\rightarrow(af+b,cf+d+g)\) 的形式。

而这个变换,可以用四元组 \((a,b,c,d)\) 唯一刻画。

同时,展开式子,就会发现这个变换具有结合律(虽然很遗憾的是,大部分情形下它不具有交换律)。

假如我们初始令 \(b=f,d=g\) 的话,就会发现,做一次上述操作,它就自动帮你更新了 \(f\)\(g\)

于是,我们把它看作区间的 tag,然后线段树合并就非常简单了。

同时,要注意的是,其单位元是 \((1,0,0,0)\)

我们来总结一下操作:当合并的时候,我们希望 \(f_x\rightarrow f_x\times(f_y+1)\),而 \(f_y+1\) 可以通过在遍历完 \(y\) 的子树后打上全体 \(+1\)tag 解决,当然这里不需要额外增加其它的 tag,我们发现 \((1,1,0,0)\) 刚好胜任了这个操作。于是现在 \(f_x\rightarrow f_{x}\times f_y\)\((f_y,0,0,0)\)tag 即可(需要注意的是,\(f_y\) 是线段树上 \(y\) 处的 \(b\))。\(g_{x}\rightarrow g_x+g_y\)\((1,0,0,g_y)\) 即可,而 \(g_y\) 则是 \(d\)。两个乘起来,就是使用 \((b,0,0,d)\)

最后合并 \(f\)\(g\) 的时候,则要使用 \((1,0,1,0)\),意义通过展开即可得到就是将 \(f\) 加到 \(g\)。而乘上 \(X\) 的操作,使用 \((X,0,0,0)\) 即可。

需要注意的是,这里并不能标记永久化,主要是因为从四元组中抽出 \(b\)\(d\) 的操作并非线性变换,不能打到 tag 上去,在线段树合并的时候要先一路下传到某一方已经没有叶子了再合并。

同时,使用 unsigned int 可以刚好把 64123 卡进一次乘法内不爆。

代码(一些比较疑惑的地方已经加了注释):

#include<bits/stdc++.h>
using namespace std;
#define int unsigned int
const int mod=64123;
int n,m,W,d[2010],X,cnt,bin[5010000],tp,rt[2010],a[2010];
struct dat{//(f,g)->(af+b,cf+d+g)
	int a,b,c,d;
	dat(){a=1,b=c=d=0;}
	dat(int A,int B,int C,int D){a=A,b=B,c=C,d=D;}
	friend dat operator*(const dat&u,const dat&v){return dat((u.a*v.a)%mod,(u.b*v.a+v.b)%mod,(u.a*v.c+u.c)%mod,(u.b*v.c+u.d+v.d)%mod);}
	void operator*=(const dat&v){(*this)=(*this)*v;}
	void print()const{printf("(%u %u %u %u)\n",a,b,c,d);}
};
#define mid ((l+r)>>1)
int newnode(){return tp?bin[tp--]:++cnt;}
struct SegTree{
	int lson,rson;
	dat tag;
}seg[5010000];
void erase(int &x){if(x)seg[x].tag=dat(),erase(seg[x].lson),erase(seg[x].rson),bin[++tp]=x,x=0;}//erase all the subtree of x.
void pushdown(int x){
	if(!seg[x].lson)seg[x].lson=newnode();
	if(!seg[x].rson)seg[x].rson=newnode();
	seg[seg[x].lson].tag*=seg[x].tag,seg[seg[x].rson].tag*=seg[x].tag,seg[x].tag=dat();
}
void modify(int &x,int l,int r,int L,int R,dat val){
	if(l>R||r<L)return;
	if(!x)x=newnode();
	if(L<=l&&r<=R){seg[x].tag*=val;return;}
	pushdown(x),modify(seg[x].lson,l,mid,L,R,val),modify(seg[x].rson,mid+1,r,L,R,val);
}
void merge(int &x,int &y){
	if(!seg[x].lson&&!seg[x].rson)swap(x,y);
	if(!seg[y].lson&&!seg[y].rson){seg[x].tag*=dat(seg[y].tag.b,0,0,seg[y].tag.d);return;}
	pushdown(x),pushdown(y),merge(seg[x].lson,seg[y].lson),merge(seg[x].rson,seg[y].rson);
}
int query(int x,int l,int r){
	if(l==r)return seg[x].tag.d;
	pushdown(x);
	return (query(seg[x].lson,l,mid)+query(seg[x].rson,mid+1,r))%mod;
}
void iterate(int x,int l,int r){
	if(!x)return;
	printf("%u:[%u,%u]\n",x,l,r);seg[x].tag.print();
	iterate(seg[x].lson,l,mid),iterate(seg[x].rson,mid+1,r);
}
vector<int>v[2010];
void dfs(int x,int fa){
	modify(rt[x],1,W,1,W,dat(0,1,0,0));//set all into (0,1,0,0),which means only f=1.
	for(auto y:v[x])if(y!=fa)dfs(y,x),merge(rt[x],rt[y]),erase(rt[y]);
	modify(rt[x],1,W,1,d[x],dat(X,0,0,0));//those <=d[x] are multiplied by an X
	modify(rt[x],1,W,1,W,dat(1,1,1,0));
	//product of (1,0,1,0) and (1,1,0,0), first means add f to g(to calculate the real g), second means add 1 to f (stands for x itself not chosen at x's father)
}
int all[2010],tmp[2010],res;
int ksm(int x,int y=mod-2){int z=1;for(;y;y>>=1,x=x*x%mod)if(y&1)z=z*x%mod;return z;}
void Lagrange(){
	all[0]=1;
	for(int i=1;i<=n+1;i++)for(int j=i-1;j<=i;j--)(all[j+1]+=all[j])%=mod,(all[j]*=mod-i)%=mod;//note that j is unsigned!!!
//	for(int i=0;i<=n+1;i++)printf("%u ",all[i]);puts("");
	for(int i=1;i<=n+1;i++){
		int inv=ksm(mod-i),sum=0;
		for(int j=0;j<=n;j++)tmp[j]=all[j];
		for(int j=0;j<=n;j++)(tmp[j]*=inv)%=mod,(tmp[j+1]+=mod-tmp[j])%=mod;
//		if(i>=1410){for(int j=0;j<=n;j++)printf("%u ",tmp[j]);puts("");}
		for(int j=m;j<=n;j++)sum+=tmp[j];sum%=mod;
		for(int j=1;j<=n+1;j++)if(j!=i)(sum*=ksm((i-j+mod)%mod))%=mod;
		res+=sum*a[i]%mod;
	}
	res%=mod;
}
signed main(){
	scanf("%u%u%u",&n,&m,&W);
	for(int i=1;i<=n;i++)scanf("%u",&d[i]);
	for(int i=1,x,y;i<n;i++)scanf("%u%u",&x,&y),v[x].push_back(y),v[y].push_back(x);
	for(X=1;X<=n+1;X++)dfs(1,0),a[X]=query(rt[1],1,W),erase(rt[1]);
//	for(int i=1;i<=n+1;i++)printf("%u ",a[i]);puts("");
	Lagrange();printf("%u\n",res);
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14601691.html