Luogu P5450 [THUPC2018]淘米神的树

题意

写的很明白了,不需要解释。

( exttt{Data Range:}1leq nleq 234567)

题解

国 际 计 数 水 平

首先考虑一开始只有一个黑点的情况怎么做。

我们钦定黑点为根,设 (f_i) 表示以 (i) 为子树,并且 (i) 这个节点为黑色的答案,那么我们得到:

[f_u=(sz_u-1)!prodlimits_{fa_v=u}frac{f_v}{sz_v!} ]

接下来我们阐述这个式子到底是在干什么。

首先我们枚举可能的排列,由于现在只有 (u) 是黑色的,第一轮肯定是先把 (u) 变成红色。也就是说,(u) 只能放在第一个,而剩下的可以任意。(注意这里是可能而不是合法)

接下来考虑排除掉所有不合法的排列,枚举每一棵子树,考虑某棵子树所有节点之间的偏序关系。

对于一个树 (T) 对应的可能的序列 (P) 来说,我们按顺序从 (P) 中选出一些元素构造 (Q),使得 (Q) 中的每一个元素都是 (T) 中某个子树 (S) 中存在的节点的编号,并且 (S) 中每个节点的编号都要在 (Q) 中出现。

这里,(T) 的根为 (u)(S) 的根为 (v)(fa_v=u)

然后我们发现,如果序列 (Q) 是不合法的,那么序列 (P) 也是不合法的。

对所有子树都这么做就得到同时满足所有约束的序列个数,也就是上面的式子。

接下来我们考虑直接计算根节点的答案。

[f_{rt}=frac{(sz_{rt}-1)!}{prodlimits_{fa_v=rt}sz_v!}prodlimits_{fa_v=rt}f_v ]

然后我们对于所有根节点的儿子 (v) 代入进来,再将所有孙子代入进来……相当于我们枚举所有节点计算,于是有

[f_{rt}=prodlimits_{u=1}^{n}frac{(sz_{u}-1)!}{prodlimits_{fa_v=u}sz_v!} ]

简单化开

[f_{rt}=frac{prodlimits_{u=1}^{n}(sz_{u}-1)!}{prodlimits_{u=1}^{n}prodlimits_{fa_v=u}sz_v!} ]

注意分母,也就是说我们对于每个节点 (u) 都枚举 (u) 的所有儿子,相当于枚举所有非根节点(因为每个非根节点总会是某个节点的儿子),于是有

[f_{rt}=frac{prodlimits_{u=1}^{n}(sz_{u}-1)!}{prodlimits_{u=1,u eq rt}^{n}sz_u!} ]

现在来拆分子。单独把 (u=rt) 的拿出来,剩下的合并并约分,得到

[f_{rt}=(sz_{rt}-1)!prodlimits_{u=1,u eq rt}^{n}frac{1}{sz_u} ]

然后分子分母同乘一个 (sz_{rt}) 得到答案为

[f_{rt}=sz_{rt}!prodlimits_{u=1}^{n}frac{1}{sz_u} ]

接下来考虑两个黑点咋做,建立一个虚拟节点 (s),向 (a,b) 连无向边,初始的时候只有 (s) 是黑点。

第一次只能变 (s),然后就变成 (s) 是红点,(a,b) 是黑点。容易证明这样对答案没有任何影响。

于是树就变成了一颗基环树,考虑怎么求答案。

我们枚举环上染红的最后一个点的位置,但是这是不可行的。

那么我们就用类似于 NOIp2018 Day 2 T1 的方法,考虑枚举一条在环上的边并删除他来计算答案。

但是这样做可能会用重复的答案,所以我们要考虑如何来排除这些答案。

注意到对于一种染色方案来说,考虑这个方案在环上最后一个被染成红色的点。我们发现这个点可以有两种方案被染黑,也就是说每个方案都会算 (2) 此,于是最终答案要乘上 (frac{1}{2})

对于不是环上的点和虚拟节点 (s) 我们都能用 ( exttt{dfs}) 求出子树大小,设他们的和为 (z)

