6543. 【GDOI2020模拟4.4】Easy Data Structure(动态dp)

题目描述


题解

第一次写动态dp,就是把转移变成矩阵然后用数据结构维护

把式子变成树的形式,等于从下往上每次合并儿子

树剖,叶子直接维护概率,非叶子(操作符)维护重儿子是多少时的01变化

儿子只有两个,转移矩阵取决于自己以及轻儿子

这样的好处是每次修改只用修改向上的链顶父亲的矩阵,修改量是O(log)的,不断向上跳即可

不用也无法维护当前子树的概率,直接询问链底到当前点即可得出

注意01的概率要放在[0,0]和[0,1]上

code

#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define mod 998244353
#define Mod 998244351
#define ll long long
#define file
using namespace std;

struct type{
	int tp;
	ll x,y,z;
} A[200001];
struct mat{
	ll a[2][2];
	void clear() {memset(a,0,sizeof(a));}
} tr[800001],one,c,S[200001];
int a[200001][2],ls[200001],bg[200001],size[200001],d[200001];
int fa[200001],nx[200001],nx2[200001],top[200001],end[200001],n,Q,i,j,k,l,len,root,tot,h,t;
ll x,y,z;

ll qpower(ll a,int b) {ll ans=1; while (b) {if (b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans;}
void New(int x,int y) {++len;a[len][0]=y;a[len][1]=ls[x];ls[x]=len;}
mat mul(mat a,mat b)
{
	mat c;
	int i,j,k;
	
	fo(i,0,1)
	{
		fo(j,0,1)
		{
			c.a[i][j]=0;
			fo(k,0,1)
			c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j])%mod;
		}
	}
	return c;
}

void work()
{
	int d[200001],i,j,k,l,t=0;
	
	fo(i,1,n)
	{
		switch (A[i].tp)
		{
			case 1:{d[++t]=i;break;}
			case 2:{d[++t]=i;break;}
			case 3:{d[++t]=-1;break;}
			case 4:{
				if (d[t-1]==-1) {d[t-1]=d[t];--t;break;}
				
				j=t;while (d[j-1]!=-1) j-=2;
				New(d[j+1],d[j]); for (k=j; k<t; k+=2) {New(d[k+1],d[k+2]); if (k+3<=t) New(d[k+3],d[k+1]);};
				d[j-1]=d[t-1],t=j-1;
				break;
			}
		}
	}
	if (t>1) {New(d[2],d[1]); for (k=1; k<t; k+=2) {New(d[k+1],d[k+2]); if (k+3<=t) New(d[k+3],d[k+1]);};root=d[t-1];} else root=d[1];
}

void dfs()
{
	int i,mx=0;
	
	h=0;t=1;
	d[1]=root;
	while (h<t)
	{
		for (i=ls[d[++h]]; i; i=a[i][1])
		fa[a[i][0]]=d[h],d[++t]=a[i][0];
	}
	while (t)
	{
		size[d[t]]=1;mx=0;
		for (i=ls[d[t]]; i; i=a[i][1])
		{
			size[d[t]]+=size[a[i][0]];
			if (size[a[i][0]]>mx)
			mx=size[a[i][0]],nx[d[t]]=a[i][0];
		}
		--t;
	}
}
void dfs2()
{
	int i,j,k;
	
	h=0;t=1;
	d[1]=root;
	while (h<t)
	{
		i=j=d[++h];
		while (nx[j]) j=nx[j];
		do{
			for (k=ls[i]; k; k=a[k][1])
			if (a[k][0]!=nx[i])
			nx2[i]=a[k][0],d[++t]=a[k][0];
			
			bg[i]=++tot;
			top[i]=d[h];
			end[i]=j;
			i=nx[i];
		}while (i);
	}
}

