多项式乘法到FFT、NTT、FWT

多项式乘法

形如 (a_{0}+a_{1}x+a_{2}x^{2}+a_{3}x^{3}.....)的式子称为多项式,接下来设 (A(x)=a_{0}+a_{1}x+a_{2}x^{2}+a_{3}x^{3}.....)(B(x)=b_{0}+b_{1}x+b_{2}x^{2}+b_{3}x^{3}.....),求 (A(x) imes B(x)),用正常的暴力算法去实现,可以看到时间复杂度为 (O(n^{2})),而FFT就可以把时间复杂度降到 (O(nlog(n)))。从而去解决一些问题。


问题:

比如说给你三个数组 (a、b、c),问在这三个数组中有多少组 (i、j、k) 满足 (a[i]+c[k]=2 imes b[j]),在 (O(n^{2})) 的复杂度行不通的情况下,就去考虑多项式乘法,以 (a[i]) 的数量作为多项式(A(x))(x^{a[i]})的系数,(c[k])同理。那么 (A(x) imes B(x))中每项的系数就是 (a[i]+c[k]) 的值的数量。


解决对应问题�

  • (C_{k}=sum_{i+j=k} A_{i} imes B_{j})
    (FFT、FNT)(FNT)没有复数运算,精度问题较少。

  • (C_{k}=sum_{ i or j=k} A_{i} imes B_{j})

  • (C_{k}=sum_{ i and j=k} A_{i} imes B_{j})

  • (C_{k}=sum_{ i xor j=k} A_{i} imes B_{j})
    (FWT),用来解决多项式的位运算卷积。


代码

FFT:

#include<cmath>
#include<cstdio>
#define R register
#define I inline
using namespace std;
const int N=4.2e6;
const double PI=acos(-1);
int n,r[N];
struct C{//手写complex,比STL快一点点
	double r,i;
	I C(){r=i=0;}
	I C(R double x,R double y){r=x;i=y;}
	I C operator+(R C&x){return C(r+x.r,i+x.i);}
	I C operator-(R C&x){return C(r-x.r,i-x.i);}
	I C operator*(R C&x){return C(r*x.r-i*x.i,r*x.i+i*x.r);}
	I void operator+=(R C&x){r+=x.r;i+=x.i;}
	I void operator*=(R C&x){R double t=r;r=r*x.r-i*x.i;i=t*x.i+i*x.r;}
}f[N],g[N];
I int in(){
	R char c=getchar();
	while(c<'-')c=getchar();
	return c&15;
}
I void FFT(R C*a,R int op){
	R C W,w,t,*a0,*a1;
	R int i,j,k;
	for(i=0;i<n;++i)//根据映射关系交换元素
		if(i<r[i])t=a[i],a[i]=a[r[i]],a[r[i]]=t;
	for(i=1;i<n;i<<=1)//控制层数
		for(W=C(cos(PI/i),sin(PI/i)*op),j=0;j<n;j+=i<<1)//控制一层中的子问题个数
			for(w=C(1,0),a1=i+(a0=a+j),k=0;k<i;++k,++a0,++a1,w*=W)
				t=*a1*w,*a1=*a0-t,*a0+=t;//蝴蝶操作
}
int main(){
    R int m,i,l=0;
    scanf("%d%d",&n,&m);
    for(i=0;i<=n;++i)f[i].r=in();
    for(i=0;i<=m;++i)g[i].r=in();
    for(m+=n,n=1;n<=m;n<<=1,++l);
    for(i=0;i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));//递推求r
    FFT(f,1);FFT(g,1);
    for(i=0;i<n;++i)f[i]*=g[i];
    FFT(f,-1);
    for(i=0;i<=m;++i)printf("%.0f ",fabs(f[i].r)/n);
    puts("");
    return 0;
}

FNT:

#define LL long long int
#define ls (x << 1)
#define rs (x << 1 | 1)
#define MID int mid=(l+r)>>1
using namespace std;
 
const int N = 300010;
const int Mod = 998244353;
int n,m,L,R[N],g[N],a[N],b[N];
 
int gi()
{
  int x=0,res=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();}
  while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
  return x*res;
}
 
inline int QPow(int d,int z)
{
  int ans=1;
  for(;z;z>>=1,d=1ll*d*d%Mod)
    if(z&1)ans=1ll*ans*d%Mod;
  return ans;
}
 
