[EOJ629] 两开花

Description

给定一棵以 (1) 为根 (n) 个节点的树。

定义 (f(k)) :从树上等概率随机选出 (k) 个节点,这 (k) 个点的虚树大小的期望。

一个点 (x) 在这些被选出的 (k) 个点的虚树上,当且仅当它满足下列条件至少一个:

  • (x) 被选出。
  • 存在两个被选出的节点 (a,b),使得 (operatorname{lca}(a,b)=x)

给定 (m),求 (f(1),f(2),cdots,f(m))。 对 (998244353) 取模。(nleq 4cdot 10^5)

Sol

又是套着期望皮的计数题。

对于每个点 (i) 求出有多少种方案对答案有贡献即可:

  • (i) 被选出,总方案数为 (C(n-1,k-1))
  • (i) 至少两个儿子的子树中存在被选出的点。

第二种不太好算,考虑用总方案数减去不合法的方案数。

总方案数就是 (C(n-1,k))

如果点 (i) 的子树中没有被选中的,方案数为 (C(n-sze[i],k))

只有一个儿子的子树中有被选中的,可以枚举儿子 (j),方案数就是 (sumlimits_{j} C(n-sze[i]+sze[j],k))

注意到这样的话,(i) 子树中没有被选中的方案数被多算了 儿子个数次,所以还需要加上 (son[i] imes C(n-sze[i],k))

所以

[f(k)=sumlimits_{i=1}^n C_{n-1}^{k-1}+C_{n-1}^k+(son[i]-1) imes C_{n-sze[i]}^k-sum_j C_{n-sze[i]+sze[j]}^k ]

[f(k)=sumlimits_{i=1}^n C_{n}^{k}+(son[i]-1) imes C_{n-sze[i]}^k-sum_j C_{n-sze[i]+sze[j]}^k ]

如何对于每个 (k) 快速求呢?

观察到式子中的每一项组合数的上标都是 (k),所以我们可以开个桶 (buc[i]),在形如 (buc[n-sze[i]]) 的地方加上 (son[i]+1),在 (buc[n-sze[i]+sze[j]])(-1)

好处就是,再推一步式子:

[f(k)=sum_{i=0}^n buc[i]cdot C_i^k ]

这就是个卷积的形式,(mathbf{NTT})优化就吼了。

Code

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=2e6+5;
const int mod=998244353;

int son[N],sze[N],buc[N];
int n,m,cnt,head[N],fac[N];
int a[N],b[N],lim,rev[N],ifac[N];

struct Edge{
    int to,nxt;
}edge[N<<1];

void add(int x,int y){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    head[x]=cnt;
}

int ksm(int a,int b=mod-2,int ans=1){
    while(b){
        if(b&1) ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;b>>=1;
    } return ans;
}

void ntt(int *f,int g){
    for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int tmp=ksm(g,(mod-1)/(mid<<1));
        for(int R=mid<<1,j=0;j<lim;j+=R){
            for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
                int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
            }
        }
    } if(g>3)
        for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}

int getint(){
    int X=0,w=0;char ch=getchar();
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

void init(int n){
    fac[0]=ifac[0]=1;
    for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=ksm(fac[n]);
    for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}

void dfs(int now,int fa=0){
    sze[now]=1; int tot=0; buc[n]++;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(sze[to]) continue;
        tot++; dfs(to,now);
        sze[now]+=sze[to];
    }
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(to==fa) continue;
        (buc[n-sze[now]+sze[to]]+=mod-1)%=mod;
    } (buc[n-sze[now]]+=tot-1+mod)%=mod;
}

int C(int n,int m){
    if(n<m) return 0;
    return 1ll*ifac[n]*fac[m]%mod*fac[n-m]%mod;
}

signed main(){
    n=getint(),m=getint(),init(N-5);
    for(int i=1;i<n;i++){
        int x=getint(),y=getint();
        add(x,y),add(y,x);
    } dfs(1);
    lim=1;while(lim<=n+n) lim<<=1;
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
    for(int i=0;i<=n;i++)
        a[n-i]=1ll*buc[i]*fac[i]%mod,
        b[i]=ifac[i];
    ntt(a,3),ntt(b,3);
    for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,(mod+1)/3);
    for(int i=1;i<=m;i++) 
        printf("%lld
",1ll*a[n-i]*ifac[i]%mod*C(n,i)%mod);
    return 0;
}

原文地址:https://www.cnblogs.com/YoungNeal/p/10363101.html