【模板】多项式开根

懒惰的我直接指对函数搞了

#include<cstdio>
#include<algorithm>
#include<iostream>
const int maxn = 1 << 19;
const int mod = 998244353,g=3;
typedef long long ll;
inline int pw(int a,int b,int ans=1){
    for(;b;b>>=1,a=ll(a)*a%mod)
        if(b&1)ans=ll(ans)*a%mod;
    return ans;
}
inline int inv(int x){return pw(x,mod-2);}
int rev[maxn],wn[maxn],lim;
inline void init(int len){
    wn[0]=lim=1;
    while(lim<len)lim<<=1;
    for(int i=1;i<lim;++i)rev[i]=rev[i>>1]>>1|(i%2*lim/2);
}
inline int reduce(int x){return x+(x>>31&mod);}
inline void fst(int*a,int type){
    for(int i=1;i<lim;++i)if(rev[i]>i)std::swap(a[i],a[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        const int W=pw(3,mod/mid/2);
        for(int k=1;k<mid;++k)wn[k]=ll(wn[k-1])*W%mod;
        for(int j=0;j<lim;j+=mid+mid){
            for(int*A=a+j,*B=A+mid,*w=wn;w!=wn+mid;++A,++B,++w){
                const int x=*A,y=ll(*B)**w%mod;
                *A=reduce(x+y-mod),*B=reduce(x-y);
            }
        }
    }
    if(!type){
        for(int i=0,lm=inv(lim);i<lim;++i)a[i]=ll(lm)*a[i]%mod;
        std::reverse(a+1,a+lim);
    }
}
inline void cpy(int*a,const int*b,int len){
    for(int i=0;i<len;++i)a[i]=b[i];
    for(int i=len;i<lim;++i)a[i]=0;
}
inline void inv(const int*a,int*b,int len){
    if(len==1)return void(*b=inv(*a));
    inv(a,b,len+1>>1),init(len*3/2+1);
    static int c[maxn],d[maxn];
    cpy(c,a,len),cpy(d,b,len+1>>1);
    fst(c,1),fst(d,1);
    for(int i=0;i<lim;++i)c[i]=ll(c[i])*d[i]%mod*d[i]%mod;
    fst(c,0);
    for(int i=len+1>>1;i<len;++i)b[i]=reduce(-c[i]);
}
int inv_[maxn];
inline void initinv(int n){inv_[1]=1;for(int i=2;i<=n;++i)inv_[i]=inv_[mod%i]*ll(mod-mod/i)%mod;}
inline void DER(int*a,int len){for(int i=0;i+1<len;++i)a[i]=a[i+1]*ll(i+1)%mod;a[len-1]=0;}
inline void INT(int*a,int len){for(int i=len;i;--i)a[i]=ll(a[i-1])*inv_[i]%mod;a[0]=0;}
inline void Ln(const int*a,int*b,int len){
    static int c[maxn];
    std::copy(a,a+len,c),DER(c,len),inv(a,b,len);
    init(len<<1);
    for(int i=len;i<lim;++i)b[i]=c[i]=0;
    fst(b,1),fst(c,1);
    for(int i=0;i<lim;++i)b[i]=ll(b[i])*c[i]%mod;
    fst(b,0),INT(b,len);
}
inline void exp(const int*a,int*b,int len){
    if(len==1)return void(*b=1);
    exp(a,b,len+1>>1);
    static int c[maxn],d[maxn];
    Ln(b,c,len),init(len);
    for(int i=0;i<len;++i)c[i]=reduce(a[i]-c[i]);
    for(int i=len;i<lim;++i)c[i]=0;++c[0];
    cpy(d,b,len+1>>1);
    fst(c,1),fst(d,1);
    for(int i=0;i<lim;++i)c[i]=ll(c[i])*d[i]%mod;
    fst(c,0);
    for(int i=len+1>>1;i<len;++i)b[i]=c[i];
}
inline void sqrt(const int*a,int*b,int n){
    static int c[maxn];
    Ln(a,c,n);
    for(int i=0;i<n;++i)c[i]=ll(c[i])*inv_[2]%mod;
    exp(c,b,n);
}
int a[maxn],b[maxn],n;
int main(){
    std::ios::sync_with_stdio(false),std::cin.tie(0);
    std::cin >> n;
    initinv(n);
    for(int i=0;i<n;++i)std::cin >> a[i];
    sqrt(a,b,n);
    for(int i=0;i<n;++i)std::cout << b[i] << ' ';
}
原文地址:https://www.cnblogs.com/skip1978/p/10334202.html