bzoj3796

后缀数组+kmp+set

前两个条件很好搞,后缀数组求lcp然后看相邻两个后缀是不是分别属于不同的串,是的话所有lcp的max就是答案,但是现在有了第三个限制就很麻烦了。

我们先把第三个串在第一个串上跑kmp,把所有匹配位置的结束点放进set里,然后像之前一样查lcp,每次查的时候在set里查询当前第一个串后缀位置+l3-1的lower_bound,然后和lcp比一比长度就是答案

把空间开小了居然是wa...记住后缀数组要开两倍空间

kmp如果匹配到了终点,那么把j向前跳一下再继续匹配,想想就知道了

#include<bits/stdc++.h>
using namespace std;
const int N = 50010;
int n, l1, l2, l3, pos, ans, k;
int rank[N * 2], sa[N * 2], lcp[N * 2], nxt[N * 2], tmp[N * 2];
char s[N * 2], s1[N], s2[N], s3[N];
set<int> S;
bool cmp(int i, int j)
{
    if(rank[i] != rank[j]) return rank[i] < rank[j];
    int ri = i + k <= n ? rank[i + k] : -1;
    int rj = j + k <= n ? rank[j + k] : -1;
    return ri < rj;
}
void construct(int n)
{
    for(int i = 1; i <= n; ++i) 
    {
        sa[i] = i;
        rank[i] = s[i];
    }
    for(k = 1; k <= n; k <<= 1)
    {
        sort(sa + 1, sa + n + 1, cmp);
        tmp[sa[1]] = 1;
        for(int i = 2; i <= n; ++i) tmp[sa[i]] = tmp[sa[i - 1]] + (cmp(sa[i - 1], sa[i]));
        for(int i = 1; i <= n; ++i) rank[i] = tmp[i];
    }
    int h = 0;
    for(int i = 1; i <= n ;++i) sa[rank[i]] = i;
    for(int i = 1; i <= n; ++i)
    {
        int j = sa[rank[i] - 1];
        if(rank[i] <= 1) continue;
        if(h) --h;
        for(; i + h <= n && j + h <= n; ++h) if(s[i + h] != s[j + h]) break;
        lcp[rank[i] - 1] = h;
    }
}
int main()
{
    scanf("%s%s%s", s1 + 1, s2 + 1, s3 + 1);
    l1 = strlen(s1 + 1);
    l2 = strlen(s2 + 1);
    l3 = strlen(s3 + 1);
    for(int i = 2, j = 0; i <= l3; ++i)
    {
        while(s3[i] != s3[j + 1] && j) j = nxt[j];
        if(s3[i] == s3[j + 1]) ++j;
        nxt[i] = j;
    }
    for(int i = 1, j = 0; i <= l1; ++i)
    {
        while(s1[i] != s3[j + 1] && j) j = nxt[j];
        if(s1[i] == s3[j + 1]) ++j;
        if(j == l3) S.insert(i), j = nxt[j];
    }
    for(int i = 1; i <= l1; ++i) s[++n] = s1[i];
    s[++n] = '#';
    pos = n;
    for(int i = 1; i <= l2; ++i) s[++n] = s2[i];
    construct(n);
    S.insert(n + 1);
    for(int i = 1; i < n; ++i) 
    {
        int p;
        if(sa[i] < pos && sa[i + 1] > pos) p = sa[i];
        else if(sa[i] > pos && sa[i + 1] < pos) p = sa[i + 1]; 
        else continue;
        set<int> :: iterator it = S.lower_bound(p + l3 - 1);
        ans = max(ans, min(*it - p, lcp[i]));
    }
    printf("%d
", ans);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/19992147orz/p/7524360.html