[BZOJ5461][LOJ#2537[PKUWC2018]Minimax(概率DP+线段树合并)

还是没有弄清楚线段树合并的时间复杂度是怎么保证的,就当是$O(mlog n)$吧。

这题有一个显然的DP,dp[i][j]表示节点i的值为j的概率,转移时维护前缀后缀和,将4项加起来就好了。

这个感觉已经很难做到比$O(n^2)$更优的复杂度了,但我们要看到题目里有什么条件没用上:每个节点最多有2个儿子。

这个提醒我们可以用启发式合并,据说splay可以做,但我们可以考虑一下线段树合并做法。

仍然采用上面的转移方程,这里线段树上的一个节点T[x]表示x表示的区间[L,R]最终成为当前子树的根的值的概率,那么答案显然可以通过最终线段树上每个叶子节点统计。现在难点就在于如何统计。

考虑某个左子树的动态开店权值线段树上某个节点x表示的区间[L,R]最后成为根rt的值的概率,它是由p[rt]*(右子树的值小于R的概率)+(1-p[rt])*(右子树的值小于L的概率),这个其实是可以在merge的过程中递归下去累加的。

具体做法是:merge(x,y,sx,sy)表示合并线段树节点x和y(显然这里x和y分属左右子树,表示的是同一个权值区间),那么我们每次将这一层的信息累加进去,接着递归下去即可。当我们发现只存在一棵子树了(设为x),那么这棵子树的概率要乘上sy。

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #define rep(i,l,r) for (int i=l; i<=r; i++)
 5 typedef long long ll;
 6 using namespace std;
 7 
 8 const int N=300010,M=6000010,g=796898467,mod=998244353;
 9 int n,nd,p,tot,w[N],rt[N],fa[N],son[N][2],b[N],s[M],tag[M],ls[M],rs[M];
10 
11 void put(int x,int k){ s[x]=1ll*s[x]*k%mod; tag[x]=1ll*tag[x]*k%mod; }
12 void push(int x){ if (tag[x]!=1) put(ls[x],tag[x]),put(rs[x],tag[x]),tag[x]=1; }
13 
14 void init(int &x,int L,int R,int pos){
15     x=++nd; s[x]=tag[x]=1;
16     if (L==R) return;
17     int mid=(L+R)>>1;
18     if (pos<=mid) init(ls[x],L,mid,pos);
19         else init(rs[x],mid+1,R,pos);
20 }
21 
22 int merge(int x,int y,int sx,int sy){
23     if (!y){ put(x,sy); return x; }
24     if (!x){ put(y,sx); return y; }
25     push(x); push(y);
26     int x0=s[ls[x]],y0=s[ls[y]],x1=s[rs[x]],y1=s[rs[y]];
27     ls[x]=merge(ls[x],ls[y],(sx+1ll*(1-p)*x1)%mod,(sy+1ll*(1-p)*y1)%mod);
28     rs[x]=merge(rs[x],rs[y],(sx+1ll*p*x0)%mod,(sy+1ll*p*y0)%mod);
29     s[x]=(s[ls[x]]+s[rs[x]])%mod; return x;
30 }
31 
32 int solve(int x){
33     if (!son[x][0]) { init(rt[x],1,tot,lower_bound(b+1,b+tot+1,w[x])-b); return rt[x]; }
34     int l=solve(son[x][0]); if (!son[x][1]) return l;
35     int r=solve(son[x][1]); p=1ll*g*w[x]%mod; return merge(l,r,0,0);
36 }
37 
38 int dfs(int x,int L,int R){
39     if (L==R) return 1ll*L*b[L]%mod*s[x]%mod*s[x]%mod;
40     int mid=(L+R)>>1; push(x);
41     return (dfs(ls[x],L,mid)+dfs(rs[x],mid+1,R))%mod;
42 }
43 
44 int main(){
45     freopen("a.in","r",stdin);
46     freopen("a.out","w",stdout);
47     scanf("%d",&n);
48     rep(i,1,n) scanf("%d",&fa[i]),son[fa[i]][son[fa[i]][0] ? 1 : 0]=i;
49     rep(i,1,n){
50         scanf("%d",&w[i]); if (!son[i][0]) b[++tot]=w[i];
51     }
52     sort(b+1,b+tot+1); printf("%d
",(dfs(solve(1),1,tot)+mod)%mod);
53     return 0;
54 }
原文地址:https://www.cnblogs.com/HocRiser/p/9059599.html