Npc50F [生成函数+NTT]

Npc50 F

题目描述见标题链接 .


$\color{red}{正解部分}$

对比赛 $i$ 构造生成函数: $2^{b_i}-1+C_{a_i}^1x^1+C_{a_i}^2x^2+\cdots+C_{a_i}^{a_i}x^{a_i}$ ,
然后将每场比赛的生成函数使用 $NTT$ 启发式合并, 得到多项式 $f$ ,
最后的 $Ans_j$ 即为多项式 $f$ 的 $\left(\sum_{i=1}^{j}a_i\right)$ 次项系数 .


$\color{red}{实现部分}$

  • 注意多项式乘数组要清空 !
  • 不能暴力合并多项式 !
#include<bits/stdc++.h>
#define reg register
#define pb push_back

// Reads the next integer from stdin.
// Characters before the first digit are skipped. A '-' seen while skipping makes
// the result negative; the character right after that '-' is where digit parsing starts.
// Returns 0 if EOF is reached before any digit.
// The previous version stored getchar()'s result in a `char`, so EOF (-1) was truthy
// and never a digit, and the skip loop spun forever at end of input.
// Keeping `c` as an int also makes every isdigit() argument valid (unsigned char or EOF).
int read(){
        int c = getchar();
        int flag = 1;
        while(c != EOF && !isdigit(c)){
                if(c == '-'){ flag = -1; c = getchar(); break; }
                c = getchar();
        }
        int s = 0;
        while(c != EOF && isdigit(c)){ s = s*10 + (c - '0'); c = getchar(); }
        return s * flag;
}

// Capacity of every global array below; also bounds the largest transform length
// (NTT_len) and the largest factorial index that can be used.
const int maxn = 4e6 + 5;
// All arithmetic is done modulo this prime. NTT() uses 3 as the root base;
// 998244353 = 119 * 2^23 + 1, so power-of-two lengths up to 2^23 are supported.
const int mod = 998244353;

int N;            // number of contests, i.e. number of polynomials to multiply
int bit_cnt;      // log2(NTT_len) for the multiplication currently in progress
int NTT_len;      // current transform length (a power of two)
int C[maxn];      // NOTE(review): not referenced anywhere in the code shown; possibly unused
int fac[maxn];    // fac[i] = i! mod `mod`, filled by Init()
int rev[maxn];    // bit-reversal permutation for the current NTT_len
int Fmp_1[maxn];  // NTT scratch buffer: first factor, later the product
int Fmp_2[maxn];  // NTT scratch buffer: second factor

// P[i] holds the coefficients of polynomial i (index = exponent), 1-based.
std::vector <int> P[maxn];

// A[i] holds the a and b values read for contest i, 1-based.
struct Node{ int a, b; } A[maxn];

// Modular exponentiation: returns a^b mod `mod` using binary exponentiation.
// `a` may be any int; it is first reduced into [0, mod).
// `b` is expected to be >= 0; a negative exponent is treated as 0 and returns 1.
// (The previous version looped forever on a negative exponent.)
// 0^b is 0 for b > 0 and 1 for b == 0. The old `if(!a) return 1;` shortcut
// wrongly returned 1 for 0^b with b > 0.
int Ksm(int a, int b){
        long long base = a % mod;
        if(base < 0) base += mod;
        long long s = 1;
        while(b > 0){
                if(b & 1) s = s * base % mod;
                base = base * base % mod;
                b >>= 1;
        }
        return (int)s;
}

void NTT(int *f, int opt){
        for(reg int i = 0; i < NTT_len; i ++) if(i < rev[i]) std::swap(f[i], f[rev[i]]);
        for(reg int p = 2; p <= NTT_len; p <<= 1){
                int half = p >> 1;
                int wn = Ksm(3, (mod-1)/p);
                if(opt == -1) wn = Ksm(wn, mod-2);
                for(reg int i = 0; i < NTT_len; i += p){
                        int buf = 1;
                        for(reg int k = i; k < i+half; k ++){
                                int Tmp_1 = 1ll*buf*f[k+half] % mod;
                                f[k+half] = (f[k] - Tmp_1 + mod) % mod;
                                f[k] = (f[k] + Tmp_1) % mod;
                                buf = 1ll*buf*wn % mod;
                        }
                }
        }
}

