[JSOI2019]神经网络(树形DP+容斥+生成函数)

首先可以把题目转化一下:把树拆成若干条链,每条链的颜色为其所在的树的颜色,然后排放所有的链成环,求使得相邻位置颜色不同的排列方案数。

然后本题分为两个部分:将一棵树分为1~n条不相交的链的方案数;将这些链安排顺序使得不存在两条相邻的链来自同一棵树。

第一部分显然可以O(n2)树形DP,f[i][j][0/1/2]表示i及其子树j条链,i向儿子连出0/1/2条边的方案数,然后直接背包DP即可。看似O(n3)的树形背包DP其实是O(n2)的。证明复杂度:其实DP时只循环到sz[u]/sz[v]即可,然后可以把每个转移视为儿子v内子树的每个节点和节点u内v外节点组成的点对,于是全部DP完就是枚举了所有的点对,复杂度显然O(n2)。

第二部分,考虑n个点的树划分成i条链的方案是f[i],如果不考虑环只考虑链其对应的指数生成函数为Σf[i]i!(Σ(-1)i-jC(i-1,i-j)xj/j!),其中i∈[1,n],j∈[1,i]。拓展到环上,钦定一棵树作为开头,如果该颜色有i条链,则被算了i次,然后其指数生成函数为:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-1/(j-1)!),其中i∈[1,n],j∈[1,i]。减去首尾同色后,生成函数是这样的:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-2/(j-2)!),其中i∈[2,n],j∈[2,i]。然后暴力卷积即可。

#include<bits/stdc++.h>
using namespace std;
const int N=5005,mod=998244353;
int n,m,sum,ans,fac[N],inv[N],sz[N],f[N][N][3],g[N],tmp[N][3],dp[310][N],b[N];
vector<int>G[N];
int qpow(int a,int b)
{
    int ret=1;
    while(b)
    {
        if(b&1)ret=1ll*ret*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return ret;
}
void dfs(int u,int fa)
{
    sz[u]=1,f[u][1][0]=1;
    for(int i=0;i<G[u].size();i++)
    if(G[u][i]!=fa)
    {
        int v=G[u][i];
        dfs(v,u);
        for(int j=0;j<=sz[u]+sz[v];j++)tmp[j][0]=tmp[j][1]=tmp[j][2]=0;
        for(int j=1;j<=sz[u];j++)
        for(int k=1;k<=sz[v];k++)
        {
            tmp[j+k][0]=(tmp[j+k][0]+1ll*f[u][j][0]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
            tmp[j+k-1][1]=(tmp[j+k-1][1]+1ll*f[u][j][0]*(f[v][k][0]+f[v][k][1]))%mod;
            tmp[j+k][1]=(tmp[j+k][1]+1ll*f[u][j][1]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
            tmp[j+k-1][2]=(tmp[j+k-1][2]+1ll*f[u][j][1]*(f[v][k][0]+f[v][k][1]))%mod;
            tmp[j+k][2]=(tmp[j+k][2]+1ll*f[u][j][2]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod;
        }
        sz[u]+=sz[v];
        for(int j=1;j<=sz[u];j++)f[u][j][0]=tmp[j][0],f[u][j][1]=tmp[j][1],f[u][j][2]=tmp[j][2];
    }
}
int C(int a,int b){return a<b?0:1ll*fac[a]*inv[b]%mod*inv[a-b]%mod;}
int S(int a,int b){return (!a&&!b)?1:1ll*fac[a]*C(a-1,a-b)%mod;}
int main()
{
    fac[0]=1;for(int i=1;i<=5000;i++)fac[i]=1ll*fac[i-1]*i%mod;
    for(int i=0;i<=5000;i++)inv[i]=qpow(fac[i],mod-2);
    scanf("%d",&m);
    dp[0][0]=1;
    for(int p=1;p<=m;p++)
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++)G[i].clear();
        for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x);
        for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
        f[i][j][0]=f[i][j][1]=f[i][j][2]=0;
        dfs(1,0);
        memset(g,0,sizeof g);
        for(int i=1;i<=n;i++)g[i]=(f[1][i][0]+2ll*f[1][i][1]+2ll*f[1][i][2])%mod;
        if(p!=m)
        {
            memset(b,0,sizeof b);
            for(int j=1;j<=n;j++)
            if(g[j])for(int k=0,t=1;k<=j;k++,t=mod-t)
            b[j-k]=(b[j-k]+1ll*t*S(j,j-k)%mod*g[j])%mod;
            for(int i=0;i<=sum;i++)
            if(dp[p-1][i])for(int j=0;j<=n;j++)
            dp[p][i+j]=(dp[p][i+j]+1ll*C(i+j,j)*b[j]%mod*dp[p-1][i])%mod;
        }
        else{
            memset(b,0,sizeof b);
            for(int j=1;j<=n;j++)
            if(g[j])for(int k=0,t=1;k<j;k++,t=mod-t)
            b[j-1-k]=(b[j-1-k]+1ll*t*S(j-1,j-k-1)%mod*g[j])%mod;
            for(int i=0;i<=sum;i++)
            if(dp[p-1][i])for(int j=0;j<=n;j++)
            ans=(ans+1ll*C(i-2+j,j)*b[j]%mod*dp[p-1][i])%mod;
        }
        sum+=n;
    }
    printf("%d",ans);
}
View Code
原文地址:https://www.cnblogs.com/hfctf0210/p/10864673.html