洛谷P4721-分治FFT(NTT分治)

题目背景

也可用多项式求逆解决。

题目描述

给定序列 $g_{1dots n - 1}$求序列 $f_{0dots n - 1}$.

其中 $f_i=sum_{j=1}^if_{i-j}g_j$,边界为$f_0=1$

答案对 998244353 取模。

输入格式

第一行一个整数 n 。

第二行 n-1个整数$ g_{1dots n - 1}$.

输出格式

一行 n个整数,表示$ f_{0dots n - 1}$ 对 998244353 取模后的值。

输入输出样例

输入
4
3 1 2
输出 
1 3 10 35
输入 
10
2 456 32 13524543 998244352 0 1231 634544 51
输出
1 2 460 1864 13738095 55389979 617768468 234028967 673827961 708520894

说明/提示

$2leq nleq 10^5$,$0leq g_i<998244353$.

emmmm,说是说分治FFT。。。。但实际上FFT是复数啊,你来给我取个模看看QAQ。所以说这题实际上是分治NTT。

我们先来解析一下这个式子,$f_i=sum_{j=1}^{i}f_{i-j}g_j$,这不就是$f_i=sum_{j+k eq i}^{i-1}f_j*g_k$然后$j$从1开始吗,这和NTT的板子题不是一样一样的嘛。只不过有点差别的是我们的这个$f$是需要递推过来的,不是一上场就给你的。

如果直接计算的话肯定时间爆炸。我们考虑采用分治方法来解决,这里的说明方法来自GGljc1301,我觉得没有比这讲得更清楚的了。

举个例子g[1..3]=1, 1, 0,求f[0..3](别问我干嘛要拿这个算斐波那契数列)

刚开始,是这样的(用中括号代表要算的区间,竖线代表中间的位置)

f =[1 0|0 0]

先算左边

f =[1|0]0 0

左半边的长度为1,不往下递归。计算左区间对右区间的贡献。就是把1, 0和g的前两项0, 1做卷积,得到*, 1(星号代表我们不在意这个位置),再把得到的后半段加到这个区间的右半边。操作后:

f =[1|1]0 0

右半边的长度为1,不往下递归。这一步就好了,回到上一步。

f =[1 1|0 0]

计算左半边对右半边的贡献。把1, 1, 0, 0和g的前四项0, 1, 1, 0做卷积,得到*, *, 2, 1,再把得到的后半段加到这个区间的右半边。操作后:

f =[1 1|2 1]

现在开始计算右半段。

f = 1 1[2|1]

左半边的长度是1,不往下递归。计算左区间对右区间的贡献。就是把2, 0(注意这里是0)和g的前(不是后)两项0, 1做卷积,得到*, 2,再把得到的后半段加到这个区间的右半边。操作后:

f = 1 1[2|3]

右半边的长度为1,不往下递归。然后这个f数组就算好了。

然后以下是AC代码:

/*
    有多项式g,f(0)=1,求f
    f(i)=f(i-1)*g(i)
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
using namespace std;

#define debug printf("@#$#@$#$#@%
")
typedef long long ll;
const int mac=2e5+10;
const int mod=998244353;

ll g[mac],f[mac],G[5][30],a[mac],b[mac];
int r[mac],n;

ll qpow(ll a,ll b)
{
    ll ans=1;
    a%=mod;
    while (b){
        if (b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}

void pre_NTT(int logn)
{
    int len=1<<logn;
    for (int i=0; i<len; i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(logn-1));
}

void NTT(int len,ll *a,int type)
{
    for (int i=0; i<len; i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    int nb=1;
    for (int mid=1; mid<len; mid<<=1){
        ll wn=G[type+1][nb++];
        for (int j=0; j<len; j+=(mid<<1)){
            ll w=1;
            for (int k=0; k<mid; k++,w=(w*wn)%mod){
                ll x=a[j+k],y=w*a[j+k+mid]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
}

void solve(int l,int r,int logn)
{
    if (logn<=0) return;
    int mid=(l+r)>>1;
    solve(l,mid,logn-1);
    pre_NTT(logn);
    for (int i=(r-l)/2; i<=r-l; i++) a[i]=0;//右区间置0
    for (int i=l,j=0; i<mid; i++,j++) a[j]=f[i];//拷贝左区间
    for (int i=0; i<r-l; i++) b[i]=g[i];
    NTT(1<<logn,a,1); NTT(1<<logn,b,1);
    for (int i=0; i<r-l; i++)
        a[i]=a[i]*b[i]%mod;
    NTT(1<<logn,a,-1);
    ll inv=qpow(r-l,mod-2);
    for (int i=0; i<r-l; i++)
        a[i]=a[i]*inv%mod;
    for (int i=(r-l)/2; i<r-l; i++)
        f[l+i]=(f[l+i]+a[i])%mod;
    solve(mid,r,logn-1);
}

int main(int argc, char const *argv[])
{
    for (int i=0; i<=25; i++){
        int len=1<<i;
        G[2][i]=qpow(3,(mod-1)/len);//正变换
        G[0][i]=qpow(G[2][i],mod-2);//逆变换
    }
    scanf ("%d",&n);
    for (int i=1; i<n; i++)
        scanf ("%lld",&g[i]);
    f[0]=1;
    int len=1,logn=0;
    while (len<n) len<<=1,logn++;
    solve(0,1<<logn,logn);
    for (int i=0; i<n; i++)
        printf("%lld ",f[i]);
    printf("
");
    return 0;
}
路漫漫兮
原文地址:https://www.cnblogs.com/lonely-wind-/p/13332962.html