Codechef BINOMSUM

题意:(复制sunset的)有\(T\)天,每天有\(K\)个小时,第\(i\)天有\(D+i-1\)道菜,第一个小时你选择\(L\)道菜吃,接下来每个小时你可以选择吃一道菜或者选择\(A\)个活动中的一个参加,不能连续两个小时吃菜,问每天的方案数之和。\(K\),\(A\)预先给定,\(Q\)次询问,每次给\(D\),\(L\),\(T\)。

题解:显然\(ans=\sum_{i=D}^{D+T-1}\binom{i}{L}F(i)\),其中\(F(i)\)是一个不超过\(k-1\)次的多项式。

把组合数暴力拆开,变为\(\sum_{i=D}^{D+T-1}\frac{i!}{L!(i-L)!}F(i)\)。因为有阶乘,所以考虑把\(F(i)\)写成上升幂多项式的形式来消掉阶乘。具体地,设\(F(x)=\sum_{i=0}^{k-1}a_i(x+1)\dots(x+i)=\sum_{i=0}^{k-1}a_i\frac{(x+i)!}{x!}\),则\(ans=\frac{1}{L!}\sum_{i=D}^{D+T-1}\sum_{j=0}^{k-1}a_j\frac{(i+j)!}{(i-L)!}\)。考虑在\(\frac{(i+j)!}{(i-L)!}\)的分母处补上\((j+L)!\)变为组合数,则\(ans=\frac{1}{L!}\sum_{j=0}^{k-1}a_j(j+L)!\sum_{i=D}^{D+T-1}\binom{i+j}{j+L}\)。后面是组合数上指标求和,可以\(O(1)\)计算:\(\sum_{i=D}^{D+T-1}\binom{i+j}{j+L}=\binom{D+T+j}{j+L+1}-\binom{D+j}{j+L+1}\)。

剩下的问题是怎样求\(a\)。上升幂多项式可以考虑用连续点值来求。具体地,假设我们求出了\(F(-1),F(-2),\dots,F(-k)\),显然有式子\(F(-u)=\sum_{i=0}^{u-1}\frac{(u-1)!}{(u-1-i)!}(-1)^ia_i\)。设\(x_i=(-1)^ia_i,y_i=\frac{1}{i!},z_i=\frac{F(-(i+1))}{i!}\),则\(Z=X*Y,X=\frac{Z}{Y}\)。多项式求逆即可。(其实可以不用求逆,可以发现\(Y=e^x,Y^{-1}=e^{-x}\)。)

剩下的问题是怎样求点值。设\(b_i\)为考虑了前\(i\)个小时的方案数,对于要求的点值\(x\),有递推式\(b_i=Ab_{i-1}+Axb_{i-2}\),可以用矩阵快速幂在\(O(\log k)\)的时间内求出单个点值。

#include<bits/stdc++.h>
using namespace std;
// Short aliases for the scalar types used throughout the solution.
typedef double db;
typedef long long ll;
typedef unsigned long long ull;

// N sizes the polynomial / FFT buffers; M sizes the factorial tables.
const int N = 1000010;
const int M = 10100010;
const db pi = acos(-1.0);

// Problem parameters, read once at the start of main().
int k;
int a;
int mod;
int q;

// FFT bookkeeping written by init(int): l is the bit length of the transform
// size and r[] the bit-reversal permutation.
int l;
int r[N];

// Factorials, modular inverses and inverse factorials, filled by init().
int fac[M];
int inv[M];
int ifac[M];

// Polynomial coefficient buffers: main() multiplies y by z into x.
int x[N];
int y[N];
int z[N];

