为什么这里的题解都是写的顺推呢?这里提供一篇倒推的题解,为那些和我一样打倒退写wa的小伙伴提供一点能够借鉴的代码。
对于这道弱化版的题目,我们考虑倒推,定义(f_{i,j,k,w})我们在被打了(i-1)次后,还剩下(j)个三血奴隶主,(k)个两血奴隶主,(w)个一血奴隶主在接下来的(i)~(k)次被打中受到伤害的期望。
那么转移就会比较容易:
定义(tot = j+k+w+1)
当(j+k+w < 7)时有:
[f_{i,j,k,w} = (f_{i+1,j,k,w}+1) imes frac{1}{tot}+ f_{i+1,j,k+1,w} imes frac {j}{tot} + f_{i+1,j+1,k-1,w+1} imes frac{k}{tot} +f_{i+1,j,k,w-1} imes frac{w}{tot}
]
否则有:
[f_{i,j,k,w} = f_{i+1,j-1,k+1,w} imes frac {j}{tot} + f_{i+1,j,k-1,w+1} imes frac{k}{tot} +f_{i+1,j,k,w-1} imes frac{w}{tot}
]
代码比较简单,注意一下输入顺序就好了。
现在我们来看这一道题,我们仍然考虑倒退,发现(n leq 10^{18}),我们需要优化,我们发现对于每个(dp)状态的转移的参数与(i)并无关系,因此我们考虑矩阵加速。
我们枚举出(j,k,w)的所有合法情况(由于我们转移的特殊性,我们需要再把1作为常数项加入矩阵参与转移),将其作为矩阵的大小,由于(K leq 8),因此状态的数量并不会太多,只有(166)个,我们枚举完合法情况将其编号过后枚举可以转移它的状态并计算出参数填入矩阵中,就可以进行矩阵优化了~。
但是这个复杂度我们仍然无法通过本题,我们考虑继续优化。我们发现这个矩阵是固定的,因此我们可以预处理出矩阵的(2^i)的结果,我们在处理询问的时候只需要用一个单行去乘以我们预处理出来的一些矩阵就可以了。
但是它还是死了,根据一些奇技淫巧,我们可以用__in128储存矩阵乘法的中间结果,最后再去取模,就能有效地减少取模次数,优化效果显著。
下面就是十分冗长的代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL mod = 998244353;
struct matrix {
LL a[170][170];
int n , m;
matrix operator * (const matrix &b) const{
matrix c;
__int128 C[170][170];
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= b.m; ++j) {
C[i][j] = 0;
}
}
for (int i = 1; i <= n; ++i) {
for (int k = 1; k <= m; ++k) {
for (int j = 1; j <= b.m; ++j) {
C[i][j] = C[i][j] + a[i][k] * b.a[k][j];
}
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= b.m; ++j) {
c.a[i][j] = C[i][j] % mod;
}
}
c.n = n;
c.m = b.m;
return c;
}
}I;
int T , m , K , tot , id[10][10][10];
matrix A , B[64] , beg , ans;
LL get_inv(LL a) {
int b = mod - 2;
LL res = 1;
while(b) {
if(b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
void get_matrix_3 () {
for (int j = 0; j <= K; ++j) {
for (int k = 0; k <= K; ++k) {
for (int w = 0; w <= K; ++w) {
if(j + k + w > K) continue;
id[j][k][w] = ++tot;
}
}
}
++tot;
A.n = A.m = tot;
A.a[tot][tot] = 1;
for (int j = 0; j <= K; ++j) {
for (int k = 0; k <= K; ++k) {
for (int w = 0; w <= K; ++w) {
if(j + k + w > K) continue;
int now = id[j][k][w];
LL tmp = get_inv((LL)(j + k + w + 1));
A.a[now][now] = tmp;A.a[now][tot] = tmp;
if(j) {
if(j + k + w < K) A.a[now][id[j][k+1][w]] = (LL)j * tmp % mod;
else A.a[now][id[j-1][k+1][w]] = (LL)j * tmp % mod;
}
if(k) {
if(j + k + w < K) A.a[now][id[j+1][k-1][w+1]] = (LL)k * tmp % mod;
else A.a[now][id[j][k-1][w+1]] = (LL)k * tmp % mod;
}
if(w) A.a[now][id[j][k][w-1]] = (LL)w * tmp % mod;
}
}
}
}
int id2[10][10];
void get_matrix_2 () {
for (int j = 0; j <= K; ++j) {
for (int k = 0; k <= K; ++k) {
if(j + k > K) continue;
id2[j][k] = ++tot;
}
}
++tot;
I.n = I.m = A.n = A.m = tot;
for (int j = 0; j <= K; ++j) {
for (int k = 0; k <= K; ++k) {
if(j + k > K) continue;
int now = id2[j][k];
LL tmp = get_inv((LL)(j + k + 1));
A.a[now][now] = tmp;A.a[now][tot] = tmp;
if(j) {
if(j + k < K) A.a[now][id2[j][k+1]] = (LL)j * tmp % mod;
else A.a[now][id2[j-1][k+1]] = (LL)j * tmp % mod;
}
if(k) A.a[now][id2[j][k-1]] = (LL)k * tmp % mod;
}
}
A.a[tot][tot] = 1;
}
void make_pow() {
B[0] = A;
for (int i = 1; i <= 59; ++i) B[i] = B[i - 1] * B[i - 1];
}
LL f[15][10];
void work_1(int n) {
for (int i = n; i >= 1; --i) {
for (int j = 0; j <= K; ++j) {
LL tmp = get_inv((LL)(j + 1));
f[i][j] = (f[i+1][j] + 1) * tmp % mod;
if(j) f[i][j] += f[i+1][j-1] * (LL)j % mod * tmp % mod;
}
}
printf("%lld
" , f[1][1]);
memset(f , 0 , sizeof f);
}
int main() {
// freopen("10.in" , "r" , stdin);
scanf("%d %d %d" , &T , &m , &K);
if(m == 3) {
get_matrix_3();
make_pow();
ans.n = tot;ans.m = 1;
while(T -- > 0) {
LL x;
scanf("%lld" , &x);
for (int i = 1; i < tot; ++i) ans.a[i][1] = 0;
ans.a[tot][1] = 1;
for (LL i = 60; i >= 1; --i) {
if(x >= (1ll << (i - 1ll))) {
x -= (1ll << (i - 1ll));
ans = B[i - 1] * ans;
}
}
printf("%lld
" , ans.a[id[1][0][0]][1]);
}
}
else if(m == 2) {
get_matrix_2();
make_pow();
beg.n = tot;beg.m = 1;
beg.a[tot][1] = 1;
while(T -- > 0) {
LL x;
scanf("%lld" , &x);
ans = beg;
for (LL i = 60; i >= 1; --i) {
if(x >= (1ll << (i - 1ll))) {
x -= (1ll << (i - 1ll));
ans = B[i - 1] * ans;
}
}
printf("%lld
" , ans.a[id2[1][0]][1]);
}
}
else {
while(T -- > 0) {
int n;
scanf("%d" , &n);
work_1(n);
}
}
return 0;
}