《AtCoder Regular Contest 130 D》

首先一开始有个误区,对于排列u的时候,把他的父节点和子节点放进来一起考虑了。

但是其实父节点到父节点的时候考虑,就是用u变成两种排列中另一种排列来考虑。

然后我们考虑一下合并子树:这里我们假设当前已经合并了的序列A大小为n,对于要合并的子树序列B为m

钦定根u在序列A的第i个位置,根v在序列B的第j个位置。

显然u和v只有u < v和u > v的情况,这里我们只需要算一种就可以了。最后方案 * 2即可。

我们计算u > v。

现在我们钦定ai即A的根u在排列后的序列的第i + j个位置。

枚举ai的原位置来转移,即a 在原序列第i个位置。

那么对于序列的排列情况,我们以ai为分解划分为:

前半部分即为A{a1...ai - 1} II B{b1.b2...bj}.

考虑求这个序列的组合方案:因为A,B序列的内部情况我们已经处理出来了,所以我们可以当成无差别的元素来计算方案即C(i - 1 + j,i - 1),即a序列插入位置中,然后乘上两序列的方案数dp.

对于后半部分A{ai + 1,ai + 2...an} II B{bj + 1,b2....bm},方案同理前面C(n - i +m - j,n - i)

因为我们计算的是u > v,那么对于u = ai,v = b1 ~ bj的情况都满足这个限制条件.

所以对于每个i方案为C(i - 1 + j,i - 1) * C(n - i +m - j,n - i) * dpA[i] * (dpB[1] + dpB[2] + ... dpB[j])

显然对于dpB可以前缀和预处理一下,然后我们就可以枚举i,j来解决这个问题。

// Author: levil
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef long double ld;
typedef pair<int,int> pii;
const int N = 3e3 + 5;
const int M = 1e4 + 5;
const LL Mod = 998244353;
#define INF 1e9
#define dbg(ax) cout << "now this num is " << ax << endl;
inline long long ADD(long long x,long long y) {
    if(x + y < 0) return ((x + y) % Mod + Mod) % Mod;
    return (x + y) % Mod;
}
inline long long MUL(long long x,long long y) {
    if(x * y < 0) return ((x * y) % Mod + Mod) % Mod;
    return x * y % Mod;
}
inline long long DEC(long long x,long long y) {
    if(x - y < 0) return (x - y + Mod) % Mod;
    return (x - y) % Mod;
}

LL fac[N],dp[N][N],inv[N],pre[N],f[N];//dp[i][j] - i在序列第j个位置
vector<int> G[N];
int n,sz[N];
LL quick_mi(LL a,LL b) {
    LL re = 1;
    while(b) {
        if(b & 1) re = re * a % Mod;
        a = a * a % Mod;
        b >>= 1;
    }
    return re;
}
void init() {
    fac[0] = 1;
    for(int i = 1;i < N;++i) fac[i] = fac[i - 1] * i % Mod;
    inv[N - 1] = quick_mi(fac[N - 1],Mod - 2) % Mod;
    for(int i = N - 2;i >= 0;--i) inv[i] = inv[i + 1] * (i + 1) % Mod;
}
LL C(int n,int m) {
    return fac[n] * inv[m] % Mod * inv[n - m] % Mod;
}
void dfs(int u,int fa) {
    sz[u] = 1;
    dp[u][1] = 1;
    for(auto v : G[u]) {
        if(v == fa) continue;
        dfs(v,u);
        memset(pre,0,sizeof(pre));
        memset(f,0,sizeof(f));
        for(int i = 1;i <= sz[v];++i) pre[i] = ADD(pre[i - 1],dp[v][sz[v] - i + 1]);
        for(int i = 1;i <= sz[u];++i) {
            for(int j = 1;j <= sz[v];++j) {
                LL tmp1 = C(i - 1 + j,i - 1),tmp2 = C(sz[u] - i + sz[v] - j,sz[u] - i);
                f[i + j] = ADD(f[i + j],MUL(MUL(MUL(C(i - 1 + j,i - 1),C(sz[u] - i + sz[v] - j,sz[u] - i)),dp[u][i]),pre[j]));
            }
        }
        sz[u] += sz[v];
        for(int i = 1;i <= sz[u];++i) dp[u][i] = f[i];
    }
}
void solve() {
    init();
    scanf("%d",&n);
    for(int i = 1;i < n;++i) {
        int u,v;scanf("%d %d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1,0);
    LL ans = 0;
    for(int i = 1;i <= sz[1];++i) ans = ADD(ans,dp[1][i]);
    printf("%lld\n",MUL(2,ans));
}   
int main() {
    //int _;
    //for(scanf("%d",&_);_;_--) {
        solve();
    //}
    //system("pause");
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/zwjzwj/p/15621290.html