P4491 [HAOI2018]染色

题目描述

洛谷

为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可 以抽象为一个长度为 (N) 的序列, 每个位置都可以被染成 (M) 种颜色中的某一种.

然而小 C 只关心序列的 (N) 个位置中出现次数恰好为 (S) 的颜色种数, 如果恰好出现了 (S) 次的颜色有 (K) 种, 则小 C 会产生 (W_k) 的愉悦度.

小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 (1004535809) 取模的结果是多少.

数据范围:(nleq 10^7,mleq 10^5,Sleq 150,0leq w_i < 1004535809)

solution

容斥+二项式反演+(NTT)

(g(k)) 表示恰好出现了 (S) 次的颜色恰好有 (k) 种的方案数。

那么答案显然为:(displaystylesum_{i=0}^{m} g(k) imes w_k)

考虑怎么求出 (g(k)) 来,因为恰好不太好求,所以可以用二项式反演转化一下。

(f(k)) 表示恰好出现了 (S) 次的颜色至少有 (k) 种的方案数。

显然: (displaystyle f(k) = {mchoose k} imes {nchoose ks} imes {ks!over (s!)^{k}} imes {(m-k)^{n-ks}})

解释一下上面的柿子是怎么来的。

首先我们先从 (m) 个颜色里面选出 (k) 个,让他恰好出现 (m) 次,方案数显然为 (displaystyle {mchoose k}), 之后我们还要在 (n) 个位置里面选出 (ks) 个位置,使他们的颜色为这 (k) 中颜色的一种,方案数为 (displaystyle {nchoose ks}), 又因为每个颜色之间的顺序不同,所以还要在乘上一个 (displaystyle {ks!over ({s!})^i}) ,剩下的 (m-ks) 的位置中可以涂 (m-k) 种颜色的任意一种,方案数为 (displaystyle {(m-k)^{n-ks}})

显然有:(displaystyle {f(k) = sum_{i=k}^{min(n,m/s)}} {ichoose k} imes g(i))

根据二项式反演可得(以下默认 (upper = {min(n,m/s)})):

(displaystyle g(k) = sum_{i=k}^{upper} (-1)^{i-k} imes {ichoose k} imes f(i))

把组合数拆开尝试构造一下卷积:

(displaystyle {g(k) = sum_{i=k}^{upper} (-1)^{i-k} imes {i!over {k! imes (i-k!)}} imes f(i)})

(k!) 移到左边去可得:

(displaystyle g(k) imes k! = sum_{i=k}^{upper} {(-1)^{i-k}over (i-k)!} imes i!f(i))

我们构造多项式,(A(x),B(x),C(x)) 其中:

(displaystyle A(x) = sum_{i=0}^{infin} g(i)i! imes x^i)

(displaystyle B(x) = sum_{i=0}^{infin} i!f(i) x^i)

(displaystyle C(x) = sum_{i=0}^{infin} {{(-1)}^{i}over i!} x^i)

我们会发现这其实是一个差卷积的形式,我们可以通过构造一下使他变为正常的加法卷积的形式。

具体来说:首先把 (B(x)) 的每一项系数反转得到 (B'(x)) ,设 (A'(x) = B'(x) * C(x)), 那么 (A'(x))(i) 项的系数其实就是 (A(x))(n-i) 项的系数。

简单证明一下,原来的时候是 (B[i] imes C[i-k] = A[k]), 经过反转后变为:(B[n-i] imes C[i-k] = A'[k’]) ,因为我们后面的是加法卷积的形式,所以 (k' = n-i+i-k = n-k) 。因此 (A'(x)) 的第 (i) 项的系数其实就是 (A(x))(n-i) 项的系数。

我们用 (NTT) 求出来这个 (A'(x)) 之后,把系数反转一下就可以得到 (A(x))

预处理出 (f(i))(g(i)) 就可以直接求了。

复杂度:(O(nlogn))

注意:卷积数组要开大点,开 (1e7) 差不多就可以过了。

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 1004535809;
const int N = 1e7+10;
int n,m,s,len,ans,w[N],rev[N],jz[N],inv[N],a[N],b[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
int C(int n,int m)
{
    return jz[n] * inv[m] % p * inv[n-m] % p;
}
void NTT(int *a,int lim,int opt)
{
    for(int i = 0; i < lim; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < lim; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < lim; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(lim,p-2);
        for(int i = 0; i < lim; i++) a[i] = a[i] * inv % p;
    }
}
signed main()
{
    n = read(); m = read(); s = read(); len = min(m,n/s) + 1;
    jz[0] = inv[0] = 1; 
    for(int i = 0; i <= m; i++) w[i] = read();
    for(int i = 1; i <= N-5; i++) jz[i] = jz[i-1] * i % p;
    inv[N-5] = ksm(jz[N-5],p-2);
    for(int i = N-6; i >= 1; i--) inv[i] = inv[i+1] * (i+1) % p;
    for(int i = 0; i < len; i++) 
    {
        int tmp = ksm(ksm(jz[s],i),p-2);
        a[i] = C(m,i) * C(n,i*s) % p * jz[i*s] % p * tmp % p * ksm(m-i,n-i*s) % p;
        a[i] = a[i] * jz[i];
    }
    for(int i = 0, tmp = 1; i < len; i++, tmp *= -1) b[i] = (tmp * inv[i] + p) % p;
    reverse(a,a+len);
    int lim = 1, tim = 0;
    while(lim < (len<<1)) tim++, lim <<= 1;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(a,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
    NTT(a,lim,-1);
    reverse(a,a+len);
    for(int i = 0; i < len; i++) a[i] = a[i] * inv[i] % p;
    for(int i = 0; i < len; i++) ans = (ans + w[i] * a[i] % p) % p;
    printf("%lld
",ans);
    return 0; 
}
原文地址:https://www.cnblogs.com/genshy/p/14537622.html