[十二省联考2019]希望

(O(nL))(DP)很普及组吧。点减边容斥,设(f_{u,i}-1)为在(u)子树内选出一个连通块,使得它包含(u)且最深点距离(u)(i)的方案数,(g_{u,i})表示选出一个连通块,使得它包含(u)且不包含(u)子树内除(u)以外的点,距离(u)最远的点距离为(i)的方案数。

显然(f)可以用长链剖分优化,也很普及组,就不说了。

(g)其实也是可以用长链剖分优化的。将当前点的(g)转移到重儿子时像(f)一样直接继承,转移到轻儿子时,我们发现最后计算答案只用到了所有的(g_{u,L}),所以只有(g_{u,L-dep[u]}sim g_{u,L})这部分的值是有用的,暴力转移这一部分就行了,复杂度和长链剖分一样分析。

(g)时有一个问题,就是需要求一个点的所有子树除掉某个子树外的(f)的积。暴力的想法是维护可持久化线段树,复杂度(O(nlogn))很不优秀。不过可以发现可持久化是假的,只要可回退化就行了。具体的,在求(f)的dfs中,每次合并一个子树时,我们记录下来合并这个子树时对(f)相关信息进行的修改。在求(g)的dfs中,我们反序遍历每个点的儿子,这样就可以栈序撤销,维护每一个前缀的(f)的积,后缀再开另一个数组记录一下就行了。

不过这样还是需要线段树。我们发现我们要支持的是全局加,后缀乘。我们可以对每个点记录标记(a,b),表示真实值为(a*)存储值+(b)。全局加时修改(b)即可,后缀乘时修改(a,b),并暴力修改没有被后缀乘影响到的元素,乘一个逆元即可。如果(a)(0),那么(a)不存在逆元,就只能将后缀乘变为后缀赋值为(-frac{b}{a})。所以还要记录标记(p,t),表示所有(>=p)的位置被赋值为(t),每次修改一个元素时下放后缀赋值标记。这样就能做到(O(n))了。

要真正做到(O(n)),还需要线性求逆元。我们发现要求的逆元都是(f_{u,dep[u]})的形式。把这些值排成一个数组记为(a),并求出(a)的前缀积(b)。我们算出(b_n^{-1}),递推出所有(b^{-1}),再利用(b)(b^{-1})计算(a^{-1})即可。

实现时有几个trick:

1.修改本质都是对内存位置的修改,所以需要支持撤销时只要保存内存中哪个位置原来的值是多少就行了。

2.stack空间占用很大,可以用list代替。

3.给(g)分配空间时可以让(g_u)的起始位置指向(p-(L-dep[u])),其中(p)是当前内存池中第一个可用位置,这样访问(g_{u,L-dep[u]})时就会访问位置(p)了。

出题人真良心,代码只要写3.3K。

