FFT&NTT学习笔记

前置芝士:单位根

复数

定义

众所周知实数分布在一维的实数轴上,单位是1。类比实数轴,我们有虚数轴,单位是(i)(i)是什么呢?简单地说就是(sqrt{-1})。类比于平面直角坐标系的x,y轴,我们有复平面,竖轴是虚数轴,横轴是实数轴,两个轴互相垂直。类比于平面直角坐标系上的每个坐标,复数轴上的每个点就是一个复数。平面直角坐标系的点((x,y)),对应的复数就是(a+bi)。其中称(a)为实部,(b)为虚部。
此处输入图片的描述

复数的模

即复数到原点的距离,(|a+bi|=sqrt{a^2+b^2})

复数的辐角

图中的(θ)。复数有无限多个辐角,一般取([-pi,pi])之间那个。

复数的运算

复数的加减法满足平行四边形法则;复数的乘法为:模相乘,辐角相加。对应到代数也很简单:

[(a+bi)+(c+di)=(a+c)+(b+d)i\ (a+bi)-(c+di)=(a-c)+(b-d)i\ (a+bi) imes(c+di)=(ac-bd)+(ad+bc)i ]

注意(i^2=-1),可以推出上面的乘法公式。

欧拉公式

(e)的定义式

有人这么想过:如果我往银行里存了笔钱,我吃一期利息,再取出来,把利息加本金作为新的本金,再存进去,再吃一期利息,再取出来,再存进去......我的钱会不会无限多了呢?然而经过实验得出,随着存取的次数增多,现在拥有的钱和最初的本金的比值趋近于一个数,大概是二点七几。于是就有了(e)的定义式:

[e=lim_{n ightarrow infty}(1+frac{1}{n})^n ]

欧拉公式

可以发现(e)有这样的性质:

[e^2=lim_{n ightarrow infty}(1+frac{1}{n})^{2n}= lim_{n ightarrow infty}((1+frac{1}{n})^2)^n=\ lim_{n ightarrow infty}(1+frac{2}{n}+frac{1}{n^2})^n= lim_{n ightarrow infty}(1+frac{2}{n})^n ]

类似地可以发现:

[e^3=lim_{n ightarrow infty}(1+frac{3}{n})^n ]

归纳证明得到:

[e^x=lim_{n ightarrow infty}(1+frac{x}{n})^n ]

对应到复数可以得到:

[e^{ix}=lim_{n ightarrow infty}(1+frac{x}{n}i)^n ]

也就是无限多个模为(sqrt{1^2+(frac{x}{n})^2}=1),并且辐角为(arctan{frac{x}{n}}=frac{x}{n})(注意这里的(n)为正无穷)的复数乘起来。根据复数乘法的定义,可以知道(e^{ix})是个复数,模长为(1),辐角为(x)。那么把这个复数表示出来就是:(cos{x}+i·sin{x})

于是就有了欧拉公式:(e^{ix}=cos{x}+i·sin{x})

单位根

定义

在复平面上做单位圆,以原点为起点,圆的(n)等分点为终点,做(n)个向量。其中辐角为正且最小的一个向量所对应的复数叫做(n)次单位根,记为(w_n)

根据复数乘法,圆上剩下的(n-1)个向量所对应的复数就是:(w_n^2,w_n^3...w_n^n)

易知(w_n^k)的辐角为(frac{2pi}{n} imes k),模为(1),那么根据欧拉公式:

[w_n^k=cos{(k imes frac{2pi}{n})}+i·sin{(k imes frac{2pi}{n})} ]

性质

1.(w_n^k=cos{(k imes frac{2pi}{n})}+i·sin{(k imes frac{2pi}{n})})

2.(w_{2n}^{2k}=cos{(2k imes frac{2pi}{2n})}+i·sin{(2k imes frac{2pi}{2n})}=w_n^k)

3.(w_n^{k+frac{n}{2}}=w_n^k imes w_n^{frac{n}{2}}=w_n^k imes (cos{pi}+i·sin{pi})=-w_n^k)

4.(w_n^0=w_n^n=1)

