校内模拟赛 休闲字符串

题目:

  给出长度为n的字符串,你需要选出尽量多的互不相交的长为k的段,使这些段按位置从左到右的字典序非降,输出最多能选出的段数。

分析:

  如果k等于1,那么就是一个最长不降子序列问题。k不等于1时,设$f[i]$表示以从位置$i$开始的段作为最后一段时最多能选的段数,则$f[i]$从$f[1 \sim i-k]$中字典序不大于该段的状态转移即可,然后用树状数组优化。

  用后缀数组(SA)和height数组预处理所有长为k的子串之间的大小关系(排名)。

代码:

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cmath>
#include<cctype>
#include<set>
#include<queue>
#include<vector>
#include<map>
#include<bitset>
using namespace std;
typedef long long LL;

/**
 * Reads one optionally-negative decimal integer from stdin.
 *
 * Every character before the first digit is skipped; if a '-' occurs anywhere
 * in that skipped run the result is negated. Digits are then accumulated until
 * the first non-digit, which is consumed as well.
 *
 * @return the parsed value, or 0 when stdin ends before any digit is found.
 */
inline int read() {
    int sign = 1;
    // Keep getchar()'s result as int: EOF stays distinguishable from real
    // bytes, and isdigit() only ever sees values it is defined for.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while (ch != EOF && isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Capacity of every global buffer below. Strings are stored 1-based in s, so
// the longest supported input is N - 2 characters (index 0 unused, plus NUL).
const int N = 200005;
char s[N];

// t1/t2: rank scratch buffers for getsa; c: counting-sort buckets;
// sa/rnk/ht: outputs of getsa; w: not touched by getsa (solve() stores a
// per-position class id in it).
int t1[N], t2[N], c[N], sa[N], rnk[N], ht[N], w[N];

/**
 * Builds the suffix array of s[1..n] by prefix doubling with counting sort,
 * then the LCP array with a single left-to-right scan.
 *
 * Results (all 1-based):
 *   sa[i]  - start position of the i-th smallest suffix,
 *   rnk[p] - position of suffix p inside sa (inverse of sa),
 *   ht[i]  - length of the common prefix of suffixes sa[i-1] and sa[i];
 *            ht[1] is 0.
 *
 * Characters are ordered by their unsigned byte value. No state is carried
 * over between calls, so the function can be called repeatedly on strings of
 * different lengths without clearing the scratch buffers.
 *
 * @param n length of the string stored in s[1..n]
 */
void getsa(int n) {
    int m = 256, i, *x = t1, *y = t2;

    // Round 0: sort suffixes by their first character.
    for (i = 0; i <= m; ++i) c[i] = 0;
    for (i = 1; i <= n; ++i) {
        // Go through unsigned char so bytes >= 0x80 cannot become a negative
        // bucket index.
        x[i] = static_cast<unsigned char>(s[i]);
        ++c[x[i]];
    }
    for (i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (i = n; i >= 1; --i) sa[c[x[i]]--] = i;

    for (int k = 1; k <= n; k <<= 1) {
        // y lists the suffixes ordered by their second half (chars k+1..2k):
        // suffixes too short to have one come first, the rest follow the
        // current sa order shifted left by k.
        int p = 0;
        for (i = n - k + 1; i <= n; ++i) y[++p] = i;
        for (i = 1; i <= n; ++i) if (sa[i] > k) y[++p] = sa[i] - k;

        // Stable counting sort of that list by first-half rank.
        for (i = 0; i <= m; ++i) c[i] = 0;
        for (i = 1; i <= n; ++i) c[x[y[i]]]++;
        for (i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];

        // Re-rank: y now holds the previous ranks, x receives the new ones.
        std::swap(x, y);
        p = 2;
        x[sa[1]] = 1;
        for (i = 2; i <= n; ++i) {
            const int cur = sa[i], prev = sa[i - 1];
            // A second half starting past the end counts as rank 0, which no
            // real suffix has. Checking the bound explicitly avoids reading
            // beyond position n (and beyond the array for large n).
            const int curTail = (cur + k <= n) ? y[cur + k] : 0;
            const int prevTail = (prev + k <= n) ? y[prev + k] : 0;
            x[cur] = (y[cur] == y[prev] && curTail == prevTail) ? p - 1 : p++;
        }
        if (p > n) break;  // all ranks distinct: order is final
        m = p;
    }

    for (i = 1; i <= n; ++i) rnk[sa[i]] = i;

    // LCP array: going from suffix i to i+1 the match length drops by at most
    // one, so the running length is reused instead of restarting from zero.
    int lcp = 0;
    ht[1] = 0;
    for (i = 1; i <= n; ++i) {
        if (rnk[i] == 1) continue;
        if (lcp) --lcp;
        const int j = sa[rnk[i] - 1];
        while (i + lcp <= n && j + lcp <= n && s[i + lcp] == s[j + lcp]) ++lcp;
        ht[rnk[i]] = lcp;
    }
}
// Binary indexed tree over positions 1..n that answers prefix-maximum queries.
// The caller sets n and zeroes mx before use; stored values only ever grow.
struct Bit{
    int mx[N], n;
    // Raises the value recorded at position p to at least v.
    void update(int p,int v) {
        while (p <= n) {
            if (mx[p] < v) mx[p] = v;
            p += p & -p;  // move to the next node covering p
        }
    }
    // Returns the largest value recorded at positions 1..p, or 0 if there is none.
    int query(int p) {
        int best = 0;
        while (p != 0) {
            if (best < mx[p]) best = mx[p];
            p -= p & -p;  // strip the lowest set bit
        }
        return best;
    }
}bit;
int f[N];
/**
 * Handles one test case: reads n, m and a string into s[1..n], then prints the
 * largest number of pairwise non-overlapping length-m windows, taken left to
 * right, whose contents are lexicographically non-decreasing.
 *
 * f[i] is the best chain ending with the window that starts at i. It is one
 * more than the best f[j] over j <= i - m whose window is not greater than
 * window i; that maximum comes from a prefix-max query in `bit`, indexed by the
 * window's class id w[].
 *
 * Prints 0 when no window of length m fits (m < 1 or m > n).
 */
void solve() {
    const int n = read(), m = read();
    scanf("%s", s + 1);
    int Ans = 0;
    if (m >= 1 && m <= n) {
        getsa(n);
        // Walk the suffixes in sorted order and give equal length-m prefixes
        // the same id: a new id starts whenever neighbouring suffixes share
        // fewer than m leading characters.
        int now = 0;
        w[sa[1]] = ++now;
        for (int i = 2; i <= n; ++i) {
            if (ht[i] < m) now ++;
            w[sa[i]] = now;
        }
        // Only slots 1..now are ever touched, so clearing that prefix suffices.
        bit.n = now;
        memset(bit.mx, 0, sizeof(int) * (now + 1));
        const int k = n - m + 1;  // number of window start positions
        for (int i = 1; i <= k; ++i) {
            f[i] = bit.query(w[i]) + 1;
            // Window j ends exactly at i, so it becomes a legal predecessor
            // from start position i + 1 on. j <= i, hence f[j] is already set.
            const int j = i - m + 1;
            if (j >= 1) bit.update(w[j], f[j]);
            Ans = max(Ans, f[i]);
        }
    }
    cout << Ans << "\n";
}
// Entry point: reads the number of test cases, then runs solve() once per case.
int main() {
    int T = read();
    // Compare against zero explicitly so a negative count runs no cases
    // instead of counting down through the whole int range.
    while (T-- > 0) solve();
    return 0;
}
原文地址:https://www.cnblogs.com/mjtcn/p/10604172.html