我们现在考虑这些在环上的点,对于某个在环上的点 (u) 来说,设 (a_u) 表示它的确定子树大小(也就是所有不在环上的儿子的大小之和加一),(b_u) 表示某种断环方式下的 (sz_u),那么通过人类智慧我们发现

将环上的 (k+1) 个点编号,(s) 点为 (0),其余的依次为 (1sim k),假设断的边是 (r o r+1),那么

[b_u=egin{cases}sumlimits_{i=1}^{u}a_u,& 1leq uleq r\sumlimits_{i=u}^{k}a_u,&r+1leq uleq kend{cases} ]

我们对 (a) 做前缀和变成 (c),那么可以发现

[prod b=prod_{i eq j}vert c_i-c_jvert ]

其中 (c_0=0)

然后我们可以用快速插值的套路去优化这个东西,复杂度 (O(nlog^2n))

代码

#include<cstdio>
#include<cstring>
#include<cctype>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
using namespace std;
typedef int ll;
typedef long long int li;
typedef unsigned long long int ull;
const ll MAXN=524291,MOD=998244353,G=3,INVG=332748118;
struct Edge{
    ll to,prev;
};
Edge ed[MAXN];
ll n,x,y,z,fct,from,to,tot,tp,cur,res;
ll rev[MAXN],omgs[MAXN],invo[MAXN];
ll last[MAXN],sz[MAXN],fa[MAXN],st[MAXN],p[MAXN],c[MAXN];
ll pr[MAXN],px[MAXN];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
inline void addEdge(ll from,ll to)
{
    ed[++tot].prev=last[from];
    ed[tot].to=to;
    last[from]=tot;
}
inline void dfs(ll node,ll f)
{
    sz[node]=1,fa[node]=f;
    for(register int i=last[node];i;i=ed[i].prev)
    {
        if(ed[i].to!=f)
        {
            dfs(ed[i].to,node),sz[node]+=sz[ed[i].to];
        }
    }
}
inline ll qpow(ll base,ll exponent)
{
    ll res=1;
    while(exponent)
    {
        if(exponent&1)
        {
            res=(li)res*base%MOD;
        }
        base=(li)base*base%MOD,exponent>>=1;
    }
    return res;
}
inline void setupOmg(ll cnt)
{
    ll limit=log2(cnt)-1,omg;
    omg=qpow(G,(MOD-1)>>(limit+1)),omgs[cnt>>1]=1;
    for(register int i=(cnt>>1|1);i!=cnt;i++)
    {
        omgs[i]=(li)omgs[i-1]*omg%MOD;
    }
    for(register int i=(cnt>>1)-1;i;i--)
    {
        omgs[i]=omgs[i<<1]; 
    }
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
    static ull tcp[MAXN];
    register ll cur=0,x,shift=log2(cnt)-__builtin_ctz(cnt);
    if(inv==-1)
    {
        reverse(cp+1,cp+cnt);
    }
    for(register int i=0;i<cnt;i++)
    {
        tcp[rev[i]>>shift]=cp[i];
    }
    for(register int i=2;i<=cnt;i<<=1)
    {
        cur=i>>1;
        for(register int j=0;j<cnt;j+=i)
        {
            for(register int k=0;k<cur;k++)
            {
                x=tcp[j|k|cur]*omgs[k|cur]%MOD;
                tcp[j|k|cur]=tcp[j|k]+MOD-x,tcp[j|k]+=x;
            }
        }
    }
    for(register int i=0;i<cnt;i++)
    {
        cp[i]=tcp[i]%MOD;
    }
    if(inv==1)
    {
        return;
    }
    x=MOD-(MOD-1)/cnt;
    for(register int i=0;i<cnt;i++)
    {
        cp[i]=(li)cp[i]*x%MOD;
    }
}
inline void conv(ll fd,ll *f,ll *g,ll *res)
{
    static ll tmpf[MAXN],tmpg[MAXN];
    ll cnt=1,limit=-1;
    while(cnt<(fd<<1))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        tmpf[i]=i<fd?f[i]:0,tmpg[i]=i<fd?g[i]:0;
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(tmpf,cnt,1),NTT(tmpg,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        tmpf[i]=(li)tmpf[i]*tmpg[i]%MOD;
    }
    NTT(tmpf,cnt,-1),cpy(res,tmpf,fd);
}
inline void inv(ll fd,ll *f,ll *res)
{
    static ll tmp[MAXN];
    if(fd==1)
    {
        res[0]=qpow(f[0],MOD-2);
        return;
    }
    inv((fd+1)>>1,f,res);
    ll cnt=1,limit=-1;
    while(cnt<(fd<<1))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        tmp[i]=i<fd?f[i]:0;
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(tmp,cnt,1),NTT(res,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        res[i]=(li)(2-(li)tmp[i]*res[i]%MOD+MOD)%MOD*res[i]%MOD;
    }
    NTT(res,cnt,-1),clr(res+fd,cnt-fd-1);
}
inline void mod(ll fd,ll gd,ll *f,ll *g,ll *r)
{
    static ll tmpf[MAXN],tmpg[MAXN],tinv[MAXN],q[MAXN];
    if(fd<gd)
    {
        for(register int i=0;i<gd-1;i++)
        {
            r[i]=f[i];
        }
        return;
    }
    for(register int i=0;i<fd;i++)
    {
        tmpf[i]=f[fd-1-i];
    }
    for(register int i=0;i<gd;i++)
    {
        tmpg[i]=g[gd-1-i];
    }
    inv(fd-gd+2,tmpg,tinv);
    ll cnt=1,limit=-1;
    while(cnt<(fd<<1))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(tmpf,cnt,1),NTT(tinv,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        q[i]=1ll*tmpf[i]*tinv[i]%MOD;
    }
    NTT(q,cnt,-1),reverse(q,q+fd-gd+1);
    for(register int i=0;i<cnt;i++)
    {
        tmpf[i]=tinv[i]=tmpg[i]=0;
        q[i]=i<fd-gd+1?q[i]:0,g[i]=i<gd?g[i]:0;
    }
    cnt>>=1,limit--;
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(q,cnt,1),NTT(g,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        tmpf[i]=1ll*q[i]*g[i]%MOD;
    }
    NTT(g,cnt,-1),NTT(tmpf,cnt,-1);
    for(register int i=0;i<gd-1;i++)
    {
        r[i]=(f[i]-tmpf[i]+MOD)%MOD;
    }
    for(register int i=0;i<cnt;i++)
    {
        q[i]=tmpf[i]=0;
    }
}
vector<ll> tmpf2[MAXN<<2];
void dnc(ll *pts,ll l,ll r,ll node)
{
    static ll tmp[MAXN],tmp2[MAXN];
    if(l==r)
    {
        tmpf2[node].push_back((MOD-pts[l])%MOD),tmpf2[node].push_back(1);
        return;
    }
    ll mid=(l+r)>>1,ls=node<<1,rs=ls|1;
    dnc(pts,l,mid,ls),dnc(pts,mid+1,r,rs);
    ll d=tmpf2[ls].size(),d2=tmpf2[rs].size();
    copy(tmpf2[ls].begin(),tmpf2[ls].end(),tmp);
    copy(tmpf2[rs].begin(),tmpf2[rs].end(),tmp2);
    ll cnt=1,limit=-1;
    while(cnt<(d+d2))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(tmp,cnt,1),NTT(tmp2,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        tmp[i]=(li)tmp[i]*tmp2[i]%MOD;
    }
    NTT(tmp,cnt,-1),tmpf2[node].resize(d+d2-1);
    copy(tmp,tmp+d+d2-1,tmpf2[node].begin()),clr(tmp,cnt),clr(tmp2,cnt);
}
ll tmpf3[19][MAXN];
void dnc2(ll fd,ll *f,ll depth,ll l,ll r,ll node,ll *res)
{
    static ll tmp[MAXN],pw[17];
    if(r-l<=1024)
    {
        for(register int i=l;i<=r;i++)
        {
            ll x=c[i],cur=f[fd-1],v1,v2,v3,v4;
            pw[0]=1;
            for(register int j=1;j<=16;j++)
            {
                pw[j]=1ll*pw[j-1]*x%MOD;
            }
            for(register int j=fd-2;j-15>=0;j-=16)
            {
                v1=(1ll*cur*pw[16]+1ll*f[j]*pw[15]+
                    1ll*f[j-1]*pw[14]+1ll*f[j-2]*pw[13])%MOD;
                v2=(1ll*f[j-3]*pw[12]+1ll*f[j-4]*pw[11]+
                    1ll*f[j-5]*pw[10]+1ll*f[j-6]*pw[9])%MOD;
                v3=(1ll*f[j-7]*pw[8]+1ll*f[j-8]*pw[7]+
                    1ll*f[j-9]*pw[6]+1ll*f[j-10]*pw[5])%MOD;
                v4=(1ll*f[j-11]*pw[4]+1ll*f[j-12]*pw[3]+
                    1ll*f[j-13]*pw[2]+1ll*f[j-14]*pw[1])%MOD;
                cur=(0ll+v1+v2+v3+v4+f[j-15])%MOD;
            }
            for(register int j=((fd-1)&15)-1;~j;j--)
            {
                cur=(1ll*cur*x+f[j])%MOD;
            }
            res[i]=cur;
        }
        return;
    }
    ll sz=tmpf2[node].size()-1;
    for(register int i=0;i<sz+1;i++)
    {
        tmp[i]=tmpf2[node][i];
    }
    clr(tmpf3[depth],sz+10),mod(fd,sz+1,f,tmp,tmpf3[depth]);
    ll mid=(l+r)>>1;
    dnc2(sz,tmpf3[depth],depth+1,l,mid,node<<1,res);
    dnc2(sz,tmpf3[depth],depth+1,mid+1,r,(node<<1)|1,res);
    for(register int i=0;i<sz;i++)
    {
        tmpf3[depth][i]=0;
    }
}
inline void eval(ll fd,ll pcnt,ll *f,ll *pts,ll *res)
{
    dnc(pts,0,pcnt-1,1),dnc2(fd,f,0,0,pcnt-1,1,res);
}
int main()
{
    setupOmg(524288),n=read(),x=read(),y=read();
    for(register int i=0;i<n-1;i++)
    {
        from=read(),to=read();
        addEdge(from,to),addEdge(to,from);
    }
    dfs(x,0);
    while(y!=x)
    {
        st[++tp]=y,p[y]=1,y=fa[y];
    }
    st[++tp]=y,p[y]=1,y=fa[y],z=n+1,fct=1;
    for(register int i=1;i<=tp;i++)
    {
        c[i]=1;
        for(register int j=last[st[i]];j;j=ed[j].prev)
        {
            if(!p[ed[j].to])
            {
                c[i]+=sz[ed[j].to];
            }
        }
    }
    for(register int i=1;i<=n;i++)
    {
        fct=(li)fct*i%MOD;
        if(!p[i])
        {
            z=(li)z*sz[i]%MOD;
        }
    }
    fct=(li)fct*(n+1)%MOD;
    for(register int i=1;i<=tp;i++)
    {
        c[i]=(c[i-1]+c[i])%MOD;
    }
    dnc(c,1,tp+1,1);
    for(register int i=1;i<=tp+1;i++)
    {
        pr[i-1]=(li)i*tmpf2[1][i]%MOD;
    }
    memset(tmpf2,0,sizeof(tmpf2)),eval(n+1,tp+1,pr,c,px);
    for(register int i=0;i<=tp;i++)
    {
        cur=(li)fct*qpow((li)px[i]*z%MOD,MOD-2)%MOD;
        res=(res+(((tp-i)&1)?MOD-cur:cur))%MOD;
    }
    printf("%d
",(li)res*499122177%MOD);
}
原文地址:https://www.cnblogs.com/Karry5307/p/13041340.html