[科技] 求数列的前k次方和

[科技] 求数列的前k次方和

到现在才会的一个科技,写一篇博客来记录一下。

简单来说,就是对于 $0 \leq t \leq k$,求 $\sum_{i = 1}^n a_i^t$(答案对 $998244353$ 取模),其中 $n, k \leq 10^5$。

我们考虑答案序列的生成函数:

$$F(x) = \sum_{t = 0}^{\infty} x^t \sum_{i = 1}^n a_i^t = \sum_{i = 1}^{n}\sum_{t = 0}^{\infty} (a_ix)^t = \sum_{i = 1}^n \frac{1}{1 - a_ix}$$

发现这样还是无法下手,但是注意到 $$(\ln(1 - a_ix))' = \ln'(1 - a_ix) \times (-a_i) = \frac{-a_i}{1 - a_ix} = -\sum_{t = 0}^{\infty}(a_ix)^t a_i$$

我们记 $$G(x) = \sum_{i = 1}^n (\ln(1 - a_ix))' = \sum_{i = 1}^n \frac{-a_i}{1 - a_ix} = -\sum_{i = 1}^{n} \sum_{t = 0}^{\infty}(a_ix)^t a_i = -\sum_{t = 0}^{\infty}x^t\sum_{i = 1}^n a_i^{t + 1}$$

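那么 $-xG(x) = \sum_{t = 1}^{\infty} x^t \sum_{i = 1}^n a_i^t$,再补上常数项 $\sum_{i = 1}^n a_i^0 = n$,就得到 $F(x) = -xG(x) + n$。于是只要求出 $G(x)$,就可以快速求出 $F(x)$ 了。
考虑继续化简 $G(x)$: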

$$G(x) = \left(\sum_{i = 1}^n \ln(1 - a_ix)\right)' = \left(\ln\prod_{i = 1}^n (1 - a_ix)\right)'$$

里面的乘积 $\prod_{i = 1}^n (1 - a_ix)$ 用分治 $FFT$(这里是模 $998244353$ 的 $NTT$)求出即可,然后求 $\ln$ 再求导,就可以得到 $G$ 了。

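不过原先的写法只能求到前 $n$ 项的幂和,必须在 $a$ 后面补零才行。原因在于:求 $\ln$ 时只保留到 $x^{\text{len}-1}$ 项,若 len 取乘积多项式的长度 $n+1$,求导后就只能得到 $G$ 的前 $n$ 项,也就是前 $n$ 项的幂和。我们需要的是 $G$ 的前 $k$ 项,即 $\ln$ 的前 $k+1$ 项,所以把求 $\ln$ 的长度直接取为 $k+1$ 即可。乘积本身仍要包含全部 $n$ 个因子,补零只是把乘积的长度凑到 $k+1$。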

#include <bits/stdc++.h>
using namespace std;
// Capacity for the input array a[] and base size of the NTT work tables (10^5 plus slack).
const int N = 100000 + 50;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; the transforms below use root 3.
const int Md = 998244353;
typedef long long ll;
// A polynomial / formal power series: element i is the coefficient of x^i.
typedef std::vector<int> Vec;

inline int Add(const int &x, const int &y) { return (x + y >= Md) ? (x + y - Md) : (x + y); }
// (x - y) mod Md for operands already reduced into [0, Md).
inline int Sub(const int &x, const int &y) {
  int diff = x - y;
  if (diff < 0) diff += Md;
  return diff;
}
inline int Mul(const int &x, const int &y) { return (ll)x * y % Md; }
// x^y mod Md by square-and-multiply; intended for y >= 0.
int Powe(int x, int y) {
  int result = 1;
  for (; y; y >>= 1, x = Mul(x, x)) {
    if (y & 1) result = Mul(result, x);
  }
  return result;
}

int n;     // number of input values, stored in a[1..n]
int k;     // largest exponent t for which sum a_i^t is wanted
int a[N];  // input sequence, 1-indexed

namespace Poly {
  // rev[i]: bit-reversed index of i for the transform length prepared most recently
  //         (MUL and GetInv fill it right before calling DFT/IDFT).
  // inv[i]: i^{-1} mod Md, filled by Init() and read by Inter().
  int rev[N << 2 | 1], inv[N << 2 | 1];

  // Precomputes inv[i] = i^{-1} mod Md with the linear recurrence
  // inv[i] = -(Md / i) * inv[Md % i]. The WHOLE table is filled: Inter() indexes
  // inv[] by coefficient position, and entries left at 0 would silently zero out
  // every coefficient of an integral longer than N.
  void Init() {
    const int cap = sizeof(inv) / sizeof(inv[0]);
    inv[0] = inv[1] = 1;  // inv[0] is never read by Inter (its loop starts at 1)
    for (int i = 2; i < cap; i++) {
      inv[i] = Mul(Md - Md / i, inv[Md % i]);
    }
  }

  // In-place iterative number-theoretic transform of A[0..len), len a power of two,
  // using primitive root 3. rev[] must already hold the permutation for this len.
  void DFT(Vec &A, int len) {
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(A[i], A[rev[i]]);
    for (int half = 1; half < len; half <<= 1) {
      const int wn = Powe(3, (Md - 1) / (half << 1));
      for (int j = 0; j < len; j += half << 1) {
        int w = 1;
        for (int t = 0; t < half; t++, w = Mul(w, wn)) {
          const int x = A[j + t], y = Mul(w, A[half + j + t]);
          A[j + t] = Add(x, y);
          A[half + j + t] = Sub(x, y);
        }
      }
    }
  }

  // Inverse transform: reversing A[1..] turns the forward DFT into the inverse one,
  // then every value is scaled by len^{-1}. Requires A.size() == len, because the
  // reversal spans the whole vector.
  void IDFT(Vec &A, int len) {
    reverse(A.begin() + 1, A.end());
    const int IV = Powe(len, Md - 2);
    DFT(A, len);
    for (int i = 0; i < len; i++) A[i] = Mul(A[i], IV);
  }

  // Full product A * B, of size A.size() + B.size() - 1; empty if either factor is empty.
  Vec MUL(Vec A, Vec B) {
    if (A.empty() || B.empty()) return Vec();
    const int na = A.size(), nb = B.size(), want = na + nb - 1;
    int len = 1;
    while (len < want) len <<= 1;
    rev[0] = 0;
    for (int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    A.resize(len); B.resize(len);
    DFT(A, len); DFT(B, len);
    for (int i = 0; i < len; i++) A[i] = Mul(A[i], B[i]);
    IDFT(A, len);
    A.resize(want);
    return A;
  }

  // Returns B with A * B == 1 (mod x^len), via Newton iteration B <- B * (2 - A * B).
  // Each round doubles the number of correct coefficients.
  // Requires A non-empty and A[0] != 0 (mod Md).
  Vec GetInv(Vec A, int len) {
    Vec B(1, Powe(A[0], Md - 2)), C;
    for (int i = 2; (i >> 1) < len; i <<= 1) {
      const int size = i << 1;  // room for deg(B*B*C) < 2i, so no cyclic wrap-around
      rev[0] = 0;
      for (int j = 1; j < size; j++) rev[j] = (rev[j >> 1] >> 1) | ((j & 1) ? i : 0);
      C = A; C.resize(i);   // only A mod x^i matters at this precision
      C.resize(size); DFT(C, size);
      B.resize(size); DFT(B, size);
      for (int j = 0; j < size; j++) B[j] = Mul(B[j], Sub(2, Mul(B[j], C[j])));
      IDFT(B, size);
      B.resize(i);
    }
    B.resize(len);
    return B;
  }

  // Formal derivative. The result has the same length as A, and its top coefficient is 0.
  Vec Dir(Vec A) {
    const int len = A.size();
    Vec B(len, 0);
    for (int i = 1; i < len; i++) B[i - 1] = Mul(i, A[i]);
    return B;
  }

  // Formal integral with zero constant term. The result has the same length as A,
  // and the last input coefficient is dropped. Needs A.size() <= size of inv[].
  Vec Inter(Vec A) {
    const int len = A.size();
    Vec B(len, 0);
    for (int i = 1; i < len; i++) B[i] = Mul(A[i - 1], inv[i]);
    return B;
  }

  // ln(A) mod x^len, computed as the integral of A' / A; assumes A[0] == 1.
  // Only A mod x^len influences the result, so A is truncated first. The product is
  // also cut to len terms before integrating, avoiding a transform of size
  // A.size() + len when A is much longer than len.
  Vec Ln(Vec A, int len) {
    if (len <= 0) return Vec();
    A.resize(len);
    Vec D = MUL(Dir(A), GetInv(A, len));
    D.resize(len);
    return Inter(D);
  }
}

// Returns prod_{i=l}^{r} (1 - a_i x) as a coefficient vector (constant term first),
// built by divide and conquer with Poly::MUL. Indices i > n contribute the factor
// (1 + 0x), so the result always has r - l + 2 coefficients when l <= r.
// An empty range (l > r) yields the constant polynomial 1.
Vec Solve(int l, int r) {
  // Guard: without it, an empty range such as Solve(1, 0) recurses forever.
  if (l > r) return Vec(1, 1);
  if (l == r) {
    Vec A(2, 0);
    A[0] = 1;
    if (l <= n) {
      // Store -a_l mod Md in [0, Md), even if a_l is 0, negative, or >= Md.
      int v = a[l] % Md;
      if (v < 0) v += Md;
      A[1] = v ? Md - v : 0;
    }
    return A;
  }
  const int mid = (l + r) >> 1;
  return Poly::MUL(Solve(l, mid), Solve(mid + 1, r));
}

int main() {
  Poly::Init();
  scanf("%d%d", &n, &k);
  for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
  Vec A = Solve(1, k);
  A = Poly::Ln(A, A.size());
  A = Poly::Dir(A);
  for(int i = 0; i < A.size(); i++) A[i] = (Md - A[i]) % Md;
  A.resize(A.size() + 1);
  for(int i = A.size(); i; i--) A[i] = A[i - 1]; A[0] = 0;
  A[0] = Add(A[0], n);
  return 0;
}

原文地址:https://www.cnblogs.com/Apocrypha/p/10619032.html