快速傅里叶变换FFT

多项式

定义

形如(A(x)=sum_{i=0}^{n}a_ix^i)(A(x))称为多项式。

系数表示法

(n+1)个系数唯一确定一个n次多项式,所以可以用系数来表示这个多项式:({{}a_0,a_1,a_2,...,a_n{}})

点值表示法

(n)次多项式代(n+1)个不同的(x),可以得到(n+1)个不同的值({{}y_0,y_1,y_2,...,y_n{}}),如果这(n+1)个点((x_0,y_0),(x_1,y_1),...,(x_n,y_n))线性无关,则这个多项式可以被这些点唯一确定。所以可以用点值来表示这个多项式。

快速傅里叶变换FFT

作用

已知一个多项式的所有系数,FFT可以在(O(nlogn))的复杂度内得到一组点值,而朴素算法需要(O(n^2))。对应的有快速傅里叶逆变换IDFFT,已知一个多项式的点指表示,可以在(O(nlogn))的复杂度内求出多项式的系数。其中(n)是带入的(x)的数量。

一个(n)次多项式和一个(m)次多项式相乘会得到一个(n+m)次多项式,所以如果代入大于(n+m+1)(x),把算出的每个(A(x))(B(x))乘起来,得到(n+m+1)个点值((x_0,A(x_0) imes B(x_0))),((x_1,A(x_1) imes B(x_1))),...,((x_{n+m},A(x_{n+m}) imes B(x_{n+m})))可以唯一确定多项式(A imes B),再用IDFFT可以求出(A imes B)的系数。所以FFT可以用来计算多项式乘法,复杂度为(O(nlogn)),朴素算法需要(O(n^2))

公式推导

设多项式(A(x))系数为({{}a_0,a_1,a_2,...,a_{n-1}{}})。这里认为(n)可以表示为(2^k)的形式。实际的(n)若不足可以在后面补(0)

[A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}\ =(a_0+a_2x^2+...+a_{n-2}x^{n-2})+(a_1x+a_3x^3+...+a_{n-1}x^{n-1}) ]

(A_1(x)=(a_0+a_2x+a_4x^2+...+a_{n-2}x^{frac{n}{2}-1}))

(A_2(x)=(a_1+a_3x+a_5x^2+...+a_{n-1}x^{frac{n}{2}-1}))

(A(x)=A_1(x^2)+xA_2(x^2))

设单位根为(w_n),代入(w_n^k(k<frac{n}{2}))得:

[A(w_{n}^{k})=A_1(w_{n}^{2k})+w_{n}^{k}A_2(w_{n}^{2k})=A_1(w_{frac{n}{2}}^{k})+w_{n}^{k}A_2(w_{frac{n}{2}}^{k}) ]

再把(w_n^{k+frac{n}{2}})代入得

[A(w_n^{k+frac{n}{2}})=A_1(w_n^{2k+n})+w_n^{k+frac{n}{2}}A_2(w_n^{2k+n})\ =A_1(w_n^{2k} imes w_n^n)-w_n^kA_2(w_n^{2k} imes w_n^n)\ =A_1(w_n^{2k})-w_n^kA_2(w_n^{2k})=A_1(w_{frac{n}{2}}^{k})-w_{n}^{k}A_2(w_{frac{n}{2}}^{k}) ]

上面有点乱,我来整理一下:

[A(w_{n}^{k})=A_1(w_{frac{n}{2}}^{k})+w_{n}^{k}A_2(w_{frac{n}{2}}^{k})\ A(w_n^{k+frac{n}{2}})=A_1(w_{frac{n}{2}}^{k})-w_{n}^{k}A_2(w_{frac{n}{2}}^{k}) ]

可以发现只有中间的符号不同。又因为当(k)取遍([0,frac{n}{2}))所有值时,(k+frac{n}{2})取遍([frac{n}{2},n))所有值。所以我们只需要计算当(kin [0,frac{n}{2}))时,长度为(frac{n}{2})的多项式(A_1(w_{frac{n}{2}}^{k}))(A_2(w_{frac{n}{2}}^{k}))的值,就可以得到当(kin [0,n))时,长度为(n)的多项式(A(w_{n}^{k}))的值。递归计算即可,复杂度为(O(nlogn))

