[学习笔记] 分治FFT

算法用途

给定 \(n-1\) 次多项式 \(g(x)\), 求 \(f(x)\) 的前 \(n\) 项, 满足 \([x^0]f(x) = 1\), 且对 \(1 \le i < n\) 有

\[ [x^i]f(x) \equiv \sum_{j = 0}^{i - 1} \left([x^j]f(x)\right) \times \left([x^{i-j}]g(x)\right) \pmod{998244353} \]

算法过程

总思路: 分治.

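对于 \([x^{l \sim r}]f(x)\) (由于 \(f\) 的第 \(i\) 项只依赖下标小于 \(i\) 的项, 处理右半边之前必须先把左半边全部求出),

  1. 先递归算出 \([x^{l \sim mid}]f(x)\);
  2. 一遍 FFT, 统计 \([x^{l \sim mid}]f(x)\) 对 \([x^{mid + 1 \sim r}]f(x)\) 的贡献 (即把 \([x^{l \sim mid}]f(x)\) 与 \([x^{1 \sim r - l}]g(x)\) 相乘, 取结果中对应 \(mid + 1 \sim r\) 的项累加);
  3. 递归算出 \([x^{mid+1 \sim r}]f(x)\).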

时间复杂度 \(O(n \log^2 n)\): 递归共 \(O(\log n)\) 层, 每层所有 FFT 的总规模为 \(O(n)\), 耗时 \(O(n \log n)\).

代码

【模板】分治 FFT

#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>

using namespace std;

using ll = long long;  // 64-bit type for products of two residues (< mod^2)

constexpr int _ = (1 << 18) + 7;        // capacity of every global array below and in POLY
constexpr int mod = 998244353, rt = 3;  // NTT prime 119 * 2^23 + 1 and the root used for it

int n;     // number of coefficients to compute (read in main)
int f[_];  // result coefficients f_0..f_{n-1}
int g[_];  // input coefficients g_1..g_{n-1}; g[0] is never written and stays 0

// Binary (square-and-multiply) exponentiation: returns a^p modulo `mod`.
// p is expected to be non-negative; intermediate products are widened to ll.
int Pw(int a, int p) {
  int acc = 1;
  for (int base = a, e = p; e != 0; e >>= 1) {
    if (e & 1) acc = (ll)acc * base % mod;  // fold in the current bit of the exponent
    base = (ll)base * base % mod;           // base becomes a^(2^k) for the next bit
  }
  return acc;
}

namespace POLY {
  int tot, num[_], pwrt[2][_], inv[_], tmp[5][_];

  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1;
    for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  void NTT(int *f, int t, bool ty) {
    for (int i = 1; i < t; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
      if (i < num[i]) swap(f[i], f[num[i]]);
    }
    for (int len = 2; len <= t; len <<= 1) {
      int gap = len >> 1, w1 = pwrt[ty][len];
      for (int i = 0, w = 1, tmp; i < t; i += len, w = 1)
        for (int j = i; j < i + gap; ++j) {
          tmp = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - tmp + mod) % mod;
          f[j] = (f[j] + tmp) % mod;
          w = (ll)w * w1 % mod;
        }
    }
    if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
  }

  void Mul(int *f, int *g, int *h, int t) {
    memcpy(tmp[0], f, t << 2);  // memcpy/memset 是按照字节数赋值, 所以int类型需要 * 4.
    memcpy(tmp[1], g, t << 2);
    NTT(tmp[0], t, 0), NTT(tmp[1], t, 0);
    for (int i = 0; i < t; ++i) h[i] = (ll)tmp[0][i] * tmp[1][i] % mod;
    NTT(h, t, 1);
  }

  void dcNTT(int *f, int *g, int t, int l, int r) {
    if (t == 1) return;
    dcNTT(f, g, t >> 1, l, (l + r) >> 1);
    memcpy(tmp[2], f, t << 1);
    Mul(tmp[2], g, tmp[2], t);  // 由于 FFT 本质上是循环卷积, 所以可以不用把多项式拓展到 t << 1
    for (int i = (t >> 1); i < t; ++i) f[i] = (f[i] + tmp[2][i]) % mod;
    memset(tmp[2], 0, t << 3);
    dcNTT(f + (t >> 1), g, t >> 1, (l + r) >> 1, r);
  }

  void dcMul(int *g, int *f) {
    cerr << (tot >> 1) << endl;
    dcNTT(f, g, tot >> 1, 0, tot >> 1);
  }
}

// Reads n and g_1..g_{n-1} from stdin and computes f_0..f_{n-1} with f_0 = 1 and
// f_i = sum_{j=1..i} f_{i-j} * g_j (mod 998244353); prints them space-separated.
// Array capacity (`_`) supports 2n < 2^18, i.e. n < 131072.
int main() {
  // With n < 1, POLY::dcMul would be called with segment length 0; bail out instead.
  if (scanf("%d", &n) != 1 || n < 1) return 0;
  for (int i = 1; i < n; ++i) scanf("%d", &g[i]);
  POLY::Init();
  f[0] = 1;
  POLY::dcMul(g, f);
  for (int i = 0; i < n; ++i) printf("%d ", f[i]);
  putchar('\n');
  return 0;
}
原文地址:https://www.cnblogs.com/BruceW/p/13886916.html