poj3415(后缀数组)

poj3415

题意

给定两个字符串,给出长度 (m) ,问这两个字符串有多少对长度大于等于 (m) 且完全相同的子串。

分析

首先连接两个字符串 A B,中间用一个特殊符号分割开。
按照 (sa) 的顺序(即枚举 (height) 值),进行分组,那么有公共前缀长大于等于 (m) 的都分到了一组,对于某一组,后缀串可能来自于 A 也可能来自于 B,那么对于 A 找前面的 B 串,对于 B 找前面的 A 串,如果某两个后缀串的公共前缀长为 (l(l geqslant m)),那么显然会有 (l - m + 1) 对子串。
注意到这个性质: 对于两个后缀串 j 和 k,设 (rnk[j] < rnk[k]) ,LCP长度为 (height[rnk[j]+1], height[rnk[j]+2], ... , height[rnk[k]]) 中的最小值。
维护一个单调递增的栈(保证栈顶最大)可以用一个二维数组表示((q[][2])),一个是栈,一个是某个数的个数。
举个例子,如果连续的 (height) 值为 (2 3 4)(m = 2),前三个为 A 串,那么 (2 3 4) 全部入栈,且计算对答案的贡献 (sum)(不是直接加到答案上),即 ((2-2+1) + (3-2+1) + (4-2+1)) ,到 B 串时,答案就加上了这个值,但是如果后面还有一个 B 串且 (height)(3),那么就要弹栈,且减去 (sum) 值多的那部分(前面多算了),栈里 (4) 的数量为 (1),所以 (sum = sum - (4 - 3) * 1) ,且栈里 (3) 的数量变为了 (2)(4) 对应的 A 串对于后面串提供的贡献减小了(注意前面的性质),所以(4) 变为了 (3) ),答案加上 (sum)

code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 10;
const int INF = 1e9;
char s[MAXN];
int sa[MAXN], t[MAXN], t2[MAXN], c[MAXN], n; // n 为 字符串长度 + 1,即最后一位为数字 0
int rnk[MAXN], height[MAXN];
// 构造字符串 s 的后缀数组。每个字符值必须为 0 ~ m-1
void build_sa(int m) {
    int i, *x = t, *y = t2;
    for(i = 0; i < m; i++) c[i] = 0;
    for(i = 0; i < n; i++) c[x[i] = s[i]]++;
    for(i = 1; i < m; i++) c[i] += c[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;
    for(int k = 1; k <= n; k <<= 1) {
        int p = 0;
        for(i = n - k; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
        for(i = 0; i < m; i++) c[i] = 0;
        for(i = 0; i < n; i++) c[x[y[i]]]++;
        for(i = 0; i < m; i++) c[i] += c[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
        swap(x, y);
        p = 1;
        x[sa[0]] = 0;
        for(i = 1; i < n; i++)
            x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
        if(p >= n) break;
        m = p;
    }
}
void getHeight() {
    int i, j, k = 0;
    for(i = 0; i < n; i++) rnk[sa[i]] = i;
    for(i = 0; i < n - 1; i++) {
        if(k) k--;
        j = sa[rnk[i] - 1];
        while(s[i + k] == s[j + k]) k++;
        height[rnk[i]] = k;
    }
}
char s2[MAXN];
int q[MAXN][2];
int main() {
    int m;
    while(~scanf("%d", &m) && m) {
        scanf("%s%s", s, s2); // A 、B串
        int l = strlen(s), l2 = strlen(s2);
        s[l++] = '#';
        for(int i = 0; i < l2; i++) s[i + l] = s2[i];
        s[l + l2] = 0;
        n = l + l2 + 1;
        build_sa(128);
        getHeight();
        ll ans = 0, sum = 0;
        int top = 0;
        // 在 B 串前找 A
        for(int i = 2; i < n; i++) {
            int cnt = 0;
            if(height[i] < m) {
                top = 0; sum = 0;
                continue;
            }
            if(sa[i - 1] < l) {
                cnt++;
                sum += height[i] - m + 1;
            }
            while(top && q[top - 1][0] >= height[i]) {
                top--;
                sum -= (q[top][0] - height[i]) * q[top][1];
                cnt += q[top][1];
            }
            q[top][0] = height[i]; q[top++][1] = cnt;
            if(sa[i] >= l) ans += sum;
        }
        // 在 A 串前找 B
        sum = 0; top = 0;
        for(int i = 2; i < n; i++) {
            int cnt = 0;
            if(height[i] < m) {
                top = 0; sum = 0;
                continue;
            }
            if(sa[i - 1] >= l) {
                cnt++;
                sum += height[i] - m + 1;
            }
            while(top && q[top - 1][0] >= height[i]) {
                top--;
                sum -= (q[top][0] - height[i]) * q[top][1];
                cnt += q[top][1];
            }
            q[top][0] = height[i]; q[top++][1] = cnt;
            if(sa[i] < l) ans += sum;
        }
        printf("%lld
", ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/ftae/p/7222999.html