快速傅里叶逆变换IDFFT

作用

前面提到过,把点值表示转系数表示

公式推导

(y_0,y_1,...y_{n-1})为多项式(a_0+a_1x+...+a_{n-1}x^{n-1})的点值表示。

(c_0),(c_1),...,(c_{n-1})满足(c_k=sum_{i=0}^{n-1}y_i(w_n^{-k})^i),即多项式(B(x)=y_0+y_1x+...+y_{n-1}x^{n-1})(w_n^0),(w_n^{-1}),...,(w_n^{-n+1})处的点值表示。

[c_k=sum_{i=0}^{n-1}y_i(w_n^{-k})^i=sum_{i=0}^{n-1}(sum_{j=0}^{n-1}a_j(w_n^i)^j)(w_n^{-k})^i\ =sum_{i=0}^{n-1}sum_{j=0}^{n-1}a_j(w_n^i)^j(w_n^{-k})^i=sum_{i=0}^{n-1}sum_{j=0}^{n-1}a_j(w_n^j)^i(w_n^{-k})^i=\ sum_{i=0}^{n-1}sum_{j=0}^{n-1}a_j(w_n^{j-k})^i=sum_{j=0}^{n-1}a_j sum_{i=0}^{n-1}(w_n^{j-k})^i ]

注意一下后面这个(sum),设(S(x)=sum_{i=0}^{n-1}x^i),代入(w_n^k)

[S(w_n^k)=1+w_n^k +(w_n^k)^2+...+(w_n^k)^{n-1} ]

1.当(k eq0)时,(w_n^kS(w_n^k)=w_n^k +(w_n^k)^2+...+(w_n^k)^n),相减得

[(w_n^k-1)S(w_n^k)=(w_n^k)^n-1\ herefore S(w_n^k)=frac{(w_n^k)^n-1}{w_n^k-1}=frac{(w_n^n)^k-1}{w_n^k-1}=frac{1-1}{w_n^k-1}=0 ]

2.当(k=0)时,(S(w_n^k)=n)

然后我们回到之前的式子(c_k=sum_{j=0}^{n-1}a_j sum_{i=0}^{n-1}(w_n^{j-k})^i),根据上面的结论可以知道,只有当(j=k)时,(sum_{i=0}^{n-1}(w_n^{j-k})^i=n),否则等于(0)

( herefore c_j=na_j,a_j=frac{c_j}{n})

所以对于一个多项式(a_0+a_1x+...+a_{n-1}x^{n-1}),如果我们知道它的点值表示(y_0),(y_1),...,(y_{n-1}),就可以用FFT求出多项式(B(x)=y_0+y_1x+...+y_{n-1}x^{n-1})(w_n^0),(w_n^{-1}),...,(w_n^{-n+1})处的点值表示(c_0),(c_1),...,(c_{n-1}),从而求出(a_0),(a_1),...,(a_{n-1})

时间复杂度也是(O(nlogn))

蝴蝶优化

你真的用递归去写?可以发现递归的方式每次都需要把(A(x))的系数复制一遍,排个顺序得到(A_1(x))(A_2(x))的系数。这是非常慢的。蝴蝶优化可以解决这个问题。

举个(n=8)的例子,我在这里列出每次递归的系数的顺序:

[n=8:a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7\ n=4:a_0,a_2,a_4,a_6,a_1,a_3,a_5,a_7\ n=2:a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7\ ]

然后把最上面一排和最下面一排的下标的二进制写出来:

[000,001,010,011,100,101,110,111\ 000,100,010,110,001,101,011,111 ]