inline void NTT(int *A,int f)
{
  for(int i=0;i<n;++i)if(i<R[i])swap(A[i],A[R[i]]);
  for(int i=1;i<n;i<<=1){
    int gn=QPow(3,(Mod-1)/(i<<1)),x,y;
    for(int j=0;j<n;j+=(i<<1)){
      int g=1;
      for(int k=0;k<i;++k,g=1ll*g*gn%Mod){
	x=A[j+k];y=1ll*g*A[i+j+k]%Mod;
	A[j+k]=(x+y)%Mod;A[i+j+k]=(x-y+Mod)%Mod;
      }
    }
  }
  if(f==1)return;reverse(A+1,A+n);
  int y=QPow(n,Mod-2);
  for(int i=0;i<n;++i)A[i]=1ll*A[i]*y%Mod;
}
 
int main()
{
  n=gi();m=gi();
  for(int i=0;i<=n;++i)a[i]=gi();
  for(int i=0;i<=m;++i)b[i]=gi();
  m+=n;for(n=1;n<=m;n<<=1)++L;
  for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
  NTT(a,1);NTT(b,1);
  for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%Mod;
  NTT(a,-1);
  for(int i=0;i<=m;++i)printf("%d ",a[i]);
  printf("
");
  return 0;
}

FWT:

const int p=998244353;
const int N=(1<<17)+10;
int n,a[N],b[N],ta[N],tb[N];

void FWT_or(int a[],int type)
{
    int i,j,k;
    for(i=1;i<=n;i++)
    for(j=0;j<(1<<n);j+=1<<i)
    for(k=0;k<(1<<i-1);k++)
    (a[j|(1<<i-1)|k]+=(a[j|k]*type+p)%p)%=p;
}
void FWT_and(int a[],int type)
{
    int i,j,k;
    for(i=1;i<=n;i++)
    for(j=0;j<(1<<n);j+=1<<i)
    for(k=0;k<(1<<i-1);k++)
    (a[j|k]+=(a[j|(1<<i-1)|k]*type+p)%p)%=p;
}
void FWT_xor(int a[],long long type)
{
    int i,j,k,x,y;
    for(i=1;i<=n;i++)
    for(j=0;j<(1<<n);j+=1<<i)
    for(k=0;k<(1<<i-1);k++)
    x=(a[j|k]+a[j|(1<<i-1)|k])*type%p,
    y=(a[j|k]-a[j|(1<<i-1)|k]+p)*type%p,
    a[j|k]=x,a[j|(1<<i-1)|k]=y;
}
int main()
{
    mem(ta,0);
    mem(tb,0);
    scanf("%d",&n);
    for(int i=0;i<(1<<n);i++){
        scanf("%d",&a[i]);
        ta[i]=a[i];
    }
    for(int i=0;i<(1<<n);i++){
        scanf("%d",&b[i]);
        tb[i]=b[i];
    }
    //or
    FWT_or(ta,1);
    FWT_or(tb,1);
    for(int i=0;i<N;i++)ta[i]=1ll*ta[i]*tb[i]%p;
    FWT_or(ta,-1);
    for(int i=0;i<(1<<n);i++)printf("%d%c",ta[i]," 
"[i==(1<<n)-1]);
    //and
    for(int i=0;i<(1<<n);i++){
        ta[i]=a[i];
    }
    for(int i=0;i<(1<<n);i++){
        tb[i]=b[i];
    }
    FWT_and(ta,1);
    FWT_and(tb,1);
    for(int i=0;i<N;i++)ta[i]=1ll*ta[i]*tb[i]%p;
    FWT_and(ta,-1);
    for(int i=0;i<(1<<n);i++)printf("%d%c",ta[i]," 
"[i==(1<<n)-1]);
    //xor
    for(int i=0;i<(1<<n);i++){
        ta[i]=a[i];
    }
    for(int i=0;i<(1<<n);i++){
        tb[i]=b[i];
    }
    FWT_xor(ta,1);
    FWT_xor(tb,1);
    for(int i=0;i<N;i++)ta[i]=1ll*ta[i]*tb[i]%MOD;
    FWT_xor(ta,1ll*((p+1)>>1));
    for(int i=0;i<(1<<n);i++)printf("%d%c",ta[i]," 
"[i==(1<<n)-1]);
}
越自律,越自由
原文地址:https://www.cnblogs.com/ha-chuochuo/p/14339419.html