题解 CF960G 【Bandit Blues】

简化题面

给你三个正整数 (n)(a)(b),定义 (A) 为一个排列中是前缀最大值的数的个数,定义 (B) 为一个排列中是后缀最大值的数的个数,求长度为 (n) 的排列中满足 (A=a)(B=b) 的排列个数。(n le 10^5),答案对 (998244353) 取模。

解题思路

考场第一眼看了之后直接考虑爆搜,没看出来是个dp,考完之后被告知是(dp)之后就开始推了推,考虑设(dp[i][j])(i)个数的排列有(j)个前缀最大值的方案个数,然后手玩样例可以得出递推式

(dp[i][j]=dp[i-1][j-1]+(i-1) imes dp[i-1][j])

然后你会发现这是第一类斯特林数,然后考虑如何统计答案,我们考虑对于一个长度为(n)的排列,枚举(n)所在的位置,我们可以得到统计答案的柿子

[sum_{i=1}^{n-1}egin{bmatrix}i\a-1end{bmatrix} imes egin{bmatrix}n-1-i\b-1end{bmatrix} imes C_{n-1}^{i} ]

然后我们考虑这个柿子的意义,(c)为枚举前半部分的数是那些,然后我们将我们所枚举出来的数排成轮换,然后再将剩下的除了(n)以外的数在排成轮换

你会发现这个过程极其的繁琐,我们为什么不直接先排成轮换然后再选呢,先把(n-1)个数排成(a+b-2)个轮换然后从(a+b-2)个轮换里面选出(a-1)个到前面,剩下的到后面,然后就有了极其简化的柿子

[egin{bmatrix}n-1\a+b-2end{bmatrix} imes C_{a+b-2}^{a-1} ]

然后你会发现数据是(10^5)没法(n^2)预处理,这里就要靠多项式分治来优化求第一类斯特林数了,直接预处理变成(O(nlogn))了,然后你就可以过了

#include<bits/stdc++.h>
    
#define LL long long
    
#define _ 0
    
using namespace std;
    
/*Grievous Lady*/
    
template <typename _n_> void read(_n_ & _x_){
    _x_ = 0;int _m_ = 1;char buf_ = getchar();
    while(buf_ < '0' || buf_ > '9'){if(buf_ == '-')_m_ =- 1;buf_ = getchar();}
    do{_x_ = _x_ * 10 + buf_ - '0';buf_ = getchar();}while(buf_ >= '0' && buf_ <= '9');_x_ *= _m_;
}
    
#define mod 998244353

const int kato = 1e6 + 10;

LL fac[kato] , Ans[kato];

#define int long long

inline LL quick_pow(LL a , LL b){
    LL res = 1;
    for(; b ; b >>= 1 , a = a * a % mod){
        if(b & 1){
            res = res * a % mod;
        }
    }
    return res % mod;
}

int n , a , b;

inline void NTT(int *y , int len , int opt){
    int *rev = new int[len];
    rev[0] = 0;
    for(int i = 1;i < len;i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    for(int i = 0;i < len;i ++) if (i < rev[i]) swap(y[i] , y[rev[i]]);
    for(int i = 1;i < len;i <<= 1){
        int G1 = quick_pow(3 , (mod - 1) / (i << 1));
        for(int j = 0;j < len;j += (i << 1)){
            for(int k = 0 , g = 1;k < i;k ++ , g = 1LL * g * G1 % mod){
                int res = 1LL * y[i + k + j] * g % mod;
                y[i + k + j] = ((y[j + k] - res) + mod) % mod;
                y[j + k] = (y[j + k] + res) % mod;
            }
        }
    }
    if(opt == -1){
        reverse(y + 1 , y + len);
        for(int i = 0 , inv = quick_pow(len , mod - 2);i < len;i ++) y[i] = 1LL * y[i] * inv % mod;
    }
    delete []rev;
}

inline void mul(int *x , int *y , int *c , int len){
    int len1 = 1;
    while(len1 <= 2 * len) len1 <<= 1;
    int *g = new int[len1] , *f = new int[len1];
    for(int i = 0;i < len;i ++) g[i] = x[i] , f[i] = y[i];
    NTT(g , len1 , 1) , NTT(f , len1 , 1);
    for(int i = 0;i < len1;i ++) c[i] = 1LL * (g[i] * f[i]) % mod;
    NTT(c , len1 , -1);
    delete []g , g = NULL;
    delete []f , f = NULL;
}

inline LL c(LL a , LL b){
    if(a < b) return 0;
    return fac[a] * quick_pow(fac[b] * fac[a - b] % mod , mod - 2) % mod;
}

#define init() 
{ 
    for(int i = 1;i <= 100000;i ++){ 
        fac[i] = fac[i - 1] * i % mod; 
    } 
}

#define del() 
{ 
    memset(L , 0 , sizeof(int)* len * 2); 
    memset(R , 0 , sizeof(int)* len * 2); 
}

inline void CDQ(int l , int r , int *ans){
    if(l == r) {
        ans[0] = l , ans[1] = 1;
        return;
    }
    // cout << "QAQ" << '
';
    int len = r - l + 1 , mid = (l + r) >> 1;
    int *L = new int[len * 2] , *R = new int[len * 2];
    // cout << "QWQ" << '
';
    del();
    CDQ(l , mid , L) , CDQ(mid + 1 , r , R);
    int len1 = 1;
    while(len1 <= len) len1 <<= 1;
    NTT(L , len1 , 1) , NTT(R , len1 , 1);
    for(int i = 0;i < len1;i ++) ans[i] = L[i] * R[i] % mod;
    NTT(ans , len1 , -1);
	delete []L , L = NULL;
    delete []R , R = NULL;
} 

inline void gets(int n){
    if(n == 0)return (void)(Ans[0] = 1);
    CDQ(0 , n - 1 , Ans);
}

inline int Ame_(){
    read(n) , read(a) , read(b);
    fac[0] = fac[1] = 1;
    init();
    // cout << "QAQ" << '
';
    // for(int i = 0;i < n;i ++) abi[i] = gali[i] = i;
    // Ans = mul(abi , gali , n);
    // for(int i = 0;i < n;i ++){
    //     cout << Ans[i] << ' ';
    // }
    gets(n - 1);
    // cout << Ans[1] << '
';
    // cout << "QAQ" << '
';
    printf("%lld
" , (1LL * Ans[a + b - 2] * c(a + b - 2 , a - 1) % mod)); 
    return ~~(0^_^0);
}
    
int Ame__ = Ame_();
    
signed main(){;}
原文地址:https://www.cnblogs.com/Ame-sora/p/13573821.html