可以发现最后的顺序的二进制,就是对应位置的最初顺序的二进制的翻转。所以我们已开始就给它翻好,就不用递归了。给出代码:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i,a,b) for(rg int i=a;i<=b;++i)
using std::swap;
inline int read(){
	rg int x(0),f(1); rg char c(gc);
	while(c<'0'||'9'<c){ if(c=='-') f=-1; c=gc; }
	while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=gc;
	return x*f;
}
#define maxn 10000010
const double pi=acos(-1.0);
int n,m,limit=1,l,r[maxn];
double Cos[maxn],Sin[maxn];
struct complex{ double x,y; }a[maxn],b[maxn];//复数
complex operator+(cn complex &x,cn complex &y){ return (complex){x.x+y.x,x.y+y.y}; }
complex operator-(cn complex &x,cn complex &y){ return (complex){x.x-y.x,x.y-y.y}; }
complex operator*(cn complex &x,cn complex &y){ return (complex){x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x}; }
inline void FastFourierTransform(complex *a,cn int &type){
	fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]);//蝴蝶优化
	for(rg int mid=1;mid<limit;mid<<=1){
		rg int len=mid<<1; complex Wn=(complex){Cos[len],type*Sin[len]};//根据上面的推导,这个type很灵性
		for(rg int j=0;j<limit;j+=len){
			complex Pow=(complex){1,0};
			for(rg int k=0;k<mid;++k,Pow=Pow*Wn){
				complex x=a[j+k],y=Pow*a[j+mid+k];
				a[j+k]=x+y,a[j+mid+k]=x-y;
			}
		}
	}
}
int main(){
	n=read(),m=read(); fp(i,0,n) a[i].x=read(); fp(i,0,m) b[i].x=read();
	while(limit<=n+m) limit<<=1,++l; fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));//求出翻转的下标
	fp(i,0,limit) Cos[i]=cos(pi*2/i),Sin[i]=sin(pi*2/i);
	FastFourierTransform(a,1),FastFourierTransform(b,1);
	fp(i,0,limit) a[i]=a[i]*b[i]; FastFourierTransform(a,-1);
	fp(i,0,n+m) printf("%d ",(int)(a[i].x/limit+0.5)); return 0;
}

快速数论变换NTT

可以发现单位根有的性质原根都有......所以可以用原根代替单位根。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define rg register
#define il inline
#define LL long long
#define cn const
#define gc getchar()
#define fp(i,a,b) for(rg int i=a;i<=b;++i)
using std::swap;
il int read(){
	rg int x(0),f(1); rg char c(gc);
	while(c<'0'||'9'<c){ if(c=='-') f=-1; c=gc; }
	while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=gc;
	return x*f;
}
#define maxn 3000010
const int G=3,invG=332748118,P=998244353;//G是原根,invG是原根的逆元
int n,m,limit=1,l,r[maxn];
LL a[maxn],b[maxn];
inline LL FastPow(LL a,int b){
	LL ans=1;
	for(;b;b>>=1,a=a*a%P) if(b&1) ans=ans*a%P;
	return ans;
}
inline void NumberTheoreticTransform(LL *a,cn int &type){
	fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]);
	for(rg int mid=1;mid<limit;mid<<=1){
		rg int len=mid<<1; rg LL Gn=FastPow(type?G:invG,(P-1)/len);
		for(rg int j=0;j<limit;j+=len){
			LL Pow=1;
			for(rg int k=0;k<mid;++k,Pow=Pow*Gn%P){
				LL x=a[j+k],y=Pow*a[j+mid+k]%P;
				a[j+k]=(x+y)%P,a[j+mid+k]=(x-y+P)%P;
			}
		}
	}
}
int main(){
	n=read(),m=read(); fp(i,0,n) a[i]=(read()+P)%P; fp(i,0,m) b[i]=(read()+P)%P;
	while(limit<=n+m) limit<<=1,++l; fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	NumberTheoreticTransform(a,1),NumberTheoreticTransform(b,1);
	fp(i,0,limit) a[i]=a[i]*b[i]%P; NumberTheoreticTransform(a,0);
	rg LL invN=FastPow(limit,P-2); fp(i,0,n+m) printf("%d ",a[i]*invN%P);
	return 0;
}

模板题:【模板】多项式乘法(FFT)

原文地址:https://www.cnblogs.com/akura/p/12236347.html