P4721 【模板】分治 FFT

P4721 【模板】分治 FFT

复习了一下,稍微写一写。

边界很奇怪,从 (1) 开始的。。。

这种自己按某一维顺序更新自己的可以考虑分治FFT。

分治FFT用的是cdq分治的思想,以这题为例。

在分治 ([l,r]) 的时候,先分治左边 ([l,mid]),然后统计 ([l,mid])([mid+1,r]) 的贡献,再分治 ([mid+1,r])

为啥先分治左边?其实想想就很显然。你要知道右边的值,必然要依靠左边的值;但是左边的值不需要右边的值。先分治左边,确定了左边的值就可以更新右边了。

个人感觉,分治FFT这东西是有实现难度的,不像倍增FFT难度全在推式子。初学,dalao别D

详细讲一下怎么实现。

设当前分治区间 ([l,r],mid=lfloordfrac{l+r}{2} floor)

(A_i=f_{i+l}(i+lin [l,mid]),B_i=g_{i+1}(iin [0,r-l)))

把这两个东西卷起来,看看是什么。

((A*B)(k)=sum_{j=0}^{k}A_jB_{k-j}=f_{j+l}g_{k+1-j})

不难发现,([l,mid]) 内的 (f) 对于 ([mid+1,r]) 内某个位置 (i) 的贡献为 (sum_{j=l}^{mid}f_jg_{i-j})

所以我们把 (iin [mid+1,r])(f) 加上 ((A*B)(i-l-1)) 即可。

我其实感觉实现难度就在上面这部分。

首先要构造 (A,B) 两个多项式。

大概说一下我是怎么搞的吧。

(A) 设成要统计贡献的部分,这个一般就是左区间或者右区间。

(B) 的范围得看 (A) 卷上哪些部分可以到达目标区间。这题 (A) 的下标是 ([l,mid]) ,目标区间是 ([mid+1,r]),由于 (l) 是对 (r) 有贡献的((f_l*g_{r-l})),那么上界至少是 (r-l)。用这个思路可以发现上界设 (r-l) 够了。同理下界设成 (1)

现在我们可以肯定的是,(A*B) 的某一项系数就是对于 ([mid+1,r]) 某一项的贡献,然后手动算一下具体是哪一项的贡献。这里一定要小心,少一个 (+1) 或者少一个 (-1) ,差之毫厘失之千里。

复杂度是 (T(n)=2T(dfrac{n}{2})+O(nlog n)=O(nlog^2 n))

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
const int N=100005;
const int M=N<<2;
#define mod 998244353
inline void fmod(int&x){x-=mod,x+=x>>31&mod;}
inline int qpow(int n,int k){int res=1;for(;k;k>>=1,n=1ll*n*n%mod)if(k&1)res=1ll*n*res%mod;return res;}
int n,g[N],f[N];
namespace poly{
int lg,lim,rev[M];
void init(const int&n){
	for(lim=1,lg=0;lim<=n;lim<<=1,++lg);
	for(int i=0;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
}
void NTT(int*a,int op){
	for(int i=0;i<lim;++i)if(i>rev[i])swap(a[i],a[rev[i]]);
	int g=op?3:qpow(3,mod-2);
	for(int i=1;i<lim;i<<=1){
		int wn=qpow(g,(mod-1)/(i<<1));
		for(int j=0;j<lim;j+=i<<1){
			int w0=1;
			for(int k=0;k<i;++k,w0=1ll*w0*wn%mod){
				const int X=a[j+k],Y=1ll*w0*a[i+j+k]%mod;
				fmod(a[j+k]=X+Y),fmod(a[i+j+k]=X-Y+mod);
			}
		}
	}
	if(op)return;int ilim=qpow(lim,mod-2);
	for(int i=0;i<lim;++i)a[i]=1ll*a[i]*ilim%mod;
}
}
#define clr(a,n) memset(a,0,sizeof(int)*(n))
void CDQ_NTT(int l,int r){
	if(l==r)return;
	int mid=(l+r)>>1;
	CDQ_NTT(l,mid);
	static int A[M],B[M];
	poly::init(r-l-1+mid-l),clr(A,poly::lim),clr(B,poly::lim);
	for(int i=l;i<=mid;++i)A[i-l]=f[i];
	for(int i=1;i<=r-l;++i)B[i-1]=g[i];
	poly::NTT(A,1),poly::NTT(B,1);
	for(int i=0;i<poly::lim;++i)A[i]=1ll*A[i]*B[i]%mod;
	poly::NTT(A,0);
	for(int i=mid+1;i<=r;++i)fmod(f[i]+=A[i-l-1]);
	CDQ_NTT(mid+1,r);
}
signed main(){
	n=read(),f[0]=1;rep(i,1,n-1)g[i]=read();
	CDQ_NTT(0,n-1);
	rep(i,0,n-1)printf("%d ",f[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/zzctommy/p/14209636.html