jzoj 5251. 【GDOI2018模拟8.11】决战


我:!@#¥%……&*()

原来这就是一个套路(* ̄︶ ̄)
我们从暴力入手,(f[i+1][T][j+|T|] += f[i][S][j])
其中(S,T)表示该行哲学家的状态,而(|T|)表示(T)状态放的哲学家个数。
我们可以将第三维去掉,改成多项式的形式:(g[i][S]=∑_{j=0}^{3*n+1}f[i][S][j]*x^j),因为这样便于快速求出第(n)项(矩乘)。
这样我们可以带入(3*n+1)(x)然后与最终结果累加的和组成(3*n+1)个点值,最后用(NTT)得到系数表达式,
而此时的(b[m])显然就是最终结果了。
如何带入(x)快速得到(∑g[n][S])?用快速幂,对于(S)状态到(T)状态,可以转移的话,矩阵填上的数即为(x^|T|),好巧妙啊。。。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define g 3
#define mo 998244353
#define ll long long
#define mem(x, a) memset(x, a, sizeof x)
#define mpy(x, y) memcpy(x, y, sizeof y)
#define fo(x, a, b) for (int x = (a); x <= (b); x++)
#define fd(x, a, b) for (int x = (a); x >= (b); x--)
#define go(x) for (int p = tail[x], v; p; p = e[p].fr)
using namespace std;
struct matrix{int a[9][9], n, m;}aw, zy, c;
int n, m, a[4][4], dl[4], to[9][9], b[15010], r[15010];

inline int read() {
	int x = 0, f = 0; char c = getchar();
	while (c < '0' || c > '9') f = (c == '-') ? 1 : f, c = getchar();
	while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
	return f ? -x : x;
}

bool judge(int x) {return ! (((x & 1) || (x & 4)) && (x & 2) && (a[2][3] || a[2][1]));}

bool check(int x, int y) {
	if (! judge(x)) return 0;
	if (! judge(y)) return 0;
	if (x & 1) {if ((a[3][2] && (y & 1)) || (a[3][3] && (y & 2))) return 0;}
	if (x & 2) {if ((a[3][1] && (y & 1)) || (a[3][2] && (y & 2)) || (a[3][3] && (y & 4))) return 0;}
	if (x & 4) {if ((a[3][1] && (y & 2)) || (a[3][2] && (y & 4))) return 0;}
	if (y & 1) {if ((a[1][2] && (x & 1)) || (a[1][3] && (x & 2))) return 0;}
	if (y & 2) {if ((a[1][1] && (x & 1)) || (a[1][2] && (x & 2)) || (a[1][3] && (x & 4))) return 0;}
	if (y & 4) {if ((a[1][1] && (x & 2)) || (a[1][2] && (x & 4))) return 0;}
	return 1;
}

ll ksm(ll x, int y) {
	ll s = 1;
	while (y) {
		if (y & 1) s = s * x % mo;
		x = x * x % mo, y >>= 1;
	}
	return s;
}

matrix ksc(matrix a, matrix b) {
	mem(c.a, 0); c.n = a.n, c.m = b.m;
	fo(k, 1, a.m) fo(i, 1, a.n) fo(j, 1, b.m)
		c.a[i][j] = (c.a[i][j] + (ll)a.a[i][k] * b.a[k][j]) % mo;
	return c;
}

ll solve(int x) {
	mem(aw.a, 0);
	fo(i, 1, 8) fo(j, 1, 8)
		if (to[i - 1][j - 1] == -1) zy.a[i][j] = 0;
		else zy.a[i][j] = ksm(x, to[i - 1][j - 1]);
	aw.a[1][1] = 1; aw.n = 1, aw.m = 8, zy.n = zy.m = 8;
	int cs = n;
	while (cs) {
		if (cs & 1) aw = ksc(aw, zy);
		zy = ksc(zy, zy), cs >>= 1;
	}
	ll ans = 0;
	fo(i, 1, 8) (ans += aw.a[1][i]) %= mo;
	return ans;
}

void NTT(int *x, int n, int type) {
	fo(i, 0, n - 1) if (i < r[i]) swap(x[i], x[r[i]]);
	for (int i = 1; i < n; i <<= 1) {
		ll wn = ksm(g, (type * (mo - 1) / (i << 1) + mo - 1) % (mo - 1));
		for (int j = 0; j < n; j += (i << 1)) {
			ll w = 1;
			for (int k = 0; k < i; k++, w = w * wn % mo) {
				int a = x[j + k], b = w * x[j + k + i] % mo;
				x[j + k] = (a + b) % mo, x[j + k + i] = (a - b + mo) % mo;
			}
		}
	}
}

int main()
{
	freopen("final.in", "r", stdin);
	freopen("final.out", "w", stdout);
	n = read(), m = read();
	fo(i, 1, 3) fo(j, 1, 3) a[i][j] = read();
	fo(i, 0, 7) fo(j, 0, 7) {
		if (! check(i, j)) to[i][j] = -1;
		else to[i][j] = ((j & 1) == 1) + ((j & 2) == 2) + ((j & 4) == 4);
	}
	int len = 1, times = 0;
	while (len <= 3 * n + 1) len <<= 1, times++;
	ll wn = ksm(g, (mo - 1) / len), w = 1;
	fo(i, 0, len - 1) b[i] = solve(w), w = w * wn % mo;
	fo(i, 0, len - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (times - 1));
	NTT(b, len, -1);
	ll inv = ksm(len, mo - 2);
	printf("%lld
", b[m] * inv % mo);
	return 0;
}
转载需注明出处。
原文地址:https://www.cnblogs.com/jz929/p/13431097.html