后缀自动机学习笔记

hihocoder系列

这是一个系列讲解，写得很好，共有六集……

构建SAM并求不同子串数

对应第二个讲解

#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
// Shared state of the suffix automaton (SAM). main sets tot = 1, so state 1 is
// the root (empty string); newSamNode hands out further ids via ++tot.
// State arrays are twice the size of the input buffer (a SAM needs fewer than
// 2n states for a string of length n).
typedef long long ll;
// maxLen / minLen: longest and shortest substring length in each state's class.
// trans: outgoing edge on 'a'..'z' (0 means "no edge").
int tot, n, maxLen[2000005], minLen[2000005], trans[2000005][26];
// slink: suffix link of each state; the root's stays 0, which ends link walks.
int slink[2000005];
// The input string.
char ss[1000005];
// Allocates the next state id and fills in its fields.
//   _maxLen_, _minLen_ : longest / shortest substring length of the state
//   _trans_            : if non-NULL, the 26 outgoing edges to copy (used when
//                        cloning); if NULL the row is left untouched — it is
//                        still all zeros because ids are never reused
//   _slink_            : suffix link
// Returns the id of the new state.
int newSamNode(int _maxLen_, int _minLen_, int *_trans_, int _slink_){
	const int id = ++tot;
	maxLen[id] = _maxLen_;
	minLen[id] = _minLen_;
	slink[id] = _slink_;
	if(_trans_ != NULL)
		memcpy(trans[id], _trans_, sizeof(trans[id]));
	return id;
}
// Extends the automaton by one character.
//   u  : the state of the whole prefix processed so far ("last")
//   ch : next character, expected in 'a'..'z'
// Returns the state of the extended prefix (the new "last").
int addChar(int u, char ch){
	const int c = ch - 'a';
	int cur = newSamNode(maxLen[u] + 1, -1, NULL, 0);
	// Walk suffix links from u, adding a c-edge to cur wherever none exists.
	int p = u;
	for(; p && !trans[p][c]; p = slink[p])
		trans[p][c] = cur;
	// Walked past the root: cur's suffix link is the root (state 1).
	if(p == 0){
		minLen[cur] = slink[cur] = 1;
		return cur;
	}
	int q = trans[p][c];
	// q's longest string is exactly longest(p)+c, so cur can link to q directly.
	if(maxLen[q] == maxLen[p] + 1){
		slink[cur] = q;
		minLen[cur] = maxLen[q] + 1;
		return cur;
	}
	// Otherwise split q: the clone takes q's edges and suffix link but only the
	// strings up to length maxLen[p]+1; q and cur then link to the clone.
	int clone = newSamNode(maxLen[p] + 1, -1, trans[q], slink[q]);
	minLen[clone] = maxLen[slink[clone]] + 1;
	slink[q] = clone;
	slink[cur] = clone;
	minLen[q] = maxLen[clone] + 1;
	minLen[cur] = maxLen[clone] + 1;
	// Redirect the c-edges into q from p and its suffix ancestors to the clone.
	for(; p && trans[p][c] == q; p = slink[p])
		trans[p][c] = clone;
	return cur;
}
int main(){
	scanf("%s", ss);
	n = strlen(ss);
	int pre=1;
	tot = 1;
	for(int i=0; i<n; i++)
		pre = addChar(pre, ss[i]);
	ll ans=0;
	for(int i=2; i<=tot; i++)
		ans += maxLen[i] - minLen[i] + 1;
	cout<<ans<<endl;
	return 0;
}

长度为 $k$ 的子串最多出现几次（对每个 $1 \le k \le n$ 分别求出）

对应第三个讲解

