2017 CCPC 杭州 HDU 6270 Marriage (NTT,容斥)

题目:传送门

题意

有 n 个家庭,每个家庭有 ai 个男孩和 bi 个女孩,n 个家庭总的男孩等于总的女孩。对于来自 i 家庭的男孩他只能和不来自 i 家庭的女孩结婚,也就是来自同个家庭的男孩女孩不能结婚。问有多少种方案,使得这 这些男孩女孩都能成功结婚。

思路

参考博客:

对于一个有 x 个男孩和 y 个女孩的家庭来说,有且仅有 k 对来自这个家庭的男孩女孩结婚(近亲结婚)的方案数是:

C(x, k) * C(y,k) * k!

那么如果在第一个家庭选 k1 对近亲结婚,第二个家庭 k2 对......第 n 个家庭 kn 对,剩下的自由组合,最后这种方案至少有 k1+k2...+kn 对近亲结婚。

那我们对每个家庭构造一个多项式:

c0 + c1*x + c2*x^2 + .... + cm*x^m  (m = min(x, y))

把这 n 个多项式乘起来,得到的多项式的 x^k 的系数 ck 代表的就是至少有 k 对近亲结婚的方案数。

因为代表的是至少,所以最后还需要容斥一下。

n个多项式相乘,复杂度跟多项式的长度有很大关系,n个多项式的长度就是所有男孩的总数;所以复杂度其实是 o(nlognlogn)的

#include <bits/stdc++.h>
#define LL long long
#define ULL unsigned long long
#define UI unsigned int
#define mem(i, j) memset(i, j, sizeof(i))
#define rep(i, j, k) for(int i = j; i <= k; i++)
#define dep(i, j, k) for(int i = k; i >= j; i--)
#define pb push_back
#define make make_pair
#define INF 0x3f3f3f3f
#define inf LLONG_MAX
#define PI acos(-1)
#define fir first
#define sec second
#define lb(x) ((x) & (-(x)))
#define dbg(x) cout<<#x<<" = "<<x<<endl;
using namespace std;

const int N = 1e6 + 5;
const LL mod = 998244353;
const LL g = 3;

int n, all, cnt;
LL fac[N];
LL x1[N], x2[N];
vector < LL > a[N];

LL ksm(LL a, LL b) {
    LL res = 1LL;
    while(b) {
        if(b & 1) res = res * a % mod;
        a = a * a % mod; b >>= 1;
    }
    return res;
}

LL C(int n, int m) { return m > n ? 0 : fac[n] * ksm(fac[m] * fac[n - m] % mod, mod - 2) % mod; }

struct cmp{
    bool operator()(int A, int B) {
        return a[A].size() > a[B].size();
    }
};
priority_queue <LL, vector<LL>, cmp> Q;

void change(LL y[], int len){
    for (int i = 1, j = len / 2; i < len - 1; i++){
        if (i < j) swap(y[i], y[j]);
        int k = len / 2;
        while (j >= k){
            j -= k;
            k /= 2;
        }
        if (j < k) j += k;
    }
}

void ntt(LL y[], int len, int on){
    change(y, len);
    for (int h = 2; h <= len; h <<= 1){
        LL wn = ksm(g, (mod - 1) / h);
        if (on == -1) wn = ksm(wn, mod - 2);
        for (int j = 0; j < len; j += h){
            LL w = 1ll;
            for (int k = j; k < j + h / 2; k++){
                LL u = y[k];
                LL t = w * y[k + h / 2] % mod;
                y[k] = (u + t) % mod;
                y[k + h / 2] = (u - t + mod) % mod;
                w = w * wn % mod;
            }
        }
    }

    if (on == -1){
        LL t = ksm(len, mod - 2);
        rep(i, 0, len - 1) y[i] = y[i] * t % mod;
    }
}

void mul(vector <LL> &a, vector <LL> &b, vector <LL> &c){
    int len = 1;
    int sz1 = a.size(), sz2 = b.size();

    while (len <= sz1 + sz2 - 1) len <<= 1;

    rep(i, 0, sz1 - 1) x1[i] = a[i];
    rep(i, sz1, len)   x1[i] = 0;

    rep(i, 0, sz2 - 1) x2[i] = b[i];
    rep(i, sz2, len)   x2[i] = 0;

    ntt(x1, len, 1);
    ntt(x2, len, 1);

    rep(i, 0, len - 1) x1[i] = x1[i] * x2[i];

    ntt(x1, len, -1);

    vector <LL>().swap(c);
    rep(i, 0, sz1 + sz2 - 2) c.push_back(x1[i]);
}

void solve() {

    scanf("%d", &n);
    rep(i, 0, n) vector < LL >().swap(a[i]);
    while(!Q.empty()) Q.pop();

    all = 0;

    rep(i, 1, n) {
        int x, y;
        scanf("%d %d", &x, &y);
        a[i].resize(min(x,y)+1);
        rep(j, 0, min(x,y)) a[i][j] = C(x, j) * C(y, j) % mod * fac[j] % mod;
        Q.push(i);  all += x;
    }

    cnt = n;

    rep(i, 1, n - 1) {
        int pos1 = Q.top(); Q.pop();
        int pos2 = Q.top(); Q.pop();

        mul(a[pos1], a[pos2], a[++cnt]);

        vector < LL >().swap(a[pos1]);
        vector < LL >().swap(a[pos2]);

        Q.push(cnt);
    }

    LL ans = 0LL, flag = 1LL;

    rep(i, 0, (int)(a[cnt].size()) - 1) {
        ans = ans + flag * fac[all - i] * a[cnt][i] % mod;
        ans = (ans + mod) % mod;
        flag *= -1;
    }
    printf("%lld
", ans);
}


int main() {

    fac[0] = 1LL; rep(i, 1, N - 5) fac[i] = 1LL * i * fac[i - 1] % mod;

    int _; scanf("%d", &_);
    while(_--) solve();

//    solve();

    return 0;
}
原文地址:https://www.cnblogs.com/Willems/p/13839974.html