线段树

线段树

九条可怜是一个喜欢数据结构的女孩子,在常见的数据结构中,可怜最喜欢的就是线段树。

线段树的核心是懒标记,下面是一个带懒标记的线段树的伪代码,其中 `tag` 数组为懒标记:

![](https://s2.ax1x.com/2019/04/02/AyHyRJ.md.png)

其中函数 $ exttt{Lson}( ext{Node})$ 表示 $ ext{Node}$ 的左儿子,$ exttt{Rson}( ext{Node})$ 表示 $ ext{Node}$ 的右儿子。

现在可怜手上有一棵 $[1, n]$ 上的线段树,编号为 $1$。这棵线段树上的所有节点的 `tag` 均为 $0$。接下来可怜进行了 $m$ 次操作,操作有两种:
- $1 l r$,假设可怜当前手上有 $t$ 棵线段树,可怜会把每棵线段树复制两份(`tag` 数组也一起复制),原先编号为 $i$ 的线段树复制得到的两棵编号为 $2i − 1$ 与 $2i$,在复制结束后,可怜手上一共有 $2t$ 棵线段树。接着,可怜会对所有编号为奇数的线段树进行一次 $ exttt{Modify}( ext{root}, 1, n, l, r)$。
- $2$,可怜定义一棵线段树的权值为它上面有多少个节点 `tag` 为 $1$。可怜想要知道她手上所有线段树的权值和是多少。


Sol

这题很妙。

考虑用概率来表示出现次数。

注意到如果一个点的祖先有标记,那么他也有可能被影响。

可以记a为i这一个节点的1的出现概率,b为i和i的祖先中有一个出现出现1的概率。

这两个互相转移一下。

大致思路是把点分成四类:经过的,覆盖的,pushdown的和覆盖的点的儿子。

最后一类需要区间加。

#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 100005
#define ll long long
#define mod 998244353
using namespace std;
int n,m;
ll ny=499122177;
struct node{
    ll a,b,sum,bc,ba;
}tr[maxn*8];
void wh(int k){
    tr[k].sum=(tr[k*2].sum+tr[k*2+1].sum+tr[k].a)%mod;
}
void ch1(int k){
    tr[k].a=tr[k].a*ny%mod;tr[k].b=tr[k].b*ny%mod;wh(k);
}
void ch2(int k){
    tr[k].a=ny*(tr[k].a+1)%mod;tr[k].b=ny*(tr[k].b+1)%mod;wh(k);
}
void ch3(int k){
    tr[k].a=(tr[k].b+tr[k].a)*ny%mod;tr[k].a%=mod;wh(k);
}
void ch4(int k){
    tr[k].b=ny*(tr[k].b+1)%mod;
    tr[k].bc=tr[k].bc*ny%mod;tr[k].ba=(tr[k].ba+1)*ny%mod; 
}
void down(int k){
    int ls=k*2,rs=k*2+1; 
    if(tr[k].bc!=1){
        ll &t=tr[k].bc;
        tr[ls].b=tr[ls].b*t%mod;tr[ls].bc=tr[ls].bc*t%mod;tr[ls].ba=tr[ls].ba*t%mod;
        tr[rs].b=tr[rs].b*t%mod;tr[rs].bc=tr[rs].bc*t%mod;tr[rs].ba=tr[rs].ba*t%mod;
        t=1;
    }
    if(tr[k].ba!=0){
        ll &t=tr[k].ba;
        tr[ls].b=(tr[ls].b+t)%mod;tr[ls].ba=(tr[ls].ba+t)%mod;
        tr[rs].b=(tr[rs].b+t)%mod;tr[rs].ba=(tr[rs].ba+t)%mod;
        t=0;
    }
}
void add(int k,int l,int r,int li,int ri){
    if(l>=li&r<=ri){
        ch2(k);
        //if(l<r)ch4(k*2),ch4(k*2+1);
        tr[k].ba=(tr[k].ba+1)*ny%mod;
        tr[k].bc=tr[k].bc*ny%mod;
        return;
    }
    down(k);
    int mid=l+r>>1;
    if(li<=mid)add(k*2,l,mid,li,ri);
    else ch3(k*2);
    if(ri>mid)add(k*2+1,mid+1,r,li,ri);
    else ch3(k*2+1);
    ch1(k);
}
int main(){
    cin>>n>>m;
    ll num=1;
    for(int i=1;i<maxn;i++)tr[i].bc=1;
    for(int i=1,op,l,r;i<=m;i++){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d",&l,&r);
            add(1,1,n,l,r);
            num=num*2%mod;
        }
        else {
            ll ans=tr[1].sum*num%mod;
            ans=(ans+mod)%mod;
            printf("%lld
",ans);
        }
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/liankewei/p/10685825.html