[Codeforces 452E] Three Strings

[题目链接]

         https://codeforces.com/contest/452/problem/E

[算法]

         构建后缀数组

         将相邻后缀按 height 从大到小排序，依次用并查集合并；每次合并两个集合时，统计跨越两个集合的三元组个数并累加到 ans[height]，最后对 ans 做一次后缀和即可

         时间复杂度 : O(NlogN)

[代码]

        

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// Array capacity for every global buffer. Presumably sized for the combined
// input length (<= 3e5 per the problem statement — confirm) plus the two
// separators and the 1-indexed / sentinel slots the code reads past the end.
const int N = 3e5 + 10;
// Modulus applied to every reported count.
const int P = 1e9 + 7;

// Loop-index shorthand. The `register` storage class is ill-formed since
// C++17 (compilers reject or warn about it), so this expands to plain `int`.
#define rint int

// One record per adjacent pair in suffix-array order, filled in main():
// id = rank i of the pair (sa[i], sa[i + 1]); ht = height[i], their LCP.
struct info
{
        int id;
        int ht;
};
info a[N];

// Length of the concatenated text s[1..n].
int n;
// Suffix-array data: sa[r] = start of the r-th smallest suffix, rk = inverse
// of sa, height[r] = LCP of suffixes sa[r] and sa[r + 1].
int height[N] , sa[N] , rk[N];
// Union-find over suffix ranks: parent link and set size ...
int fa[N] , sz[N];
// ... plus how many suffixes of input string 1 / 2 / 3 each set holds.
int cnta[N] , cntb[N] , cntc[N];
// bel[p] = input string (1, 2 or 3) that position p of s came from; 0 for separators.
int bel[N];
// ans[L] = triple count (mod P); after main's suffix sum, for common prefix >= L.
ll ans[N];
// s = concatenated text; s1, s2, s3 = the three inputs. All 1-indexed.
char s[N] , s1[N] , s2[N] , s3[N];

/// Raises x to y when y is strictly larger (x = max(x, y)).
template <typename T> inline void chkmax(T &x,T y)
{
        if (x < y) x = y;
}
/// Lowers x to y when y is strictly smaller (x = min(x, y)).
template <typename T> inline void chkmin(T &x,T y)
{
        if (y < x) x = y;
}
/// Modular accumulate: x += y, then subtract P until x < P.
/// Callers in this file pass operands already reduced mod P, so at most
/// one subtraction actually happens.
template <typename T> inline void add(T &x , T y)
{
        for (x += y; x >= P; x -= P) {}
}
/*
 * Reads the next (optionally '-'-signed) decimal integer from stdin into x.
 * Every '-' seen before the first digit flips the sign; other non-digit
 * characters are skipped. If end of input is reached first, x is set to 0
 * instead of looping forever (the previous `char` version could never see
 * EOF, and passed possibly-negative chars to isdigit, which is UB).
 */
template <typename T> inline void read(T &x)
{
    T f = 1; x = 0;
    int c = getchar(); // int, not char, so EOF stays distinguishable
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
    x *= f;
}
/*
 * Builds the suffix array of s[1..n] by prefix doubling with two radix
 * passes per round.
 *   sa[r] = start position of the r-th smallest suffix (1-indexed)
 *   rk[i] = rank of the suffix starting at i (inverse of sa)
 * Characters are bucketed by their unsigned value, so bytes >= 128 can no
 * longer produce negative array indices. Buckets are cleared before use.
 */
inline void build_sa()
{
        static int x[N] , y[N] , cnt[N];
        // Round 0: counting sort on the first character (buckets 0..255).
        for (int c = 0; c <= 256; ++c) cnt[c] = 0;
        for (int i = 1; i <= n; ++i) ++cnt[(unsigned char)s[i]];
        for (int i = 1; i <= 256; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i) sa[cnt[(unsigned char)s[i]]--] = i;
        rk[sa[1]] = 1;
        for (int i = 2; i <= n; ++i) rk[sa[i]] = rk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]);
        // Double the compared prefix length until every rank is distinct.
        for (int k = 1; rk[sa[n]] != n; k <<= 1)
        {
                // Key of suffix i: (x, y) = (rank of first k chars, rank of next k; 0 if past the end).
                for (int i = 1; i <= n; ++i)
                        x[i] = rk[i] , y[i] = (i + k <= n) ? rk[i + k] : 0;
                // Pass 1: order positions by y. rk[1..n] is reused as scratch for that order.
                for (int i = 0; i <= n; ++i) cnt[i] = 0;
                for (int i = 1; i <= n; ++i) ++cnt[y[i]];
                for (int i = 1; i <= n; ++i) cnt[i] += cnt[i - 1];
                for (int i = n; i >= 1; i--) rk[cnt[y[i]]--] = i;
                // Pass 2: stable counting sort of that order by x.
                for (int i = 0; i <= n; ++i) cnt[i] = 0;
                for (int i = 1; i <= n; ++i) ++cnt[x[i]];
                for (int i = 1; i <= n; ++i) cnt[i] += cnt[i - 1];
                for (int i = n; i >= 1; i--) sa[cnt[x[rk[i]]]--] = rk[i];
                // Re-rank: suffixes with equal (x, y) keys share a rank.
                rk[sa[1]] = 1;
                for (int i = 2; i <= n; ++i) rk[sa[i]] = rk[sa[i - 1]] + (x[sa[i]] != x[sa[i - 1]] || y[sa[i]] != y[sa[i - 1]]);
        }
}
/*
 * Kasai's algorithm over the arrays filled by build_sa().
 * height[r] = length of the longest common prefix of suffixes sa[r] and
 * sa[r + 1] (the NEXT suffix in sorted order) for r = 1..n-1; height[n] = 0.
 * The last-ranked suffix is handled explicitly rather than by reading
 * sa[n + 1] and relying on it and s[0] being zero.
 */
