P4723 【模板】常系数齐次线性递推 题解

祭奠我逝去的一下午加一晚上=-=

Description

Luogu传送门

顺便 sto \(Zhang\_RQ\)学长 orz

Solution

这名字听起来就很高大上的样子(事实上确实如此

好吧其实我并没有打算推式子,因为 \(BJpers2\) 巨佬在他的题解中已经把式子推的很明白了,只是他的代码着实有些毒瘤,因此我这里只是想放下我自己的代码罢了。

简单说两句,推出特征多项式 \(p(x)\) 之后,就是要求 \(x^n \ \ mod \ \ p(x)\),然而我们这个 \(n\)\(10^9\),没办法直接放到多项式里算,所以采用快速幂的思想。

快速幂过程

设最开始的多项式 \(t(x) = 1\),倍增往上跳,最大情况下 \(t\) 会是一个 \(k - 1\) 次的多项式乘一下就会变成 \(2 \times k - 2\) 次,然后去对 \(k\) 次的多项式 \(p(x)\) 取模。

次数体现在代码里的话,就是 \(Mod\) 函数中传的实参是 \(n << 1\)

坑点

  1. 数组最好都开成局部变量,不然就各种错乱(我一开始用的全局变量数组就一直都是 0).
  2. 边界!边界!边界!好吧,说实话这玩意就算知道有坑点也没啥用,就算让我再写一遍可能也得调半天。

废话不多说,上代码吧,希望对您有帮助。

Code(大常数警告)

#include <bits/stdc++.h>
#define ll long long

using namespace std;

namespace IO{
    inline ll read(){
        ll x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }

    template <typename T> inline void write(T x){
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }

    inline void print(ll a[], ll n){
        for(int i = 0; i <= n; ++i) printf("%lld ", a[i]);
        puts("");
    }
}
using namespace IO;

const ll N = 5e5 + 10;
const ll mod = 998244353;
const ll G = 3, Gi = 332748118;
ll n, m, k;
ll a[N], b[N], c[N], d[N], e[N], p[N], res[N], t[N];
ll f[N], g[N], ig[N], q[N], r[N];

namespace NTT{
    ll lim, len;

    inline ll qpow(ll a, ll b){
        ll res = 1;
        while(b){
            if(b & 1) res = res * a % mod;
            a = a * a % mod, b >>= 1;
        }
        return res;
    }

    inline void get_rev(ll n){
        lim = 1, len = 0;
        while(lim <= n) lim <<= 1, ++len;
        for(int i = 0; i <= lim; ++i) p[i] = (p[i >> 1] >> 1) | ((i & 1) << (len - 1));
    }

    inline void ntt(ll A[], ll lim, ll type){
        for(int i = 0; i <= lim; ++i)
            if(i < p[i]) swap(A[i], A[p[i]]);
        for(int mid = 1; mid < lim; mid <<= 1){
            ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
            for(int i = 0; i < lim; i += (mid << 1)){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                    ll x = A[i + j], y = w * A[i + j + mid] % mod;
                    A[i + j] = (x + y) % mod;
                    A[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == 1) return;
        ll inv = qpow(lim, mod - 2);
        for(int i = 0; i <= lim; ++i) A[i] = A[i] * inv % mod;
    }

    inline void Mul(ll n, ll m, ll a[], ll b[], bool flag = 1){
        static ll d[N], e[N];
        for(int i = 0; i < (n << 2); ++i) d[i] = e[i] = 0;
        for(int i = 0; i < n; ++i) d[i] = a[i], e[i] = b[i];
        get_rev(n + m);
        ntt(d, lim, 1), ntt(e, lim, 1);
        for(int i = 0; i < lim; ++i) d[i] = d[i] * e[i] % mod;
        ntt(d, lim, -1);
        for(int i = 0; i < (n << 1); ++i) a[i] = d[i];
        for(int i = (n << 1); i <= lim; ++i) a[i] = 0;
        if(flag) for(int i = n; i < (n << 1); ++i) a[i] = 0;

    }

    inline void Inv(ll n, ll a[], ll b[]){
        if(!n) return b[0] = qpow(a[0], mod - 2), void();
        Inv(n >> 1, a, b);
        get_rev(n << 1);
        for(int i = 0; i <= n; ++i) c[i] = a[i];
        for(int i = n + 1; i <= lim; ++i) c[i] = 0;
        ntt(c, lim, 1), ntt(b, lim, 1);
        for(int i = 0; i <  lim; ++i) b[i] = (2ll - c[i] * b[i] % mod + mod) * b[i] % mod;
        ntt(b, lim, -1);
        for(int i = n + 1; i <= lim; ++i) b[i] = 0;
    }
}
using namespace NTT;

inline void Mod(ll n, ll m, ll f[], ll g[], ll r[]){
    static ll a[N], b[N];
    for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = d[i] = 0;
    for(int i = 0; i < n - m + 1; ++i) a[i] = f[n - i - 1];
    for(int i = 0; i < n - m + 1; ++i) b[i] = g[m - i - 1];

    Inv(n - m + 1, b, d);
    Mul(n - m + 1, n - m + 1, a, d);
    for(int i = 0; i <= n - m; ++i) q[i] = a[n - m - i];

    for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = 0;
    for(int i = 0; i < n; ++i) a[i] = f[i];
    for(int i = 0; i < m; ++i) b[i] = g[i];
    Mul(n, n, b, q);

    for(int i = 0; i < m - 1; ++i) r[i] = (a[i] - b[i] + mod) % mod;
    for(int i = m - 1; i < lim; ++i) r[i] = 0;
}

inline void solve(ll p, ll n){
    t[1] = res[0] = 1;
    while(p){
        if(p & 1) Mul(n, n, res, t, 0), Mod(n << 1, n, res, g, res);// b % g --> b
        Mul(n, n, t, t, 0), Mod(n << 1, n, t, g, t);
        p >>= 1;
    }
}

signed main(){
    // freopen("P4723.in", "r", stdin);
    // freopen("P4723.out", "w", stdout);
    n = read(), m = read();
    g[0] = 1;
    for(int i = 1; i <= m; ++i) g[i] = (mod - (read() % mod + mod) % mod);
    reverse(g, g + 1 + m);
    for(int i = 0; i < m; ++i) f[i] = read();
    solve(n, m + 1);
    ll ans = 0;
    for(int i = 0; i < m; ++i) ans = (ans + res[i] * f[i] % mod + mod) % mod;
    write(ans), puts("");
    return 0;
}

\[\_EOF\_ \]

原文地址:https://www.cnblogs.com/xixike/p/15626626.html