P3803 【模板】多项式乘法(FFT)

题目背景

这是一道FFT模板题

注意:虽然本题开到3s,但是建议程序在1s内可以跑完,本题需要一定程度的常数优化。

题目描述

给定一个n次多项式F(x),和一个m次多项式G(x)。

请求出F(x)和G(x)的卷积。

输入输出格式

输入格式:

第一行2个正整数n,m。

接下来一行n+1个数字,从低到高表示F(x)的系数。

接下来一行m+1个数字,从低到高表示G(x))的系数。

输出格式:

一行n+m+1个数字,从低到高表示F(x)∗G(x)的系数。

输入输出样例

输入样例#1: 复制
1 2
1 2
1 2 1
输出样例#1: 复制
1 4 5 2

说明

保证输入中的系数大于等于 0 且小于等于9。

对于100%的数据: n, m leq {10}^6n,m106 , 共计20个数据点,2s。

数据有一定梯度。

空间限制:256MB

//problem: P3803 【模板】多项式乘法(FFT)

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;

const int N=1e7+5;
const double Pi=acos(-1);

int n,m;
int rev[N];
int bit,len=1;
struct Complex
{
    double x,y;
    Complex(double xx=0,double yy=0){x=xx,y=yy;}
    Complex operator + (const Complex &a)
    {
        return Complex(this->x + a.x,this->y + a.y); 
    }
    Complex operator - (const Complex &a)
    {
        return Complex(this->x - a.x,this->y - a.y);
    }
    Complex operator * (const Complex &a)
    {
        return Complex(this->x * a.x - this->y * a.y, this->x * a.y + this->y * a.x);
    }
}a[N],b[N];
//Complex operator * (Complex a,Complex b){ return Complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);}//不懂的看复数的运算那部分 

inline int read()
{
    char c=getchar();int num=0;
    for(;!isdigit(c);c=getchar());
    for(;isdigit(c);c=getchar())
        num=num*10+c-'0';
    return num;
}

void fft(Complex *A,int type)
{
    for(int i=0;i<len;++i)
        if(i<rev[i])
            swap(A[i],A[rev[i]]);
    for(int step=1;step<len;step<<=1)
    {
        Complex wn(cos(Pi/step),type*sin(Pi/step));
        for(int j=0;j<len;j+=step<<1)
        {
            Complex wnk(1,0);
            for(int k=j;k<step+j;++k)
            {
                Complex x=A[k];
                Complex y=wnk*A[k+step];
                A[k]=x+y;
                A[k+step]=x-y;
                wnk=wnk*wn;
            }
        }
    }
}

int main()
{
    n=read(),m=read();
    for(int i=0;i<=n;++i)
        a[i].x=read();
    for(int i=0;i<=m;++i)
        b[i].x=read();
    while(len<=n+m)
        len<<=1,++bit;
    for(int i=0;i<len;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    fft(a,1);
    fft(b,1);
    for(int i=0;i<=len;++i)
        a[i]=a[i]*b[i];
    fft(a,-1);
    for(int i=0;i<=n+m;++i)
        printf("%d ",(int)(a[i].x/len+0.5));
    return 0;
}
原文地址:https://www.cnblogs.com/lovewhy/p/8977500.html