【BZOJ 3771】Triple

Problem

Description

给出 (n) 个物品,第 (i) 个物品体积为 (a_i)

对于每个体积 (V) ,求选出 (3) 个物品,体积之和为 (V) 的方案总数。

选择顺序不同算同一种方案。

Range

(n) 保证不会读入到 (TLE)(a_ile 4 imes 10^4)

Algorithm

多项式,生成函数。

Mentality

设生成函数 (A(x)) 为只选择一个物品的生成函数。其中 ([x^m]A(x)) 的系数代表了体积 (m) 有多少种选法。

同理设 (B(x)) 为选择两个相同物品的生成函数,设 (C(x)) 为选择三个相同物品的生成函数。

则对于最后的答案而言:

若选择的 (3) 个物品互不相同,则方案数为:

[frac{A^3(x)-3B(x)A(x)+2C(x)}{6} ]

因为根据容斥,(A^3(x)) 等于所有选择三个物品的方案数,(B(x)A(x)) 则是所有形如 ((a, a, b)) 的方案数,由于这种方案在 (A^3(x)) 会出现三次,所以要乘 (3) ,然后对于所有 ((a,a,a)) ,也即生成函数 (C(x))(B(x)A(x)) 中出现了 (3) 次,但实际上在 (A^3(x)) 只会被计算一次,所以还要加回 (2) 个来。

若选择 (2) 个物品,那么方案为:

[frac{A^2(x)-B(x)}{2} ]

这个很好理解。

选择一个物品的方案自然就是 (A(x)) 了。

(FFT) 即可。

Code

#include <cmath>
#include <complex>
#include <cstdio>
#include <iostream>
using namespace std;
#define LL long long
#define cp complex<double>
#define inline __inline__ __attribute__((always_inline))
inline LL read() {
  LL x = 0, w = 1;
  char ch = getchar();
  while (!isdigit(ch)) {
    if (ch == '-') w = -1;
    ch = getchar();
  }
  while (isdigit(ch)) {
    x = (x << 3) + (x << 1) + ch - '0';
    ch = getchar();
  }
  return x * w;
}

const int Max_n = 4e5 + 5, Ml = 1.2e5;
const double pi = acos(-1);
cp ans[Max_n], A[Max_n], B[Max_n], C[Max_n];

namespace Input {
void main() {
  int n = read();
  for (int i = 1, x; i <= n; i++)
    x = read(), A[x] += 1, B[x * 2] += 1, C[x * 3] += 1;
}
}  // namespace Input

namespace Solve {
int bit, len, rev[Max_n];
void init() {
  int bit = log2(Ml + 1) + 1;
  len = 1 << bit;
  for (int i = 0; i < len; i++)
    rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (bit - 1));
}
void dft(cp *f, int t) {
  for (int i = 0; i < len; i++)
    if (i < rev[i]) swap(f[i], f[rev[i]]);
  for (int l = 1; l < len; l <<= 1) {
    cp Wn(cos(t * pi / (double)l), sin(t * pi / (double)l));
    for (int i = 0; i < len; i += (l << 1)) {
      cp Wnk(1, 0);
      for (int k = i; k < i + l; k++, Wnk *= Wn) {
        cp x = f[k], y = f[k + l] * Wnk;
        f[k] = x + y, f[k + l] = x - y;
      }
    }
  }
}
void main() {
  init();
  dft(A, 1), dft(B, 1), dft(C, 1);
  for (int i = 0; i < len; i++) {
    ans[i] = (A[i] * A[i] * A[i] - A[i] * B[i] * 3.0 + 2.0 * C[i]) / 6.0;
    ans[i] += (A[i] * A[i] - B[i]) / 2.0 + A[i];
  }
  dft(ans, -1);
  for (int i = 0; i <= Ml; i++) ans[i] /= (double)len;
  for (int i = 0; i <= Ml; i++) {
    LL Ans = (LL)(ans[i].real() + 0.5);
    if (Ans) printf("%d %lld
", i, Ans);
  }
}
}  // namespace Solve

int main() {
  Input::main();
  Solve::main();
}
原文地址:https://www.cnblogs.com/luoshuitianyi/p/12056962.html