P4705 玩游戏

思路

超级麻烦。。。
写了一堆最后常数太大T飞了。。。
真的难受
发现solve函数可以不用把下一层复制上来,直接传指针就可以,下次再说写不写叭

思路

[ans_k=sum_{i=1}^nsum_{j=1}^m (a_i+b_j)^k ]

二项式定理拆一下式子

[egin{align}ans_k=&sum_{i=1}^nsum_{j=1}^m (a_i+b_j)^k\=&sum_{i=1}^nsum_{j=1}^msum_{t=0}^kleft(egin{matrix}k\tend{matrix} ight)a_i^{k-t}b_j^t\=&sum_{t=0}^kleft(egin{matrix}k\tend{matrix} ight)sum_{i=1}^na_i^{k-t}sum_{j=1}^mb_j^t\=&k!sum_{r=0}^k(sum_{i=1}^nfrac{a_i^r}{r!})(sum_{j=1}^mfrac{b_j^r}{r!})end{align} ]

所以只要求出(sum_i a_i^k)即可

(a_i),设其生成函数(A(x)=1+a_ix+a_i^2x^2+a_i^3x^3+dots)

[A_k(x)=sum_{i=0}^{infty}a_k^ix^i=frac{1}{1-a_kx} ]

最后答案的生成函数(G(x))就是(sum_{i=0}^n A_i(x))

