「题解」[LG7670]【JOI2018】Snake Escaping

我们先来考虑一个很暴力的做法:显然每个位置上只有 01? 三种可能,我们考虑通过递推或记搜预处理出 (3^L) 种状态,然后直接查询即可,这样做的时间复杂度是 (Theta(3^L + q)),可以通过 (L le 13) 的数据。

我们再来考虑一个更暴力的做法:对于每个询问暴力枚举每个 ?0 还是 1,时间复杂度 (Theta(q2^L))

我们发现第一个暴力的瓶颈在于预处理,第二个暴力的瓶颈在于询问,那么结合这两个暴力是不是就可以得到一个复杂度正确的算法了呢?

答案是肯定的,我们考虑预处理 (L-k) 位的答案,对于每一个询问将前 (k) 位单独提出来枚举 ? 的所有情况,然后对于后 (L-k) 位直接在预处理出的数据里查询即可。这样子时间复杂度就是 (Theta(q2^k + 2^k3^{L-k})) 的,当 (n=20) 时在 (k=7) 时取到最小值。

实现起来与上面所讲的略有不同,具体如下所述:

记一个二进制数 (i) 与一个由 01? 组成的字符串 (j) 相似,当且仅当枚举 (j)? 的所有情况,其中有一种情况的二进制形式等于 (i),例如 ((101)_2)?0? 是相似的,((101)_2)?1? 则不是相似的。

我们将询问存储下来离线处理,先枚举 (2^k) 种状态 (i),然后根据 (i) 预处理 (3^{L-k}) 种状态,最后对于每个询问 (j),如果 (j) 的前 (k) 位与 (i) 是相似的,那么这一组询问的答案就加上 (j) 的后 (L-k) 位预处理出的答案。

现在剩下两个问题:

  1. 如何快速判断一个二进制数 (i) 与一个由 01? 组成的字符串 (j) 相似?
  2. 如何预处理后 (3^{L-k}) 种状态?

对于第一个问题,我们考虑先把 (j) 中的问号当作 1,然后将 (j) 转换为二进制形式(记为 (f)),我们再把 (j) 中的问号当作 0,其他当作 1,将 (j) 转换为二进制形式(记为 (g)),如果 ((f oplus i) ext{AND} g) 的值为 (0)(i)(j) 相似(其中 (oplus) 为按位异或运算,( ext{AND}) 为按位与运算)。

对于第二个问题,我们每次将状态中的一个 ? 分别替换成 01 计算答案再加起来,可以通过记搜或递推实现。

代码如下:

/*
    I will never forget it.
*/
 
// 392699
 
#include <bits/stdc++.h>

#define reg register

using namespace std;

#define getchar() (__S == __T && (__T = (__S = __B) + fread(__B, 1, 1 << 23, stdin), __S == __T) ? EOF : *__S++)
#define putchar(c) (_P == _Q && (fwrite(_O, 1, 1 << 23, stdout), _P = _O), *_P++ = c)

char __B[1 << 23], *__S = __B, *__T = __B;
char _O[1 << 23], *_P = _O, *_Q = _O + (1 << 23);
 
template <class T> inline void fr(reg T &a, reg bool f = 0, reg char ch = getchar()) {
    for (a = 0; ch < '0' || ch > '9'; ch = getchar()) ch == '-' ? f = 1 : 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) a = a * 10 + ch - '0';
    a = f ? -a : a;
}

template <class T> inline void fw(reg T x) {
    reg int tot = 0;
    reg char TMP[20];
    if (x == 0) return (void)(putchar('0'));
    if (x < 0) putchar('-'), x = -x;
    while (x) TMP[++tot] = x % 10 + '0', x /= 10;
    for (; tot; tot--) putchar(TMP[tot]);
}
 
const int N = 20, M = 1e6, K = 7, L = 20 - K, F = (1 << K) - 1;
 
int n, q, o2, o, a[(1 << N) + 10], ans[M + 10], pow3[N + 10] = {1}, pow32[N + 10] = {2};
int d[M + 10], cc[M + 10], bbb[M + 10], aaaa[1594323 + 10];
int lb[1594323 + 10], ts[1594323 + 10], td[(1 << N) + 10];
char str[N + 10];
 
int dfs(reg int x) {
    reg int t = x - o;
    if (aaaa[t] != -1) return aaaa[t];
    if (lb[t] == -1) return aaaa[t] = a[ts[t] + o2];
    return aaaa[t] = dfs(x - pow3[lb[t]]) + dfs(x - pow32[lb[t]]);
}

inline void getS() {
    reg char ch = getchar();
    reg int tot = 0;
    while (ch == '
' || ch == ' ') ch = getchar();
    while (ch != '
' && ch != ' ' && ch != EOF) str[++tot] = ch, ch = getchar();
}

struct OI {
    int RP, score;
} CSPS2021, NOIP2021;
 
int main() {
    CSPS2021.RP++, CSPS2021.score++, NOIP2021.RP++, NOIP2021.score++, 392699;
    memset(lb, -1, sizeof(lb));
    for (reg int i = 0; i < 1594323; i++) i % 3 == 2 ? lb[i] = 0 : ~lb[i / 3] ? lb[i] = lb[i / 3] + 1 : 0;
    for (reg int i = 0; i < 1594323; i++) ts[i] = ts[i / 3] << 1 | i % 3, td[ts[i]] = i;
    fr(n), fr(q);
    int n2 = 1 << n;
    for (reg int i = 0; i < n2; i++) a[i] = getchar() - '0';
    for (reg int i = 1; i <= N; i++) pow3[i] = pow3[i - 1] * 3, pow32[i] = pow3[i] * 2;
    if (n <= L) {
        memset(aaaa, -1, sizeof(aaaa));
        for (int i = 0; i < pow3[n]; i++) dfs(i);
        for (int i = 1; i <= q; i++) {
            getS();
            int k = 0;
            for (int j = 1; j <= n; j++) {
                if (str[j] == '?') k = k * 3 + 2;
                else k = k * 3 + str[j] - '0';
            }
            fw(aaaa[k]), putchar('
');
        }
        return fwrite(_O, 1, _P - _O, stdout), 0;
    }
    for (reg int i = 1; i <= q; i++) {
        getS(), cc[i] = F;
        for (reg int j = 1; j <= K; j++) {
            int k = str[j] == '0' ? 0 : 1;
            d[i] = d[i] * 2 + k;
            if (str[j] == '?') cc[i] = cc[i] ^ (1 << (K - j));
        }
        for (reg int j = K + 1; j <= n; j++) {
            int k;
            if (str[j] == '?') k = 2;
            else k = str[j] - '0';
            bbb[i] = bbb[i] * 3 + k;
        }
    }
    int awsl = 1 << K;
    for (reg int i = 0; i < awsl; i++) {
        o2 = i << (n - K), o = td[i] * pow3[n - K];
        memset(aaaa, -1, sizeof(aaaa));
        for (reg int j = 0; j < pow3[n - K]; j++) dfs(o + j);
        for (reg int j = 1; j <= q; j++)
            if (((d[j] ^ i) & cc[j]) == 0) ans[j] += aaaa[bbb[j]];
    }
    for (reg int i = 1; i <= q; i++) fw(ans[i]), putchar('
');
    return fwrite(_O, 1, _P - _O, stdout), 0;
}
原文地址:https://www.cnblogs.com/Aestas16/p/LG7670.html