#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
using namespace std;
// tot starts at 1: state 1 is the root (empty string).
// maxlen / minlen: longest / shortest substring length of each state;
// trans: outgoing edge on 'a'..'z' (0 means "no edge").
int n, tot=1, maxlen[2000005], minlen[2000005], trans[2000005][26];
// slink: suffix link (the root's stays 0).
// cntendpos: occurrence count of each state, seeded with 1 in addChar for the
// state created for each new prefix and summed up the suffix-link tree in main.
// cnt: number of unprocessed suffix-link children, used by that pass.
// ans[k]: best occurrence count found for length k.
int slink[2000005], cntendpos[2000005], cnt[2000005], ans[2000005];
char ss[1000005];
// Work queue for the bottom-up pass in main.
queue<int> d;
// Allocates the next state id and fills in its fields. When Trans is non-NULL
// its 26 edges are copied (cloning); when NULL the row is left as is, which is
// still all zeros because ids are never reused. Returns the new id.
int newSamNode(int Maxlen, int Minlen, int *Trans, int Slink){
	const int id = ++tot;
	maxlen[id] = Maxlen;
	minlen[id] = Minlen;
	slink[id] = Slink;
	if(Trans != NULL)
		memcpy(trans[id], Trans, sizeof(trans[id]));
	return id;
}
// Extends the automaton by character ch (expected 'a'..'z') after state u, the
// state of the whole prefix read so far; returns the new "last" state.
// The state created for the new prefix gets cntendpos = 1 (it is one new end
// position); clones keep the default 0.
int addChar(int u, char ch){
	const int c = ch - 'a';
	int cur = newSamNode(maxlen[u] + 1, -1, NULL, 0);
	cntendpos[cur] = 1;
	// Walk suffix links from u, adding a c-edge to cur wherever none exists.
	int p = u;
	for(; p && !trans[p][c]; p = slink[p])
		trans[p][c] = cur;
	// Walked past the root: cur's suffix link is the root (state 1).
	if(p == 0){
		minlen[cur] = slink[cur] = 1;
		return cur;
	}
	int q = trans[p][c];
	// q's longest string is exactly longest(p)+c, so cur can link to q directly.
	if(maxlen[q] == maxlen[p] + 1){
		slink[cur] = q;
		minlen[cur] = maxlen[q] + 1;
		return cur;
	}
	// Otherwise split q into a clone holding the strings up to maxlen[p]+1.
	int clone = newSamNode(maxlen[p] + 1, -1, trans[q], slink[q]);
	minlen[clone] = maxlen[slink[clone]] + 1;
	slink[cur] = clone;
	slink[q] = clone;
	minlen[cur] = maxlen[clone] + 1;
	minlen[q] = maxlen[clone] + 1;
	// Redirect the c-edges into q from p and its suffix ancestors to the clone.
	for(; p && trans[p][c] == q; p = slink[p])
		trans[p][c] = clone;
	return cur;
}
// Reads a lowercase string of length n and prints, for every k = 1..n, the
// largest number of times any length-k substring occurs (one value per line).
int main(){
	scanf("%s", ss);
	n = strlen(ss);
	int pre=1;
	for(int i=0; i<n; i++)
		pre = addChar(pre, ss[i]);
	// Occurrence count of a state = its own seed (1 or 0, set in addChar) plus
	// the counts of its children in the suffix-link tree. Accumulate bottom-up
	// in topological order: cnt[p] holds the number of unprocessed children of p.
	// The root (state 1) has no parent, so it is not counted as anyone's child.
	for(int i=2; i<=tot; i++)
		cnt[slink[i]]++;
	for(int i=1; i<=tot; i++)
		if(!cnt[i])
			d.push(i);
	while(!d.empty()){
		int x=d.front();
		d.pop();
		int p=slink[x];
		if(!p)	continue;	// x is the root: nothing to propagate into
		cntendpos[p] += cntendpos[x];
		if(!--cnt[p])	d.push(p);
	}
	// All substrings of state i occur cntendpos[i] times; the longest has length maxlen[i].
	for(int i=2; i<=tot; i++)
		ans[maxlen[i]] = max(ans[maxlen[i]], cntendpos[i]);
	// A length-(k+1) substring occurring t times has a length-k suffix occurring
	// at least t times, so ans[k] is a suffix maximum. Go from long to short so
	// that each step sees the already-propagated ans[k+1].
	for(int i=n-1; i>=1; i--)
		ans[i] = max(ans[i], ans[i+1]);
	for(int i=1; i<=n; i++)
		printf("%d\n", ans[i]);
	return 0;
}

多个串：用分隔符拼接后建 SAM，求所有不同数字子串的数值之和（模 $10^9+7$）

对应第四个讲解

