hdu 6900 Residual Polynomial (NTT)

题目链接

http://acm.hdu.edu.cn/showproblem.php?pid=6900

题意

定义 \(f_1(x)=\sum_{i=0}^{n}a_ix^i\),给定序列 \(a_i,b_i,c_i\),以及递推式 \(f_i(x)=b_i\,f_{i-1}'(x)+c_i\,f_{i-1}(x)\)(\(2\le i\le n\)),求 \(f_n(x)\) 的各项系数(对 998244353 取模)。

思路

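一篇讲得很好的博客:https://www.cnblogs.com/JustinRochester/p/13705300.html

简述:记求导算子为 \(D\),则 \(f_n(x)=\prod_{i=2}^{n}(c_i+b_iD)\,f_1(x)\)。用分治 NTT 求出(系数顺序反转的)算子多项式 \(\prod_{i=2}^{n}(b_i+c_iy)\),再与 \(e_k=a_k\,k!\) 做卷积;卷积结果的第 \(t+n-1\) 项除以 \(t!\) 即为 \(f_n\) 中 \(x^t\) 的系数。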

#include<bits/stdc++.h>
using namespace std;
using LL = long long;           // 64-bit working type for residues modulo `mod`

const int maxx = 1e6+10;        // capacity of every global work array

// NTT prime 998244353 = 119 * 2^23 + 1, with G = 3 used as the root of unity
// generator and G1 = 332748118 its modular inverse (3 * 332748118 == 1 mod p).
const int mod = 998244353, G = 3, G1 = 332748118;

LL a[maxx], b[maxx], c[maxx];   // input: coefficients of f_1, and per-step b_i, c_i
LL r[maxx];                     // bit-reversal permutation filled by mul(), used by NTT()
LL f[maxx], g[maxx];            // scratch buffers for the two NTT operands in mul()
LL *d[maxx << 2];               // segment-tree nodes: new[]-allocated partial products
LL e[maxx], ans[maxx];          // e[i] = a[i] * i!, ans = final convolution result
LL p[maxx], invp[maxx];         // factorials and inverse factorials, filled by init()
int limit;                      // current NTT length (power of two) shared by mul()/NTT()
// Square-and-multiply modular exponentiation: returns a^b mod `mod`.
// Assumes b >= 0 and a small enough that a*a fits in a long long
// (callers in this file pass values already reduced below `mod`).
LL quick(LL a,LL b)
{
    LL result = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1)
            result = result * a % mod;   // fold in the current power for a set bit
        a = a * a % mod;                 // advance to the next power of two
    }
    return result;
}
// Precomputes factorials p[i] = i! mod `mod` and inverse factorials
// invp[i] = (i!)^{-1} mod `mod` for every 0 <= i < maxx.
// Only one modular exponentiation is needed: invp[maxx-1] comes from Fermat's
// little theorem, then 1/(i-1)! = i * (1/i!) walks the table back down to
// invp[0] = 1 (the original loop never set invp[0] and did maxx exponentiations).
void init()
{
    p[0]=1;
    for(int i=1;i<maxx;i++)p[i]=p[i-1]*i%mod;
    // maxx-1 < mod, so (maxx-1)! is nonzero mod the prime and therefore invertible.
    invp[maxx-1]=quick(p[maxx-1],mod-2);
    for(int i=maxx-1;i>0;i--)invp[i-1]=invp[i]*i%mod;
}
// In-place iterative number-theoretic transform of A[0..limit-1] modulo `mod`.
// Relies on the globals `limit` (a power of two) and `r` (its bit-reversal
// permutation), both prepared by mul().
// type == 1 runs the forward transform with root G; any other value runs the
// inverse transform with G1 = G^{-1}. The inverse is NOT scaled by 1/limit;
// the caller does that.
void NTT(LL *A,int type)
{
    for (int idx = 0; idx < limit; ++idx) {
        if (idx < r[idx])
            swap(A[idx], A[r[idx]]);
    }
    const LL root = (type == 1) ? G : G1;
    for (int half = 1; half < limit; half <<= 1) {
        const int span = half << 1;
        const LL step = quick(root, (mod - 1) / span);   // primitive span-th root of unity
        for (int base = 0; base < limit; base += span) {
            LL tw = 1;
            for (int k = 0; k < half; ++k) {
                LL &top = A[base + k];
                LL &bottom = A[base + k + half];
                // Both operands are < mod, so they fit in int and so does their sum.
                const int u = top;
                const int v = tw * bottom % mod;
                top = (u + v) % mod;
                bottom = (u - v + mod) % mod;
                tw = tw * step % mod;
            }
        }
    }
}
// Multiplies polynomials a (degree n, reads a[0..n]) and b (degree m, reads
// b[0..m]) modulo `mod` using NTT, writing the product coefficients to h[0..n+m].
// Side effects: overwrites the globals `limit`, `r` (bit-reversal table), and the
// scratch arrays `f` and `g`. It also writes h[0..limit-1], where limit is the
// smallest power of two strictly greater than n+m, so h must have that much room.
// Entries h[n+m+1..limit-1] come out as zero, because the cyclic convolution has
// no wrap-around.
void mul(LL *a,LL *b,LL *h,int n,int m)
{
    int L=0;
    limit=1;
    while(limit<=n+m)limit<<=1,L++;
    // Bit-reversal permutation. (i&1)*(limit>>1) equals (i&1)<<(L-1) for L>=1.
    // Unlike the shift, it stays well-defined when L==0 (n==m==0, limit==1).
    for(int i=0;i<limit;i++)r[i]=(r[i>>1]>>1)|((i&1)*(limit>>1));
    for(int i=0;i<limit;i++)f[i]=g[i]=0;
    for(int i=0;i<=n;i++)f[i]=a[i];
    for(int i=0;i<=m;i++)g[i]=b[i];
    NTT(f,1),NTT(g,1);
    for(int i=0;i<limit;i++)h[i]=(f[i]*g[i])%mod;
    NTT(h,-1);
    LL inv=quick(limit,mod-2);   // undo the factor `limit` left by the inverse transform
    for(int i=0;i<=n+m;i++)h[i]=h[i]*inv%mod;
}
void solve(int l,int r,int rt)
{
    if(l==r)
    {
        d[rt]=new LL[2];
        d[rt][0]=b[l];
        d[rt][1]=c[l];
        return;
    }
    int mid=(l+r)/2;
    solve(l,mid,rt*2);
    solve(mid+1,r,rt*2+1);
    d[rt]=new LL[2*(r-l+1)];
    mul(d[rt*2],d[rt*2+1],d[rt],mid-l+1,r-mid);
}
void del(int l,int r,int rt)
{
    delete d[rt];
    if(l==r)return;
    int mid=(l+r)/2;
    del(l,mid,rt*2);
    del(mid+1,r,rt*2+1);
}
// Reads T test cases. For each one, prints the n+1 coefficients of f_n modulo `mod`
// (constant term first) on one line, separated by single spaces.
int main()
{
    init();
    int T;
    scanf("%d",&T);
    while(T--)
    {
        int n;
        scanf("%d",&n);
        // e[i] = a[i] * i!  turns D^j x^k = k!/(k-j)! x^(k-j) into a plain convolution.
        for(int i=0;i<=n;i++)scanf("%lld",&a[i]),e[i]=a[i]*p[i]%mod;
        for(int i=2;i<=n;i++)scanf("%lld",&b[i]);
        for(int i=2;i<=n;i++)scanf("%lld",&c[i]);
        if(n<=1)
        {
            // No recurrence steps, so f_n = f_1. solve(2, n, 1) would recurse forever
            // on the empty range, so print the input coefficients directly.
            printf("%lld",e[0]);
            for(int i=1;i<=n;i++)printf(" %lld",e[i]*invp[i]%mod);
            printf("\n");
            continue;
        }
        solve(2,n,1);
        // d[1] = prod_{i=2}^{n} (b_i + c_i y) has degree n-1 (one linear factor per
        // step). Declaring degree n would read d[1][n]: out of bounds for n == 2 and
        // uninitialised heap memory whenever n >= 4 is a power of two.
        // NOTE(review): ans/f/g hold maxx entries. This assumes the NTT length for a
        // degree-(2n-1) product fits; confirm against the problem's limit on n.
        mul(d[1],e,ans,n-1,n);
        // The coefficient of x^t in f_n is ans[t+n-1] / t!.
        printf("%lld",ans[n-1]);
        for(int i=1;i<=n;i++)printf(" %lld",ans[i+n-1]*invp[i]%mod);
        printf("\n");
        del(2,n,1);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/HooYing/p/13809306.html