[十二省联考]字符串问题

SAM上定位子串然后通过parent树优化建图就可以了

由于一个节点可能会有很多串所以拆出来一些点就行了

//Love and Freedom.
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<vector>
#define ll long long
#define inf 20021225
#define N 1600010
using namespace std;
vector<int> st[N]; int na,nb; char ch[N];
int poi,lt,rt,id[N]; int f[N][20]; int sz;
int ma(char c){return c-'a';}
int a[N],b[N],lst[N],m;
struct node{int fa,len,ch[26]; bool isa;}t[N];
struct edge{int to,lt;}e[N]; int cnt,in[N];
void insert(int c)
{
    int p=lt,np=lt=++poi; t[np].len=t[p].len+1;
    for(;p&&!t[p].ch[c];p=t[p].fa)    t[p].ch[c]=np;
    if(!p){t[np].fa=rt; return;}
    int q=t[p].ch[c];
    if(t[q].len==t[p].len+1){t[np].fa=q; return;}
    int nq=++poi; t[nq].fa=t[q].fa; t[q].fa=t[np].fa=nq;
    memcpy(t[nq].ch,t[q].ch,sizeof(t[q].ch)); t[nq].len=t[p].len+1;
    for(;p&&t[p].ch[c]==q;p=t[p].fa)    t[p].ch[c]=nq;
}
void build()
{
    for(int i=1;i<=poi;i++)    f[i][0]=t[i].fa;
    for(int j=1;j<=19;j++)
        for(int i=1;i<=poi;i++)    f[i][j]=f[f[i][j-1]][j-1];
}
void add(int x,int y){e[++cnt].to=y; e[cnt].lt=in[x]; in[x]=cnt;}
void locate(int l,int r,int flag)
{
    int ln=r-l+1; int x=id[l];
    for(int i=19;~i;i--)
        if(f[x][i] && t[f[x][i]].len>=ln)    x=f[x][i];
    t[++sz].isa=flag; t[sz].len=ln; st[x].push_back(sz);
}
bool cmp(int x,int y)
{
    return t[x].len>t[y].len || (t[x].len==t[y].len && t[x].isa>t[y].isa);
}
ll dp[N]; bool vis[N];
ll work(int x)
{
    if(vis[x])    return -1;
    //printf("MMP
");
    if(~dp[x])    return dp[x];
    vis[x]=1; dp[x]=0;
    for(int i=in[x];i;i=e[i].lt)
    {
        ll to=work(e[i].to);
        if(to==-1)    return -1;
        dp[x]=max(dp[x],to);
    }
    vis[x]=0; dp[x]+=t[x].len; return dp[x];
}
void solve()
{
    scanf("%s",ch+1); int n=strlen(ch+1); lt=rt=++poi;
    for(int i=n;i;i--)    insert(ma(ch[i])),id[i]=lt;
    build(); scanf("%d",&na); sz=poi; int l,r;
    for(int i=1;i<=na;i++)    scanf("%d%d",&l,&r),locate(l,r,1),a[i]=sz;
    scanf("%d",&nb);
    for(int i=1;i<=nb;i++)    scanf("%d%d",&l,&r),locate(l,r,0),b[i]=sz;
    for(int i=1;i<=poi;i++)    sort(st[i].begin(),st[i].end(),cmp);
    for(int i=1;i<=poi;i++)
    {
        int tmp=i;
        for(int j=(int)st[i].size()-1;~j;j--)
        {
            int cur=st[i][j]; add(tmp,cur);
            if(!t[cur].isa)    tmp=cur;
        }
        lst[i]=tmp;// printf("!%d
",tmp); 
    }
    for(int i=2;i<=poi;i++)    add(lst[t[i].fa],i);
    for(int i=1;i<=sz;i++)    if(!t[i].isa)
        t[i].len=0;
    scanf("%d",&m);
    int x,y; ll ans=0;
    for(int i=1;i<=m;i++)
        scanf("%d%d",&x,&y),add(a[x],b[y]);
    for(int i=1;i<=sz;i++)
    {
        ll tmp=work(i); if(tmp==-1){ans=-1;break;}
        ans=max(ans,tmp);
    }
    printf("%lld
",ans);// poi=sz;
}
void clear()
{
    for(int i=1;i<=sz;i++)    st[i].clear(),memset(t[i].ch,0,sizeof(t[i].ch)),t[i].isa=0,t[i].len=0,t[i].fa=0;
    memset(dp,-1,sizeof(dp)); memset(vis,0,sizeof(vis)); memset(in,0,sizeof(in));
    //memset(id,0,sizeof(id)); memset(f,0,sizeof(f)); memset(lst,0,sizeof(lst));
    poi=cnt=lt=rt=sz=na=nb=0;
}
int main()
{
    int T;// memset(dp,-1,sizeof(dp));
    scanf("%d",&T); while(T--)
    {
        clear(); solve();
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/hanyuweining/p/10801730.html