bzoj 5496: [2019省队联测]字符串问题【SAM+拓扑】

有一个想法就是暴力建图,把每个A向有和他相连的B前缀的A,然后拓扑一下,这样的图是n^2的;
考虑优化建图,因为大部分数据结构都是处理后缀的,所以把串反过来,题目中要求的前缀B就变成了后缀B
建立SAM,发现在parent树中每个B能走到的A都在子树中,所以保留这个树结构,连边权为0的边;
然后在parent树上倍增找到每个AB串对应的点,因为SAM上每个对应不止一个串,所以找完之后把对应多个AB串的点拆成一条链
然后对于一对(x,y)的AB串关系,Ax对应的点向By对应的点连边权为A长度的边
然后拓扑,找环/dp,dp的时候注意每个点最后的dp值都要加上自己的长度

#include<iostream>
#include<cstdio>
#include<cstring>
#include<set>
#include<vector>
#include<algorithm>
#include<queue>
using namespace std;
const int N=400005;
int T,n,m,na,nb,h[N<<1],cnt,dis[N],fa[N],f[20][N],la,cur,tot,ch[N][26],l[N],r[N],p[N],rl[N],wsu[N],sa[N],de[N],du[N<<1],co[N],top,a[N<<1],q[N<<1];
long long g[N<<1],ans;
char s[N];
vector<int>v[N];
vector<pair<int,int> >d[N];
set<pair<int,int> >st;
struct qwe
{
	int ne,to,va;
}e[N<<1];
int read()
{
	int r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
void ins(int c,int id)
{
	la=cur,dis[cur=++tot]=dis[la]+1;
	rl[id]=cur;
	int p=la;
	for(;!ch[p][c]&&p;p=fa[p])
		ch[p][c]=cur;
	if(!p)
		fa[cur]=1;
	else
	{
		int q=ch[p][c];
		if(dis[q]==dis[p]+1)
			fa[cur]=q;
		else
		{
			int nq=++tot;
			memcpy(ch[nq],ch[q],sizeof(ch[q]));
			dis[nq]=dis[p]+1;
			fa[nq]=fa[q];
			fa[q]=fa[cur]=nq;
			for(;ch[p][c]==q;p=fa[p])
				ch[p][c]=nq;
		}
	}
}
int clc(int x,int len)
{
	for(int i=18;i>=0;i--)
		if(dis[f[i][x]]>=len)
			x=f[i][x];
	pair<int,int> nw=make_pair(x,len);
	if(st.find(nw)==st.end())
		st.insert(nw),co[x]++;
	return x;
}
void add(int u,int v,int w)
{//cout<<u<<" "<<v<<" "<<w<<endl;
	cnt++;
	e[cnt].ne=h[u];
	e[cnt].to=v;
	e[cnt].va=w;
	du[v]++;
	h[u]=cnt;
}
void dfs(int u,int fat)
{//cerr<<u<<" "<<fat<<endl;
	if(fat)
		add(fat,u,0);
	int w=0,nw=u,len=d[u].size();
	sort(d[u].begin(),d[u].end());
	if(len)
	{
		p[d[u][0].second]=u;
		for(int i=1;i<len;i++)
		{
			if(d[u][i].first!=d[u][i-1].first)
			{
				w=i;
				break;
			}
			p[d[u][i].second]=u;
		}
	}
	while(co[u]>1)
	{
		co[u]--;
		p[d[u][w].second]=++tot;
		int lw=w;
		for(int i=w+1;i<len;i++)
		{
			if(d[u][i].first!=d[u][i-1].first)
			{
				w=i;
				break;
			}
			p[d[u][i].second]=tot;
		}
		add(nw,tot,0);
		w=max(w,lw+1);
		nw=tot;
	}
	for(int i=0,len=v[u].size();i<len;i++)
		dfs(v[u][i],nw);
}
int main()
{
	T=read();
	while(T--)
	{
		memset(h,0,sizeof(h));
		memset(dis,0,sizeof(dis));
		memset(fa,0,sizeof(fa));
		memset(f,0,sizeof(f));
		memset(ch,0,sizeof(ch));
		memset(co,0,sizeof(co));
		memset(p,0,sizeof(p));
		memset(rl,0,sizeof(rl));
		memset(wsu,0,sizeof(wsu));
		memset(du,0,sizeof(du));
		memset(g,0,sizeof(g));
		memset(a,0,sizeof(a));
		for(int i=0;i<N;i++)
			v[i].clear(),d[i].clear();
		st.clear();
		cur=tot=1,cnt=top=ans=0;
		scanf("%s",s+1);
		n=strlen(s+1);
		for(int i=n;i>=1;i--)
			ins(s[i]-'a',i);
		for(int i=1;i<=tot;i++)
			wsu[dis[i]]++;
		for(int i=1;i<=n;i++)
			wsu[i]+=wsu[i-1];
		for(int i=tot;i>=1;i--)
			sa[wsu[dis[i]]--]=i;
		for(int i=1;i<=tot;i++)
			de[sa[i]]=de[fa[sa[i]]]+1,v[fa[sa[i]]].push_back(sa[i]);//,cout<<sa[i]<<endl;
		for(int i=2;i<=tot;i++)
			f[0][i]=fa[i];
		for(int i=1;i<=18;i++)
			for(int j=2;j<=tot;j++)
				f[i][j]=f[i-1][f[i-1][j]];
		na=read();
		for(int i=1;i<=na;i++)
			l[i]=read(),r[i]=read(),p[i]=clc(rl[l[i]],r[i]-l[i]+1);
		nb=read();
		for(int i=na+1;i<=na+nb;i++)
			l[i]=read(),r[i]=read(),p[i]=clc(rl[l[i]],r[i]-l[i]+1);
		for(int i=1;i<=na+nb;i++)
			d[p[i]].push_back(make_pair(r[i]-l[i]+1,i));//,cout<<p[i]<<endl;
		dfs(1,0);
		m=read();
		for(int i=1;i<=m;i++)
		{
			int x=read(),y=read();
			add(p[x],p[y+na],r[x]-l[x]+1);
		}
		for(int i=1;i<=na;i++)
			a[p[i]]=r[i]-l[i]+1;
		// queue<int>q;
		// q.push(1);
		// while(!q.empty())
		// {
			// int u=q.front();//cerr<<u<<endl;
			// q.pop();
			// top++;
			// ans=max(ans,g[u]+a[u]);
			// for(int i=h[u];i;i=e[i].ne)
			// {
				// g[e[i].to]=max(g[e[i].to],g[u]+e[i].va);
				// if(!(--du[e[i].to]))
					// q.push(e[i].to);
			// }
		// }//cerr<<"OK"<<endl;
		top=0;
		q[++top]=1;
		for(int i=1;i<=top;i++)
		{
			int u=q[i];
			ans=max(ans,g[u]+a[u]);
			for(int j=h[u];j;j=e[j].ne)
			{
				g[e[j].to]=max(g[e[j].to],g[u]+e[j].va);
				if(!(--du[e[j].to])) 
					q[++top]=e[j].to;
			}
		}
		if(top<tot)
			puts("-1");
		else
			printf("%lld
",ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lokiii/p/10712338.html