[学习笔记] 拆系数 FFT

0. 楔子

我们知道,(mathtt{FFT}) 是用复数运算的,当我们需要取模且数据范围较大的时候就没有办法了。

比如 这道题

如果能减小数据范围,我们就可以先算再模,于是,拆系数 (mathtt{FFT}) 就闪亮登场了!

1. 正文

注意(*) 表示卷积。

首先计算一下本题的数据范围:(10^9 imes 10^9 imes 10^5=10^{23})(两数相乘,加 (n) 遍)。

(base) 为一个 (sqrt p) 级别的数,可以将 (F(x),G(x)) 分别分解成 (A(x) imes base+B(x),C(x) imes base+D(x))(注意要除尽),这样拆出来的函数是 (<sqrt p) 的。

[F*G=(A imes base+B)*(C imes base+D) ]

[=A*C imes base^2+(A*D+B*C) imes base+B*D ]

这样数据范围就是 (10^{14}) 左右。但是这样要做 (7)(mathtt{FFT})(插值 (4) 次,转换 (3) 次)。

我们知道,(mathtt{FFT}) 的虚部赋值为 (0),我们可以利用这个空间。

考虑计算 (A,B,C,D) 的点值表示。令

[f(k)=A(k)+i imes B(k) ]

[g(k)=A(k)-i imes B(k) ]

[h(k)=C(k)+i imes D(k) ]

我们发现 (f(k),g(n-k)) 是共轭的((n)(n) 补全的 (2) 的幂)。

[f(k)=A(k)+i imes B(k) ]

[=sum_{j=0}^{n-1}a_j(omega_n^k)^j+i imes sum_{j=0}^{n-1}b_j(omega_n^k)^j ]

[=sum_{j=0}^{n-1}(a_j+i imes b_j)(omega_n^k)^j ]

[=sum_{j=0}^{n-1}(a_j+i imes b_j)(cos(frac{2pi kj}{n})+isin (frac{2pi kj}{n})) ]

[=sum_{j=0}^{n-1}(a_j imes cos(frac{2pi kj}{n})-b_j imes sin (frac{2pi kj}{n}))+i(b_j imes cos(frac{2pi kj}{n})+a_j imes sin (frac{2pi kj}{n})) ]

[g(n-k)=A(n-k)-i imes B(n-k) ]

[=sum_{j=0}^{n-1}a_j(omega_n^{-k})^j-i imes sum_{j=0}^{n-1}b_j(omega_n^{-k})^j ]

[=sum_{j=0}^{n-1}(a_j-i imes b_j)(omega_n^{-k})^j ]

[=sum_{j=0}^{n-1}(a_j-i imes b_j)(cos(frac{2pi kj}{n})-isin (frac{2pi kj}{n})) ]

[=sum_{j=0}^{n-1}(a_j imes cos(frac{2pi kj}{n})-b_j imes sin (frac{2pi kj}{n}))-i(b_j imes cos(frac{2pi kj}{n})+a_j imes sin (frac{2pi kj}{n})) ]

所以计算 (f,g,h) 的点值表达式只用 (2)(mathtt{FFT})

(p=f*h,q=g*h)。所以

[p=A*C-B*D+i(A*D+B*C) ]

[q=A*C+B*D+i(A*D-B*C) ]

左边的每一项就是右边卷完后每一项系数经过一番运算。

所以计算出 (p,q) 需要 (2)(mathtt{FFT}),将 (p,q) 对应项相加即可解出 (A*C,A*D),从而都解出来。

总共需要 (4)(mathtt{FFT})

2. 代码

用到了预处理单位根,这样精度会高一些。

#include <cstdio>

#define rep(i,_l,_r) for(register signed i=(_l),_end=(_r);i<=_end;++i)
#define fep(i,_l,_r) for(register signed i=(_l),_end=(_r);i>=_end;--i)
#define erep(i,u) for(signed i=head[u],v=to[i];i;i=nxt[i],v=to[i])
#define efep(i,u) for(signed i=Head[u],v=to[i];i;i=nxt[i],v=to[i])
#define print(x,y) write(x),putchar(y)