void change(int t,int l,int r,int x,mat s)
{
	int mid=(l+r)/2;
	
	if (l==r) {tr[t]=s;return;}
	
	if (x<=mid) change(t*2,l,mid,x,s);
	else change(t*2+1,mid+1,r,x,s);
	
	tr[t]=mul(tr[t*2+1],tr[t*2]);
}
mat find(int t,int l,int r,int x,int y)
{
	int mid=(l+r)/2;
	mat ans=one,s;
	
	if (x<=l && r<=y) {return tr[t];}
	
	if (mid<y) s=find(t*2+1,mid+1,r,x,y),ans=mul(ans,s);
	if (x<=mid) s=find(t*2,l,mid,x,y),ans=mul(ans,s);
	
	return ans;
}

mat js(int t,mat B)
{
	ll a=B.a[0][0],b=B.a[0][1];
	mat c;
	
	c.a[0][0]=(A[t].x+A[t].y*a+A[t].z*a)%mod;
	c.a[0][1]=(A[t].y*b+A[t].z*b)%mod;
	c.a[1][0]=(A[t].x*a+A[t].z*b)%mod;
	c.a[1][1]=(A[t].x*b+A[t].y+A[t].z*a)%mod;
	
	return c;
}

mat get(int t)
{
	return find(1,1,tot,bg[t],bg[end[t]]);
}

void Dfs()
{
	mat c;
	
	h=0;t=1;
	d[1]=root;
	while (h<t)
	{
		for (i=ls[d[++h]]; i; i=a[i][1])
		d[++t]=a[i][0];
	}
	while (t)
	{
		if (nx[d[t]])
		{
			c=js(d[t],S[nx2[d[t]]]);
			S[d[t]]=mul(S[nx[d[t]]],c);
			change(1,1,tot,bg[d[t]],c);
		}
		else
		{
			S[d[t]].a[1][0]=S[d[t]].a[1][1]=0;S[d[t]].a[0][0]=A[d[t]].x;S[d[t]].a[0][1]=A[d[t]].y;
			change(1,1,tot,bg[d[t]],S[d[t]]);
		}
		--t;
	}
}

void Change(int t)
{
	while (t)
	{
		t=top[t];
		if (!fa[t]) break;
		change(1,1,tot,bg[fa[t]],js(fa[t],get(t)));
		t=fa[t];
	}
}

int main()
{
	one.a[0][0]=one.a[1][1]=1;
	
	freopen("structure.in","r",stdin);
	#ifdef file
	freopen("structure.out","w",stdout);
	#endif
	
	scanf("%d%d",&n,&Q);
	fo(i,1,n)
	{
		scanf("%d",&A[i].tp);
		if (A[i].tp==1) scanf("%lld%lld",&A[i].x,&A[i].y),l=qpower((A[i].x+A[i].y)%mod,Mod);
		if (A[i].tp==2) scanf("%lld%lld%lld",&A[i].x,&A[i].y,&A[i].z),l=qpower((A[i].x+A[i].y+A[i].z)%mod,Mod);
		if (A[i].tp<=2) A[i].x=A[i].x*l%mod,A[i].y=A[i].y*l%mod,A[i].z=A[i].z*l%mod;
	}
	
	work();
	dfs();
	tot=0;top[root]=root;
	dfs2();
	Dfs();
	
	for (;Q;--Q)
	{
		scanf("%d%lld%lld",&t,&x,&y);
		if (A[t].tp==2) scanf("%lld",&z); else z=0;
		l=qpower(x+y+z,Mod);
		x=x*l%mod,y=y*l%mod,z=z*l%mod;
		
		switch (A[t].tp)
		{
			case 1:{
				c.a[1][0]=c.a[1][1]=0;c.a[0][0]=x;c.a[0][1]=y;
				change(1,1,tot,bg[t],c);
				break;
			}
			case 2:{
				A[t].x=x;A[t].y=y;A[t].z=z;
				change(1,1,tot,bg[t],js(t,get(nx2[t])));
				break;
			}
		}
		Change(t);
		
		c=get(root);
		printf("%lld
",c.a[0][1]);
	}
	
	fclose(stdin);
	fclose(stdout);
	
	return 0;
}
原文地址:https://www.cnblogs.com/gmh77/p/12684380.html