高效处理字符串!——AC自动机

AC自动机

这两天进军AC自动机算法,越做越觉得这种算法的灵活与高效,接下来对这阵子的学习做个总结。

AC自动机,当然它最主要的作用自动帮你AC题目多模式串的匹配,也就是字典树trie和kmp的结合,再深入讲就是把kmp中失配时跳转的思想运用到trie上!

1.AC自动机构建

对于构建,基本上都是模板,先建trie,再BFS这颗trie从而构建出最重要的fail指针,即失配跳转指针(口头表达),fail指针指向的是当前状态的最长后继。

而一般我们为了加快速度,每个节点还会构建类似虚节点的东东,说清楚点就是把这个节点不存在的儿子指针指向它最长后继(即fail指向的那个节点)的该儿子,这样做可以在结束该串匹配时快速跳转到另一条模式串上继续匹配,具体自己脑补或者看代码吧(作者比较懒)。

2.AC自动机的扩展点

一般我们在失配时,会跳转fail指针继续匹配,我们把这个叫暴跳!!

暴跳,顾名思义,很费时间,所以一般会被题目卡掉,这样的话,我们一般就考虑fail树,即fail边所连成的树。为什么保证是树?因为显然每个节点只有一个父亲。

这样建好fail树后,每个节点的子树中所有节点表示的状态肯定包含该节点状态所表示的字符串,所以若要统计某一模式串在要求串中共出现的次数,只需要把要求串的所有节点权值加一,然后答案就是模式串的末节点子树大小。

统计子树大小的问题值得深究,因为很多题目中模式串和要求串都不止一个,即树上的权值肯定要变,若用遍历统计子树大小的方法来计算不优秀,这时我们采用DFS序来把树转化为线段区间问题,这样我们就可以用树状数组维护。具体来讲就是记录每个节点的dfn和low值,放在线段上看,dfn和low中间的区间,就是它的整个子树,这样就很好解释为什么可以用树状数组了吧。

3.几道例题

[Noi2011]阿狸的打字机

以 dfs 的方式给定一棵 Trie(操作次数 (le 10^5)(Sigma=26)),(m)(1le m le 10^5))次询问两个节点表示的串 (x)(y) 中的出现次数。

题解

先对dfs序中读入的所有字符串建AC自动机,因为题目支持离线,所以再将每个询问保存后以第二关键字排序,这样就可以一类一类的处理了。

之后就dfs字典树,每到一个点就把该点在树状数组中dfn编号的位置加一,回溯时减回去,这样可以保证我们遍历到某个节点时,都只会是该节点到根节点的这一条链加了一,然后判断此节点有没有在询问中以文本串的方式出现过,出现过的话就去统计该文本串所对应的模式串的末节点的子树大小。