#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
using namespace std;
typedef long long ll;
// tot starts at 1: state 1 is the root. trans has 11 columns: digits '0'..'9'
// plus ':' (index 10), which main inserts as a separator between input strings.
// NOTE(review): state arrays are sized for ~1e6 input characters; the n-1
// separators add more characters — confirm against the problem limits.
int n, len, tot=1, minlen[2000005], maxlen[2000005], trans[2000005][11];
// ru: in-degree of each state in the transition DAG.
// hefa ("合法", valid): number of separator-free root-to-state paths, mod `mod`.
// sum: total numeric value of the substrings those paths spell, mod `mod`.
int slink[2000005], ru[2000005], hefa[2000005], sum[2000005];
const int mod=1e9+7;
// Buffer for one input string at a time.
char ss[1000005];
// Work queue for the topological pass in main.
queue<int> d;
// Allocates the next state id and fills in its fields. When Trans is non-NULL
// its 11 edges ('0'..'9' and ':') are copied (cloning); when NULL the row is
// left as is, which is still all zeros because ids are never reused.
// Returns the new id.
int newSamNode(int Maxlen, int Minlen, int *Trans, int Slink){
	const int id = ++tot;
	maxlen[id] = Maxlen;
	minlen[id] = Minlen;
	slink[id] = Slink;
	if(Trans != NULL)
		memcpy(trans[id], Trans, sizeof(trans[id]));
	return id;
}
// Extends the automaton by character ch after state u (the state of the whole
// prefix read so far) and returns the new "last" state. ch is a digit or the
// ':' separator; ch - '0' maps them to edge indices 0..10.
int addChar(int u, char ch){
	const int c = ch - '0';
	int cur = newSamNode(maxlen[u] + 1, -1, NULL, 0);
	// Walk suffix links from u, adding a c-edge to cur wherever none exists.
	int p = u;
	for(; p && !trans[p][c]; p = slink[p])
		trans[p][c] = cur;
	// Walked past the root: cur's suffix link is the root (state 1).
	if(p == 0){
		minlen[cur] = slink[cur] = 1;
		return cur;
	}
	int q = trans[p][c];
	// q's longest string is exactly longest(p)+c, so cur can link to q directly.
	if(maxlen[q] == maxlen[p] + 1){
		slink[cur] = q;
		minlen[cur] = maxlen[q] + 1;
		return cur;
	}
	// Otherwise split q into a clone holding the strings up to maxlen[p]+1.
	int clone = newSamNode(maxlen[p] + 1, -1, trans[q], slink[q]);
	minlen[clone] = maxlen[slink[clone]] + 1;
	slink[cur] = clone;
	slink[q] = clone;
	minlen[cur] = maxlen[clone] + 1;
	minlen[q] = maxlen[clone] + 1;
	// Redirect the c-edges into q from p and its suffix ancestors to the clone.
	for(; p && trans[p][c] == q; p = slink[p])
		trans[p][c] = clone;
	return cur;
}
int main(){
	cin>>n;
	int pre=1;
	for(int i=1; i<=n; i++){
		scanf("%s", ss);
		len = strlen(ss);
		if(i!=1)	pre = addChar(pre, ':');
		for(int i=0; i<len; i++)
			pre = addChar(pre, ss[i]);
	}
	for(int i=1; i<=tot; i++)
		for(int j=0; j<=10; j++)
			if(trans[i][j])
				ru[trans[i][j]]++;
	for(int i=1; i<=tot; i++)
		if(!ru[i]){
			d.push(i);
			hefa[i] = 1;
		}
	while(!d.empty()){
		int x=d.front();
		d.pop();
		for(int i=0; i<=10; i++)
			if(trans[x][i]){
				int t=trans[x][i];
				if(i<10){
					hefa[t] = (hefa[t] + hefa[x]) % mod;
					sum[t] = (sum[t] + (ll)sum[x] * 10) % mod;
					sum[t] = (sum[t] + (ll)i * hefa[x]) % mod;
				}
				ru[t]--;
				if(!ru[t])	d.push(t);
			}
	}
	int ans=0;
	for(int i=2; i<=tot; i++)
		ans = (ans + sum[i]) % mod;
	cout<<ans<<endl;
	return 0;
}
原文地址:https://www.cnblogs.com/poorpool/p/9042653.html