AtCoder Regular Contest 086 E

  好强的题。

  方案不好算,改成算概率,注意因为是模意义下的概率所以直接乘法逆元就好不要傻傻地开double。

  设$f[i][d][0]$为第i个节点离d层的球球走到第i个点时第i个点没有球的概率, $f[i][d][1]$为有1个球的概率, $f[i][d][2]$为有2个球及以上的概率。

  我们可以把$f[i]$看成一个队列, 然后从儿子转移的时候, 就是把儿子的队列一个一个合并起来,最后在队列头加上一个$f[i][0]$, 并且把队列里的所有$f[i][0$~$d][2]$加上$f[i][0$~$d][0]$,并且$f[i][0$~$d][2]$变成0就好了。

  合并的时候转移为:

  $f[i][d][0]=f[i][d][0]*f[j][d][0]$

  $f[i][d][1]=f[i][d][1]*f[j][d][0]+f[i][d][0]*f[j][d][1]$

  $f[i][d][2]=f[i][d][0]*f[j][d][2]+f[i][d][1]*f[j][d][2]+f[i][d][1]*f[j][d][1]+f[i][d][2]*f[j][d][2]+f[i][d][2]*f[j][d][1]+f[i][d][2]*f[j][d][0]$

  复杂度为O(N),因为每层元素只加1,交集最多为N。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#define ll long long
#define MOD(x) ((x)>=mod?(x)-mod:(x))
using namespace std;
const int maxn=500010, mod=1e9+7;
struct tjm{int too, pre;}e[maxn<<1];
struct poi{int f[3];};
int n, x, ans, tot, tott;
int last[maxn], root[maxn];
vector<poi>q[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add(int x, int y){e[++tot]=(tjm){y, last[x]}; last[x]=tot;}
inline int merge(int x, int y)
{
    if(q[x].size()<q[y].size()) swap(x, y);
    int nx=q[x].size()-1, ny=q[y].size()-1;
    for(int i=0;i<=ny;i++) 
    {
        int sum0=0, sum1=0, sum2=0;
        sum0=1ll*q[x][nx-i].f[0]*q[y][ny-i].f[0]%mod;
        sum1=(1ll*q[x][nx-i].f[1]*q[y][ny-i].f[0]+1ll*q[x][nx-i].f[0]*q[y][ny-i].f[1])%mod;
        for(int j=0;j<3;j++)
            for(int k=2;j+k>=2;k--)
                sum2=(1ll*sum2+1ll*q[x][nx-i].f[j]*q[y][ny-i].f[k])%mod;
        q[x][nx-i].f[0]=sum0; q[x][nx-i].f[1]=sum1; q[x][nx-i].f[2]=sum2;
    }
    q[y].clear(); return x;
}
void dfs(int x, int fa)
{
    if(!last[x]) root[x]=++tott; int dep=0;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa)
    {
        dfs(too, x);
        if(!root[x]) root[x]=root[too];
        else dep=max(dep, (int)min(q[root[x]].size(), q[root[too]].size())), root[x]=merge(root[x], root[too]);
    }
    int nx=q[root[x]].size()-1; 
    for(int i=0;i<dep;i++) 
        q[root[x]][nx-i].f[0]=MOD(q[root[x]][nx-i].f[0]+q[root[x]][nx-i].f[2]), q[root[x]][nx-i].f[2]=0;
    poi tmp; tmp.f[1]=tmp.f[0]=(mod+1)>>1; tmp.f[2]=0; q[root[x]].push_back(tmp);
}
inline int power(int a, int b)
{
    int ans=1;
    for(;b;b>>=1, a=1ll*a*a%mod)
    if(b&1) ans=1ll*ans*a%mod;
    return ans;
}
int main()
{
    read(n);
    for(int i=1;i<=n;i++) read(x), add(x, i);
    dfs(0, -1); 
    for(int i=0;i<q[root[0]].size();i++) ans=MOD(ans+q[root[0]][i].f[1]);
    printf("%lld
", 1ll*ans*power(2, n+1)%mod);
}
View Code
原文地址:https://www.cnblogs.com/Sakits/p/8027335.html