Separate String(Ac自动机+dp)

题意

给出大小为 (n) 的字符串集合,给定字符串 (t) ,求拆分 (t) 的方案数,要求串 (t) 拆分后每一个串都要是集合中的某个串。

答案取模1e9+7

思路

对于一个串 (t) 的第i个位置,如果他是某个串的结尾, 并且这个串之前的串也是个合法串,那么可进行dp转移,可用ac自动机的fail指针来维护第 (i) 个位置是否为某个串的结尾并遍历集合中所有合法的串。

Code

#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+10;
const int mod = 1e9+7;
int dp[maxn], pos[maxn];
struct Ac {
    int tr[maxn][26], fail[maxn*26], e[maxn*26], len[maxn*26];
    int dep[maxn*26];
    int tot;
    void insert(char *t) {
        int p=0;
        for (int c, i=0; t[i]; ++i) {
            c = t[i]-'a';
            if(!tr[p][c]) {
                tr[p][c] = ++tot;
                dep[tot] = dep[p]+1;
            }
            p = tr[p][c];
        }
        e[p]=1;
        len[p] = dep[p];
    }
    void build() {
        queue<int>q;
        for (int i=0; i<26; ++i) if(tr[0][i]) q.push(tr[0][i]);
        while(!q.empty()) {
            int u=q.front();
            q.pop();
            for (int i=0; i<26; ++i) {
                if(tr[u][i]) {
                    fail[tr[u][i]] = tr[fail[u]][i], q.push(tr[u][i]);
                    e[tr[u][i]] |= e[fail[tr[u][i]]];
//                    len[tr[u][i]] =len[fail[tr[u][i]]];
//                    cout << tr[u][i] << " " << fail[tr[u][i]] << endl;
                } else tr[u][i] = tr[fail[u]][i];
            }
        }
//        for (int i=1; i<=tot; ++i) len[i] = dep[i]-dep[fail[i]];
//        for (int i=1; i<=tot; ++i) printf("%d ", len[i]); puts("");
    }
    void query(char *t) {
        int p = 0;
        dp[0] = 1, e[0] = 1;
        for (int c, i=1; t[i]; ++i) {
            c = t[i]-'a';
            p=tr[p][c], pos[i] = p;
            for (int j=p; j && e[j]; j=fail[j]) {
                if(e[pos[i-len[j]]] && len[j])
                    dp[i] = (dp[i] + dp[i-len[j]])%mod;
            }
        }
    }
} ac;
int n;
char str[maxn];
 
int main() {
//    freopen("input.txt", "r", stdin);
    scanf("%d", &n);
    for (int i=1; i<=n; ++i) {
        scanf("%s", str);
        ac.insert(str);
    }
    ac.build();
     scanf("%s", str+1);
    ac.query(str);
   // n=100000;
    n = strlen(str+1);
//    for (int i=1; i<=n; ++i) printf("%d
", dp[i]);
    printf("%d
", dp[n]);
    return 0;
}
 
/*
3
aaaaaa
aaaaa
a
aaaaa
 
3
a
aaaaaa
aaaaaaaa
aaa
 
4
a
b
ab
aba
ababa
 
5
a
b
ab
ba
aba
abab
 
5
a
b
ab
ba
aba
ababa
 
*/
原文地址:https://www.cnblogs.com/acerkoo/p/11386486.html