[被踩计划] 题解 [NOI2020]命运

[被踩计划] 题解 [NOI2020]命运

为什么叫被踩记录呢?因为感觉自己之前真的是太菜了,打算把之前联赛等考过的题目做一做,看看自已以前有多菜,所以取名叫被踩记录。

题目链接

题意简述

给定一棵 (n) 个点的有根树,同时给定 (m) 条链 ((u,v)) (保证 (u)(v) 的祖先),询问有多少种给每条边赋权 ([0,1]) 的方案使得每条链都至少有一条边边权为 (1) ,答案对 (998244353) 取模。

(1le n,mle 5 imes 10^5) ,时限 2s ,空间限制 1GB 。

题目分析

考场上傻乎乎的我想了个容斥,然后就被踩了。

据说这道题目是套路题,自己做的题目果然还是太少了。

不难发现,一个点 (u) 往上的所有链中我们只关心另一个端点最深的那条链,因为如果这条链被覆盖了,那么其它的链都会被覆盖。

(dp(u,i)) 表示考虑完了以 (u) 为根的整个子树中的所有边的取值情况,除了跨越 (u) 节点的链以外其它链都被覆盖,并且其中跨越了 (u) 节点的没有被覆盖的链中最深的链另一个端点的深度为 (i) 的方案数。特殊的,其中 (dp(u,0)) 表示没有跨越 (u) 节点的未被覆盖的链的方案数。 (dp(1,0)) 就是答案。

转移就是考虑子树合并,考虑将子树 (v) 并入子树 (u) ,那么转移就是枚举 (u-v) 的权值,设转移后得到的数组为 (f()) ,取值为 (0) 时:

[f(i)gets f(i)+sum_{j=0}^idp(u,i) imes dp(v,j)+sum_{j=0}^{i-1}dp(u,j) imes dp(v,i) ]

取值为 (1) 时:

[f(i)gets f(i)+sum_{j=0}^{operatorname{deep}_v-1}dp(u,i) imes dp(v,j) ]

总的方程:

[f(i)=sum_{j=0}^idp(u,i) imes dp(v,j)+sum_{j=0}^{i-1}dp(u,j) imes dp(v,i)+sum_{j=0}^{operatorname{deep}_v-1}dp(u,i) imes dp(v,j) ]

设前缀和数组为 (g(u,i)) ,则:

[f(i)=dp(u,i)(g(v,i)+g(v,operatorname{deep}_v-1))+dp(v,i)g(u,i-1) ]

这样如果直接 dp ,那么时间复杂度就是 (mathcal O(n^2))

由于 (dp(u,i)) 的第二维只有若干个位置是有值的(值非 (0) ),这些位置可以认为是它自己向上的链加上它子树内向上的链,所以可以使用线段树来维护所有非 (0) 的位置的值,而 dp 合并的时候就采用线段树合并来实现转移。

具体如何实现?观察转移方程,发现有两个前缀和数组,可以在线段树合并的过程中记录前缀和以供转移。如果出现了某一个节点代表的位置只有 (dp(u,lsim r)) 有值,那么转移方程可以认为是 (f(lsim r)=dp(u,lsim r)(g(v,i)+g(v,operatorname{deep}_v-1))) ,只需要实现区间乘即可;如果出现了某一个节点代表的位置只有 (dp(v,lsim r)) 有值,那么转移方程可以认为是 (f(lsim r)=dp(v,lsim r)g(u,l-1)) ,也只需要实现区间乘即可。如果合并到了叶子节点,就直接按 dp 转移方程合并即可。

还有一点需要注意的是,从儿子节点合并上来的线段树可能有若干个位置是非法的(当 (dp(u,i)) 满足 (ige operatorname{deep}_u) 时,这个状态时非法的),要把非法状态去掉,就需要再进行一遍区间赋 (0) (区间乘以 (0) )。

总的时间复杂度是 (mathcal O(nlog_2n))

