数论杂记——快速求解组合数 C(n,m) 取模

模板:

#include <bits/stdc++.h>
using namespace std; 
// Shared definitions for the template:
//   ll    - 64-bit signed integer alias used for all modular arithmetic
//   mod   - the modulus (998244353, a prime, so Fermat inverses apply)
//   Max   - size of the factorial tables; C(n, m) supports 0 <= n < Max
using ll = long long;
constexpr ll mod = 998244353;
constexpr int Max = 1000000 + 10;
ll fact[Max];   // fact[i]  = i! mod `mod`, filled by init()
ll ifact[Max];  // ifact[i] = (i!)^(-1) mod `mod`, filled by init()
ll n, m;        // global input slots; not referenced by the template's main
/**
 * Modular exponentiation by repeated squaring (square-and-multiply).
 *
 * @param n base; any value, including negative - it is first reduced into [0, mod)
 *          (C++ '%' keeps the sign of the dividend, so a plain n % mod could stay negative).
 * @param k exponent; when k <= 0 the loop does not run and 1 is returned.
 * @return n^k mod `mod`, in [0, mod). Uses O(log k) multiplications.
 */
ll pow_mod(ll n,ll k)
{
    ll res = 1;
    n %= mod;
    if (n < 0)
        n += mod;  // normalise a negative base so the result is a proper residue
    while (k > 0)
    {
        if (k & 1)
            res = res * n % mod;
        // Both operands are < mod; with mod below 2^31 the product fits in a signed 64-bit ll.
        n = n * n % mod;
        k >>= 1;
    }
    return res;
}
/**
 * Fills the global tables for 0 <= i < Max:
 *   fact[i]  = i! mod `mod`
 *   ifact[i] = (i!)^(-1) mod `mod`
 *
 * Only ONE modular exponentiation is performed, for ifact[Max-1] (Fermat:
 * a^(mod-2) is a's inverse because mod is prime). The rest follow from
 * 1/(i-1)! = i * (1/i!), walking downwards. Total cost is O(Max + log mod)
 * instead of one pow_mod per index. ifact[0] comes out as 1, as before.
 * Assumes Max - 1 < mod, so every i! is non-zero and hence invertible.
 */
void init()
{
    fact[0] = 1;
    for (int i = 1; i < Max; i++)
        fact[i] = fact[i - 1] * i % mod;

    ifact[Max - 1] = pow_mod(fact[Max - 1], mod - 2);
    for (int i = Max - 1; i >= 1; i--)
        ifact[i - 1] = ifact[i] * i % mod;  // 1/(i-1)! = i / i!
}
/**
 * Binomial coefficient C(n, m) modulo `mod`, looked up from the tables
 * built by init(): n! * (m!)^(-1) * ((n-m)!)^(-1).
 * Returns 0 when m < 0 or m > n. Requires init() to have run and n < Max.
 */
ll C(ll n, ll m)
{
    const bool outside = (m < 0) || (m > n);
    if (outside)
        return 0;
    const ll partial = fact[n] * ifact[m] % mod;  // n! / m!
    return partial * ifact[n - m] % mod;          // ... / (n-m)!
}
/**
 * Template driver: builds the factorial tables, then reads two integers
 * (the n and m of C(n, m)) from standard input and prints C(n, m) mod 998244353.
 * Adapt the input handling to your problem; n must be below Max.
 * Example: input "5 2" prints 10.
 */
int main()
{
    init();
    long long n_val = 0, m_val = 0;
    std::cin >> n_val >> m_val;
    std::cout << C(n_val, m_val) << '\n';  // '\n' instead of endl: no forced flush
    return 0;
}

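例题:Codeforces 1312D(Count the Arrays)。按数组最大值 $i$ 分类计数,答案为 $\sum_{i=n-1}^{m} \binom{i-1}{n-2}(n-2)\,2^{n-3}$,由朱世杰恒等式(曲棍球棒恒等式)可化简为 $\binom{m}{n-1}(n-2)\,2^{n-3}$。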

#include <bits/stdc++.h>
using namespace std; 
// Shared definitions for the example solution:
//   ll    - 64-bit signed integer alias used for all modular arithmetic
//   mod   - the answer modulus (998244353, prime)
//   Max   - factorial table size; C(n, m) supports 0 <= n < Max
using ll = long long;
constexpr ll mod = 998244353;
constexpr int Max = 1000000 + 10;
ll fact[Max];   // fact[i]  = i! mod `mod`, filled by init()
ll ifact[Max];  // ifact[i] = (i!)^(-1) mod `mod`, filled by init()
ll n, m;        // input values, read by main
/**
 * Modular exponentiation by repeated squaring (square-and-multiply).
 *
 * @param n base; any value, including negative - it is first reduced into [0, mod)
 *          (C++ '%' keeps the sign of the dividend, so a plain n % mod could stay negative).
 * @param k exponent; when k <= 0 the loop does not run and 1 is returned
 *          (main relies on this for n == 2, where k = n - 3 = -1).
 * @return n^k mod `mod`, in [0, mod). Uses O(log k) multiplications.
 */
ll pow_mod(ll n,ll k)
{
    ll res = 1;
    n %= mod;
    if (n < 0)
        n += mod;  // normalise a negative base so the result is a proper residue
    while (k > 0)
    {
        if (k & 1)
            res = res * n % mod;
        // Both operands are < mod; with mod below 2^31 the product fits in a signed 64-bit ll.
        n = n * n % mod;
        k >>= 1;
    }
    return res;
}
/**
 * Fills the global tables for 0 <= i < Max:
 *   fact[i]  = i! mod `mod`
 *   ifact[i] = (i!)^(-1) mod `mod`
 *
 * Only ONE modular exponentiation is performed, for ifact[Max-1] (Fermat:
 * a^(mod-2) is a's inverse because mod is prime). The rest follow from
 * 1/(i-1)! = i * (1/i!), walking downwards. Total cost is O(Max + log mod)
 * instead of one pow_mod per index. ifact[0] comes out as 1, as before.
 * Assumes Max - 1 < mod, so every i! is non-zero and hence invertible.
 */
void init()
{
    fact[0] = 1;
    for (int i = 1; i < Max; i++)
        fact[i] = fact[i - 1] * i % mod;

    ifact[Max - 1] = pow_mod(fact[Max - 1], mod - 2);
    for (int i = Max - 1; i >= 1; i--)
        ifact[i - 1] = ifact[i] * i % mod;  // 1/(i-1)! = i / i!
}
/**
 * Binomial coefficient C(n, m) modulo `mod` from the precomputed tables:
 * n! * (m!)^(-1) * ((n-m)!)^(-1). Yields 0 for m < 0 or m > n.
 * init() must have been called first, and n must be below Max.
 */
ll C(ll n, ll m)
{
    if (m < 0 || m > n)
        return 0;
    const ll top = fact[n];
    const ll inv_bottom = ifact[m] * ifact[n - m] % mod;
    return top * ifact[m] % mod * ifact[n - m] % mod * (inv_bottom ? 1 : 1);
}
int main()
{
    init();
    cin>>n>>m;
    ll ans=0;
    for (int i=n-1;i<=m;i++)
    {
        ans+=C(i-1,n-2)%mod*(n-2)%mod*pow_mod(2,n-3)% mod;
    }
    cout<<ans%mod;
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/Y-Knightqin/p/12709360.html