void Merge_poly(){
        int tot = N; 

while(tot > 1){
        int cnt = 0;
        for(reg int t = 1; t <= tot/2; t ++){
                for(reg int i = 0; i < P[t*2-1].size(); i ++) Fmp_1[i] = P[t*2-1][i];
                for(reg int i = 0; i < P[t<<1].size(); i ++) Fmp_2[i] = P[t<<1][i];
                NTT_len = 1, bit_cnt = 0;
                while(NTT_len <= (P[t*2-1].size()+P[t<<1].size())) NTT_len <<= 1, bit_cnt ++; 
                for(reg int i = P[t*2-1].size(); i <= NTT_len; i ++) Fmp_1[i] = 0;
                for(reg int i = P[t<<1].size(); i <= NTT_len; i ++) Fmp_2[i] = 0;
                for(reg int i = 0; i < NTT_len; i ++) rev[i] = (rev[i>>1]>>1)|((i&1)<<bit_cnt-1);
                NTT(Fmp_1, 1), NTT(Fmp_2, 1);
                for(reg int i = 0; i < NTT_len; i ++) Fmp_1[i] = 1ll*Fmp_1[i]*Fmp_2[i] % mod;
                NTT(Fmp_1, -1);
                int Tmp_len = P[t<<1].size()-1+P[t*2-1].size()-1, Inv = Ksm(NTT_len, mod-2);
                for(reg int i = 0; i < NTT_len; i ++) Fmp_1[i] = 1ll*Fmp_1[i]*Inv%mod;
                P[++ cnt].clear();
                for(reg int i = 0; i <= Tmp_len; i ++) P[cnt].pb(Fmp_1[i]);
        }
        if(tot & 1){
                P[++ cnt].clear();
                for(reg int i = 0; i < P[tot].size(); i ++) P[cnt].pb(P[tot][i]);
        }
        tot = cnt;
}                  

}

// Fills the factorial table: fac[i] = i! mod `mod` for every i in [0, maxn).
void Init(){
        fac[0] = 1;
        int i = 1;
        while(i < maxn){
                fac[i] = (long long)fac[i - 1] * i % mod;
                ++i;
        }
}

// Binomial coefficient C(n, m) mod `mod`, computed as n! * ((n-m)!)^{-1} * (m!)^{-1}.
// The inverses come from Fermat's little theorem via Ksm.
// Returns 0 when m < 0 or m > n, instead of indexing fac[] out of range.
// Requires Init() to have filled fac[0..n]; n must be < maxn.
int ZhC(int n, int m){
        if(m < 0 || m > n) return 0;
        return 1ll*fac[n]*Ksm(fac[n-m], mod-2)%mod*Ksm(fac[m], mod-2)%mod;
}

// Sort predicate: orders contests by ascending a. main() uses it so that
// polynomials of similar degree end up next to each other.
bool cmp(Node x, Node y){
        const bool x_first = y.a > x.a;
        return x_first;
}

// Reads N, then a_1..a_N, then b_1..b_N.
// Builds each contest's polynomial (2^{b_i} - 1) + sum_{j=1}^{a_i} C(a_i, j) x^j.
// Multiplies all the polynomials together, then prints the coefficients of
// x^0 .. x^{sum a_i}, space-separated, followed by a newline.
int main(){
        Init();
        N = read();
        if(N < 0) N = 0;   // a negative count would hand std::sort an invalid range
        for(int i = 1; i <= N; i ++) A[i].a = read();
        for(int i = 1; i <= N; i ++) A[i].b = read();
        // Sort by a so that Merge_poly pairs polynomials of similar degree.
        std::sort(A+1, A+N+1, cmp);
        for(int i = 1; i <= N; i ++){
                P[i].clear();
                if(A[i].a >= 0) P[i].reserve(A[i].a + 1);
                P[i].pb(Ksm(2, A[i].b) - 1);
                for(int j = 1; j <= A[i].a; j ++) P[i].pb(ZhC(A[i].a, j));
        }
        // With no contests the product is the empty product, f(x) = 1.
        // Without this, P[1] would be empty and the print loop would read past it.
        if(N == 0) P[1].assign(1, 1);
        Merge_poly();
        // Contests are 1-based; index 0 of A is not real input.
        long long sum_a = 0;
        for(int i = 1; i <= N; i ++) sum_a += A[i].a;
        // The product has degree sum_a, i.e. sum_a + 1 coefficients.
        for(long long i = 0; i <= sum_a && i < (long long)P[1].size(); i ++) printf("%d ", P[1][i]);
        putchar('\n');
        return 0;
}
原文地址:https://www.cnblogs.com/zbr162/p/11822511.html