// Reads the next decimal integer from stdin using getchar().
//
// Skips every character up to the first digit or '-', then parses an optional
// single '-' followed by a run of digits. The character that terminates the
// digit run is consumed as well.
//
// Returns the parsed value, or 0 if end-of-input is reached before any digit.
// Overflow of int is not detected.
int gi() {
  int value = 0, sign = 1;
  // Keep the result of getchar() in an int: a plain char cannot represent EOF
  // reliably, and the skip loop below would otherwise never terminate once the
  // input is exhausted.
  int ch = getchar();
  while(ch != EOF && ch != '-' && (ch < '0' || ch > '9')) {
    ch = getchar();
  }
  if(ch == '-') {
    sign = -1;
    ch = getchar();
  }
  while(ch >= '0' && ch <= '9') {
    // Subtract '0' first so the intermediate never carries the raw character
    // code on top of value * 10.
    value = value * 10 + (ch - '0');
    ch = getchar();
  }
  return value * sign;
}

struct com {
  db x, y;
  com(db x = 0, db y = 0): x(x), y(y) {}
  com operator+(const com &A) const {
    return com(x + A.x, y + A.y);
  }
  com operator-(const com &A) const {
    return com(x - A.x, y - A.y);
  }
  com operator*(const com &A) const {
    return com(x * A.x - y * A.y, x * A.y + y * A.x);
  }
  com conj() {
    return com(x, -y);
  }
} w[N];

// Prepares the FFT tables for a transform of size n.
//
// Sets the global l to the number of doublings needed to reach n from 1,
// fills r[0..n) with the l-bit bit-reversal permutation and w[0..n) with
// w[i] = (cos(pi * i / n), sin(pi * i / n)).
//
// n is expected to be a power of two not exceeding the table capacity; this
// function does not check either condition.
void init(int n) {
  l = 0;
  for(int len = 1; len < n; len <<= 1) {
    ++l;
  }
  for(int i = 0; i < n; i++) {
    // For n == 1 we have l == 0 and the original expression shifted by -1,
    // which is undefined behaviour. A size-1 transform needs no reordering,
    // so the only entry is simply 0.
    r[i] = (l > 0) ? ((r[i >> 1] >> 1) | ((i & 1) << (l - 1))) : 0;
    w[i] = com(cos(pi * i / n), sin(pi * i / n));
  }
}

// In-place iterative radix-2 transform of a[0..n).
//
// Uses the bit-reversal table r[] and the twiddle table w[] as last prepared
// by init(n). No scaling is applied to the result.
void FFT(com *a, int n) {
  // Reorder into bit-reversed order; the i < r[i] test swaps each pair once.
  for(int i = 0; i < n; i++) {
    if(i < r[i]) {
      swap(a[i], a[r[i]]);
    }
  }
  // Butterfly passes: `half` is the half-length of the blocks being merged.
  for(int half = 1; half < n; half <<= 1) {
    const int span = half << 1;
    for(int base = 0; base < n; base += span) {
      for(int t = 0; t < half; t++) {
        const com lo = a[base + t];
        const com hi = w[n / half * t] * a[base + t + half];
        a[base + t] = lo + hi;
        a[base + t + half] = lo - hi;
      }
    }
  }
}

