【THUPC 2017】小L的计算题

Problem

Description

现有一个长度为 \(n\) 的非负整数数组 \(\{a_i\}\)。小 L 定义了一种神奇变换:

\[f_k=\sum_{i=1}^{n}a_i^k \pmod{998244353}\]

小 L 计划用变换生成的序列 \(f\) 做一些有趣的事情,但是他并不擅长算乘法,所以来找你帮忙,希望你能帮他尽快计算出 \(f_1\sim f_n\)。为减少输出量,每组数据只需输出 \(f_1\oplus f_2\oplus\cdots\oplus f_n\)(按位异或和)。

总共有 \(T\) 组数据。

Range

\(n\le 2\times10^5,\ T\le 20,\ \sum n\le 4\times10^5,\ a_i\le 10^9\)

Algorithm

生成函数,多项式

Mentality

写出生成函数 (F) 的表达式:

\[\begin{aligned}F&=\sum_{k\ge0} f_kx^k\\&=\sum_{k\ge0}\sum_{i=1}^n a_i^kx^k\\&=\sum_{i=1}^n\sum_{k\ge0}(a_ix)^k\\&=\sum_{i=1}^n\frac{1}{1-a_ix}\\&=n-x\sum_{i=1}^n\frac{-a_i}{1-a_ix}\end{aligned}\]

然后发现 \((\ln(1-a_ix))'=\frac{-a_i}{1-a_ix}\),直接代入:

\[\begin{aligned}F&=n-x\sum_{i=1}^n(\ln(1-a_ix))'\\&=n-x\Big(\sum_{i=1}^n\ln(1-a_ix)\Big)'\\&=n-x\Big(\ln\prod_{i=1}^n(1-a_ix)\Big)'\end{aligned}\]

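用分治计算 \(P=\prod_{i=1}^n(1-a_ix)\),然后求个 \(\ln\) 就完事了。实际上不需要真的积分:只要求出 \((\ln P)'=P'/P\)(一次求导、一次多项式求逆、一次乘法),再由上式得 \(f_k=-[x^{k-1}](\ln P)'\ (k\ge1)\)。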

Code

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#define LL long long
// Edge-list traversal helper (not used in this file): for node x of G, walks
// edge ids i starting at G.hd[x] and following G.nx[i] until 0, with v set to
// G.to[i] for the current edge. The trailing backslash is required; without it
// the `for` header was left as a stray statement at file scope and the file
// did not compile.
#define go(G, x, i, v) \
  for (int i = G.hd[x], v = G.to[i]; i; v = G.to[i = G.nx[i]])
// NOTE: a former `#define inline __inline__ __attribute__((always_inline))`
// was removed: redefining a keyword in a translation unit that includes
// standard headers is not allowed, and plain `inline` is sufficient here.
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; if any of them is '-',
// the result is negated. Returns 0 when EOF is reached before any digit,
// instead of looping forever.
inline long long read() {
  long long x = 0, w = 1;
  // Keep the getchar() result as int: this keeps EOF distinguishable and
  // guarantees isdigit() only ever sees unsigned-char values or EOF.
  int ch = getchar();
  while (ch != EOF && !isdigit(ch)) {
    if (ch == '-') w = -1;
    ch = getchar();
  }
  while (ch != EOF && isdigit(ch)) {
    x = x * 10 + (ch - '0');
    ch = getchar();
  }
  return x * w;
}

// Capacity of the static arrays a[] and Poly::rev[].
const int Max_n = 4000000 + 5;
// Prime modulus; every coefficient is kept in [0, mod).
const int mod = 998244353;

int T;     // number of remaining test cases
bool fl;   // set by Solve::main; not read in the visible code
int n;     // length of the current test's array
int cnt;   // not used in the visible code
int a[Max_n];  // a[1..n]: current test's input values

// Segment-tree nodes: f[o] holds the product of (1 - a_i x) over node o's range.
std::vector<int> f[1000000];
// Result of Poly::Ln on the root product.
std::vector<int> ans;

namespace Input {
// Reads one test case: the length n, followed by n values stored in a[1..n].
void main() {
  n = static_cast<int>(read());
  for (int idx = 0; idx < n; ++idx) {
    a[idx + 1] = static_cast<int>(read());
  }
}
}  // namespace Input

namespace Poly {
// len: current transform length (a power of two), bit: log2(len),
// rev: bit-reversal permutation for that length. Set by init().
int len, bit, rev[Max_n];
// Modular exponentiation a^b mod `mod`. With the default exponent mod - 2 this
// is the modular inverse of a (Fermat's little theorem; mod is prime).
int ksm(int a, int b = mod - 2) {
  int res = 1;
  for (; b; b >>= 1, a = (LL)a * a % mod)
    if (b & 1) res = (LL)res * a % mod;
  return res;
}
// Sets len to the smallest power of two strictly greater than n (n >= 1)
// and rebuilds the bit-reversal table for it.
void init(int n) {
  len = 1 << (bit = log2(n) + 1);
  for (int i = 0; i < len; i++)
    rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (bit - 1));
}
// In-place iterative NTT on the first len entries of f (f.size() >= len).
// t == false: forward transform using primitive root 3.
// t == true: inverse transform, including the final division by len.
void dft(vector<int> &f, bool t) {
  for (int i = 0; i < len; i++)
    if (rev[i] > i) swap(f[i], f[rev[i]]);
  for (int l = 1; l < len; l <<= 1) {
    int Wn = ksm(3, (mod - 1) / (l << 1));
    if (t) Wn = ksm(Wn);
    for (int i = 0; i < len; i += l << 1) {
      int Wnk = 1;
      for (int j = i; j < i + l; j++, Wnk = (LL)Wnk * Wn % mod) {
        int x = f[j], y = (LL)f[j + l] * Wnk % mod;
        f[j] = (x + y) % mod, f[j + l] = (x - y + mod) % mod;
      }
    }
  }
  if (t)
    for (int i = 0, Inv = ksm(len); i < len; i++) f[i] = (LL)f[i] * Inv % mod;
}
// Makes f exactly len zeros.
void Resize(vector<int> &f, int len) {
  f.resize(len);
  for (int i = 0; i < len; i++) f[i] = 0;
}
// res = f * g as a cyclic convolution of length len, the smallest power of two
// greater than N. The caller must pick N above the product's degree to avoid
// wrap-around. f is taken by value so that res may alias it (Ln relies on this).
void Mul(vector<int> f, vector<int> &g, vector<int> &res, int N) {
  init(N);
  static vector<int> G;
  Resize(res, len), Resize(G, len);
  for (int i = 0; i < min((int)f.size(), len); i++) res[i] = f[i];
  for (int i = 0; i < min((int)g.size(), len); i++) G[i] = g[i];
  dft(res, 0), dft(G, 0);
  for (int i = 0; i < len; i++) res[i] = (LL)res[i] * G[i] % mod;
  dft(res, 1);
}
// Power-series inverse: afterwards f * res == 1 (mod x^deg), where deg is the
// largest power of two below 2N (so deg >= N). Requires f[0] != 0 (mod `mod`).
// Newton iteration res <- res * (2 - f * res) doubles the precision each round.
void Inv(vector<int> &f, vector<int> &res, int N) {
  init(N * 6);  // sizes the buffers for the largest round below
  Resize(res, len);
  static vector<int> F;
  Resize(F, len);
  res[0] = ksm(f[0]);
  for (int deg = 2; deg < (N << 1); deg <<= 1) {
    // res * res * F has degree < 3 * deg, so this length avoids wrap-around.
    init(deg * 3);
    for (int i = 0; i < min(deg, (int)f.size()); i++) F[i] = f[i];
    for (int i = min(deg, (int)f.size()); i < len; i++) F[i] = 0;
    dft(F, 0), dft(res, 0);
    for (int i = 0; i < len; i++)
      res[i] = (2ll * res[i] % mod + mod - (LL)res[i] * res[i] % mod * F[i] % mod) % mod;
    dft(res, 1);
    for (int i = deg; i < len; i++) res[i] = 0;  // keep only res mod x^deg
  }
}
// Despite the name, this does not integrate. It computes the derivative of
// ln f, i.e. f' / f. The first N coefficients of res equal (f' / f) mod x^N;
// requires f[0] != 0 (mod `mod`).
void Ln(vector<int> &f, vector<int> &res, int N) {
  static vector<int> inv;
  res = f;
  // The shift below reads res[N] and the next line writes res[N]. f may be
  // shorter than N + 1 (e.g. the length-2 leaf product when n == 1), so pad
  // with zeros instead of indexing out of bounds.
  if ((int)res.size() < N + 1) res.resize(N + 1, 0);
  for (int i = 0; i < N; i++) res[i] = (LL)res[i + 1] * (i + 1) % mod;
  res[N] = 0, Inv(f, inv, N);
  Mul(res, inv, res, N + N);
}
}  // namespace Poly
using namespace Poly;

namespace Solve {
// Divide and conquer over a[l..r]: stores prod_{i=l}^{r} (1 - a_i x) in f[o],
// where o is the segment-tree index of [l, r]. Leaves are linear factors;
// internal nodes multiply their children's products with NTT.
void Solve(int o, int l, int r) {
  if (l == r) {
    f[o].resize(2);
    // Coefficients of 1 - a_l x, with -a_l reduced into [0, mod).
    f[o][0] = 1, f[o][1] = (-a[l] % mod + mod) % mod;
    return;
  }
  int mid = (l + r) >> 1;
  Solve(o << 1, l, mid), Solve(o << 1 | 1, mid + 1, r);
  // The product has degree r - l + 1, so a length above r - l + 2 suffices.
  Mul(f[o << 1], f[o << 1 | 1], f[o], r - l + 2);
}
// Solves one test case and prints f_1 xor f_2 xor ... xor f_n.
// With P(x) = prod (1 - a_i x), f_k = -[x^{k-1}] (ln P)'.
void main() {
  if (n <= 0) {
    // Empty array: the XOR of no values is 0. This also avoids
    // Solve(1, 1, 0), which would never reach its l == r base case.
    cout << 0 << endl;
    return;
  }
  Solve(1, 1, n);
  // Ln(f[1], ans, n + 1) reads and writes index n + 1, so make room for it.
  // When n == 1 the root is a length-2 leaf and Mul() never ran, so the global
  // Poly::len may still hold a previous test's value; bounding this loop by
  // len (as before) wrote past the end of f[1]. Use the vector's own size.
  if ((int)f[1].size() < n + 2) f[1].resize(n + 2, 0);
  for (int i = n + 1; i < (int)f[1].size(); i++) f[1][i] = 0;
  fl = 1;
  Ln(f[1], ans, n + 1);
  int Ans = 0;
  for (int i = 0; i < n; i++) Ans ^= (-ans[i] + mod) % mod;
  cout << Ans << endl;
}
}  // namespace Solve

// Program entry: reads the number of test cases T, then reads and solves each
// case in turn.
int main() {
#ifndef ONLINE_JUDGE
  // Local runs only: redirect stdin/stdout to files.
  freopen("2409.in", "r", stdin);
  freopen("2409.out", "w", stdout);
#endif
  for (T = read(); T--;) {
    Input::main();
    Solve::main();
  }
}
原文地址:https://www.cnblogs.com/luoshuitianyi/p/12891662.html