题解-P4723 【模板】常系数齐次线性递推 [*hard]

直接正着递推需要 $O(nk)$ 的时间，$n$ 很大时不可做，考虑倒着推。

例如，要求 $a_i = a_{i-1} + 2a_{i-2}$（$a_0 = 2,\ a_1 = 1$）的第 $4$ 项时，我们倒着推：

$$a_4 = a_3 + 2a_2 = 3a_2 + 2a_1 = 5a_1 + 6a_0 = 5 \times 1 + 6 \times 2 = 17$$

可以发现，对于一般的递推式 $a_n = \sum\limits_{i=1}^{k} f_i a_{n-i}$，倒着推的过程相当于每次把下标最大的那一项 $a_x$ 替换成 $\sum\limits_{i=1}^{k} f_i a_{x-i}$。

这很像一个取模的过程：如果一个数可以表示成 $a$ 数组中若干项的线性组合 $\sum\limits_{i \ge 0} p_i a_i$，那么就把它对应成多项式 $\sum\limits_{i \ge 0} p_i x^i$。

一次操作相当于把 $x^t$ 替换成 $\sum\limits_{i=1}^{k} f_i x^{t-i}$。

构造多项式 $\lambda$，使得 $\forall t \ge k,\ x^t \equiv \sum\limits_{i=1}^{k} f_i x^{t-i} \pmod{\lambda}$，即：

$$x^t - \sum\limits_{i=1}^{k} f_i x^{t-i} \equiv 0 \pmod{\lambda}$$

$$x^{t-k}\left(x^k - \sum\limits_{i=1}^{k} f_i x^{k-i}\right) \equiv 0 \pmod{\lambda}$$

取 $\lambda = x^k - \sum\limits_{i=1}^{k} f_i x^{k-i}$ 即可满足。

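因此 $x^n$ 可以当作 $x^n \bmod \lambda$ 来计算：若 $x^n \bmod \lambda = \sum\limits_{i=0}^{k-1} c_i x^i$，则 $a_n = \sum\limits_{i=0}^{k-1} c_i a_i$。用快速幂求 $x^n \bmod \lambda$，每一步做一次 NTT 多项式乘法和一次多项式取模，总复杂度 $O(k \log k \log n)$。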

代码:

#include<bits/stdc++.h>
// L(i, j, k): ascending loop i = j .. k inclusive; the bound k is evaluated once (kept in i##E).
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
// R(i, j, k): descending loop i = j .. k inclusive; the bound k is evaluated once (kept in i##E).
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
// Short type / helper aliases used by the code below.
#define ll long long
#define ull unsigned long long
#define db double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
// Read one signed decimal integer from stdin.
// Every character before the first digit is skipped; any '-' among them makes the result negative.
// NOTE(review): c is a plain char, so bytes >= 0x80 reach isdigit as negative values (UB),
// and hitting EOF before any digit loops forever.
inline int read() {
	int sign = 1;
	char c;
	for(c = getchar(); !isdigit(c); c = getchar())
		if(c == '-') sign = -1;
	int value = 0;
	for(; isdigit(c); c = getchar()) value = value * 10 + (c - '0');
	return sign * value;
}
// N bounds every work array (2^18 entries).
// mod = 998244353 = 119 * 2^23 + 1 is the NTT prime with primitive root G = 3;
// iG = (mod + 1) / G is the inverse of G, since 3 * 332748118 = mod + 1.
const int N = (1 << 18);
const int mod = 998244353;
const int G = 3;
const int iG = (mod + 1) / G;
// Binary exponentiation: returns x^y modulo mod.
// The default exponent mod - 2 gives the modular inverse of x (Fermat; mod is prime).
// qpow(0) therefore yields 0.
int qpow(int x, int y = mod - 2) {
	int result = 1;
	while(y) {
		if(y & 1) result = (long long) result * x % mod;
		x = (long long) x * x % mod;
		y >>= 1;
	}
	return result;
}
int Lim, lim, pp[N], PowG[N], iPowG[N];
void revlim() { L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1))); }
void up(int x) { lim = 1; for(; lim <= x; lim <<= 1); }
void cle(int *f) { L(i, 0, lim - 1) f[i] = 0; }
void init(int x) {
	int Pw;
	up(x), Lim = lim;
	Pw = qpow(G, (mod - 1) / Lim), PowG[0] = 1;
	L(i, 1, lim - 1) PowG[i] = (ll) PowG[i - 1] * Pw % mod;
	Pw = qpow(iG, (mod - 1) / Lim), iPowG[0] = 1;
	L(i, 1, lim - 1) iPowG[i] = (ll) iPowG[i - 1] * Pw % mod;
}
// Modular helpers; operands are expected to lie in [0, mod) unless noted otherwise.
// fmod: lift a negative value (in [-mod, 0)) into [0, mod); non-negative values are untouched.
inline void fmod(int &x) {
	if(x < 0) x += mod;
}
// Sum: (x + y) mod mod for x, y in [0, mod), using a single conditional correction.
inline int Sum(int x, int y) {
	const int t = x + y - mod;
	return t < 0 ? t + mod : t;
}
// ad: in-place x = (x + y) mod mod, with the same preconditions as Sum.
inline void ad(int &x, int y) {
	x = Sum(x, y);
}
void NTT(int *f, int flag) {
	L(i, 0, lim - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
	for(int i = 2; i <= lim; i <<= 1) 
		for(int j = 0, l = (i >> 1), ch = Lim / i; j < lim; j += i) 
			for(int k = j, now = 0; k < j + l; k ++) {
				int pa = f[k], pb = (ll) f[k + l] * (flag == 1 ? PowG[now] : iPowG[now]) % mod;
				f[k] = Sum(pa, pb), f[k + l] = Sum(pa, mod - pb), now += ch;
			}
	if(flag == -1) {
		int nylim = qpow(lim);
		L(i, 0, lim - 1) f[i] = (ll) f[i] * nylim % mod;
	}
}
int sav[N];
void inv(int *f, int *g, int len) { 
	if(len == 1) return g[0] = qpow(f[0]), void();
	inv(f, g, (len + 1) >> 1), up(len << 1), cle(sav), copy(f, f + len, sav), revlim(), NTT(sav, 1), NTT(g, 1);
	L(i, 0, lim - 1) g[i] = (ll) g[i] * (2ll + mod - (ll) g[i] * sav[i] % mod) % mod;
	NTT(g, -1), fill(g + len, g + lim, 0);
}
void Mul(int *f, int *g, int *ans, int n, int m) {
	static int A[N], B[N];
	up(n + m), revlim(), cle(A), cle(B), copy(f, f + n, A), copy(g, g + m, B);
	NTT(A, 1), NTT(B, 1);
	L(i, 0, lim - 1) A[i] = (ll) A[i] * B[i] % mod;
	NTT(A, -1), copy(A, A + n + m - 1, ans);
}
void div(int *f, int *g, int *ansa, int *ansb, int n, int m) {
	static int A[N];
	reverse(f, f + n), reverse(g, g + m), up((n - m + 1) << 1), cle(A);
	inv(g, A, n - m + 1), Mul(f, A, A, n - m + 1, n - m + 1);
	reverse(A, A + n - m + 1), copy(A, A + n - m + 1, ansa);
	reverse(f, f + n), reverse(g, g + m), Mul(A, g, A, n - m + 1, m);
	L(i, 0, m - 2) ansb[i] = (f[i] - A[i] + mod) % mod; 
}
int n, m, f[N], a[N], res[N], g[N], sv[N], vs[N], ans;
int main() {
	n = read(), m = read(), init(m << 1);
	f[m] = 1;
	R(i, m - 1, 0) f[i] = ( mod - read() % mod ) % mod;
	L(i, 0, m - 1) a[i] = read() % mod, fmod(a[i] += mod);
	res[0] = 1, g[1] = 1;
	for(; n; Mul(g, g, g, m, m), div(g, f, sv, vs, m << 1, m + 1), copy(vs, vs + m, g), n >>= 1) 
		if(n & 1) Mul(res, g, res, m, m), div(res, f, sv, vs, m << 1, m + 1), copy(vs, vs + m, res);
	L(i, 0, m - 1) ad(ans, (ll) res[i] * a[i] % mod);
	cout << ans << endl;
	return 0;
} 
原文地址:https://www.cnblogs.com/zkyJuruo/p/14320301.html