好了,这样就做完了!

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 200000
inline int read()
{
    int x=0,t=1;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-')t=-1,ch=getchar();
    while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
    return x*t;
}
char ss[MAX];
int nd[MAX],n,tot;
int ans[MAX];
int c[MAX];
int dfn[MAX],low[MAX],tim;
int ql[MAX],qr[MAX];
inline int lowbit(int x){return x&(-x);}
void Modify(int x,int w){while(x<=tim)c[x]+=w,x+=lowbit(x);}
int getsum(int x){int ret=0;while(x)ret+=c[x],x-=lowbit(x);return ret;}
struct Node
{
    int vis[26];
    int Vis[26];
    int fail,fa;
    int lt;
}t[MAX];
struct Question{int x,y,id,ans;}q[MAX];
bool operator<(Question a,Question b){return a.y<b.y;}
void GetFail()
{
    queue<int> Q;
    for(int i=0;i<26;++i)
        if(t[0].vis[i])Q.push(t[0].vis[i]);
    while(!Q.empty())
    {
        int u=Q.front();Q.pop();
        for(int i=0;i<26;++i)
            if(t[u].vis[i])
                t[t[u].vis[i]].fail=t[t[u].fail].vis[i],Q.push(t[u].vis[i]);
            else t[u].vis[i]=t[t[u].fail].vis[i];
    }
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
void dfs(int u)
{
    dfn[u]=++tim;
    for(int i=h[u];i;i=e[i].next)dfs(e[i].v);
    low[u]=tim;
}
void DFS(int u)
{
    Modify(dfn[u],1);
    if(t[u].lt)
        for(int i=ql[t[u].lt];i<=qr[t[u].lt];++i)
            q[i].ans=getsum(low[nd[q[i].x]])-getsum(dfn[nd[q[i].x]]-1);
    for(int i=0;i<26;++i)
        if(t[u].Vis[i])
            DFS(t[u].Vis[i]);
    Modify(dfn[u],-1);
}
int main()
{
    scanf("%s",ss+1);
    int now=0;
    for(int i=1,l=strlen(ss+1);i<=l;++i)
    {
        if(ss[i]>='a'&&ss[i]<='z')
        {
            if(!t[now].vis[ss[i]-'a'])t[now].vis[ss[i]-'a']=++tot,t[tot].fa=now;
            now=t[now].vis[ss[i]-'a'];
        }
        if(ss[i]=='B')now=t[now].fa;
        if(ss[i]=='P'){nd[++n]=now;t[now].lt=n;}
    }
    for(int i=0;i<=tot;++i)
        for(int j=0;j<26;++j)
            t[i].Vis[j]=t[i].vis[j];
    int Q=read();
    GetFail();
    for(int i=1;i<=tot;++i)Add(t[i].fail,i);
    dfs(0);
    for(int i=1;i<=Q;++i)
    {
        q[i].x=read(),q[i].y=read();
        q[i].id=i;
    }
    sort(&q[1],&q[Q+1]);
    for(int i=1,pos=1;i<=Q;i=pos)
    {
        ql[q[i].y]=i;
        while(q[pos].y==q[i].y)pos++;
        qr[q[i].y]=pos-1;
    }
    DFS(0);
    for(int i=1;i<=Q;++i)ans[q[i].id]=q[i].ans;
    for(int i=1;i<=Q;++i)
        printf("%d
",ans[i]);
    return 0;
}

【JSOI 2007】 文本生成器

给定字符串集合 (|S|)(|S| le 60),$ |s| le 100 $, (|Σ| = 26)),求长度为$ m $ ( $ 1 le m le 100 $)的包含至少一个子串在 $ S $ 中的字符串数量模
10007。

题解

正难则反,这道题我们正着做会很麻烦,于是我们考虑计算出 (m) 的所有字符情况减去不包含至少一个子串的情况,结果就是包含至少一个子串在 (S) 中的字符串数量。
怎么样才算不包含至少一个子串的情况呢?

这样来做,把字符串集合建一个AC自动机,注意如果当前字符跳完fail指针后发现那个节点是某个模式串的末尾,这样的话就把这个节点end也标记一下,因为它的后缀包含了一个完整的模式串,然后就dp一下就好了,字符串处理上的dp听dalao说通常把第一维设为当前匹配字符串的长度,第二维表示当前字符。

然后此题就解决了

#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<string>
#include<queue>
#include<iostream>
using namespace std;
int n,m,cnt,ans;
int Mod=10007;
string s;
struct note
{
	int fail;
	int vis[26];
	int end;
}AC[100010];
void insert(string x)
{
	int i,j,now=0;
	int len=x.length();
	for(i=0;i<len;i++)
	{
		if(AC[now].vis[x[i]-'A']==0) AC[now].vis[x[i]-'A']=++cnt;
		now=AC[now].vis[x[i]-'A'];
	}
	AC[now].end=1;
}
void get_fail()
{
	int i,j;
	queue<int> q;
	for(i=0;i<26;i++)
    {
    	if(AC[0].vis[i]!=0) q.push(AC[0].vis[i]),AC[0].fail=0;
    }
    while(!q.empty())
    {
    	int u=q.front();
    	q.pop();
    	for(i=0;i<26;i++)
    	{
    		if(AC[u].vis[i]!=0)
    		{
    			AC[AC[u].vis[i]].fail=AC[AC[u].fail].vis[i];
    			AC[AC[u].vis[i]].end|=AC[AC[AC[u].fail].vis[i]].end;
    			q.push(AC[u].vis[i]);
    		}
    		else AC[u].vis[i]=AC[AC[u].fail].vis[i];
    	}
    }
}
int f[105][10005];
int main()
{
	int i,j,k;
	scanf("%d%d",&n,&m);
	for(i=1;i<=n;i++)
	{
		cin>>s; insert(s);
	}
	f[0][0]=1;
	get_fail();
	for(i=1;i<=m;i++)
	{
		for(j=0;j<=cnt;j++)
		{
			for(k=0;k<26;k++)
			{
				if(AC[AC[j].vis[k]].end==0)
				{
					(f[i][AC[j].vis[k]]+=f[i-1][j])%=Mod;
				}
			}
		}
	}
	for(i=0;i<=cnt;i++)
	{
		(ans+=f[m][i])%=Mod;
	}
	int sum=1;
	for(i=1;i<=m;i++)
	{
		sum=sum*26%Mod;
	}
	printf("%d
",(sum-ans+Mod)%Mod);
	return 0;
}

回忆树

给定一棵 (n)(n le 10^5))个节点的无根树,每条边上有一个小写字符,(m)($ m le 10^5 $) 次询问某个串 (t)($ Sigma |t| le 3 ⋅ 10^5$)在树上从节点 (u) 至节点 (v) 的串中出现的次数。

题解

这题好题,但码起来很恶心....

思路很简单:

经过 lca 的串可以提取出来 KMP。

再差分一下,问题转化为求一个串在根到一个结点的路径上的出现次数。

离线对询问串及其反串建 AC 自动机, dfs 原树,用 BIT 维护子树和即可。

#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<string>
#include<queue>
#include<vector>
using namespace std;
#define N 820005
#define M 800015
vector<int> ID[N],V[N];
char ch[2],s[M],zz[M],xx[M];
int n,m,dep[N],w[N],nx[M];
int head[N],cnt=1,f[N][18],headb[N];
int pos[N][2],ant,ans[N],bnt=1;
int dfn[N],low[N],dnt;
int tr[M];
struct A
{
	int to;
	int nxt;
	int d;
}a[N],b[N];

struct B
{
	int fail;
	int vis[26];
	int end;
}AC[M];

void add2(int x,int y)
{
	b[bnt].to=y;
	b[bnt].nxt=headb[x];
	headb[x]=bnt++;
}

int lowbit(int x)
{
	return x&(-x);
}

void increase(int x,int y)
{
	while(x<=dnt)
    {
    	tr[x]+=y;
    	x+=lowbit(x);
    }
}

int query(int x)
{
	int sum=0;
	while(x)
	{
		sum+=tr[x];
		x-=lowbit(x);
	}
	return sum;
}

int insert(int opt)
{
	int i,j,now=0;
	int len=strlen(s);
	for(j=0,i;j<len;j++)
    {
    	if(opt==1) i=j; else i=len-j-1;
    	if(AC[now].vis[s[i]-'a']==0)
    	 AC[now].vis[s[i]-'a']=++ant;
    	now=AC[now].vis[s[i]-'a'];
	}
	return now;
}

void get_fail()
{
	int i,j;
	queue<int> q;
	for(i=0;i<26;i++)
	{
		if(AC[0].vis[i]!=0) q.push(AC[0].vis[i]),add2(0,AC[0].vis[i]);
	}
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(i=0;i<26;i++)
		{
			if(AC[u].vis[i])
			{
				AC[AC[u].vis[i]].fail=AC[AC[u].fail].vis[i];
				add2(AC[AC[u].vis[i]].fail,AC[u].vis[i]);
				q.push(AC[u].vis[i]);
			}
			else AC[u].vis[i]=AC[AC[u].fail].vis[i];
		}
	}
}