template <class T> inline T read(const T sample) {
    T x=0; int f=1; char s;
    while((s=getchar())>'9'||s<'0') if(s=='-') f=-1;
    while(s>='0'&&s<='9') x=(x<<1)+(x<<3)+(s^48),s=getchar();
    return x*f;
}
template <class T> inline void write(const T x) {
    if(x<0) return (void) (putchar('-'),write(-x));
    if(x>9) write(x/10);
    putchar(x%10^48);
}
template <class T> inline T Max(const T x,const T y) {if(x>y) return x; return y;}
template <class T> inline T Min(const T x,const T y) {if(x<y) return x; return y;}
template <class T> inline T fab(const T x) {return x>0?x:-x;}
template <class T> inline T gcd(const T x,const T y) {return y?gcd(y,x%y):x;}
template <class T> inline T lcm(const T x,const T y) {return x/gcd(x,y)*y;}
template <class T> inline T Swap(T &x,T &y) {x^=y^=x^=y;}

#include <cmath>
#include <iostream>
using namespace std;
typedef long long ll;

const double Pi=acos(-1.0);
const int num1=(1<<30),num2=(1<<15),maxn=262150;

int n,m,mod,rev[maxn],lim,bit;
ll a1b1,a1b2,a2b1,a2b2;
struct cp {
	double x,y;
	
	cp operator + (const cp t) const {return (cp){x+t.x,y+t.y};}
	cp operator - (const cp t) const {return (cp){x-t.x,y-t.y};}
	cp operator * (const cp t) const {return (cp){x*t.x-y*t.y,y*t.x+x*t.y};}
} f[maxn],g[maxn],h[maxn],w[maxn][2],tmp;

void FFT(cp *f,const int op=1) {
	rep(i,0,lim-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
	for(int mid=1;mid<lim;mid<<=1) {
		for(int i=0,p=(mid<<1);i<lim;i+=p) {
			for(int j=0;j<mid;++j) {
				tmp=w[lim/mid/2*j][op==1]*f[i+j+mid];
				f[i+j+mid]=f[i+j]-tmp,f[i+j]=f[i+j]+tmp;
			}
		}
	}
}

void init() {
	lim=1;
	while(lim<=n+m) lim<<=1,++bit;
	rep(i,0,lim-1) {
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
		w[i][1]=(cp){cos(Pi*2*i/lim),sin(Pi*2*i/lim)};
		w[i][0]=(cp){w[i][1].x,-w[i][1].y};
	}
}

void MTT() {
	int val;
	rep(i,0,n) val=read(9),f[i]=(cp){val>>15,val&32767};
	rep(i,0,m) val=read(9),h[i]=(cp){val>>15,val&32767};
	init();
	FFT(f),FFT(h);
	g[0]=(cp){f[0].x,-f[0].y};
	rep(i,1,lim-1) g[i]=(cp){f[lim-i].x,-f[lim-i].y};
	// 先除以 lim,这样就只用除一次
	rep(i,0,lim-1) h[i].x/=lim,h[i].y/=lim,f[i]=f[i]*h[i],g[i]=g[i]*h[i];
	FFT(f,-1),FFT(g,-1);
	rep(i,0,n+m) {
		a1b1=(ll)((f[i].x+g[i].x)/2+0.5)%mod;
		a1b2=(ll)((f[i].y+g[i].y)/2+0.5)%mod;
		a2b1=((ll)(f[i].y+0.5)-a1b2)%mod;
		a2b2=((ll)(g[i].x+0.5)-a1b1)%mod;
		print((a1b1*num1%mod+(a1b2+a2b1)*num2%mod+a2b2)%mod,' ');
	}
	puts("");
}

int main() {
	n=read(9),m=read(9),mod=read(9);
	MTT();
	return 0;
} 
原文地址:https://www.cnblogs.com/AWhiteWall/p/14407422.html