多项式exp

调了很久,一直蜜汁错误,然而结果是b数组没有及时清零……

前置技能:多项式求逆。

简单讲一下牛顿迭代(推导详见picks博客,前置技能是泰勒公式):
求多项式F(x),使得G(F(x))≡0 (mod x^n)。方法倍增。
设已知多项式F_t满足G(F_t(x))≡0 (mod x(2t)),有递推式:
F_(t+1)(x)≡F_t(x)-G(F_t(x))/G'(F_t(x)) (mod x(2t+1))

然后是多项式求ln,方法:
因为(lnF(x))'=F'(x)/F(x),所以对F'(x)/F(x)求积分即可。(我起初纠结0次项无法积分出来怎么办,然而事实上0次项的值就是0)

现在才是多项式exp:
给出A(x),求e^A(x),%P。方法:
即求F(x),满足F(x)-e^A(x)≡0 (mod x^n)。
即 ln(F(x))-A(x)≡0
设计G(x),有G(F(x))=ln(F(x))-A(x)
套牛顿迭代式得:F_(t+1)(x)≡F_t(x)·(1-ln(F_t(x))+A(x))
这样就好了。

代码:

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define ll long long
#define P 998244353
using namespace std;
const int maxn=400010;
const int g=3;
inline int read(){
    int f=1,x=0;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return f*x;
}
 
inline int qpow(int x,int p){int ans=1;
    for(;p;p>>=1,x=1ll*x*x%P)if(p&1)ans=1ll*ans*x%P;
    return ans;
}
int n,m;
int a[maxn],b[maxn],c[maxn],d[maxn],tmp[maxn];
 
inline void NTT(int *a,int f,int k){  
    for (int i=0,j=0;i<k;i++){
        if (i>j)swap(a[i],a[j]);
        for (int l=k>>1;(j^=l)<l;l>>=1);
    }
    for (int i=1;i<k;i<<=1){
		int w=qpow(g,(f*(P-1)/(i<<1)+P-1)%(P-1));
        for (int j=0;j<k;j+=(i<<1)){
			int e=1;
            for (int k=0;k<i;++k,e=1ll*e*w%P){
				int x,y;
                x=a[j+k];y=1ll*a[j+k+i]*e%P;
                a[j+k]=(x+y)%P;a[j+k+i]=(x-y+P)%P;        
            }
        }
    }
    if(f==-1){
		int _inv=qpow(k,P-2);
		for (int i=0;i<k;++i)a[i]=1ll*a[i]*_inv%P;
	}
}

inline void GetInv(int *a,int *b,int n){
  	if (n==1) return void(b[0]=qpow(a[0],P-2));
  	GetInv(a,b,n>>1);
  	for (int i=0;i<n;++i)tmp[i]=a[i],tmp[n+i]=0;
  	int k;for (k=1;k<=n;k<<=1);
  	for (int i=n;i<k;++i)b[i]=0;
  	NTT(tmp,1,k);NTT(b,1,k);
  	for (int i=0;i<k;++i)
    	tmp[i]=b[i]*(2+P-1ll*tmp[i]*b[i]%P)%P;
  	NTT(tmp,-1,k);
  	for (int i=0;i<n;++i)b[i]=tmp[i],b[n+i]=0;
}

inline void Getln(int *a,int *b,int *c,int n){
	for (int i=0;i<n;++i)b[i]=1ll*(i+1)*a[i+1]%P;
	GetInv(a,c,n);
	for (int i=0;i<n;++i)tmp[i]=c[i],tmp[n+i]=0;
	int k;for (k=1;k<=n;k<<=1);
	NTT(tmp,1,k);NTT(b,1,k);
	for (int i=0;i<k;++i)
		tmp[i]=1ll*b[i]*tmp[i]%P;
	NTT(tmp,-1,k);
	for (int i=1;i<=n;++i)b[i]=1ll*tmp[i-1]*qpow(i,P-2)%P,b[i+n]=0;b[0]=0;
}

inline void Getexp(int *a,int *b,int *c,int *d,int n){
	if (n==1)return void(b[0]=1);
	Getexp(a,b,c,d,n>>1);
	Getln(b,c,d,n);
	for (int i=0;i<n;++i)tmp[i]=(a[i]-c[i]+P)%P,tmp[n+i]=0;++tmp[0];
	int k;for (k=1;k<=n;k<<=1);
	NTT(tmp,1,k);NTT(b,1,k);
	for (int i=0;i<k;++i)
		tmp[i]=(ll)tmp[i]*b[i]%P;
	NTT(tmp,-1,k);
	for (int i=0;i<n;++i)b[i]=tmp[i],b[i+n]=0;
}

int main(){
	n=read();
	for (int i=0;i<n;++i)a[i]=read();
	for (m=1;m<n;m<<=1);
	Getexp(a,b,c,d,m);
	for (int i=0;i<=m;++i)printf("%d ",b[i]);
}
原文地址:https://www.cnblogs.com/szboyi/p/7692645.html