// LibreOJ #108. 多项式乘法 (Polynomial Multiplication)

// 二次联通门 : LibreOJ #108. 多项式乘法 (Polynomial Multiplication)

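/*
    LibreOJ #108. 多项式乘法 (Polynomial Multiplication)

    FFT板子题
    (A textbook FFT template problem.)
    不行啊。。。跑的还是慢
    (Not good enough... it still runs slowly.)

    应该找个机会学一学由乃dalao的fft
    或者是毛爷爷的fft,跑的真是快啊。。。
    (TODO: find time to study the FFT implementations of 由乃 or 毛爷爷,
     which run much faster.)
*/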
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <utility>

const int BUF = 12312312;
// Whole-input buffer (a zero-initialised global) and the parse cursor into it.
char Buf[BUF], *buf = Buf;

// Parses the next unsigned decimal integer at the cursor `buf` into `now`
// and advances the cursor past it.
// Any non-digit characters before the number (spaces, newlines, and also a
// '-' sign, so negative numbers are NOT supported) are skipped.
// The scan stops at the '\0' that ends the input, so asking for a number
// after the last one yields now == 0 and leaves the cursor on the terminator
// instead of running off the end of Buf.
inline void read (int &now)
{
    now = 0;
    // isdigit() requires a value representable as unsigned char (or EOF);
    // passing a negative plain char (e.g. a UTF-8 byte) is undefined.
    while (*buf && !isdigit ((unsigned char) *buf)) ++ buf;
    while (isdigit ((unsigned char) *buf)) now = now * 10 + (*buf ++ - '0');
}
using std :: swap;
// Capacity of the global coefficient / bit-reversal arrays.
#define Max 3000000
// Floating-point type used for all FFT arithmetic.
using flo = double;

// Minimal complex number for the FFT: r is the real part, i the imaginary part.
struct Vec
{
    flo r, i;

    // Intentionally leaves r and i uninitialised (no cost for big arrays).
    Vec () {}
    Vec (flo re, flo im) : r (re), i (im) {}

    // Complex product (r + i*j) * (o.r + o.i*j).
    Vec operator * (const Vec &o) const
    {
        const flo re = r * o.r - i * o.i;
        const flo im = r * o.i + i * o.r;
        return Vec (re, im);
    }

    // Scale both components by a real factor.
    Vec operator * (const flo &s) const
    {
        return Vec (s * r, s * i);
    }

    // Component-wise sum.
    Vec operator + (const Vec &o) const
    {
        return Vec (r + o.r, i + o.i);
    }

    // Component-wise difference.
    Vec operator - (const Vec &o) const
    {
        return Vec (r - o.r, i - o.i);
    }

    // In-place division of both components by a real divisor; returns *this.
    Vec& operator /= (const flo &d)
    {
        r /= d;
        i /= d;
        return *this;
    }
};

// Work arrays for the two operand polynomials; Main leaves the product in a.
Vec a[Max];
Vec b[Max];
// Sizes read and derived in Main; Maxn is the transform length passed to FFT.
int N, M, Maxn;
// Bit-reversal permutation table, filled in Main for length Maxn.
int rader[Max];
// pi. Defined before ZlycerQan, so it is initialised before Main runs
// during static initialisation.
const flo PI = acos (-1.0);

// Iterative radix-2 Cooley-Tukey FFT, performed in place on a[0 .. N-1].
//   a : input array, overwritten with its transform.
//   N : transform length; assumed to be a power of two (not checked).
//   f : +1 for the forward transform, -1 for the inverse. The inverse also
//       divides every entry by N, so FFT(a, N) followed by FFT(a, N, -1)
//       restores a (up to rounding).
void FFT (Vec *a, int N, int f = 1)
{
    // Bit-reversal permutation, computed here for this N. The transform thus
    // no longer depends on a global table having been built for exactly N.
    for (int i = 1, j = 0; i < N; ++ i)
    {
        int bit = N >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap (a[i], a[j]);
    }
    // Butterflies: combine pairs of length-k transforms into length-2k ones.
    for (int k = 1; k < N; k <<= 1)
    {
        // Principal 2k-th root of unity (conjugated when f == -1).
        const Vec wn (cos (PI / k), f * sin (PI / k));
        for (int j = 0; j < N; j += k << 1)
        {
            // Twiddle factor is advanced by repeated multiplication.
            Vec w (1, 0);
            for (int i = j; i < j + k; ++ i, w = w * wn)
            {
                const Vec t = w * a[i + k];
                a[i + k] = a[i] - t;
                a[i] = a[i] + t;
            }
        }
    }
    if (f == -1)
        for (int i = 0; i < N; ++ i) a[i] /= N;
}

// Program body: reads two polynomials from stdin and prints their product.
//
// Input (the whole of stdin is slurped into Buf): two integers n and m,
// then n + 1 coefficients of the first polynomial and m + 1 coefficients
// of the second (presumably n and m are the degrees). read() only parses
// unsigned digit runs.
// Output: the n + m + 1 coefficients of the product, each followed by a
// single space.
int Main ()
{
    // Read at most BUF - 1 bytes so the zero-initialised Buf always keeps a
    // terminating '\0' for the parser to stop at.
    fread (buf, 1, BUF - 1, stdin);
    read (N), read (M);
    ++ N, ++ M;                      // degrees -> coefficient counts
    const int len = N + M - 1;       // number of coefficients in the product

    // Smallest power of two >= len. Integer doubling avoids relying on
    // floating-point log2/ceil.
    // NOTE(review): Maxn is not checked against the global arrays' capacity;
    // this relies on the problem's input limits.
    for (Maxn = 1; Maxn < len; Maxn <<= 1) {}

    int x;
    for (int i = 0; i < N; ++ i) read (x), a[i].r = x;
    for (int i = 0; i < M; ++ i) read (x), b[i].r = x;

    // Build the global bit-reversal table rader[] for length Maxn.
    for (int i = 1; i < Maxn; ++ i)
        rader[i] = rader[i >> 1] >> 1 | (i & 1) * (Maxn >> 1);

    FFT (a, Maxn), FFT (b, Maxn);
    // Point-wise product of the transforms; the inverse transform turns it
    // into the coefficient convolution.
    for (int i = 0; i < Maxn; ++ i)
        a[i] = a[i] * b[i];
    FFT (a, Maxn, -1);

    // Round away floating-point noise before printing each coefficient.
    for (int i = 0; i < len; ++ i)
        printf ("%d ", int (round (a[i].r)));
    return 0;
}
// The whole program runs while this global is dynamically initialised
// (before main is entered); main itself has nothing left to do.
int ZlycerQan = Main ();
int main (int argc, char *argv[])
{
    (void) argc;
    (void) argv;
    return 0;
}
// 原文地址 (original post): https://www.cnblogs.com/ZlycerQan/p/7435542.html