启智树提高组Day4T3 2的幂拆分

问把 (n) 拆分成2的幂的和的形式的方案数。注意拆分是无序的,即 (2^0,2^2,2^1,2^0)(2^0,2^0,2^1,2^2) 是一种方案。(n le 10^{18})

前置知识

一些定义与约定

定义:

[S_k(n)=sum_{i=0}^{n-1}i^k ]

[]

部分下取整符号可能有遗漏。

题解

首先通过打表发现前几项是:(从0开始)

[1,1,2,2,4,4,6,6,10,10,14,14,20,20,26,26,... ]

发现去重差分还是原数组,并且发现:

[f_i = f_{i - 1} ,i mod 2=1 ]

[f_i = f_{i-1}+f_{i/2},imod2=0 ]

不断重复使用这两个式子拆开发现:

[f_{n} = sum_{j=0}^{left lfloor n/2 ight floor}f_j ]

(f_j) 发现:

[f_n=sum_{j=0}^{left lfloor n/2 ight floor }sum_{i=0}^{left lfloor j/2 ight floor }f_i ]

[=sum_{i=0}^{left lfloor n/4 ight floor } f_isum_{j=2i}^{left lfloor n/2 ight floor }1 ]

[=sum_{i=0}^{n/4}f_i(S_{0}(left lfloor n/2 ight floor +1)-S_0(2i)) ]

[=S_0(left lfloor n/2 ight floor +1)sum_{i=0}^{n/4}f_i-sum_{i=0}^{n/4}f_iS_0(2i) ]

[=S_0(left lfloor n/2 ight floor +1)sum_{i=0}^{n/4}f_i-2sum_{i=0}^{n/4}f_ii ]

(公式不太好打,就不打那么多了)

发现前面的那部分可以直接递归子问题,我们难以解决的是后边的那个类似 (sum_{i=0}^{n}f_ii^k) 的式子。我们试图求解这个式子:

[F(n,k)=sum_{i=0}^nf_ii^k ]

[=sum_{i=0}^ni^ksum_{j=0}^{i/2}f_j ]

[=sum_{j=0}^{n/2}f_j sum_{i=2j}^ni^k ]

[=sum_{j=0}^{n/2}f_j(S_k(n+1)-S_k(2j)) ]

[=sum_{j=0}^{n/2}f_jS_k(n+1)-sum_{j=0}^{n/2}f_jsum_{d=1}^{k+1}frac{{k + 1 choose d}}{k+1}B_{k+1-d}(2j)^d ]

[=sum_{j=0}^{n/2}f_jS_k(n+1)-sum_{d=1}^{k+1}frac{{k + 1 choose d}}{k+1}B_{k+1-d}2^d sum_{j=0}^{n/2}f_jj^d ]

[=F(n/2,0)S_k(n+1)-sum_{d=1}^{k+1}frac{{k + 1 choose d}}{k+1}B_{k+1-d}2^d F(n/2,d) ]

发现左边是子问题,右边每次需要的 (k) 会加一。因此 (k) 最多是 (logn)。状态数 (O(log^2n)),转移复杂度 (O(logn)),总复杂度:(O(log^3n))。由于 (k) 比较小,不需要多项式求逆加速求 (B_i),且这并不是复杂度瓶颈。

注意最好手写个哈希表,不然用 STL map 可能会被卡常。

(Code:)

#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <ctime>
#include <algorithm>
#define N 76
typedef long long ll;
template <typename T> inline void read(T &x) {
	x = 0; char c = getchar(); bool flag = false;
	while (!isdigit(c)) { if (c == '-')	flag = true; c = getchar(); }
	while (isdigit(c)) { x = (x << 1) + (x << 3) + (c ^ 48); c = getchar(); }
	if (flag)	x = -x;
}
using namespace std;
const int P = 1e9 + 7;
inline void MAX(int &a, int b) {
	if (b > a)	a = b;
}
ll c[N][N], B[N];
inline ll quickpow(ll x, ll k) {
	ll res = 1;
	while (k) {
		if (k & 1)	res = res * x % P;
		x = x * x % P;
		k >>= 1;
	}
	return res;
}
ll inv[N];
inline ll get_c(int n, int m) {
	return c[n + 1][m + 1];
}
inline void init_B() {
	const int up = 66;
	for (register int i = 1; i <= up; ++i)	inv[i] = quickpow(i, P - 2);
	c[1][1] = 1;
	for (register int i = 2; i <= up; ++i) {
		for (register int j = 1; j <= i; ++j) {
			c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % P;
		}
	}
	B[0] = 1;
	for (register int i = 1; i <= up; ++i) {
		for (register int j = 0; j <= i - 1; ++j) {
			B[i] = (B[i] + get_c(i + 1, j) * B[j]) % P;
		}
		B[i] = (P - inv[i + 1] * B[i]) % P; 
	}
}
const int PP = 13331;
struct Hashtable {//手写了哈希表,也可以用STL map,但是复杂度会多一个 log
	struct edge{
		int nxt;
		ll to;
		int val;
	}e[N << 1];
	int head[PP + 1000], ecnt;
	inline void init() {
		memset(head, 0, sizeof(head));
		ecnt = 0;
	}
	inline void addedge(ll v, int val) {
		int mirr = v % PP + 1;
		e[++ecnt] = (edge){head[mirr], v, val};
		head[mirr] = ecnt;
	}
	inline int find(ll v) {
		int mirr = v % PP + 1;
		for (register int i = head[mirr]; i; i = e[i].nxt) {
			ll to = e[i].to; if (to == v)	return e[i].val;
		}
		return -1;
	}
}Hash;
int mtot;
ll memo[N << 1][N];
ll calc(ll n, int k) {
	int mirr = Hash.find(n);
	if (mirr == -1)	Hash.addedge(n, ++mtot), mirr = mtot;
	if (memo[mirr][k])	return memo[mirr][k];
	if (n <= 1) {
		if (n == 0) {
			return memo[mirr][k] = (k == 0 ? 1 : 0);
		}
		return memo[mirr][k] = (k == 0 ? 2 : 1);
	}
	ll res = 0, tmp = (n + 1) % P;
	for (register int d = 1; d <= k + 1; ++d) {
		res = (res + get_c(k + 1, d) * B[k - d + 1] % P * inv[k + 1] % P * tmp) % P;
		tmp = (n + 1) % P * tmp % P;
	}
	res = res * calc(n >> 1, 0) % P;
	for (register int d = 1; d <= k + 1; ++d) {
		res = (res - (1ll << d) % P * get_c(k + 1, d) % P * inv[k + 1] % P * B[k - d + 1] % P * calc(n >> 1, d) % P) % P;
	}
	return memo[mirr][k] = res;
}
ll n;
inline void work() {
	Hash.init();
	mtot = 0;
	memset(memo, 0, sizeof(memo));
	read(n);
	printf("%lld
", (calc(n >> 1, 0) % P + P) % P);
}

int main() {
	init_B();
	int _; read(_);
	while (_--) {
		work();
	}
	return 0;
}
原文地址:https://www.cnblogs.com/JiaZP/p/13550434.html