hdu6086(AC 自动机)

hdu6086

题意

字符串只由 (01) 组成,求长度为 (2L) 且包含给定的 (n) 个子串的字符串的个数(且要求字符串满足 (s[i] eq s[|s| - i + 1]))。

分析

没有想到可以暴力预处理中间那些字符。

官方题解:

如果没有反对称串的限制,直接求一个长度为 (L)(01) 串满足所有给定串都出现过,那么是一个经典的 AC 自动机的问题,状态 (f[i][j][S]) 表示长度为 (i),目前在 AC 自动机的节点 (j) 上,已经出现的字符串集合为 (S) 的方案数,然后直接转移即可,时间复杂度 (O(2^nLsum |s|))

然后如果不考虑有串跨越中轴线,那么可以预处理所有正串的 AC 自动机和所有反串(即原串左右翻转)的 AC 自动机,然后从中间向两边 DP,每一次枚举右侧下一个字符是 (0) 还是 (1),那么另一侧一定是另外一个字符。状态 (f[i][j][k][S]) 表示长度为 (2i),目前右半边在正串 AC 自动机的节点 (j) 上,左半边的反串在反串 AC 自动机的节点 (k) 上,已经出现的字符串集合为 (S) 的方案数,然后直接转移,时间复杂度 (O(2^nL(sum |s|)^2))

现在考虑有串跨越中轴线,可以先爆枚从中间开始左右各 (max|s|-1) 个字符,统计出哪些串以及出现了。对于之后左右扩展出去的字符来说,肯定没有经过的它们的字符串跨越中轴线,因此可以以爆枚的结果为 DP 的初始值,从第 (max|s|) 个字符开始 DP。

时间复杂度 (O(2^nL(sum |s|)^2+max|s|2^{max|s|}))

数组要开成滚动数组,然后爆搜的时候自动机上的状态也要跟着转移。

时限还是很宽松的。

code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<iostream>
using namespace std;
typedef long long ll;
const int MAXN = 121;
const int MOD = 998244353;
struct Trie {
    int root, L, nxt[MAXN][2], fail[MAXN], val[MAXN];
    int newnode() {
        memset(nxt[L], -1, sizeof nxt[L]);
        return L++;
    }
    void init() {
        L = 0;
        root = newnode();
        memset(val, 0, sizeof val);
        memset(fail, 0, sizeof fail);
    }
    void insert(int id, char S[]) {
        int len = strlen(S);
        int now = root;
        for(int i = 0; i < len; i++) {
            int d = S[i] - '0';
            if(nxt[now][d] == -1) nxt[now][d] = newnode();
            now = nxt[now][d];
        }
        val[now] |= (1 << id);
    }
    void build() {
        queue<int> Q;
        for(int i = 0; i < 2; i++) {
            if(nxt[root][i] == -1) nxt[root][i] = 0;
            else { fail[nxt[root][i]] = root; Q.push(nxt[root][i]); }
        }
        while(!Q.empty()) {
            int now = Q.front(); Q.pop();
            val[now] |= val[fail[now]];
            for(int i = 0; i < 2; i++) {
                if(nxt[now][i] == -1) nxt[now][i] = nxt[fail[now]][i];
                else { fail[nxt[now][i]] = nxt[fail[now]][i]; Q.push(nxt[now][i]); }
            }
        }
    }
    int query(char S[], int l, int r) {
        int now = root;
        int res = 0;
        int flg = 0;
        int mid = (r - l) / 2 + l;
        for(int i = l; i <= r; i++) {
            int d = S[i] - '0';
            now = nxt[now][d];
            res |= val[now];
        }
        return res;
    }
}trie1, trie2;
int n, L, mx;
int dp[2][MAXN][MAXN][64];
void dfs(char s[], int l, int r, int nl, int nr) {
    int len = r - l + 1;
    if(len / 2  >= mx) {
        int tmp = trie2.query(s, l, r);
        dp[1][nl][nr][tmp]++;
        return;
    }
    s[l - 1] = '0'; s[r + 1] = '1';
    dfs(s, l - 1, r + 1, trie1.nxt[nl][0], trie2.nxt[nr][1]);
    s[l - 1] = '1'; s[r + 1] = '0';
    dfs(s, l - 1, r + 1, trie1.nxt[nl][1], trie2.nxt[nr][0]);
}
int cnt[64];
int main() {
    cnt[0] = 0;
    for(int i = 1; i < 64; i++) {
        int j = 0;
        while(!((i >> j) & 1)) j++;
        cnt[i] = cnt[i - (1 << j)] + 1;
    }
    int T;
    scanf("%d", &T);
    while(T--) {
        scanf("%d%d", &n, &L);
        trie1.init();
        trie2.init();
        mx = 0;
        for(int i = 0; i < n; i++) {
            char s[22];
            scanf("%s", s);
            trie2.insert(i, s);
            int len = strlen(s);
            mx = max(mx, len);
            reverse(s, s + len);
            trie1.insert(i, s);
        }
        mx--;
        trie1.build();
        trie2.build();
        memset(dp, 0, sizeof dp);
        char s[65];
        dfs(s, 23, 22, 0, 0);
        int z = 1;
        for(int i = mx; i < L; i++, z = !z) {
            memset(dp[!z], 0, sizeof dp[!z]);
            for(int j = 0; j < trie1.L; j++) {
                for(int k = 0; k < trie2.L; k++) {
                    for(int p = 0; p < (1 << n); p++) {
                        if(!dp[z][j][k][p]) continue;
                        for(int q = 0; q < 2; q++) {
                            int tmp1 = trie1.nxt[j][q], tmp2 = trie2.nxt[k][!q];
                            (dp[!z][tmp1][tmp2][p | trie1.val[tmp1] | trie2.val[tmp2]] += dp[z][j][k][p]) %= MOD;
                        }
                    }
                }
            }
        }
        int sum = 0;
        for(int i = 0; i < trie1.L; i++) {
            for(int j = 0; j < trie2.L; j++) {
                sum = (sum + dp[z][i][j][(1 << n) - 1]) % MOD;
            }
        }
        printf("%d
", sum);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/ftae/p/7329315.html