花园

题意

给一个长度为(n)的环,要求在这个环上填上(0)(1),使得这个环满足对于任意长度为(m)的区间,其中(0)的个数不超过(k)。请求出所有合法的填数的方案数

将环上的结点标号为(1)(n),两种方案不同当且仅当至少存在一个节点,两种方案在此处所填的数不同

(nleq 10^{15},kleq mleq 5,mod = 10^9+7)


解法

见数据范围识矩乘优化DP

这里引用一下miracle大佬对矩乘优化DP题目的一些特点:

  • 一定存在一个线性递推式
  • 总有一个保持不变的转移矩阵
  • 由于矩乘的复杂度是(O(n^3))的,所以转移矩阵的边长不能太大
  • 矩阵只需要保留可以继续转移的项

首先,我们观察到(mleq 5),显然是可以进行状压的

我们能够得到以下的转移方程:(f[i][j]=f[i-1][k] imes a[j][k]),其中(a[j][k])意味着状态(j)可以转移到状态(k)

(a)数组是很好处理的,这里不再赘述

这个转移方程与floyd很类似,可以用矩乘优化,转移矩阵即是我们处理出的(a)矩阵

至于环的情况如何处理?

有一个经典套路:枚举第一个状态(s)(复杂度(O(2^m)))后转移(n)次,将(f[n+m][s])作为答案加入,这样就能保证首尾的状态均为(s),也就连接成了一个环

我们构建一个((2^m imes 2^m))的初始矩阵,初始状态分别位于(0 o 2^m)之间,值为(1)

可以发现这就是一个单位矩阵

我们直接把转移矩阵自乘(n)次,统计对角线上元素的和作为答案即可


代码

#include <cstdio>
#include <cstring>

using namespace std;

template<typename _T> void read(_T& x) {
	int c = getchar(); x = 0;	
	while (c < '0' || c > '9')   c = getchar();
	while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
}

const int mod = 1e9 + 7;

long long n, m, k, sz;

void add(int &x, int y) { (x += y) > mod ? x -= mod : x; }
int mul(int x, int y) { return 1LL * x * y % mod; }

struct matrix {
	
	int a[50][50];	
	
	matrix() { memset(a, 0, sizeof a); }
	matrix operator = (const matrix& rhs) {
		for (int i = 0; i < sz; ++i)
			for (int j = 0; j < sz; ++j)  a[i][j] = rhs.a[i][j];	
		return *this;
	}
	friend matrix operator * (const matrix &lhs, const matrix &rhs) {
		matrix res;
		for (int i = 0; i < sz; ++i)
			for (int j = 0; j < sz; ++j)	
				for (int k = 0; k < sz; ++k)
					add(res.a[i][j], mul(lhs.a[i][k], rhs.a[k][j]));
		return res;
	}	
	matrix operator ^ (long long k) const {
		matrix res, t = *this;
		for (int i = 0; i < sz; ++i)  res.a[i][i] = 1;
		for (; k; t = t * t, k >>= 1)
			if (k & 1)  res = res * t;	
		return res;
	}
} mt;

int calc(int x) {
	int res = 0;
	while (x)  ++res, x -= x & -x;
	return res;
}

int main() {
	
	read(n), read(m), read(k);
	
	sz = 1 << m;
	
	for (int i = 0; i < sz; ++i) {
		if (calc(i) > k)  continue;
		mt.a[(i >> 1)][i] = 1;
		mt.a[(i >> 1) | (1 << m - 1)][i] = 1;
	}
	
	mt = mt ^ n;
	
	long long ans = 0;
	for (int i = 0; i < sz; ++i) 
		ans = (ans + mt.a[i][i]) % mod;
	
	printf("%lld
", ans);
	
	return 0;
}
原文地址:https://www.cnblogs.com/VeniVidiVici/p/11600652.html