CF749E Inversions After Shuffle 解题报告 (期望 树状数组)

E. Inversions After Shuffle

题意

有一个长 (n) 的排列, 随机选取一段区间进行随机全排列, 求排列后整个序列的逆序对期望个数.

((n le 10^5)).


思路

首先, 考虑一整个排列进行排序后的逆序对期望个数,

一共有 (frac{n(n-1)}{2}) 对点, 每对点行程逆序对的概率为 (frac{1}{2}) , 所以逆序对期望个数为 (frac{n(n-1)}{4}).


那么, 我们就可以算出每个区间对答案的贡献,

长度为 (len) 的区间有 (n-len+1) 个, 它们排序后的逆序对期望个数为 (frac{len(len-1)}{4}), 所以总贡献为

[sum_{len=1}^{n} frac{len(len-1)}{4} imes(n-len+1) ]


然后, 我们再考虑在选定区间外的逆序对个数,

((i,j)) 来表示一对逆序对, ([l,r]) 表示当前选中的区间.


先考虑 (i ot in [l,r]) 的情况, 设 (i=k) 时的逆序对个数为 (pir[k]), 那么贡献为

[sum_{i=1}^{l-1} pir[i] + sum_{i=r+1}^{n} pir[i] ]

考虑 (pir[i]) 怎么算,

我们可以从小到大枚举排列中的每个数 (i), 在它的位置, 设为 (pl[i]) 上 +1, 然后在树状数组上查询 ([pl[i]+1,n]) 的区间和,

因为我们是从小到大枚举的每个数, 所以这时查询到的都是比 (i) 小的数, 也就是 (i) 为左端点的逆序对, 即 (pir[i]).


再考虑 (i in [l,r], j >r) 的情况, 直接算貌似不太好算, 考虑容斥,

先不考虑 (j>r) 的限制, 则贡献为 (pir[i] imes i imes (n-i+1)),

我们再考虑每个满足 (a[k]<a[i])(k>i)(k), 它会导致 (i) 多算了 (1 imes i imes (n-k+1)) 的贡献, 因为当 (j ge k) 时, (k) 本不应该有贡献的,

所以, 对于每个 (a[i]), 我们在枚举到它的时候, 先在 ([1,i-1]) 打上 ((n-i+1)) 的标记, 表示到时候要减去的权值,

然后再查询 (i) 的标记, 设为 (val[i]), 它对答案的总贡献就是

[pir[i] imes i imes (n-i+1) - val[i] imes i ]


这样, 我们就得到了选取每个区间时, 所得到的逆序对的总期望个数, 最后 (ans/=frac{n(n+1)}{2}) 即可. (总共有 (frac{n(n+1)}{2}) 个区间.)


代码

#include<bits/stdc++.h>
#define ll long long
#define db double
using namespace std;
const int _=1e5+7;
int n,a[_],pl[_]; // c0 用来算 pir, c1 用来算多余的贡献
ll c[2][_],pir[_],sum[_];
// 先算出后缀和, 再算出后缀和的后缀和.
db ans;
void add(int x,ll w,bool id){
  for(int i=x;i<=n;i+=i&(-i))
    c[id][i]+=w;
}
void modify(int l,int r,ll w,bool id){
  add(l,w,id);
  add(r+1,-w,id);
}
ll Sum(int x,bool id){
  ll res=0;
  for(int i=x;i;i-=i&(-i))
    res+=c[id][i];
  return res;
}
ll query(int l,int r,bool id){
  return Sum(r,id)-Sum(l-1,id);
}
void init(){
  cin>>n;
  for(int i=1;i<=n;i++){
    scanf("%d",&a[i]);
    pl[a[i]]=i;
  }
  for(int i=1;i<=n;i++){
    pir[pl[i]]=query(pl[i],n,0);
    add(pl[i],1ll,0);
    modify(1,pl[i]-1,n-pl[i]+1,1ll);
    ll val=Sum(pl[i],1);
    ans+=(db)pl[i]*pir[pl[i]]*(n-pl[i]+1)-(db)val*pl[i];
  }
}
void run(){
  for(int i=n;i>=1;i--)
    sum[i]=sum[i+1]+pir[i];
  for(int i=n;i>=1;i--)
    sum[i]=sum[i+1]+sum[i];
  for(db i=1;i<=n;i+=1)
    ans+=i*(i-1)/4*(n-i+1);
  ll res=0;
  for(int i=1;i<=n;i++){
    res+=pir[i-1];
    ans+=sum[i+1]+(n-i+1)*res;
  }
  
  ans/=(db)n*(n+1)/2;
}
int main(){
  init();
  run();
  printf("%.9lf
",ans);
  return 0;
}
原文地址:https://www.cnblogs.com/BruceW/p/12189372.html