void dfs2(int x)
{
	dfn[x]=++dnt;
	for(int i=headb[x];i;i=b[i].nxt)
	{
		int to=b[i].to;
		dfs2(to);
	}
	low[x]=dnt;
}

void add(int x,int y,int z)
{
	a[cnt].to=y;
	a[cnt].d=z;
	a[cnt].nxt=head[x];
	head[x]=cnt++;
}

void dfs(int x,int fa)
{
	int i,j;
	dep[x]=dep[fa]+1;
	f[x][0]=fa;
	for(i=1;i<=17;i++) f[x][i]=f[f[x][i-1]][i-1];
	for(i=head[x];i;i=a[i].nxt)
	{
		int to=a[i].to;
		if(to==fa) continue;
		w[to]=a[i].d;
		dfs(to,x);
	}
}

int getlca(int x,int y)
{
	int i,j;
	if(dep[x]>dep[y]) swap(x,y);
	for(i=17;i>=0;i--) 
	{
		if(dep[f[y][i]]>=dep[x]) y=f[y][i];
	}
	if(x==y) return x;
	for(i=17;i>=0;i--)
	{
		if(f[x][i]!=f[y][i])
		x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}

int getpoint(int u,int d)
{
	for(int i=0;d;d>>=1,i++)
	{
		if(d&1) u=f[u][i];
	}
	return u;
}

void kmp(int u,int v,int lca,int id)
{
    int lent=strlen(s),lens=0;
    int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
    int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));
    lens=dep[x]-dep[lca]+dep[y]-dep[lca];
    int tmp=x,i=0,j; while(tmp!=lca) zz[i++]=w[tmp],tmp=f[tmp][0];
    tmp=y,i=1;while(tmp!=lca) zz[lens-i]=w[tmp],tmp=f[tmp][0],i++;
    for(i=0;i<lent;i++) xx[i]=s[i]-'a';
    nx[0]=-1,i=0,j=-1;
    while(i<lent)
    {
    	if(j==-1||xx[i]==xx[j]) i++,j++,nx[i]=j;
    	else j=nx[j];
    }
    i=0,j=0;
    int ret=0;
    while(i<lens)
    {
    	if(j==-1||xx[j]==zz[i])
		{
			i++,j++;
			if(j==lent) ret++,j=nx[j];
		}
		else j=nx[j];
    }
    pos[id][0]=insert(1),pos[id][1]=insert(-1);
    ans[id]=ret;
    if(u!=x)
    {
    	ID[x].push_back(-id); ID[u].push_back(id);
    	V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
    }
    if(v!=y)
    {
    	ID[y].push_back(-id); ID[v].push_back(id);
    	V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
    }
}