#include<bits/stdc++.h>
#define foo int a[N],ia[N],b[N],p[N],t[N];int trs(int u,int w){return 1ll*J(w,b[u])*ia[u]%M;}
using namespace std;const int N=1e6+9;const int M=998244353;int gi(){int x;cin>>x;return x;}int P(int a,int b){return a+b>=M?a+b-M:a+b;}int J(int a,int b){return a-b<0?a-b+M:a-b;}void I(int&a,int b){a=a+b>=M?a+b-M:a+b;}void K(int&a,int b){a=a-b<0?a-b+M:a-b;}int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=1ll*ret*a%M;a=1ll*a*a%M,b>>=1;}return ret;}vector<int>E[N];int n,L,k,U[N],O[N],V[N],ans=0,s[N],m=0,ff[N*10],*f[N],*pf=ff,gg[N*10],*g[N],*pg=gg,h[N];namespace F{foo int val(int u,int i){i=min(i,U[u]);return P(1ll*a[u]*(i<p[u]?f[u][i]:t[u])%M,b[u]);}}namespace G{foo int val(int u,int i){return P(1ll*a[u]*(i<p[u]?g[u][i]:t[u])%M,b[u]);}}bool cmp(int a,int b){return U[a]>U[b];}void D1(int u,int fa){U[u]=-1,V[u]=1;for(auto v:E[u])if(v!=fa)D1(v,u),U[v]>U[u]?O[u]=v,U[u]=U[v]:0,V[u]=1ll*V[u]*V[v]%M;sort(E[u].begin(),E[u].end(),cmp);++U[u],I(V[u],1);}void New(int u){f[u]=pf;pf+=U[u]+2;pg+=U[u]+2;g[u]=pg-max(0,L-U[u]);pg+=U[u]+2;}struct op{int x,*p;};list<op>st[N];void S(int u,int &x){st[u].push_back((op){x,&x});}void undo(int u){op t;while(!st[u].empty())t=st[u].back(),*t.p=t.x,st[u].pop_back();}void D2(int u,int fa){using namespace F;int z=0;p[u]=U[u]+1;a[u]=ia[u]=b[u]=1;if(O[u])f[z=O[u]]=f[u]+1,D2(z,u),a[u]=a[z],ia[u]=ia[z],b[u]=b[z],p[u]=p[z]+1,t[u]=t[z],f[u][0]=trs(u,1);for(auto v:E[u])if(v!=fa&&v!=O[u]){z=v,New(v),D2(v,u);for(int i=0;i<=U[v]+1;i++){if(p[u]==i)S(v,p[u]),S(v,f[u][i]),f[u][p[u]++]=t[u];S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*(i?val(v,i-1):1)%M);}if(U[u]>U[v]+1){int w=val(v,U[v]);if(w){S(v,a[u]),S(v,ia[u]),S(v,b[u]),a[u]=1ll*a[u]*w%M,b[u]=1ll*b[u]*w%M,ia[u]=1ll*ia[u]*V[v]%M;for(int i=0;i<=U[v]+1;i++)S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*V[v]%M);}else S(v,p[u]),S(v,t[u]),p[u]=U[v]+1,t[u]=trs(u,0);}}if(z)S(z,b[u]);I(b[u],1);}void D3(int u,int fa){using namespace G;reverse(E[u].begin(),E[u].end());int sx=0,prd=1;h[0]=1;I(ans,qpow(1ll*J(F::val(u,L),1)*val(u,L)%M,k));if(fa)K(ans,qpow(1ll*J(F::val(u,L-1),1)*J(val(u,L),1)%M,k));for(auto v:E[u])if(v!=fa&&v!=O[u]){undo(v);p[v]=L+1;a[v]=ia[v]=1;for(int i=max(0,L-U[v]);i<=L;i++)g[v][i]=P(1ll*(i?val(u,i-1):0)*(i>1?1ll*F::val(u,i-1)*(i-2<=sx?h[i-2]:prd)%M:1)%M,1);for(int i=0;i<=U[v];i++)h[i]=1ll*(i>sx?prd:h[i])*F::val(v,i)%M;sx=U[v];prd=1ll*prd*F::val(v,U[v])%M;}int z;if(O[u]){g[z=O[u]]=g[u]-1;a[z]=a[u];b[z]=b[u];ia[z]=ia[u];t[z]=t[u];p[z]=p[u]+1,L-U[z]<=0?g[z][0]=trs(z,0):0;for(auto v:E[u])if(v!=fa&&v!=O[u]){p[z]=max(p[z],L-U[z]);for(int i=max(0,L-U[z]);i<=min(L,U[v]+2);i++){if(p[z]==i)g[z][p[z]++]=t[z];g[z][i]=trs(z,1ll*val(z,i)*(i>1?F::val(v,i-2):1)%M);}if(L>U[v]+2){int w=F::val(v,U[v]);if(w){a[z]=1ll*a[z]*w%M,b[z]=1ll*b[z]*w%M,ia[z]=1ll*ia[z]*V[v]%M;for(int i=max(0,L-U[z]);i<=min(L,U[v]+2);i++)g[z][i]=trs(z,1ll*val(z,i)*V[v]%M);}else p[z]=U[v]+2,t[z]=trs(z,0);}}I(b[z],1);D3(z,u);}for(auto v:E[u])if(v!=fa&&v!=O[u])D3(v,u);}int main(){n=gi(),L=gi(),k=gi();if(!L)return printf("%d
",n),0;for(int i=1,u,v;i<n;i++)u=gi(),v=gi(),E[u].push_back(v),E[v].push_back(u);D1(1,0);s[0]=1;for(int i=1;i<=n;i++)if(V[i])++m,s[m]=1ll*s[m-1]*V[i]%M;s[m]=qpow(s[m],M-2);for(int i=n,t;i;i--)if(V[i])--m,t=V[i],V[i]=1ll*s[m+1]*s[m]%M,s[m]=1ll*s[m+1]*t%M;New(1),D2(1,0);G::a[1]=G::ia[1]=G::b[1]=1;G::p[1]=L+1;D3(1,0);printf("%d
",ans);return 0;}

正常代码:

//HNOIday1t1出题人nmsl
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
const int mod=998244353;

int gi() {
	int x=0,o=1;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
	if(ch=='-') o=-1,ch=getchar();
	while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
	return x*o;
}

int add(int a,int b) {
	return a+b>=mod?a+b-mod:a+b;
}

int sub(int a,int b) {
	return a-b<0?a-b+mod:a-b;
}

void inc(int &a,int b) {
	a=a+b>=mod?a+b-mod:a+b;
}

void dec(int &a,int b) {
	a=a-b<0?a-b+mod:a-b;
}

int qpow(int a,int b) {
	int ret=1;
	while(b) {
		if(b&1) ret=1ll*ret*a%mod;
		a=1ll*a*a%mod,b>>=1;
	}
	return ret;
}

vector<int> E[N];
int n,L,k,dep[N],son[N],inv[N],ans=0,s[N],m=0,ff[N*10],*f[N],*pf=ff,gg[N*10],*g[N],*pg=gg,h[N];
vector<int> vis[N];

#define foo int a[N],ia[N],b[N],p[N],t[N];		
	int trs(int u,int w) {						
		return 1ll*sub(w,b[u])*ia[u]%mod;		
	}											

namespace F {

	foo

	int val(int u,int i) {
		i=min(i,dep[u]);
		return add(1ll*a[u]*(i<p[u]?f[u][i]:t[u])%mod,b[u]);
	}
	
}

namespace G {

	foo

	int val(int u,int i) {
		return add(1ll*a[u]*(i<p[u]?g[u][i]:t[u])%mod,b[u]);
	}
	
}

bool cmp(int a,int b) {
	return dep[a]>dep[b];
}

void dfs1(int u,int fa) {
	dep[u]=-1,inv[u]=1;
	for(auto v:E[u])
		if(v!=fa) dfs1(v,u),dep[v]>dep[u]?son[u]=v,dep[u]=dep[v]:0,inv[u]=1ll*inv[u]*inv[v]%mod;
	if(son[u]) {
		int mx=0;
		for(auto v:E[u]) if(v!=fa&&v!=son[u]) mx=max(mx,dep[v]),vis[dep[v]].push_back(v);
		E[u]=vector<int>(1,son[u]);
		for(int i=mx;~i;i--) for(auto v:vis[i]) E[u].push_back(v);
		for(int i=mx;~i;i--) vis[i].clear();
	}
	++dep[u],inc(inv[u],1);
}

void New(int u) {
	f[u]=pf;pf+=dep[u]+2;pg+=dep[u]+2;g[u]=pg-max(0,L-dep[u]);pg+=dep[u]+2;
}

struct op { int x,*p; };
list<op> st[N];

void S(int u,int &x) {
	st[u].push_back((op){x,&x});
}

void undo(int u) {
	op t;
	while(!st[u].empty()) t=st[u].back(),*t.p=t.x,st[u].pop_back();
}

void dfs2(int u,int fa) {
	using namespace F;
	int z=0;p[u]=dep[u]+1;a[u]=ia[u]=b[u]=1;
	if(son[u]) f[z=son[u]]=f[u]+1,dfs2(z,u),a[u]=a[z],ia[u]=ia[z],b[u]=b[z],p[u]=p[z]+1,t[u]=t[z],f[u][0]=trs(u,1);
	for(auto v:E[u])
		if(v!=fa&&v!=son[u]) {
			z=v,New(v),dfs2(v,u);
			for(int i=0;i<=dep[v]+1;i++) {
				if(p[u]==i) S(v,p[u]),S(v,f[u][i]),f[u][p[u]++]=t[u];
				S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*(i?val(v,i-1):1)%mod);
			}
			if(dep[u]>dep[v]+1) {
				int w=val(v,dep[v]);
				if(w) {
					S(v,a[u]),S(v,ia[u]),S(v,b[u]),a[u]=1ll*a[u]*w%mod,b[u]=1ll*b[u]*w%mod,ia[u]=1ll*ia[u]*inv[v]%mod;
					for(int i=0;i<=dep[v]+1;i++) S(v,f[u][i]),f[u][i]=trs(u,1ll*val(u,i)*inv[v]%mod);
				}
				else S(v,p[u]),S(v,t[u]),p[u]=dep[v]+1,t[u]=trs(u,0);
			}
		}
	if(z) S(z,b[u]);inc(b[u],1);
}

void dfs3(int u,int fa) {
	using namespace G;
	reverse(E[u].begin(),E[u].end());
	int sx=0,prd=1;h[0]=1;
	inc(ans,qpow(1ll*sub(F::val(u,L),1)*val(u,L)%mod,k));
	if(fa) dec(ans,qpow(1ll*sub(F::val(u,L-1),1)*sub(val(u,L),1)%mod,k));
	for(auto v:E[u])
		if(v!=fa&&v!=son[u]) {
			undo(v);p[v]=L+1;a[v]=ia[v]=1;
			for(int i=max(0,L-dep[v]);i<=L;i++) g[v][i]=add(1ll*(i?val(u,i-1):0)*(i>1?1ll*F::val(u,i-1)*(i-2<=sx?h[i-2]:prd)%mod:1)%mod,1);
			for(int i=0;i<=dep[v];i++) h[i]=1ll*(i>sx?prd:h[i])*F::val(v,i)%mod;
			sx=dep[v];prd=1ll*prd*F::val(v,dep[v])%mod;
		}
	int z;
	if(son[u]) {
		g[z=son[u]]=g[u]-1;a[z]=a[u];b[z]=b[u];ia[z]=ia[u];t[z]=t[u];p[z]=p[u]+1,L-dep[z]<=0?g[z][0]=trs(z,0):0;
		for(auto v:E[u])
			if(v!=fa&&v!=son[u]) {
				p[z]=max(p[z],L-dep[z]);
				for(int i=max(0,L-dep[z]);i<=min(L,dep[v]+2);i++) {
					if(p[z]==i) g[z][p[z]++]=t[z];
					g[z][i]=trs(z,1ll*val(z,i)*(i>1?F::val(v,i-2):1)%mod);
				}
				if(L>dep[v]+2) {
					int w=F::val(v,dep[v]);
					if(w) {
						a[z]=1ll*a[z]*w%mod,b[z]=1ll*b[z]*w%mod,ia[z]=1ll*ia[z]*inv[v]%mod;
						for(int i=max(0,L-dep[z]);i<=min(L,dep[v]+2);i++) g[z][i]=trs(z,1ll*val(z,i)*inv[v]%mod);
					}
					else p[z]=dep[v]+2,t[z]=trs(z,0);
				}
			}
		inc(b[z],1);dfs3(z,u);
	}
	for(auto v:E[u]) if(v!=fa&&v!=son[u]) dfs3(v,u);
}

int main() {
#ifndef ONLINE_JUDGE
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
#endif
	n=gi(),L=gi(),k=gi();
	if(L==0) return printf("%d
",n),0;
	for(int i=1,u,v;i<n;i++) u=gi(),v=gi(),E[u].push_back(v),E[v].push_back(u);
	dfs1(1,0);
	s[0]=1;for(int i=1;i<=n;i++) if(inv[i]) ++m,s[m]=1ll*s[m-1]*inv[i]%mod;
	s[m]=qpow(s[m],mod-2);
	for(int i=n,t;i;i--) if(inv[i]) --m,t=inv[i],inv[i]=1ll*s[m+1]*s[m]%mod,s[m]=1ll*s[m+1]*t%mod;
	New(1),dfs2(1,0);G::a[1]=G::ia[1]=G::b[1]=1;G::p[1]=L+1;dfs3(1,0);
	printf("%d
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/gczdajuruo/p/10712624.html