P4721 【模板】分治 FFT

题意

给定序列 (g_1,cdots ,g_n) ,求序列 (f_0,cdots ,f_{n-1})。其中,

[f_i=sum_{j=1}^{i}{f_{i-j}·g_j} ]

边界为:(f_0=1) ,答案对 (998244353) 取模。

(2leq n leq 10^5,0leq g_i<998244353)

分析

(CDQ) 分治的思想与 (FFT) 结合。

(CDQ) 分治,将序列 ([l,r]) 以中点 (mid) 为分界线,将整个序列分为 ([l,mid])([mid+1,r]) 两部分。每次求区间 ([l,r]) 的时候,先求出左边区间的答案,然后利用 (FFT) 求出左边区间对右边区间的贡献,最后再计算右边区间的值。当区间内只有一个元素时,直接返回。

复杂度:(O(nlog^2n))

代码

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int mod=998244353;
const int pr=3;//mod的原根
const int N=2e5+5;
int g[N],f[N],A[N],B[N],rev[N];
void read(int &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<3)+(x<<1)+ch-'0';
        ch=getchar();
    }
    x*=f;
}
ll power(ll x,ll y)
{
    ll res=1;
    x%=mod;
    while(y)
    {
        if(y&1) res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void NTT(int *pn,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<rev[i]) swap(pn[i],pn[rev[i]]);
    for(int i=1;i<len;i<<=1)
    {
        int wn=power(pr,(mod-1)/(2*i));
        if(f==-1) wn=power(wn,mod-2);
        for(int j=0,d=(i<<1);j<len;j+=d)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int u=pn[j+k],v=1LL*w*pn[j+k+i]%mod;
                pn[j+k]=1LL*(u+v)%mod;
                pn[j+k+i]=1LL*((u-v)%mod+mod)%mod;
                w=1LL*w*wn%mod;
            }
        }
    }
    if(f==-1)
    {
        int inv_len=power(len,mod-2);
        for(int i=0;i<len;i++)
            pn[i]=1LL*pn[i]*inv_len%mod;
    }
}
void sol(int l,int r)//区间:[l,r)
{
    if(l+1>=r) return;
    int mid=(l+r)>>1;
    sol(l,mid);//先计算左边
    int len=r-l;//要计算贡献的左边的区间的长度
    for(int i=0;i<len;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)?(len>>1):0);
    for(int i=0;i<len;i++) B[i]=g[i];
    for(int i=l;i<mid;i++) A[i-l]=f[i];
    for(int i=mid;i<r;i++) A[i-l]=0;
    NTT(A,len,1);
    NTT(B,len,1);
    for(int i=0;i<len;i++) A[i]=1LL*A[i]*B[i]%mod;
    NTT(A,len,-1);
    for(int i=mid;i<r;i++) f[i]=(f[i]+A[i-l])%mod;
    //先把左边对右边的贡献累加,再计算右边
    sol(mid,r);
}
int main()
{
    int n;
    read(n);
    for(int i=1;i<n;i++) read(g[i]);
    f[0]=1;
    int len=1,cnt=0;
    while(len<n) len<<=1;
    sol(0,len);
    for(int i=0;i<n;i++) printf("%d%c",f[i],i==n-1?'
':' ');
    return 0;
}
原文地址:https://www.cnblogs.com/1024-xzx/p/13847842.html