参考代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ch() getchar()
#define pc(x) putchar(x)
using namespace std;
template<typename T>void read(T&x){
	static char c;static int f;
	for(c=ch(),f=1;c<'0'||c>'9';c=ch())if(c=='-')f=-f;
	for(x=0;c>='0'&&c<='9';c=ch())x=x*10+(c&15);x*=f;
}
template<typename T>void write(T x){
	static char q[65];int cnt=0;
	if(x<0)pc('-'),x=-x;
	q[++cnt]=x%10,x/=10;
	while(x)
		q[++cnt]=x%10,x/=10;
	while(cnt)pc(q[cnt--]+'0');
}
const int mod=998244353,maxn=500005;
int mo(const int x){
	return x>=mod?x-mod:x;
}
struct Edge{
	int v,nt;
	Edge(int v=0,int nt=0):
		v(v),nt(nt){}
}e[maxn*2];
int hd[maxn],num;
void qwq(int u,int v){
	e[++num]=Edge(v,hd[u]),hd[u]=num;
}
int dp[maxn],mx[maxn];
void dfs(int u,int fa){
	dp[u]=dp[fa]+1;
	for(int i=hd[u];i;i=e[i].nt){
		int v=e[i].v;
		if(v==fa)continue;
		dfs(v,u);
	}
}
struct Node{
	int l,r,sum,mul;
	Node(int l=0,int r=0,int sum=0,int mul=1):
		l(l),r(r),sum(sum),mul(mul){}
}P[maxn*25];
int tot;
int build(int l,int r,int p){
	int re=++tot,mid=(l+r)>>1;
	P[re]=Node(0,0,1);
	if(l==r)return re;
	if(p<=mid)P[re].l=build(l,mid,p);
	else P[re].r=build(mid+1,r,p);
	return re;
}
void push(int x,int mul){
	if(!x)return;
	P[x].sum=1ll*P[x].sum*mul%mod;
	P[x].mul=1ll*P[x].mul*mul%mod;
}
void pushdown(int x){
	if(P[x].mul==1)return;
	push(P[x].l,P[x].mul);
	push(P[x].r,P[x].mul);
	P[x].mul=1;
}
void pushup(int x){
	P[x].sum=mo(P[P[x].l].sum+P[P[x].r].sum);
}
int Merge(int x,int y,int smu,int smv){
	if(!x)return push(y,smu),y;
	if(!y)return push(x,smv),x;
	pushdown(x);pushdown(y);
	if(!P[x].l&&!P[x].r)
		P[x].sum=mo(1ll*P[x].sum*mo(smv+P[y].sum)%mod+1ll*P[y].sum*smu%mod);
	else{
		P[x].r=Merge(P[x].r,P[y].r,mo(smu+P[P[x].l].sum),mo(smv+P[P[y].l].sum));
		P[x].l=Merge(P[x].l,P[y].l,smu,smv);pushup(x);
	}
	return x;
}
void cover(int x,int l,int r,int L,int R){
	if(!x||(L<=l&&r<=R))return push(x,0);
	pushdown(x);int mid=(l+r)>>1;
	if(L<=mid)cover(P[x].l,l,mid,L,R);
	if(R>mid)cover(P[x].r,mid+1,r,L,R);
	return pushup(x);
}
int n,rt[maxn];
void print(int x,int l,int r){
	if(l==r)return write(P[x].sum),pc(" 
"[r==n]),void();
	pushdown(x);int mid=(l+r)>>1;
	print(P[x].l,l,mid);print(P[x].r,mid+1,r);
}
void solve(int u,int fa){
	rt[u]=build(0,n,mx[u]);
	for(int i=hd[u];i;i=e[i].nt){
		int v=e[i].v;if(v==fa)continue;solve(v,u);
		rt[u]=Merge(rt[u],rt[v],0,P[rt[v]].sum);
	}
	cover(rt[u],0,n,dp[u],n);
}
int query(int x,int l,int r){
	if(l==r)return P[x].sum;
	pushdown(x);int mid=(l+r)>>1;
	return query(P[x].l,l,mid);
}
int main(){
	read(n);
	for(int i=2;i<=n;++i){
		int u,v;
		read(u),read(v);
		qwq(u,v),qwq(v,u);
	}
	dfs(1,0);
	int m;read(m);
	for(int i=1;i<=m;++i){
		int u,v;
		read(u),read(v);
		mx[v]=max(mx[v],dp[u]);
	}
	solve(1,0);
	write(query(rt[1],0,n)),pc('
');
	return 0;
}

原文地址:https://www.cnblogs.com/lsq147/p/14302756.html