void f_dfs(int u,int fa,int x)
{
	increase(dfn[x],1);
	int sz=V[u].size();
	for(int i=0;i<sz;i++)
	{
		int ret=query(low[V[u][i]])-query(dfn[V[u][i]]-1);
		if(ID[u][i]>0) ans[ID[u][i]]+=ret;
		else ans[-ID[u][i]]-=ret;
	}
	for(int i=head[u];i;i=a[i].nxt)
	{
		int to=a[i].to;
		if(to==fa) continue;
		f_dfs(to,u,AC[x].vis[a[i].d]);
	}
	increase(dfn[x],-1);
}

int main()
{
	int i,j,u,v;
	scanf("%d%d",&n,&m);
	for(i=1;i<n;i++)
	{
		scanf("%d%d%s",&u,&v,ch);
		add(u,v,ch[0]-'a'),add(v,u,ch[0]-'a');
	}
	dfs(1,0);
	for(i=1;i<=m;i++)
	{
		scanf("%d%d",&u,&v);
		scanf("%s",s);
		if(u==v) continue;
		int lca=getlca(u,v);
		kmp(u,v,lca,i);
	}
	get_fail();
	dfs2(0);
	f_dfs(1,0,0);
	for(i=1;i<=m;i++) printf("%d
",ans[i]);
	return 0;
}

更多题目请点赞后查询

原文地址:https://www.cnblogs.com/yzxx/p/11258140.html