模拟赛 sutoringu

sutoringu

题意:

  询问有多少一个字符串内有多少个个子区间,满足可以分成k个相同的串。

分析:

  首先可以枚举一个长度len,表示分成的k个长为len的串。然后从1开始,每len的长度分成一块,分成(n-1)/k+1块,首先可以求出连续的k块的是否是合法。

  此时只求了起点是1+len*i的串,还有些起点在块内的没有求。

  枚举k-1个相同的块,设这些块为i...j,j-i+1=k。然后与求一下第i块和第i-1块最长后缀,设为a,求一下第j块和第j+1块的最长前缀,设为b。说明如果起点在第i-1块的串,必须是后面a个字符,这些串的终点必须是第j+1块的前b个字符。于是计算一下。

  如何求连续的k块是否是一样的?可以求出这连续k块在的rank,然后取一个最大的rank和一个最小的rank,然后求之间的height最小值即可。

  复杂度$nlog^2n$。

代码:

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cmath>
#include<cctype>
#include<set>
#include<queue>
#include<vector>
#include<map>
#include<bitset>
using namespace std;
typedef long long LL;

inline int read() {
    int x=0,f=1;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';return x*f;
}

const int N = 600005;
char s[N];
int t1[N], t2[N], c[N], sa[N], rnk[N], ht[N], f[N][21], Log[N];
void getsa(int n) {
    int m = 130, i, *x = t1, *y = t2;
    for (i = 1; i <= m; ++i) c[i] = 0;
    for (i = 1; i <= n; ++i) x[i] = s[i], c[x[i]] ++;
    for (i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (i = n; i >= 1; --i) sa[c[x[i]]--] = i;
    for (int k = 1; k <= n; k <<= 1) {
        int p = 0;
        for (i = n - k + 1; i <= n; ++i) y[++p] = i;
        for (i = 1; i <= n; ++i) if (sa[i] > k) y[++p] = sa[i] - k;
        for (i = 1; i <= m; ++i) c[i] = 0;
        for (i = 1; i <= n; ++i) c[x[y[i]]] ++;
        for (i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
        swap(x, y);
        p = 2;
        x[sa[1]] = 1;
        for (i = 2; i <= n; ++i) 
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? p - 1 : p ++;
        if (p > n) break;
        m = p;
    }
    for (i = 1; i <= n; ++i) rnk[sa[i]] = i;
    ht[1] = 0;
    int k = 0;
    for (i = 1; i <= n; ++i) {
        if (rnk[i] == 1) continue;
        if (k) k --;
        int j = sa[rnk[i] - 1];
        while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++;
        ht[rnk[i]] = k;
    }
    for (i = 1; i <= n; ++i) f[i][0] = ht[i];
    for (i = 2; i <= n; ++i) Log[i] = Log[i >> 1] + 1;
    for (int j = 1; j <= Log[n]; ++j) 
        for (i = 1; i + (1 << j) - 1 <= n; ++i) 
            f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
int LCP(int i,int j) {
    i = rnk[i], j = rnk[j];
    if (i > j) swap(i, j);
    i ++;
    int k = Log[j - i + 1];
    return min(f[i][k], f[j - (1 << k) + 1][k]);
}
int LCP2(int i,int j) {
    i ++;
    int k = Log[j - i + 1];
    return min(f[i][k], f[j - (1 << k) + 1][k]);
}

set<int> sk;
int n, k, rev[N];

bool check(int len) {
    int l = *sk.begin();
    set<int>::iterator it = sk.end(); it --;
    int r = *it;
    return LCP2(l, r) >= len;
}
int check2(int i,int j,int len) {
    if (sk.size() >= 2 && !check(len)) return 0;
    int a = min(len - 1, LCP(rev[i - 1], rev[i - 1 + len]));
    if (j + len > n) return 0;
    int b = min(len - 1, LCP(j, j + len));
    return max(0, b - (len - a) + 1);
}
int main() {
    freopen("sutoringu.in", "r", stdin);
    freopen("sutoringu.out", "w", stdout);
    n = read(), k = read();
    scanf("%s", s + 1);
    s[n + 1] = '#';
    for (int i = 1; i <= n; ++i) 
        s[i + n + 1] = s[n - i + 1], rev[n - i + 1] = i + n + 1;
    getsa(n + n + 1);
    LL ans = 0;
    for (int len = 1; len <= n; ++len) {
        sk.clear();
        for (int i = 1; i <= n; i += len) {
            sk.insert(rnk[i]);
            if (sk.size() > k) sk.erase(rnk[i - len * k]);
            if (sk.size() == k) ans += check(len);
        }
        if (len == 1) continue;
        sk.clear();
        for (int i = len + 1; i <= n; i += len) {
            sk.insert(rnk[i]);
            if (sk.size() > k - 1) sk.erase(rnk[i - len * (k - 1)]);
            if (sk.size() == k - 1) ans += check2(i - (k - 2) * len, i, len);
        }
    }
    cout << ans;
    return 0;
}
原文地址:https://www.cnblogs.com/mjtcn/p/10610070.html