HDu4416 Good Article Good sentence (后缀自动机)

题解:先对每个模式串建立一个后缀自动机,计算出子串的种类个数,再把文本串加上去,再计算一次子串种类个数,最后相减就好了

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
typedef long long LL;
typedef pair<int, int> pii;
const int maxn = 2e5 + 50;
int n, m;
string s, t;
struct state
{
    int link, nex[26];
    int len;
} st[maxn * 2];

int last, sz;
void sam_init(){
    for(int i = 0; i < 26; i++) st[0].nex[i] = 0;
    last = 0;
    st[last].link = -1;
    st[last].len = 0;
    sz = 1;
}
void init_nex(int u){
    for(int i = 0; i < 26; i++) st[u].nex[i] = 0;
}
void sam_extend(int x){
    if(st[last].nex[x]){
    	int q = st[last].nex[x];
    	if(st[q].len == st[last].len + 1){
    		last = q;
    		return ;
    	}
    }
    int cur = sz++;
    init_nex(cur);
    st[cur].len = st[last].len + 1;
    int p = last;
    while(p != -1 && !st[p].nex[x]) {
        st[p].nex[x] = cur;
        p = st[p].link;
    }
    if(p == -1) st[cur].link = 0;
    else{
        int q = st[p].nex[x];
        if(st[q].len == st[p].len + 1){
            st[cur].link = q;
        } else {
            int clone = sz++;
            st[clone].link = st[q].link;
            st[clone].len = st[p].len + 1;
            for(int i = 0; i < 26; i++) st[clone].nex[i] = st[q].nex[i];
            while(p != -1 && st[p].nex[x] == q){
                st[p].nex[x] = clone;
                p = st[p].link;
            }
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}
int main(int argc, char const *argv[])
{
    int tt;
    cin >> tt;
    int ca = 0;
    while(tt--){
        sam_init();
        cin >> n;
        cin >> t;
        for(int i = 1; i <= n; i++){
            cin >> s;
            last = 0;
            int len = s.size();
            for(int j = 0; j < len; j++){
                sam_extend(s[j] - 'a');
            }
        }
        LL ans1 = 0;
        for(int i = 1; i < sz; i++){
        	ans1 += st[i].len - st[st[i].link].len;
        }
        int len = t.size();
        last = 0;
        for(int i = 0; i < len; i++){
            sam_extend(t[i] - 'a');
        }
        LL ans = 0;
        for(int i = 1; i < sz; i++) {
            ans += st[i].len - st[st[i].link].len;
        }

        printf("Case %d: %I64d
", ++ca, ans - ans1);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/PCCCCC/p/13439860.html