4566: [Haoi2016]找相同字符——后缀自动机

题目大意:

给定两个串,求有多少种方式从两个串中各提取出一个子串并且两个子串相等。

思路:

涉及两个串的子串问题考虑对第一个串建立SAM。
然后用第个二串在SAM上匹配,每到一个点,贡献是(目前的长度-这个状态的父亲的长度)x这个状态RIGHT集合的大小,同时对这个状态的每个祖先也像这样计算贡献即可。

/*=======================================
 * Author : ylsoi
 * Time : 2019.2.14
 * Problem : bzoj4566
 * E-mail : ylsoi@foxmail.com
 * ====================================*/
#include<bits/stdc++.h>

#define REP(i,a,b) for(int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;

using namespace std;

void File(){
	freopen("bzoj4566.in","r",stdin);
	freopen("bzoj4566.out","w",stdout);
}

template<typename T>void read(T &_){
	_=0; T f=1; char c=getchar();
	for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
	for(;isdigit(c);c=getchar())_=(_<<1)+(_<<3)+(c^'0');
	_*=f;
}

const int maxn=2e5+10;
int n1,n2;
char s1[maxn],s2[maxn];

int len[maxn<<1],fa[maxn<<1],ch[maxn<<1][26];
int cnt=1,last=1,sz[maxn<<1];
ll sum[maxn<<1],ans;

void insert(int x){
	int p=last,np=last=++cnt;
	len[np]=len[p]+1;
	sz[np]=1;
	while(p && !ch[p][x])ch[p][x]=np,p=fa[p];
	if(!p)fa[np]=1;
	else{
		int q=ch[p][x];
		if(len[q]==len[p]+1)fa[np]=q;
		else{
			int nq=++cnt;
			memcpy(ch[nq],ch[q],sizeof(ch[nq]));
			len[nq]=len[p]+1,fa[nq]=fa[q];
			fa[q]=fa[np]=nq;
			while(p && ch[p][x]==q)ch[p][x]=nq,p=fa[p];
		}
	}
}

void get_sz(){
	int tax[maxn<<1]={0},lis[maxn<<1]={0};
	REP(i,1,cnt)++tax[len[i]];
	REP(i,1,n1)tax[i]+=tax[i-1];
	REP(i,1,cnt)lis[tax[len[i]]--]=i;
	DREP(i,cnt,1)sz[fa[lis[i]]]+=sz[lis[i]];

	REP(i,1,cnt)sum[i]=1ll*(len[i]-len[fa[i]])*sz[i];
	REP(i,1,cnt)sum[lis[i]]+=sum[fa[lis[i]]];
}

void compare(){
	int o=1,now=0;
	REP(i,1,n2){
		int x=s2[i]-'a';
		while(o!=1 && !ch[o][x])o=fa[o],now=len[o];
		if(ch[o][x]){
			o=ch[o][x];
			++now;
			ans+=1ll*(now-len[fa[o]])*sz[o];
			ans+=sum[fa[o]];
		}
	}
}

int st[maxn],tp;

void dfs(int o){
	REP(i,1,tp)printf("%c",st[i]+'a');
	printf("
");
	REP(i,0,25)if(ch[o][i]){
		st[++tp]=i;
		dfs(ch[o][i]);
		--tp;
	}
}

int main(){
	File();
	scanf("%s%s",s1+1,s2+1);
	n1=strlen(s1+1),n2=strlen(s2+1);
	REP(i,1,n1){
		insert(s1[i]-'a');
	}
	get_sz();
	compare();
	printf("%lld
",ans);
	return 0;
}

原文地址:https://www.cnblogs.com/ylsoi/p/10376042.html