[学习笔记] 多项式 ln & exp

前置知识

  1. $$\begin{aligned} (\ln x)' &= \frac{1}{x} \\ (\exp x)' &= \exp x \end{aligned}$$

  2. 复合函数的求导(链式法则)

    $$(g \circ f)'(x) = g'(f(x))\, f'(x)$$

  3. 多项式求逆, 分治FFT.

多项式 ln

对 $\ln f(x)$ 先求导再积分.

$$\begin{aligned} \frac{\mathrm{d} \ln f(x)}{\mathrm{d} x} &\equiv \frac{f'(x)}{f(x)} \pmod{x^n} \\ \ln f(x) &\equiv \int \mathrm{d} \ln f(x) \equiv \int \frac{f'(x)}{f(x)} \, \mathrm{d} x \pmod{x^n} \end{aligned}$$

多项式求导、积分都是 $O(n)$ 的, 多项式求逆和乘法都是 $O(n\log n)$ 的, 所以总复杂度为 $O(n\log n)$.

多项式 exp

普通方法

和求 ln 一样, 也是求导再积分.

$$\begin{aligned} (\exp f(x))' &\equiv f'(x) \exp f(x) \pmod{x^n} \\ \exp f(x) &\equiv \int f'(x) \exp f(x) \, \mathrm{d}x \pmod{x^n} \end{aligned}$$

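记 $F(x) = \exp f(x)$, $g(x) = f'(x)$, 比较 $F' = gF$ 两边 $x^{i-1}$ 的系数得 $i F_i = \sum_{j=0}^{i-1} F_j\, g_{i-1-j}$, 右边只依赖下标更小的 $F_j$, 因此可以用分治FFT解决, 时间复杂度为 $O(n \log^2 n)$.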

牛顿迭代

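[[学习笔记] 牛顿迭代](https://www.cnblogs.com/BruceW/p/14079514.html)

对 $\exp$ 而言, 即求 $G$ 使 $\ln G(x) - f(x) \equiv 0$, 迭代式为 $G(x) \equiv G_0(x)\bigl(1 - \ln G_0(x) + f(x)\bigr) \pmod{x^{2n}}$, 其中 $G_0(x) \equiv \exp f(x) \pmod{x^n}$.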

时间复杂度为 $O(n \log n)$, 但由于每次迭代都需要求 $\ln$, 常数较大, 所以实际上快不了多少(至少在洛谷的模板上跑得差不多)。

代码

$\ln$

#include <cctype>
#include <cstdio>
#include <iostream>

using namespace std;

// Wide type for products of two residues before reducing them modulo `mod`.
typedef long long ll;

// Capacity of every coefficient array. Init() indexes up to tot (the smallest power of two
// greater than 2n), so this supports 2n < 2^18, i.e. n < 131072.
const int _ = (1 << 18) + 7;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root of `mod`; Init() derives the NTT roots of unity from it.
const int rt = 3;

// n: number of input coefficients; f: the input polynomial, replaced in place by ln f in main().
int n, f[_];

int Pw(int a, int p) {
  int res = 1;
  while (p) {
    if (p & 1) res = (ll)res * a % mod;
    a = (ll)a * a % mod;
    p >>= 1;
  }
  return res;
}

namespace POLY {
  int tot, num[_], inv[_], pwrt[2][_], tmp[6][_];
  
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1;
    for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  void NTT(int *f, int t, bool ty) {
    for (int i = 1; i < t; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
      if (i < num[i]) swap(f[i], f[num[i]]);
    }
    for (int len = 2; len <= t; len <<= 1) {
      int gap = len >> 1, w1 = pwrt[ty][len], w, tmp;
      for (int i = 0; i < t; i += len) {
        w = 1;
        for (int j = i; j < i + gap; ++j) {
          tmp = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - tmp + mod) % mod;
          f[j] = (f[j] + tmp) % mod;
          w = (ll)w * w1 % mod;
        }
      }
    }
    if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
  }

  void Mul(int *f, int *g, int *h) {
    for (int i = 0; i < tot; ++i) tmp[2][i] = f[i], tmp[3][i] = g[i];
    NTT(tmp[2], tot, 0), NTT(tmp[3], tot, 0);
    for (int i = 0; i < tot; ++i) h[i] = (ll)tmp[2][i] * tmp[3][i] % mod;
    NTT(h, tot, 1);
  }

  void Inv(int *f, int *h) {
    for (int i = 0; i < tot; ++i) h[i] = tmp[1][i] = 0;
    h[0] = Pw(f[0], mod - 2), tmp[1][0] = f[0], tmp[1][1] = f[1];
    for (int len = 2, t = 4; len < tot; len <<= 1, t = (len << 1)) {
      NTT(h, t, 0), NTT(tmp[1], t, 0);
      for (int i = 0; i < t; ++i) h[i] = (ll)h[i] * (2 - (ll)h[i] * tmp[1][i] % mod + mod) % mod;
      NTT(h, t, 1), NTT(tmp[1], t, 1);
      for (int i = len; i < t; ++i) tmp[1][i] = f[i], h[i] = 0;
    }
  }

  void Deriv(int *f, int *h) { for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }

  void Integ(int *f, int *h) { for (int i = tot - 1; i > 0; --i) h[i] = (ll)f[i - 1] * Pw(i, mod - 2) % mod; h[0] = 0; }

  void Ln(int *f, int *h) {
    for (int i = 0; i < tot; ++i) tmp[4][i] = f[i];
    Inv(f, tmp[4]);
    Deriv(f, f);
    Mul(f, tmp[4], f);
    Integ(f, h);
  }
}

// Reads the next non-negative decimal integer from stdin, skipping any non-digit characters.
// Returns 0 if the input ends before any digit is found (instead of looping forever).
int gi() {
  int x = 0;
  int c = getchar();  // int, not char: it must be able to hold EOF and is safe to pass to isdigit
  while (c != EOF && !isdigit(c)) c = getchar();
  while (isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
  return x;
}

// Reads n and the n coefficients of f, then prints the first n coefficients of ln f.
int main() {
  n = gi();
  for (int i = 0; i < n; ++i) f[i] = gi();
  POLY::Init();
  POLY::Ln(f, f);
  for (int i = 0; i < n; ++i) printf("%d ", f[i]);
  putchar('\n');
  return 0;
}

$\exp$ (普通方法)

#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

// Wide type for products of two residues before reducing them modulo `mod`.
typedef long long ll;

// Capacity of every coefficient array; Init() indexes up to tot (> 2n), so 2n < 2^18 is supported.
const int _ = (1 << 18) + 7;
// mod: NTT-friendly prime 119 * 2^23 + 1; rt: its primitive root, used to build roots of unity.
const int mod = 998244353, rt = 3;

// n: number of coefficients; g: the input polynomial, replaced by its derivative in main();
// f: the result exp(input), seeded with f[0] = 1 in main().
int n, g[_], f[_];

int Pw(int a, int p) {
  int res = 1;
  while (p) {
    if (p & 1) res = (ll)res * a % mod;
    a = (ll)a * a % mod;
    p >>= 1;
  }
  return res;
}

namespace POLY {
  // tot: smallest power of two greater than 2n (set by Init). num: NTT bit-reversal table.
  // pwrt[0][len] / pwrt[1][len]: a primitive len-th root of unity and its inverse.
  // inv[i]: modular inverse of i for 1 <= i <= tot.
  // tmp: scratch -- dcNTT uses tmp[0], Mul uses tmp[1] and tmp[2].
  int tot, num[_], pwrt[2][_], inv[_], tmp[5][_];

  // Precomputes tot, the inverse table and the roots of unity; call once after n is set.
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1;
    for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  // In-place iterative NTT of f[0..t) (t a power of two, t <= tot).
  // ty == 0: forward transform; ty == 1: inverse transform, including the division by t.
  void NTT(int *f, int t, bool ty) {
    for (int i = 1; i < t; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
      if (i < num[i]) swap(f[i], f[num[i]]);
    }
    for (int len = 2; len <= t; len <<= 1) {
      int gap = len >> 1, w1 = pwrt[ty][len];
      // The butterfly temporary is `x` so that it does not shadow the POLY::tmp buffers.
      for (int i = 0, w = 1, x; i < t; i += len, w = 1)
        for (int j = i; j < i + gap; ++j) {
          x = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - x + mod) % mod;
          f[j] = (f[j] + x) % mod;
          w = (ll)w * w1 % mod;
        }
    }
    if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
  }

  // h = f * g as a cyclic convolution of length t (a power of two, t <= tot).
  // Reads f[0..t) and g[0..t), writes only h[0..t); h may alias f or g. Clobbers tmp[1], tmp[2].
  void Mul(int *f, int *g, int *h, int t) {
    memcpy(tmp[1], f, t * sizeof(int));
    memcpy(tmp[2], g, t * sizeof(int));
    NTT(tmp[1], t, 0), NTT(tmp[2], t, 0);
    // A length-t transform has exactly t point values; nothing beyond index t-1 is meaningful.
    for (int i = 0; i < t; ++i) h[i] = (ll)tmp[1][i] * tmp[2][i] % mod;
    NTT(h, t, 1);
  }

  // Divide-and-conquer (CDQ) solver for i * f[i] = sum_{k<i} f[k] * g[i-1-k].
  // f points at a block of t coefficients (t a power of two) and r is the block's end index.
  // l is the global index of f[0], except that the leftmost block passes l = 1 instead of 0
  // (so global f[0] is scaled by inv[1] = 1 rather than needing inv[0]).
  // On entry the block must already hold every contribution from coefficients left of it;
  // on return the block is final.
  void dcNTT(int *f, int *g, int t, int l, int r) {
    // A single coefficient has received all of its sum; only the 1/i factor is missing.
    if (t == 1) { f[0] = l ? (ll)f[0] * inv[l] % mod : f[0]; return; }
    int half = t >> 1;
    dcNTT(f, g, half, l, (l + r) >> 1);
    // tmp[0] = the finished left half, zero-padded to length t.
    memcpy(tmp[0], f, half * sizeof(int));
    memset(tmp[0] + half, 0, half * sizeof(int));
    // Middle product: the cyclic wrap-around of this length-t convolution only reaches
    // indices <= half - 2, so the entries [half - 1, t - 1) read below are exact.
    Mul(tmp[0], g, tmp[0], t);
    for (int i = half; i < t; ++i) f[i] = (f[i] + tmp[0][i - 1]) % mod;
    dcNTT(f + half, g, half, (l + r) >> 1, r);
  }

  // f = exp(A) mod x^(tot/2), given g = A' and f = 1 (f[0] = 1, all other entries 0) on entry.
  void Exp(int *f, int *g) { dcNTT(f, g, tot >> 1, 1, tot >> 1); }

  // h = f' (h[i] = (i+1) f[i+1] for i < tot-1); the top coefficient h[tot-1] is cleared.
  // h may alias f (the loop reads f[i+1] before that slot is overwritten).
  void Deriv(int *f, int *h) {
    for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    h[tot - 1] = 0;
  }
}

// Reads n and the n coefficients of A, then prints the first n coefficients of exp(A),
// computed from F' = A' F with the divide-and-conquer solver.
int main() {
  scanf("%d", &n);
  for (int i = 0; i < n; ++i) scanf("%d", &g[i]);
  POLY::Init();
  POLY::Deriv(g, g);
  f[0] = 1;
  POLY::Exp(f, g);
  for (int i = 0; i < n; ++i) printf("%d ", f[i]);
  putchar('\n');
  return 0;
}

$\exp$ (牛顿迭代)

#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;
// Wide type for products of two residues before reducing them modulo `mod`.
typedef long long ll;
// Capacity of every coefficient array; Init() indexes up to tot (> 2n), so 2n < 2^18 is supported.
const int _ = (1 << 18) + 7;
// mod: NTT-friendly prime 119 * 2^23 + 1; rt: its primitive root, used to build roots of unity.
const int mod = 998244353, rt = 3;
// n: number of coefficients; g: the input polynomial; f: its exponential, computed by POLY::Exp.
int n, f[_], g[_];

int Pw(int a, int p) {
  int res = 1;
  while (p) {
    if (p & 1) res = (ll)res * a % mod;
    a = (ll)a * a % mod;
    p >>= 1;
  }
  return res;
}

namespace POLY {
  // tot: smallest power of two greater than 2n (set by Init). num: NTT bit-reversal table.
  // pwrt[0][len] / pwrt[1][len]: a primitive len-th root of unity and its inverse.
  // inv[i]: modular inverse of i for 1 <= i <= tot.
  int tot, num[_], pwrt[2][_], inv[_];

  // Precomputes tot, the inverse table and the roots of unity; call once after n is set.
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1; for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  // Zeroes f[0..2L): a product of two length-L polynomials needs a length-2L buffer.
  void Clear(int *f, int L) { memset(f, 0, 2 * L * sizeof(int)); }

  // In-place iterative NTT of f[0..L) (L a power of two, L <= tot).
  // ty == 0: forward transform; ty == 1: inverse transform, including the division by L.
  void NTT(int *f, int L, bool ty) {
    for (int i = 1; i < L; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? L >> 1 : 0);
      if (i < num[i]) swap(f[i], f[num[i]]);
    }
    for (int len = 2; len <= L; len <<= 1) {
      int gap = len >> 1, w1 = pwrt[ty][len];
      for (int i = 0, w = 1, tmp; i < L; i += len, w = 1)
        for (int j = i; j < i + gap; ++j) {
          tmp = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - tmp + mod) % mod;
          f[j] = (f[j] + tmp) % mod;
          w = (ll)w * w1 % mod;
        }
    }
    if (ty) for (int i = 0; i < L; ++i) f[i] = (ll)f[i] * inv[L] % mod;
  }

  // Copies f[0..L) into h[0..L).
  void Cpy(int *h, int *f, int L) { memcpy(h, f, L * sizeof(int)); }

  // h = f^{-1} mod x^L by Newton iteration (L a power of two); h[L..2L) is left zero.
  // Requires f[0] != 0 (mod p). h must not alias f (h is cleared before f[0] is read).
  void Inv(int *h, int *f, int L) {
    // static: Exp -> Ln -> Inv would otherwise hold seven arrays of _ ints (about 7 MB) on the
    // stack at once. None of these functions recurses, so reusing the storage is safe.
    static int a[_], b[_];
    Clear(h, L), Clear(a, L), Clear(b, L);
    h[0] = Pw(f[0], mod - 2);
    for (int len = 2, t = 4; len <= L; len <<= 1, t <<= 1) {
      // a = f mod x^len, b = h (exact mod x^(len/2)); both are zero beyond len up to t.
      Cpy(a, f, len), Cpy(b, h, len), NTT(b, t, 0), NTT(a, t, 0);
      for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 - (ll)a[i] * b[i] % mod + mod) % mod;
      NTT(b, t, 1);
      for (int i = (len >> 1); i < len; ++i) h[i] = b[i];
    }
  }

  // h = f' truncated to L coefficients (h[i] = (i+1) f[i+1] for i < L-1).
  void Deriv(int *h, int *f, int L) { for (int i = 0; i < L - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }
  // h = integral of f (h[i] = f[i-1] / i for 1 <= i < L, h[0] = 0); h may alias f.
  void Integ(int *h, int *f, int L) { for (int i = L - 1; i; --i) h[i] = (ll)f[i - 1] * inv[i] % mod; h[0] = 0; }

  // h = ln f mod x^L via ln f = \int f'/f (assumes f[0] == 1; the constant term is forced to 0).
  // Only h[0..L) is meaningful: h[L..2L) still holds the high half of the product.
  // h must not alias f.
  void Ln(int *h, int *f, int L) {
    static int a[_], b[_];  // static for the same stack-size reason as in Inv
    Clear(h, L), Clear(a, L), Clear(b, L);
    Deriv(a, f, L), Inv(b, f, L);
    // deg f' < L-1 and deg(1/f) < L, so the length-2L product does not wrap.
    NTT(a, L << 1, 0), NTT(b, L << 1, 0);
    for (int i = 0; i < (L << 1); ++i) h[i] = (ll)a[i] * b[i] % mod;
    NTT(h, L << 1, 1);
    Integ(h, h, L);
  }

  // h = exp f mod x^L by Newton iteration h <- h (1 - ln h + f) (L a power of two).
  // Assumes f[0] == 0, since h starts from exp(0) = 1. h must not alias f.
  void Exp(int *h, int *f, int L) {
    static int a[_], b[_], c[_];  // static for the same stack-size reason as in Inv
    Clear(h, L), Clear(a, L), Clear(b, L), Clear(c, L);
    h[0] = 1;
    for (int len = 2; len <= L; len <<= 1) {
      // c = h (exact mod x^(len/2)), b = ln h mod x^len, a = f mod x^len.
      Cpy(c, h, len), Ln(b, h, len), Cpy(a, f, len);
      // Length-len cyclic product: deg c < len/2 and the other factor has degree < len, so the
      // wrap-around only reaches indices below len/2 and the new half [len/2, len) is exact.
      NTT(c, len, 0), NTT(b, len, 0), NTT(a, len, 0);
      for (int i = 0; i < len; ++i) c[i] = (ll)c[i] * (1ll - b[i] + a[i] + mod) % mod;
      NTT(c, len, 1);
      for (int i = (len >> 1); i < len; ++i) h[i] = c[i];
    }
  }
}

// Reads n and the n coefficients of g, then prints the first n coefficients of exp(g),
// computed by Newton iteration.
int main() {
  scanf("%d", &n);
  for (int i = 0; i < n; ++i) scanf("%d", &g[i]);
  POLY::Init();
  POLY::Exp(f, g, POLY::tot >> 1);
  for (int i = 0; i < n; ++i) printf("%d ", f[i]);
  putchar('\n');
  return 0;
}
原文地址:https://www.cnblogs.com/BruceW/p/14079550.html