然后一个常见套路就是把(frac{1}{x})(ln 'x)代替

所以有

[G(x)=sum_{i=1}^n frac{1}{1-a_ix}=sum_{i=1}^n ln'(1-a_ix) ]

但是这样依然无法快速计算

我们可以再设一个(F(x))

[F(x)=sum_{i=1}^n (ln(1-a_ix))'=sum_{i=1}^nfrac{-a_i}{1-a_ix}\ G(x)=sum_{i=1}^n frac{1}{1-a_ix} ]

所以(G(x)=-xF(x)+n)

然后对于(F(x))

[F(x)=sum_{i=1}^n (ln(1-a_ix))'=(ln(prod_{i=1}(1-a_ix)))' ]

分治加NTT就可以在(O(nlog^2n))的时间内解决

常数过大T掉的代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <cstdlib>
#include <assert.h>
#define int long long
using namespace std;
const int MAXN = 600000;
const int MAXL = 100100;
const int G = 3;
const int invG = 332748118;
const int MOD = 998244353;
 
const int InputBufferSize = 67108864;//输入缓冲区大小
const int OutputBufferSize = 67108864;//输出缓冲区大小 
 
namespace input
{
    char buffer[InputBufferSize],*s,*eof;
    inline void init()
    {
        assert(stdin!=NULL);
        s=buffer;
        eof=s+fread(buffer,1,InputBufferSize,stdin);
    }
    inline bool read(int &x)
    {
        x=0;
        int flag=1;
        while(!isdigit(*s)&&*s!='-')s++;
        if(eof<=s)return false;
        if(*s=='-')flag=-1,s++;
        while(isdigit(*s))x=x*10+*s++-'0';
        x*=flag;
        return true;
    }
    inline bool read(char* str)
    {
        *str=0;
        while(isspace(*s))s++;
        if(eof<s)return false;
        while(!isspace(*s))*str=0,*str=*s,str++,s++;
        *str=0;
        return true;
    }
}
 
namespace output
{
    char buffer[OutputBufferSize];
    char *s=buffer;
    inline void flush()
    {
        assert(stdout!=NULL);
        fwrite(buffer,1,s-buffer,stdout);
        s=buffer;
        fflush(stdout);
    }
    inline void print(const char ch)
    {
        if(s-buffer>OutputBufferSize-2)flush();
        *s++=ch;
    }
    inline void print(char* str)
    {
        while(*str!=0)print(char(*str++));
    }
    inline void print(int x)
    {
        char buf[25]= {0},*p=buf;
        if(x<0)print('-'),x=-x;
        if(x==0)print('0');
        while(x)*(++p)=x%10,x/=10;
        while(p!=buf)print(char(*(p--)+'0'));
    }
}
 
using namespace input;
using namespace output;

int pow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)
            ans=(1LL*ans*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ans;
}
void FFT(int *a,int n,int opt,int lim){
    for(int i=0;i<n;++i){
        int t=0;
        for(int j=0;j<lim;++j)
            if((i>>j)&1)
                t|=(1LL<<(lim-j-1));
        if(i<t)
            swap(a[i],a[t]);
    }
    for(int i=2;i<=n;i<<=1){
        int len=i/2;
        int tmp=pow((opt)?G:invG,(MOD-1)/i);
        for(int j=0;j<n;j+=i){
            int arr=1;
            for(int k=j;k<j+len;++k){
                int t=(1LL*a[k+len]*arr)%MOD;
                a[k+len]=(a[k]-t+MOD)%MOD;
                a[k]=(a[k]+t)%MOD;
                arr=(1LL*arr*tmp)%MOD;;
            }
        }
    }
    if(!opt){
        int invN=pow(n,MOD-2);
        for(int i=0;i<n;++i)
            a[i]=1LL*a[i]*invN%MOD;        
    }
}
void mul(int *a,int *bx,int &at,int bt){
    static int b[MAXN];
    int lim=0,num=at+bt,logt;
    while((1<<lim)<=(num+2))
        lim++;
    logt=lim;
    lim=(1<<lim);
    for(int i=0;i<lim;++i)
        b[i]=bx[i];
    FFT(a,lim,1,logt);
    FFT(b,lim,1,logt);
    for(int i=0;i<lim;++i)
        a[i]=(1LL*a[i]*b[i])%MOD;
    FFT(a,lim,0,logt);
    for(int i=num+1;i<lim;++i)
        a[i]=0;
    at=num;
}
void inv(int *a,int *b,int &bt,int dep,int &midlen,int &logt){
    if(dep==1){
        b[0]=pow(a[0],MOD-2);
        bt=0;
        return;
    }
    inv(a,b,bt,(dep+1)>>1,midlen,logt);
    static int tmp1[MAXN];
    for(int i=0;i<dep;++i)
        tmp1[i]=a[i];
    while((dep<<1)>midlen)
        midlen<<=1,logt++;
    for(int i=dep;i<midlen;++i)
        tmp1[i]=0;
    FFT(tmp1,midlen,1,logt);
    FFT(b,midlen,1,logt);
    for(int i=0;i<midlen;++i)
        b[i]=1LL*b[i]*(2-1LL*tmp1[i]*b[i]%MOD+MOD)%MOD;
    FFT(b,midlen,0,logt);
    for(int i=dep;i<midlen;++i)
        b[i]=0;
    bt=dep-1;
}
void jf(int *a,int &at){
    for(int i=at;i>=0;--i)
        a[i+1]=(1LL*a[i]*pow(i+1,MOD-2))%MOD;
    a[0]=0;
    at++;
}
void qd(int *a,int &at){
    for(int i=0;i<at;++i)
        a[i]=(1LL*a[i+1]*(i+1))%MOD;
    a[at]=0;
    at--;
}
void ln(int *a,int *b,int at,int &bt,int n){
    int midlen=1,logt=0;
    inv(a,b,bt,at+1,midlen,logt);
    qd(a,at);
    mul(b,a,bt,at);
    jf(b,bt);
    for(int i=n;i<=bt;++i)
        b[i]=0;
    bt=n-1;
}
int val[MAXL];
int P[20][MAXN],Pt[20];
void solve(int l,int r,int dep){
    if(l==r){
        for(int i=2;i<=Pt[dep];++i)
            P[dep][i]=0;
        P[dep][0]=1;
        P[dep][1]=MOD-val[l];
        Pt[dep]=1;
        return;
    }
    int mid=(l+r)>>1;
    solve(l,mid,dep+1);
    for(int i=0;i<=Pt[dep+1];++i)
        P[dep][i]=P[dep+1][i];
    for(int i=Pt[dep+1]+1;i<=Pt[dep];++i)
        P[dep][i]=0;
    Pt[dep]=Pt[dep+1];
    solve(mid+1,r,dep+1);
    mul(P[dep],P[dep+1],Pt[dep],Pt[dep+1]);
}
int jc[MAXL],jc_inv[MAXL];
int n,m,t;
void initx(void){
    jc[0]=1;
    int up=max(max(n,m),t)+1;
    for(int i=1;i<up;++i)
        jc[i]=(1LL*jc[i-1]*i)%MOD;
    jc_inv[up-1]=pow(jc[up-1],MOD-2);
    for(int i=up-2;i>=0;--i){
        jc_inv[i]=(1LL*jc_inv[i+1]*(i+1))%MOD;
    }
}
void getf(int *b,int &bt,int n){
    solve(1,n,0);
    int midlen=1,midlog=0;
    Pt[0]=max(n+1,1LL*t+1);
    inv(P[0],b,bt,Pt[0]+1,midlen,midlog);   
    qd(P[0],Pt[0]);
    mul(b,P[0],bt,Pt[0]);
    for(int i=bt;i>=0;--i)
        b[i+1]=MOD-b[i];
    b[0]=n;
    for(int i=0;i<=bt;++i){
        b[i]=(1LL*b[i]*jc_inv[i])%MOD;
    }
}
int ap[MAXN],bp[MAXN];
int ax[MAXL],bx[MAXL];
signed main(){
    freopen("test.in","r",stdin);
    freopen("test.out","w",stdout);
    // scanf("%d %d",&n,&m);
    init();
    read(n);
    read(m);
    for(int i=1;i<=n;++i)
        read(ax[i]);
        // scanf("%d",&ax[i]);
    for(int i=1;i<=m;++i)
        read(bx[i]);
        // scanf("%d",&bx[i]);
    // scanf("%d",&t);
    read(t);
    initx();
    for(int i=1;i<=n;++i)
        val[i]=ax[i];
    int apt=0,bpt=0;
    getf(ap,apt,n);
    for(int i=1;i<=m;++i)
        val[i]=bx[i];
    getf(bp,bpt,m);
    mul(ap,bp,apt,bpt);
    int n_inv=pow(n,MOD-2),m_inv=pow(m,MOD-2);
    for(int i=1;i<=t;++i){
        print(1LL*jc[i]*ap[i]%MOD*n_inv%MOD*m_inv%MOD);
        print('
');
    }
    flush();
    return 0;    
}
原文地址:https://www.cnblogs.com/dreagonm/p/10729748.html