inline void get_height()
{
        int k = 0;
        for (int i = 1; i <= n; ++i)
        {
                if (k) --k; // LCP can shrink by at most one from suffix i-1 to suffix i
                if (rk[i] == n)
                {
                        // No successor in sorted order: nothing to compare against.
                        k = 0;
                        height[n] = 0;
                        continue;
                }
                int j = sa[rk[i] + 1];
                // s[n + 1] is never written (stays '\0'), so a suffix running off the end mismatches.
                while (s[i + k] == s[j + k]) ++k;
                height[rk[i]] = k;
        }
}
/// Sort predicate: orders info records by strictly decreasing ht.
inline bool cmp(info lhs , info rhs)
{
        return rhs.ht < lhs.ht;
}
/// Returns the representative of x's union-find set. Every node on the
/// path from x is re-pointed directly at it (path compression).
inline int get_root(int x)
{
        int root = x;
        while (fa[root] != root) root = fa[root];
        // Second walk: hang each visited node straight under the root.
        while (fa[x] != root)
        {
                int next = fa[x];
                fa[x] = root;
                x = next;
        }
        return root;
}
/// Unions two set roots by size. The smaller set is hung under the larger
/// (on a tie, x goes under y), and the parent absorbs the child's size and
/// per-string suffix counters.
inline void merge(int x , int y)
{
        const bool x_bigger = sz[x] > sz[y];
        const int parent = x_bigger ? x : y;
        const int child = x_bigger ? y : x;
        fa[child] = parent;
        sz[parent] += sz[child];
        cnta[parent] += cnta[child];
        cntb[parent] += cntb[child];
        cntc[parent] += cntc[child];
}
/// Smallest of three ints.
inline int _min(int x , int y , int z)
{
        return std::min({x , y , z});
}

int main()
{
        
        scanf("%s%s%s" , s1 + 1 , s2 + 1 , s3 + 1);
        int l1 = strlen(s1 + 1) , l2 = strlen(s2 + 1)  , l3 = strlen(s3 + 1);
        for (rint i = 1; i <= l1; ++i) 
        {
                s[++n] = s1[i];
                bel[n] = 1;
        }
        s[++n] = '#';
        for (rint i = 1; i <= l2; ++i) 
        {
                s[++n] = s2[i];
                bel[n] = 2;
        }
        s[++n] = '@';
        for (rint i = 1; i <= l3; ++i) 
        {
                s[++n] = s3[i];
                bel[n] = 3;
        }
        build_sa();
        get_height();
        for (rint i = 1; i < n; ++i)
        {
                a[i].id = i;
                a[i].ht = height[i];        
        }
        sort(a + 1 , a + n , cmp);
        for (rint i = 1; i <= n; ++i)
        {
                fa[i] = i;
                sz[i] = 1;
                if (bel[sa[i]] == 1) cnta[i] = 1;
                if (bel[sa[i]] == 2) cntb[i] = 1;
                if (bel[sa[i]] == 3) cntc[i] = 1;    
        }
        for (rint i = 1; i < n; ++i)
        {
                int rk = a[i].id , ht = a[i].ht;
                if (!ht) break;
                int fx = get_root(rk) , fy = get_root(rk + 1);
                add(ans[ht] , 1LL * cnta[fx] * cntb[fx] % P * cntc[fy] % P);
                add(ans[ht]    , 1LL * cnta[fx] * cntc[fx] % P * cntb[fy] % P);
                add(ans[ht] , 1LL * cntb[fx] * cntc[fx] % P * cnta[fy] % P);
                add(ans[ht] , 1LL * cnta[fy] * cntb[fy] % P * cntc[fx] % P);
                add(ans[ht] , 1LL * cnta[fy] * cntc[fy] % P * cntb[fx] % P);
                add(ans[ht] , 1LL * cntb[fy] * cntc[fy] % P * cnta[fx] % P);
                merge(fx , fy);
        }
         for (rint i = n; i >= 1; --i) add(ans[i] , ans[i + 1]);
         for (rint i = 1; i <= _min(l1 , l2 , l3); ++i) printf("%lld " , ans[i]);
         printf("
");
         
        return 0;
    
}
原文地址:https://www.cnblogs.com/evenbao/p/10660041.html