洛谷[1007]梦美与线段树(线段树+概率期望)

传送门

我我我我我我我终于AAAAAAAA了!!!

我们来考虑概率期望原来的定义,是总收益除以总情况数

然后因为每一个节点的权值都是左右儿子权值之和,而走的概率也与权值有关

那么我们可以看做是从根节点出发走$sum[1]$次($sum[1]$为根节点的权值),那么走到每一个节点的次数就是这一个节点的权值

那么总收益就是每一个节点被走到的次数乘上节点的权值,就是每一个节点权值的平方和。而总共走了$sum[1]$次

那么答案就是整棵子树的权值的平方和乘上$sum[1]$的逆元

问题来了,因为这是区间加和,我们得打懒标记,所以要$O(1)$在打懒标记时统计出子树里的权值平方和

只有对于需要被打懒标记的节点(也就是区间被完全包含的节点)我们需要统计答案,其他情况下直接pushup即可维护

我们考虑一下,设一个节点代表的区间长度为$l$,区间加和的值为$t$,原节点的值为$sum$,那么这个节点的权值将变为$sum+l*t$

那么平方和的增加量为$(sum+l*t)^2-sum^2=2*sum*l*t-l^2*t^2$

首先,$l^2$是很好维护的,只要记录一下子树里所有区间长度的平方和即可,这个是不变的,建树的时候就可以弄出来了

于是考虑怎么维护$sum*l$

对于每一个节点来说,区间加和之后,它的$sum$变成$sum+l*t$,那么$sum*l$变为$(sum+l*t)*l$

那么增加量就是$l^2*t$,而且$l^2$我们已经维护好了

所以每一次打完懒标记,我们都可以$O(1)$维护答案了!

然后最坑的一点是……辣鸡出题人最后一个点要用高精或__int128……然后导致无数个90分……调死掉……

ps:__int128编译器上过不了的,建议用洛谷IDE调,或者先写好调到90分再改成__int128

  1 //minamoto
  2 #include<bits/stdc++.h>
  3 #define ll long long
  4 #define add(x,y) ((x)+(y))
  5 #define dec(x,y) ((x)-(y)<0?(x)-(y)+P:(x)-(y))
  6 #define mul(x,y) ((I)*(x)*(y))
  7 #define ls (p<<1)
  8 #define rs (p<<1|1)
  9 using namespace std;
 10 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 11 char buf[1<<21],*p1=buf,*p2=buf;
 12 inline ll read(){
 13     #define num ch-'0'
 14     char ch;bool flag=0;ll res;
 15     while(!isdigit(ch=getc()))
 16     (ch=='-')&&(flag=true);
 17     for(res=num;isdigit(ch=getc());res=res*10+num);
 18     (flag)&&(res=-res);
 19     #undef num
 20     return res;
 21 }
 22 const int N=500005,P=998244353;const __int128 I=1;
 23 inline ll ksm(ll a,ll b){
 24     ll res=1;
 25     while(b){
 26         if(b&1) res=mul(res,a)%P;
 27         a=mul(a,a)%P,b>>=1;
 28     }
 29     return res;
 30 }
 31 __int128 sum[N<<2],len[N<<2],len_2[N<<2],ans[N<<2],tag[N<<2],xl[N<<2],a[N],n,m;
 32 //sz子树节点个数,sum节点,ans当前子树答案 ,s子树总和 
 33 inline void upd(int p,int l,int r){
 34     sum[p]=add(sum[ls],sum[rs]),ans[p]=add(add(ans[ls],ans[rs]),mul(sum[p],sum[p]));
 35     xl[p]=add(add(xl[ls],xl[rs]),mul((r-l+1),sum[p]));
 36 }
 37 void pushdown(int p,int l,int r){
 38     if(tag[p]){
 39         int mid=(l+r)>>1;
 40         tag[ls]=add(tag[ls],tag[p]),tag[rs]=add(tag[rs],tag[p]);
 41         
 42         ans[ls]=add(ans[ls],mul(mul(xl[ls],2),tag[p]));
 43         ans[ls]=add(ans[ls],mul(mul(tag[p],tag[p]),len_2[ls]));
 44         
 45         ans[rs]=add(ans[rs],mul(mul(xl[rs],2),tag[p]));
 46         ans[rs]=add(ans[rs],mul(mul(tag[p],tag[p]),len_2[rs]));
 47         
 48         xl[ls]=add(xl[ls],mul(len_2[ls],tag[p]));
 49         xl[rs]=add(xl[rs],mul(len_2[rs],tag[p]));
 50         
 51         sum[ls]=add(sum[ls],mul(tag[p],(mid-l+1)));
 52         sum[rs]=add(sum[rs],mul(tag[p],(r-mid)));
 53         
 54         tag[p]=0;
 55     }
 56 }
 57 void build(int p,int l,int r){
 58     if(l==r){
 59         sum[p]=a[l],len[p]=1,len_2[p]=1,ans[p]=mul(sum[p],sum[p]),tag[p]=0,xl[p]=sum[p];
 60         return;
 61     }
 62     int mid=(l+r)>>1;
 63     build(ls,l,mid),build(rs,mid+1,r);
 64     len[p]=add(add(len[ls],len[rs]),(r-l+1));
 65     len_2[p]=add(add(len_2[ls],len_2[rs]),mul((r-l+1),(r-l+1)));
 66     upd(p,l,r);
 67 }
 68 void update(int p,int l,int r,int ql,int qr,ll val){
 69     if(ql<=l&&qr>=r){
 70         tag[p]=add(tag[p],val);
 71         
 72         ans[p]=add(ans[p],mul(mul(xl[p],2),val));
 73         ans[p]=add(ans[p],mul(len_2[p],mul(val,val)));
 74         
 75         xl[p]=add(xl[p],mul(len_2[p],val));
 76         
 77         sum[p]=add(sum[p],mul((r-l+1),val));
 78         
 79         return;
 80     }
 81     pushdown(p,l,r);
 82     int mid=(l+r)>>1;
 83     if(ql<=mid) update(ls,l,mid,ql,qr,val);
 84     if(qr>mid) update(rs,mid+1,r,ql,qr,val);
 85     upd(p,l,r);
 86 }
 87 ll calc(){
 88     __int128 x=ans[1],y=sum[1];
 89     while(y%P==0) y/=P,x/=P;
 90     x%=P,y%=P;
 91     return x*ksm(y,P-2)%P;
 92 }
 93 int main(){
 94 //    freopen("testdata.in","r",stdin);
 95     n=read(),m=read();
 96     for(int i=1;i<=n;++i) a[i]=read();
 97     build(1,1,n);
 98     while(m--){
 99         int op=read();
100         if(op==2) printf("%lld
",calc());
101         else{
102             int l=read(),r=read();ll v=read();
103             update(1,1,n,l,r,v);
104         }
105     }
106     return 0;
107 }
原文地址:https://www.cnblogs.com/bztMinamoto/p/9751262.html