[ZJOI2019]线段树

题目大意

一开始有一棵线段树,然后有一个操作序列,问执行这个操作序列的所有子集时线段树上有标记的节点个数和。

题解

其实我们把它除以(2^m)后发现就是有标记节点的期望个数。

然后套路的根据期望的线性性,我们要统计所有点有标记的概率和。

然后我们来讨论一些情况:

1、当前节点和修改区间没有交且当前节点的父亲节点也没有交,那么这个点的标记就不会动。

2、当前节点被修改区间包含且父亲节点也被包含,那根本碰不到这个节点,也不会动。

3、当前节点被修改区间包含且父亲节点没有被包含,那么这个节点一定会有标记。

4、当前节点和修改区间有交但不包含,那么这个点一定没有标记。

5、这个点和修改区间没有交但是父亲有,那么这个点有没有标记取决于这个点的祖先节点(包括自己)有没有标记。

我们观察到第5种情况需要考虑到祖先节点是否有标记,所以我们设一个(f​)表示这个点有标记的概率,(g)表示这个点的祖先节点有(包括自己)标记的概率。

对于第一种情况不做讨论。

对于第二种情况,(f​)是没有变化的,(g=g*0.5+0.5​)

对于第三种情况,(f=f*0.5+0.5 g=g*0.5+0.5​);

对于第四种情况(f=f*0.5 g=g*0.5​)

对于第五种情况(f=f*0.5+g*0.5 g=g*0.5+g*0.5)

这些都可以按照线段树的操作去维护。

代码

#include<iostream>
#include<cstdio>
#define ls tr[cnt].l
#define rs tr[cnt].r
#define N 100009
using namespace std;
typedef long long ll;
const int mod=998244353;
ll now,inv2;
int tot,n,m,rot;
inline int rd(){
	int x=0;char c=getchar();bool f=0;
	while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
	while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return f?-x:x;
}
inline ll power(ll x,ll y){
	ll ans=1;
	while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
	return ans;
}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
struct node{
	int l,r;
	ll multag,addtag,f,g,sum;
}tr[N<<1];
inline void pushup(int cnt){(tr[cnt].sum=tr[ls].sum+tr[rs].sum+tr[cnt].f)%=mod;}
inline void gx(int cnt){
	tr[cnt].f=(tr[cnt].g+tr[cnt].f)*inv2%mod;
	pushup(cnt);
}
inline void gan(int cnt,ll x,ll y){
	MOD(tr[cnt].g=tr[cnt].g*x%mod+y);
	tr[cnt].multag=tr[cnt].multag*x%mod;
	MOD(tr[cnt].addtag=tr[cnt].addtag*x%mod+y);
}
inline void pushdown(int cnt){
	if(tr[cnt].multag==1&&!tr[cnt].addtag)return;
	gan(ls,tr[cnt].multag,tr[cnt].addtag);
	gan(rs,tr[cnt].multag,tr[cnt].addtag);
	tr[cnt].multag=1;tr[cnt].addtag=0;
}
void upd(int cnt,int l,int r,int L,int R){
	if(l>=L&&r<=R){
		MOD(tr[cnt].f=tr[cnt].f*inv2%mod+inv2);
	    gan(cnt,inv2,inv2);pushup(cnt);return;
	}
	tr[cnt].f=tr[cnt].f*inv2%mod;
    tr[cnt].g=tr[cnt].g*inv2%mod;
	int mid=(l+r)>>1;
	pushdown(cnt); 
	if(mid>=L)upd(ls,l,mid,L,R);
	if(mid<R)upd(rs,mid+1,r,L,R);
	if(mid<L)gx(ls);if(mid>=R)gx(rs);
	pushup(cnt);
}
void build(int &cnt,int l,int r){
	cnt=++tot;
	tr[cnt].multag=1;
	if(l==r)return;
	int mid=(l+r)>>1;
	build(ls,l,mid);build(rs,mid+1,r);
}
int main(){
	n=rd();m=rd();inv2=power(2,mod-2);
	build(rot,1,n);
	int l,r,opt;now=1;
	for(int i=1;i<=m;++i){
		opt=rd();	
		if(opt==1){
			l=rd();r=rd();
			upd(rot,1,n,l,r);
			MOD(now=now+now); 
		}
		else printf("%lld
",tr[rot].sum*now%mod);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/10677284.html