void mul(int *a, int *b, int *c, int n) {
  static com s1[N], s2[N], s3[N], s4[N], s5[N], s6[N];
  init(n);
  for(int i = 0; i < n; i++) {
    s1[i] = com(a[i] & 32767, a[i] >> 15);
    s2[i] = com(b[i] & 32767, b[i] >> 15);
  }
  FFT(s1, n), FFT(s2, n);
  for(int i = 0; i < n; i++) {
    int j = (n - 1) & (n - i);
    com da = (s1[i] + s1[j].conj()) * com(0.5, 0);
    com db = (s1[i] - s1[j].conj()) * com(0, -0.5);
    com dc = (s2[i] + s2[j].conj()) * com(0.5, 0);
    com dd = (s2[i] - s2[j].conj()) * com(0, -0.5);
    s3[i] = da * dc, s4[i] = da * dd, s5[i] = db * dc, s6[i] = db * dd;
  }
  for(int i = 0; i < n; i++) {
    s1[i] = s3[i] + s4[i] * com(0, 1);
    s2[i] = s5[i] + s6[i] * com(0, 1);
  }
  FFT(s1, n), FFT(s2, n);
  reverse(s1 + 1, s1 + n), reverse(s2 + 1, s2 + n);
  for(int i = 0; i < n; i++) {
    int da = (ll)(s1[i].x / n + 0.5) % mod;
    int db = (ll)(s1[i].y / n + 0.5) % mod;
    int dc = (ll)(s2[i].x / n + 0.5) % mod;
    int dd = (ll)(s2[i].y / n + 0.5) % mod;
    c[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
  }
}

struct mat {
  int v[2][2];
  mat() {
    memset(v, 0, sizeof(v));
  }
  mat operator*(const mat &A) const {
    mat ret;
    for(int i = 0; i < 2; i++)
      for(int j = 0; j < 2; j++) {
        ull tmp = 0;
        for(int k = 0; k < 2; k++) {
          tmp += 1ll * v[i][k] * A.v[k][j];
        }
        ret.v[i][j] = tmp % mod;
      }
    return ret;
  }
} S, T;

// Returns a raised to the power b by binary exponentiation (identity for b == 0).
mat qpow(mat a, int b) {
  mat result;
  result.v[0][0] = 1;
  result.v[1][1] = 1;
  for(; b; b >>= 1) {
    if(b & 1) {
      result = result * a;
    }
    a = a * a;
  }
  return result;
}

void init() {
  const int n = 1e7 + 1e5 + 1;
  fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
  for(int i = 2; i <= n; i++) {
    fac[i] = 1ll * fac[i - 1] * i % mod;
    inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    ifac[i] = 1ll * ifac[i - 1] * inv[i] % mod;
  }
}

int C(int n, int m) {
  if(m < 0 || n < m) {
    return 0;
  }
  return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

// Entry point: reads k, a, mod and q, precomputes the coefficient array x[]
// once, then answers q queries in O(k) each.
int main() {
#ifndef ONLINE_JUDGE
  freopen("a.in", "r", stdin);
  freopen("a.out", "w", stdout);
#endif
  // The header line is read with cin and the queries with gi()/getchar();
  // this only works while iostreams stay synchronised with stdio (the default).
  cin >> k >> a >> mod >> q;
  init();

  // Evaluate the per-day polynomial at the points -1, -2, ..., -k with a 2x2
  // matrix power. z[i] holds the value at -(i + 1) divided by i!, and y[i]
  // holds (-1)^i / i!.
  S.v[0][0] = 1, S.v[0][1] = a, T.v[1][0] = 1, T.v[1][1] = a;
  for(int i = 0; i < k; i++) {
    T.v[0][1] = 1ll * a * (mod - i - 1) % mod;
    z[i] = 1ll * (S * qpow(T, k - 1)).v[0][0] * ifac[i] % mod;
    y[i] = 1ll * ((i & 1) ? mod - 1 : 1) * ifac[i] % mod;
  }

  // Smallest power of two strictly greater than 2k - 2, so the cyclic
  // convolution in mul() does not wrap around.
  int fft_len = 1;
  while(fft_len <= 2 * k - 2) {
    fft_len <<= 1;
  }
  mul(y, z, x, fft_len);

  // Undo the alternating sign on the convolution result.
  for(int i = 0; i < k; i++) {
    x[i] = 1ll * x[i] * ((i & 1) ? mod - 1 : 1) % mod;
  }

  while(q--) {
    // NOTE(review): the accompanying write-up lists each query as D, L, T but
    // the values are read here in the order l, d, t - confirm against the
    // actual input format.
    int l = gi(), d = gi(), t = gi(), ans = 0;
    for(int i = 0; i < k; i++) {
      // The inner sum over the days collapses to a difference of two
      // binomials; + mod keeps that difference non-negative.
      ans = (ans + 1ll * x[i] * fac[i + l] % mod * (C(d + t + i, i + l + 1) - C(d + i, i + l + 1) + mod)) % mod;
    }
    // The newline literal was split across two source lines, which does not
    // compile; it is restored here.
    cout << 1ll * ans * ifac[l] % mod << '\n';
  }
  return 0;
}

原文地址:https://www.cnblogs.com/gczdajuruo/p/10921236.html