[ZJOI2019]线段树

https://www.luogu.com.cn/blog/ShadowassIIXVIIIIV/solution-p5280
题目问的相当于一颗线段树的每个操作序列的子集在树上的标记数之和
可以转为求期望 这样的好处是每次操作只会影响(O(log n))个节点存在标记的概率
可以把线段树上的点分为三类:
1.不为终止节点且被经过的(这些点会失去标记)
2.得到1类点下传标记的没被经过的节点
3.终止节点

记录(dp_i)表示(i)有标记概率 (g_i)表示(i)祖先(含(i))有标记概率
但是(g)受到影响的节点很多 但可以用lazy tag处理
因为实现的问题需要开八倍空间。。。

#include<bits/stdc++.h>
using namespace std;
#define fp(i,l,r) for(register int (i)=(l);(i)<=(r);++(i))
#define fd(i,l,r) for(register int (i)=(l);(i)>=(r);--(i))
#define fe(i,u) for(register int (i)=front[(u)];(i);(i)=e[(i)].next)
#define mem(a) memset((a),0,sizeof (a))
#define O(x) cerr<<#x<<':'<<x<<endl 
#define int long long
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return x*f;
}
void wr(int x){
	if(x<0)putchar('-'),x=-x;
	if(x>=10)wr(x/10);
	putchar('0'+x%10);
}
const int MAXN=8e5+20,mod=998244353,inv2=(mod+1)/2;
inline void tmod(int &x){x%=mod;}
inline int qpow(int a,int b){
	int res=1;a%=mod;
	for(;b;b>>=1,tmod(a*=a))
	if(b&1)tmod(res*=a);
	return res;
}
inline int ginv(int x){return qpow(x,mod-2);}
#define lson o<<1
#define rson o<<1|1
int tag[MAXN],dp[MAXN],g[MAXN],n,m;
int pw[MAXN],a[MAXN],b[MAXN],c,sum[MAXN];
inline void madd(int o,int v){
	tag[o]+=v;tmod(g[o]=(g[o]+a[v])*b[v]);
}
inline void pushdown(int o){
	if(!tag[o])return;
	int v=tag[o];tag[o]=0;
	madd(lson,v);madd(rson,v);
}
inline void pushup(int o){tmod(sum[o]=sum[lson]+sum[rson]+dp[o]);}
inline void calc(int o){tmod(dp[o]=(g[o]+dp[o])*inv2);pushup(o);}
void mdy(int o,int l,int r,int ql,int qr){
	if(l==ql&&r==qr){
		dp[o]=(dp[o]+1)*inv2%mod;madd(o,1);pushup(o);return;
	}
	int mid=l+r>>1;pushdown(o);
	tmod(dp[o]*=inv2);tmod(g[o]*=inv2);
	if(qr<=mid){
		calc(rson);mdy(lson,l,mid,ql,qr);
	}
	else if(ql>mid){
		calc(lson);mdy(rson,mid+1,r,ql,qr);
	}
	else{
		mdy(lson,l,mid,ql,mid);mdy(rson,mid+1,r,mid+1,qr);
	}
	pushup(o);
}
int tot;
main(){
	pw[0]=b[0]=1;fp(i,1,MAXN-1)tmod(pw[i]=pw[i-1]*2),tmod(b[i]=b[i-1]*inv2);
	fp(i,0,MAXN-1)tmod(a[i]=pw[i]+mod-1);
	n=read();m=read();
	fp(i,1,m){
		int op=read();
		if(op==1){
			int l=read(),r=read();
			mdy(1,1,n,l,r);++tot;
		}
		else wr(sum[1]*pw[tot]%mod),putchar('
');
	}
	return 0;
}
原文地址:https://www.cnblogs.com/WinterSpell/p/13275535.html