CF712E

线段树维护概率好题

设状态$f(i)$表示从$i$走到终点的概率

不妨令起点为$1$,终点为$n$

那么有转移:$f(i)=(1-p(i))f(i-1)+p(i)*f(i+1)$

移项,得:

$p(i)[f(i+1)-f(i-1)]=f(i)-f(i-1)$

令$g(i)=f(i)-f(i-1)$

于是原式可以写作

$p(i)[g(i+1)+g(i)]=g(i)$

据此可以写出:

$g(i+1)=frac{1-p(i)}{p(i)}g(i)$

设$t(i)=frac{1-p(i)}{p(i)}$

则有:

$g(i+1)=t(i)g(i)$

设了一大堆,这要干啥?

我们分析一下$g$的性质:

考虑到$f(n)=1$,$f(0)=0$

于是可以推出$g(1)+g(2)+...+g(n)=f(n)-f(0)=1$

同时不要忘了上面的递推公式:

代入,得:

$g(1)+t(1)g(1)+t(2)t(1)g(1)+...+t(n-1)...t(1)g(1)=1$

提出g(1),可得:$g(1)=frac{1}{sum_{i=1}^{n-1} prod_{j=1}^{i-1} t(j)}$

同时不要忘了,$g(1)=f(1)-f(0)=f(1)$,因此我们只需求出$g(1)$即可

因此我们用一棵线段树分别维护区间$t(i)$乘积和$t(1)+t(1)t(2)....$这个表达式的值

区间乘积好维护,后面那个表达式怎么合并两个区间的答案呢?

将后面区间的答案整体乘上前面区间的乘积再加上前面区间的答案即可

贴代码:

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define rt1 rt<<1
#define rt2 (rt<<1)|1
using namespace std;
struct Seg_tree
{
    double mul_sum;
    double S_sum;
}tree[400005];
double p[100005];
int n,q;
void pushup(int rt)
{
    tree[rt].mul_sum=tree[rt1].mul_sum*tree[rt2].mul_sum;
    tree[rt].S_sum=tree[rt1].S_sum+tree[rt1].mul_sum*tree[rt2].S_sum;
}
void buildtree(int rt,int l,int r)
{
    if(l==r)
    {
        tree[rt].mul_sum=p[l];
        tree[rt].S_sum=p[l];
        return;
    }
    int mid=(l+r)>>1;
    buildtree(rt1,l,mid),buildtree(rt2,mid+1,r);
    pushup(rt);
}
void update(int rt,int l,int r,int posi)
{
    if(l==r)
    {
        tree[rt].mul_sum=p[l];
        tree[rt].S_sum=p[l];
        return;
    }
    int mid=(l+r)>>1;
    if(posi<=mid)update(rt1,l,mid,posi);
    else update(rt2,mid+1,r,posi);
    pushup(rt);
}
double query_m(int rt,int l,int r,int lq,int rq)
{
    if(lq<=l&&rq>=r)return tree[rt].mul_sum;
    int mid=(l+r)>>1;
    double ret=1.0;
    if(lq<=mid)ret=ret*query_m(rt1,l,mid,lq,rq);
    if(rq>mid)ret=ret*query_m(rt2,mid+1,r,lq,rq);
    return ret;
}
double query(int rt,int l,int r,int lq,int rq)
{
    if(lq<=l&&rq>=r)return tree[rt].S_sum;
    int mid=(l+r)>>1;
    double ret=0;
    if(lq<=mid)ret=ret+1.0*query(rt1,l,mid,lq,rq);
    if(rq>mid)ret=ret+1.0*query_m(rt1,l,mid,lq,mid)*query(rt2,mid+1,r,lq,rq);
    return ret;
}
int main()
{
//    freopen("hs.in","r",stdin);
//    freopen("hs.out","w",stdout);
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++)
    {
        double x,y;
        scanf("%lf%lf",&x,&y);
        p[i]=(y-x)/x;
    }
    buildtree(1,1,n);
    while(q--)
    {
        int typ;
        scanf("%d",&typ);
        if(typ==1)
        {
            int posi;
            double x,y;
            scanf("%d%lf%lf",&posi,&x,&y);
            p[posi]=(y-x)/x;
            update(1,1,n,posi);
        }else
        {
            int l,r;
            scanf("%d%d",&l,&r);
            double temp=query(1,1,n,l,r)+1.0;
            printf("%.10lf
",1.0/temp);
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/zhangleo/p/10975752.html