AC自动机--简单版、加强版、二次加强版

AC自动机算法流程

若我们没有构建出失败指针,则 AC 自动机就是一棵普通的 $Trie$。用 $Trie$ 完成上面的单词查询问题时,我们需要从文章的每一个位置 $i$ 开始,将 $S_{i \rightarrow m}$($m$ 为文章长度)视作一个字符串,在 $Trie$ 中进行查询,将所有到达的节点进行标记,最后统计那些代表模式串的节点的标记总个数。

这与 $KMP$ 中的暴力算法非常相似。我们回忆 $KMP$ 是通过寻找模式串 $B$ 中最大的 $j$,使得在当前位置 $i$ 有 $B_{1 \rightarrow j}=B_{i-j+1 \rightarrow i}$,使下次尝试匹配的位置变为 $j$,减少明显无用的尝试来优化复杂度的。AC 自动机也可以借鉴这个思想。

现在匹配是在 $Trie$ 上进行的,因此当我们匹配到 $Trie$ 中某个节点 $x$ 时,它代表的是模式串中某个(些)串的前缀 $T_{1 \rightarrow j}$,同时也是主串的一段子串 $S_{i-j+1 \rightarrow i}$,而 $j$ 其实就是 $x$ 在 $Trie$ 中的深度。当在 $x$ 匹配失败(即失配)时,我们需要找到另一个串(也可能是自己本身,在 $KMP$ 中由于只有一个模式串,所以每次它只能找自己),使得它有尽量长的前缀与 $x$ 所代表的串的后缀相等,即 $T_{1 \rightarrow k}=T_{j-k+1 \rightarrow j}$,且 $k<j$,而 $T_{1 \rightarrow k}$ 也肯定对应着 $Trie$ 中的某个节点 $to$,因此我们将 $x$ 的失败指针指向 $to$,就像 $KMP$ 中的 $fail$ 数组一样。

例题:

简单版:暴力跳 $fail$,累加以当前字符结尾的单词个数;为了让每个单词只算一遍,走过一个节点之后把它标记一下,就可以 $A$ 了。

code:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
/**
 * Reads one decimal integer from stdin.
 *
 * Every character before the first digit is skipped; if a '-' is among the
 * skipped characters the result is negated. The character that terminates
 * the digit run is consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit is seen.
 */
