分治FFT

分治FFT

考虑(F(i)),显然因为(F)是个卷积的形式(虽然我们不知道其中的某一部分),因此有:

[F(x)=sum_{i+j=x} F(i)G(j) ]

因此考虑我们计算出了前边一段的(F)值,可以通过乘上(G)中的一部分让这个(F)整体右移,如(F(1)-F(3))卷上(G(1)-G(3))就成为了(F(4)-F(6))中得的一部分。

因此考虑分治。

考虑我们已经计算出了一段([l,mid])中的真实(F)值,我们给右边的部分加上这些的贡献。

那么很显然就是(F[l,mid])卷上一个(G[0,r-l])就得到了(F[mid,r])的一部分。

那么我们每次分治计算左区间后,(NTT)计算出左边对右边的贡献,然后累加上去即可。

对比代码理解更好哦QAQ。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define N 500005
#define pb push_back
#define g 3
#define gi 332748118
#define mod 998244353 
#define int long long
using namespace std;
int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return x*f;
}
int n,rev[N];
vector<int>F,G,S,T;
int ksm(int a,int b)
{
	int res=1;
	while(b)
	{
		if(b&1)res*=a,res%=mod;
		a*=a;a%=mod;b>>=1;
	}
	return res%mod;
}
void NTT(vector<int>&a,int limit,int type)
{
	for(int i=0;i<limit;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int mid=1;mid<limit;mid<<=1)
	{
		int Wn=ksm(type==1?g:gi,(mod-1)/(mid<<1));
		for(int j=0;j<limit;j+=(mid<<1))
		{
			int w=1;
			for(int k=0;k<mid;k++,w=(w*Wn%mod)%mod)
			{
				int x=a[j+k]%mod,y=w*a[j+k+mid]%mod;
				a[j+k]=(x+y)%mod;
				a[j+k+mid]=(x-y+mod)%mod;
			}
		}
	}
	if(type==-1)
	{
		int INV=ksm(limit,mod-2);
		for(int i=0;i<limit;i++)a[i]=a[i]*INV%mod;
	}
}
int get_limit(int x)
{
	int limit=1;while(limit<=x)limit<<=1;
	for(int i=0;i<limit;i++)rev[i]=((rev[i>>1]>>1)|((i&1)?limit>>1:0));
	return limit;
}
vector<int> operator*(vector<int>&a,vector<int>&b)
{
	int len=a.size()+b.size()-1;
	int limit=get_limit(len);
	a.resize(limit);b.resize(limit);
	NTT(a,limit,1);NTT(b,limit,1);
	for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod;
	NTT(a,limit,-1);a.resize(len);
	return a;
}
void solve(int l,int r)
{
	if(l==r)return;
	int mid=(l+r)>>1;
	solve(l,mid);
	S.clear();T.clear();
	for(int i=l;i<=mid;i++)S.pb(F[i]),T.pb(G[i-l]);
	for(int i=mid+1;i<=r;i++)S.pb(0),T.pb(G[i-l]);
	S=S*T;
	for(int i=mid+1;i<=r;i++)F[i]=(F[i]+S[i-l])%mod;
	solve(mid+1,r);
}
signed main()
{
	n=read();G.pb(0);F.pb(1);
	for(int i=1;i<n;i++)G.pb(read());
	for(int i=1;i<n;i++)F.pb(0);
	solve(0,n-1);
	for(int i=0;i<n;i++)printf("%d ",F[i]);
	return 0;
}

原文地址:https://www.cnblogs.com/szmssf/p/14476716.html