题解 洛谷 P5298 【[PKUWC2018]Minimax】

首先发现每个叶子节点的权值都有可能成为最终根节点的权值,观察题目中给出的式子:

[ sum_{i=1}^m i V_i D_i^2 ]

发现只要算出每个权值被根节点取到的概率 (D_i),然后就能计算答案。

(f_{x,i}) 为节点 (x) 取到第 (i) 小权值的概率,根据是从左儿子还是从右儿子取到的权值来进行分类讨论:

[f_{x,i} = f_{ls,i}(p_x sum_{j=1}^{i-1}f_{rs,j}+(1-p_x) sum_{j=i+1}^m f_{rs,j}) + f_{rs,i}(p_x sum_{j=1}^{i-1}f_{ls,j}+(1-p_x) sum_{j=i+1}^m f_{ls,j}) ]

直接进行转移的话,时间和空间都无法接受,由本题为树形 (DP) 和转移方程的特点得,可以通过整体 (DP) 来优化。

用动态开点线段树来维护第二维,通过线段树合并实现从左右儿子转移。当一个节点没有儿子时就直接赋初值,只有一个儿子就直接转移。线段树合并时记录前缀和和后缀和,来计算左右区间互相的贡献,线段树上区间乘,打标记即可。

(code:)

#include<bits/stdc++.h>
#define maxn 600010
#define maxm 10000010
#define p 998244353
#define inv 796898467
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,tot,cnt;
ll ans;
int ch[maxn][2],rt[maxn],ls[maxm],rs[maxm];
ll val[maxn],s[maxn],d[maxn],sum[maxm],tag[maxm];
void pushtag(int x,ll v)
{
    sum[x]=sum[x]*v%p,tag[x]=tag[x]*v%p;
}
void pushdown(int x)
{
    if(tag[x]==1) return;
    pushtag(ls[x],tag[x]),pushtag(rs[x],tag[x]),tag[x]=1;
}
void insert(int l,int r,int pos,int &cur)
{
    if(!cur) cur=++tot,tag[cur]=1;
    sum[cur]++;
    if(l==r) return;
    if(pos<=mid) insert(l,mid,pos,ls[cur]);
    else insert(mid+1,r,pos,rs[cur]);
}
int merge(int x,int y,ll lx,ll rx,ll ly,ll ry,ll v)
{
    if(!x&&!y) return 0;
    pushdown(x),pushdown(y);
    if(x&&!y)
    {
        pushtag(x,(v*ly%p+(1-v+p)*ry%p)%p);
        return x;
    }
    if(!x&&y)
    {
        pushtag(y,(v*lx%p+(1-v+p)*rx%p)%p);
        return y;
    }
    ll px=sum[ls[x]],sx=sum[rs[x]],py=sum[ls[y]],sy=sum[rs[y]];
    ls[x]=merge(ls[x],ls[y],lx,(rx+sx)%p,ly,(ry+sy)%p,v);
    rs[x]=merge(rs[x],rs[y],(lx+px)%p,rx,(ly+py)%p,ry,v);
    sum[x]=(sum[ls[x]]+sum[rs[x]])%p;
    return x;
}
void solve(int x)
{
    if(!ch[x][0])
    {
        insert(1,cnt,val[x],rt[x]);
        return;
    }
    if(!ch[x][1])
    {
        solve(ch[x][0]),rt[x]=rt[ch[x][0]];
        return;
    }
    solve(ch[x][0]),solve(ch[x][1]);
    rt[x]=merge(rt[ch[x][0]],rt[ch[x][1]],0,0,0,0,val[x]);
}
void dfs(int l,int r,int cur)
{
    if(l==r)
    {
        d[l]=sum[cur];
        return;
    }
    pushdown(cur),dfs(l,mid,ls[cur]),dfs(mid+1,r,rs[cur]);
}
int main()
{
    read(n);
    for(int i=1;i<=n;++i)
    {
        int fa;
        read(fa);
        if(ch[fa][0]) ch[fa][1]=i;
        else ch[fa][0]=i;
    }
    for(int i=1;i<=n;++i)
    {
        read(val[i]);
        if(ch[i][0]) val[i]=val[i]*inv%p;
        else s[++cnt]=val[i];
    }
    sort(s+1,s+cnt+1);
    for(int i=1;i<=n;++i)
        if(!ch[i][0])
            val[i]=lower_bound(s+1,s+cnt+1,val[i])-s;
    solve(1),dfs(1,cnt,rt[1]);
    for(int i=1;i<=cnt;++i)
        ans=(ans+i*s[i]%p*d[i]%p*d[i]%p)%p;
    printf("%lld",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/lhm-/p/13364659.html