后缀自动机的应用

求不同子串个数

LUOGU P2408

解题思路

  这个其实就是在后缀自动机上统计不同的路径条数,可以(dp)解决,(f[u]=sumlimits_{v=son[u]} f[v]+1),时间复杂度(O(n))。突然发现这种太麻烦了。。直接枚举每个点,(ans=sum l[i]-l[fa[i]])就行了。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>

using namespace std;
const int N=200005;
typedef long long LL;

inline int rd(){
	int x=0,f=1; char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?0:1,ch=getchar();
	while(isdigit(ch)) x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
	return f?x:-x;
}

int n,c[N],a[N],l[N],r[N],len;
LL f[N];
char s[N];

struct SAM{
	int ch[N][28],fa[N],cnt,lst,l[N];
	void insert(int c){
		int p=lst,np=++cnt; lst=np; l[np]=l[p]+1;
		for(;!ch[p][c] && p;p=fa[p]) ch[p][c]=np;
		if(!p) fa[np]=1;
		else {
			int q=ch[p][c]; 
			if(l[q]==l[p]+1) fa[np]=q;
			else {
				int nq=++cnt; l[nq]=l[p]+1;
				memcpy(ch[nq],ch[q],sizeof(ch[nq]));
				fa[nq]=fa[q]; fa[np]=fa[q]=nq;
				for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
			} 
		}
	}
	void build(){
		for(int i=1;i<=n;i++) insert(s[i]-'a'+1);
		for(int i=1;i<=cnt;i++) c[l[i]]++;
		for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
		for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
	}
	void solve(){
		for(int i=cnt;i;i--)
			for(int j=1;j<=26;j++)
				if(ch[a[i]][j]) f[a[i]]+=f[ch[a[i]][j]]+1;
		printf("%lld
",f[1]);
	}
}sam;

int main(){
	sam.lst=sam.cnt=1; n=rd();
	scanf("%s",s+1); sam.build();
	sam.solve();
	return 0;
}

求模式串在文本串中出现次数

  要处理(parent)树上的倍增数组和(right)集合大小,然后扫描模式串,记录当前匹配了多少位,最后先判断匹配位数是否为(m)(m为模式串的长度)。是的话就倍增往上跳到满足(len>=m)的最高位置,把这个位置的(right)集合大小统计到答案中。

例题1:CF 235C
题解

(a)子串在(b)串中出现(k)次的出现次数

bzoj 3277

解题思路

  首先构造广义(sam),然后对于每一个状态记录它属于哪个串,这个需要用(set)。如果一个状态属于串(a),那么其父节点也一定属于串(a),那么可以(dfs)启发式合并,把每个节点的(set)插到父节点。然后计算时要枚举每个串,同时在后缀自动机上走相应节点,然后判断(set)集合大小是否(>k),是的话统计答案,否则跳父节点,时间复杂度(O(nlogn))

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<string>
#include<set>
 
using namespace std;
const int N=100005;
const int M=200005;
typedef long long LL;
 
int n,k,head[M],tot,to[M<<1],nxt[M<<1],siz[M];
string s[N];
LL ans;
set<int> S[M];
 
struct SAM{
    int fa[M],ch[M][28],len[M],cnt,lst;
    void Insert(int c,int id){
        int p=lst,np=++cnt; len[np]=len[p]+1;
        lst=cnt; S[np].insert(id);
        for(;p && !ch[p][c];p=fa[p]) ch[p][c]=np;
        if(!p) fa[np]=1;
        else {
            int q=ch[p][c];
            if(len[q]==len[p]+1) fa[np]=q;
            else {
                int nq=++cnt; len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[nq]));
                fa[nq]=fa[q]; fa[q]=fa[np]=nq;
                for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq; 
            }
        }
    }
}sam;
 
inline void add(int bg,int ed){
    to[++tot]=ed,nxt[tot]=head[bg],head[bg]=tot;
}
 
void dfs(int x){
    for(int i=head[x];i;i=nxt[i]){
        int u=to[i]; dfs(u); if(x==1) continue;
        for(set<int>::iterator it=S[u].begin();it!=S[u].end();it++)
            S[x].insert(*it);
    }
    siz[x]=S[x].size();
}
 
int main(){
    scanf("%d%d",&n,&k); int len; sam.cnt=sam.lst=1;
    for(int i=1;i<=n;i++){
        cin>>s[i]; len=s[i].length(); sam.lst=1;
        for(int j=0;j<len;j++) sam.Insert(s[i][j]-'a'+1,i);
    }
    if(k>n) {for(int i=1;i<=n;i++) puts("0"); return 0;}
    for(int i=2;i<=sam.cnt;i++) 
        if(sam.fa[i]) add(sam.fa[i],i);
    dfs(1); int p;
    for(int i=1;i<=n;i++){
        ans=0; len=s[i].length(); p=1;
        for(int j=0;j<len;j++){
            p=sam.ch[p][s[i][j]-'a'+1];
            while(siz[p]<k) p=sam.fa[p];
            ans+=sam.len[p];
        }
        printf("%lld
",ans);
    }
    return 0;
}

原文地址:https://www.cnblogs.com/sdfzsyq/p/10446178.html