int read(){
	int sign = 1,value = 0;
	// Keep getchar()'s result as an int: once it is narrowed to char, EOF can
	// no longer be told apart and the skip loop below would spin forever on
	// exhausted input.
	int ch = std::getchar();
	while (ch != EOF&&(ch < '0'||ch > '9')){
		if (ch == '-') sign = -1;
		ch = std::getchar();
	}
	while (ch >= '0'&&ch <= '9'){
		value = value*10+(ch-'0');
		ch = std::getchar();
	}
	return sign*value;
}
// Capacity of the arrays below (1e6 + 10 entries).
const int maxn = 1e6+10;
// n:      number of patterns, read in main().
// vis[x]: how many inserted patterns end at trie node x; query() overwrites it
//         with -1 once the node has been counted.
// cnt:    id of the most recently allocated trie node. Node 1 is the root,
//         which is why the counter starts at 1.
int n,vis[maxn],cnt = 1;
// Input buffer, reused for every pattern and afterwards for the text.
char s[maxn];
// trie[x][c]: node reached from x on letter 'a'+c (only columns 0..25 are
//             used); row 0 is a virtual node filled in by getfail().
// kmp[x]:     fail pointer of node x.
int trie[maxn][30],kmp[maxn];
// Inserts the NUL-terminated word `a` into the trie, allocating a fresh node
// id from `cnt` for every missing edge, and bumps the end-of-word counter
// `vis` of the node the word finishes on. Letters are indexed as ch - 'a'.
void build(char *a){
	int node = 1;                       // walking starts at the root (node 1)
	for (const char *p = a;*p != '\0';++p){
		int &child = trie[node][*p-'a'];
		if (child == 0) child = ++cnt;  // edge missing: create the node
		node = child;
	}
	++vis[node];
}
// Computes the fail pointers (`kmp`) with a breadth-first walk from the root
// and, on the way, fills every missing trie edge with the target reached via
// the fail pointer, so that later lookups never have to follow fail links to
// find a transition.
void getfail(){
	// Virtual node 0 leads back to the root on every letter; the root's own
	// fail pointer is 0, so the generic rule below also covers depth-1 nodes.
	for (int c = 0;c < 26;c++) trie[0][c] = 1;

	std::queue<int> pending;
	pending.push(1);
	while (!pending.empty()){
		const int cur = pending.front();
		pending.pop();
		const int fallback = kmp[cur];
		for (int c = 0;c < 26;c++){
			int &next = trie[cur][c];
			const int viaFail = trie[fallback][c];
			if (next){
				// Real child: its fail target is where the parent's fail
				// target goes on the same letter.
				kmp[next] = viaFail;
				pending.push(next);
			}
			else next = viaFail;        // missing edge: borrow the transition
		}
	}
}
// Runs the text `a` through the automaton and returns the sum of the `vis`
// counters of all nodes reached directly or through fail links.
// Each node is counted once only: after being added its `vis` entry is set to
// -1, and a fail-chain walk stops as soon as it meets such a node.
int query(char *a){
	int total = 0;
	int state = 1;                      // current automaton state, root first
	for (const char *p = a;*p;++p){
		state = trie[state][*p-'a'];
		int j = state;
		while (j != 0&&vis[j] != -1){
			total += vis[j];
			vis[j] = -1;                // mark as already counted
			j = kmp[j];
		}
	}
	return total;
}
int main(){
	scanf ("%d",&n);
	for (int i = 1;i <= n;i++){
		scanf ("%s",s);
		build(s);
	}
	scanf ("%s",s);
	getfail();
	printf("%d
",query(s));
	return 0;
}

加强版:把简单版中记录以当前字符结尾的单词个数改为编号,且跳过后不用标记,多组数据记得初始化。

code:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>
using namespace std;
/**
 * Reads one decimal integer from stdin.
 *
 * Every character before the first digit is skipped; if a '-' is among the
 * skipped characters the result is negated. The character that terminates
 * the digit run is consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit is seen.
 */
int read(){
	int sign = 1,value = 0;
	// Keep getchar()'s result as an int: once it is narrowed to char, EOF can
	// no longer be told apart and the skip loop below would spin forever on
	// exhausted input.
	int ch = std::getchar();
	while (ch != EOF&&(ch < '0'||ch > '9')){
		if (ch == '-') sign = -1;
		ch = std::getchar();
	}
	while (ch >= '0'&&ch <= '9'){
		value = value*10+(ch-'0');
		ch = std::getchar();
	}
	return sign*value;
}
// Capacity of the node-indexed arrays and of the text buffer (1e6 + 10).
const int maxn = 1e6+10;
// n:   number of patterns in the current test case.
// num: largest occurrence count found among the patterns of the test case.
int n,num;
// s[i]: the i-th pattern (1-based), kept so it can be printed afterwards.
// c:    the text to be searched.
char s[1000][100],c[maxn];
// cnt:        id of the most recently allocated trie node (init() sets it to 1,
//             the root).
// trie[x][c]: node reached from x on letter 'a'+c (columns 0..25 are used);
//             row 0 is a virtual node filled in by getfail().
// vis[x]:     id of the pattern that ends at node x, 0 when none was stored.
int cnt,trie[maxn][30],vis[maxn];
// ans[id]: number of hits recorded for pattern `id`; ans[0] collects the hits
//          on nodes that carry no pattern and is never read.
// kmp[x]:  fail pointer of node x.
int ans[maxn],kmp[maxn];
// Inserts the word `a` into the trie and records `id` as the pattern that
// ends on its final node (a later insert ending on the same node overwrites
// the id). Letters are indexed as ch - 'a'.
void build(char *a,int id){
	const int len = strlen(a);
	int cur = 1;                        // start from the root
	int pos = 0;
	while (pos < len){
		const int letter = a[pos++]-'a';
		if (trie[cur][letter] == 0){
			++cnt;                      // allocate a new node id
			trie[cur][letter] = cnt;
		}
		cur = trie[cur][letter];
	}
	vis[cur] = id;
}
// Breadth-first construction of the fail pointers (`kmp`). Missing edges are
// replaced by the transition of the fail target, turning the trie into a
// complete transition table over the 26 letters.
void getfail(){
	// Node 0 is a virtual parent of the root: every letter leads to node 1.
	for (int c = 0;c < 26;++c) trie[0][c] = 1;

	std::queue<int> bfs;
	// The front element is popped only after its body ran; children pushed in
	// the body go to the back, so the front is still the current node.
	for (bfs.push(1);!bfs.empty();bfs.pop()){
		const int node = bfs.front();
		const int fail = kmp[node];
		for (int c = 0;c < 26;++c){
			const int child = trie[node][c];
			if (child == 0){
				trie[node][c] = trie[fail][c];
				continue;
			}
			kmp[child] = trie[fail][c];
			bfs.push(child);
		}
	}
}
// Feeds the text `a` through the automaton. After every character the whole
// fail chain of the current state is walked and the counter `ans` of the
// pattern id stored on each visited node is incremented (nodes without a
// pattern have id 0, so they increment ans[0]).
void query(char *a){
	const int len = strlen(a);
	int state = 1;
	for (int pos = 0;pos < len;++pos){
		state = trie[state][a[pos]-'a'];
		int hop = state;
		while (hop){
			++ans[vis[hop]];
			hop = kmp[hop];
		}
	}
}
// Resets the automaton between test cases.
//
// build() and getfail() only ever write rows 0..cnt of `trie`, `vis` and
// `kmp` (row 0 is the virtual node set up by getfail()), so clearing that
// prefix is sufficient. Wiping the complete arrays would touch well over
// 100 MB for every single test case. On the very first call `cnt` is still 0
// and the arrays are zero-initialised globals, so clearing one row is enough.
// `ans` is indexed by pattern id rather than by node id and is therefore
// cleared completely.
void init(){
	const size_t used = static_cast<size_t>(cnt)+1;
	memset(trie,0,used*sizeof(trie[0]));
	memset(vis,0,used*sizeof(vis[0]));
	memset(kmp,0,used*sizeof(kmp[0]));
	memset(ans,0,sizeof(ans));
	cnt = 1,num = 0;                    // node 1 is the root of the new trie
}
// Processes test cases until a case with n == 0 (or the end of the input) is
// met. For each case: reads n patterns and a text, counts the hits of every
// pattern, prints the highest count and then every pattern reaching it, in
// input order, one per line.
int main(){
	// Compare against 1: at end of input scanf returns EOF (-1), which is
	// truthy, so testing the bare result would loop forever on input that
	// lacks the terminating 0.
	while (scanf ("%d",&n) == 1&&n != 0){
		init();
		for (int i = 1;i <= n;i++){
			scanf ("%s",s[i]);
			build(s[i],i);
		}
		getfail();
		scanf("%s",c);
		query(c);
		for (int i = 1;i <= n;i++) num = std::max(num,ans[i]);
		// Newlines are written as "\n" escapes; a raw line break inside a
		// string literal does not compile.
		printf("%d\n",num);
		for (int i = 1;i <= n;i++){
			if (ans[i] == num) printf("%s\n",s[i]);
		}
	}
	return 0;
}

二次加强版:拓扑优化

code:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
// maxn bounds the number of trie rows (2e5 + 10); maxm bounds the buffer and
// the remaining arrays (2e6 + 10).
const int maxn = 2e5+10,maxm = 2e6+10;
// n:   number of patterns, read in main().
// cnt: id of the most recently allocated trie node; node 1 is the root.
int n,cnt = 1;
// Input buffer, reused for every pattern and afterwards for the text.
char s[maxm];
// vis[i]:     trie node on which pattern i ends.
// trie[x][c]: node reached from x on letter 'a'+c (columns 0..25 are used);
//             row 0 is a virtual node filled in by getfail().
// dp[x]:      number of text positions whose automaton state is x; toopo()
//             then adds each node's value onto its fail target.
// kmp[x]:     fail pointer of node x.
int vis[maxm],trie[maxn][30],dp[maxm],kmp[maxm];
// Inserts the word `a` into the trie and stores in vis[x] the node on which
// pattern number `x` ends. Letters are indexed as ch - 'a'.
void build(char *a,int x){
	int node = 1;                       // root
	for (int i = 0,len = strlen(a);i < len;++i){
		int &edge = trie[node][a[i]-'a'];
		if (!edge) edge = ++cnt;        // allocate the missing child
		node = edge;
	}
	vis[x] = node;
}
// Builds the fail pointers (`kmp`) level by level and completes the
// transition table: an absent edge is redirected to wherever the fail target
// goes on the same letter.
void getfail(){
	// Virtual node 0: all letters lead to the root, whose fail pointer is 0.
	for (int letter = 0;letter < 26;++letter) trie[0][letter] = 1;

	std::queue<int> order;
	order.push(1);
	do{
		const int cur = order.front();
		order.pop();
		for (int letter = 0;letter < 26;++letter){
			const int target = trie[kmp[cur]][letter];
			if (trie[cur][letter] == 0){
				trie[cur][letter] = target;
				continue;
			}
			kmp[trie[cur][letter]] = target;
			order.push(trie[cur][letter]);
		}
	}while (!order.empty());
}
// du[x]: number of nodes whose fail pointer is x that are still unprocessed.
int du[maxm];
// Propagates the visit counts in `dp` along the fail links in topological
// order: a node is processed once every node failing to it has been handled,
// and its count is then added onto its own fail target.
// The root's fail pointer is 0, so the virtual node 0 is eventually queued
// and processed too; its own fail pointer is 0 as well, which pushes du[0]
// to -1 and thereby ends the propagation.
void toopo(){
	for (int node = 1;node <= cnt;++node) ++du[kmp[node]];

	std::queue<int> ready;
	for (int node = 1;node <= cnt;++node){
		if (du[node] == 0) ready.push(node);
	}
	while (!ready.empty()){
		const int node = ready.front();
		ready.pop();
		const int parent = kmp[node];
		dp[parent] += dp[node];
		du[parent] -= 1;
		if (du[parent] == 0) ready.push(parent);
	}
}
int main(){
	scanf ("%d",&n);
	for (int i = 1;i <= n;i++){
		scanf ("%s",s);
		build(s,i);
	}
	getfail();
	scanf ("%s",s);
	int len = strlen(s);
	for (int i = 0,root = 1;i < len;i++){
		root = trie[root][s[i]-'a'];
		dp[root]++;
	}	
	toopo();
	for (int i = 1;i <= n;i++){
		printf("%d
",dp[vis[i]]);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/little-uu/p/13964113.html