SPOJ DISUBSTR (suffix array: counting distinct substrings)

DISUBSTR - Distinct Substrings

Given a string, we need to find the total number of its distinct substrings.

Input

T — the number of test cases (T <= 20).
Each test case consists of one string, whose length is <= 1000

Output

For each test case, output a single number: the number of distinct substrings.

Example

Sample Input:
2
CCCCC
ABABA

Sample Output:
5
9

Explanation for the test case with string ABABA:
len=1 : A,B
len=2 : AB,BA
len=3 : ABA,BAB
len=4 : ABAB,BABA
len=5 : ABABA
Thus, total number of distinct substrings is 9.

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of the per-position arrays. It must exceed the longest input:
// buf needs room for the trailing NUL, and sa/rnk/tmp/lcp hold len+1
// entries (the extra one is the empty "sentinel" suffix at index len).
constexpr int MAXN = 200005;

// Current test case's string (NUL-terminated, filled by scanf in main).
char buf[MAXN];
// sa[r]: start index of the suffix with rank r; index len is the empty suffix.
int sa[MAXN];
// rnk[i]: rank of the suffix starting at i. While getsa runs this is a rank
// over a prefix of that suffix; getlcp rebuilds it as the inverse of sa.
int rnk[MAXN];
// Scratch space for the new ranks computed in each doubling round.
int tmp[MAXN];
// lcp[r]: longest common prefix length of suffixes sa[r] and sa[r+1].
int lcp[MAXN];
// Length of buf.
int len;
// Current doubling step; read by comp() while getsa() sorts.
int k;

// Strict weak ordering for one prefix-doubling round with step k.
// Suffixes are compared by the pair (rnk[x], rnk[x+k]). The second key is -1
// when x+k runs past the sentinel position len, so a shorter suffix sorts first.
bool comp(int i,int j)
{
    // Primary key: the rank of the first k characters.
    if(rnk[i]<rnk[j])    return true;
    if(rnk[i]>rnk[j])    return false;

    // Tie: fall back to the rank of the next k characters.
    const int second_i=(i+k>len)?-1:rnk[i+k];
    const int second_j=(j+k>len)?-1:rnk[j+k];
    return second_i<second_j;
}

// Builds the suffix array of buf by prefix doubling, using std::sort with comp
// in each round.
// On return: len == strlen(buf), and sa[0..len] lists suffix start indices in
// lexicographic order. The empty suffix (index len) is always at sa[0], because
// its starting rank (-1) is below every real character's rank.
void getsa()
{
    memset(rnk,0,sizeof(rnk));
    memset(sa,0,sizeof(sa));
    memset(tmp,0,sizeof(tmp));

    len=strlen(buf);
    for(int i=0;i<len;i++)
    {
        sa[i]=i;
        // Rank by raw byte value. Going through unsigned char keeps every
        // initial rank >= 0, so the sentinel's -1 stays strictly smallest.
        // The old buf[i]-'A' went below -1 for bytes such as digits or
        // punctuation (and for high-bit bytes when char is signed). A real
        // suffix could then sort ahead of the empty one, and getlcp would
        // read sa[-1].
        rnk[i]=static_cast<unsigned char>(buf[i]);
    }
    sa[len]=len;
    rnk[len]=-1;

    for(k=1;k<=len;k*=2)
    {
        sort(sa,sa+len+1,comp);

        // Re-rank: equal (rank, next-rank) pairs share a rank.
        tmp[sa[0]]=0;
        for(int i=1;i<=len;i++)
            tmp[sa[i]]=tmp[sa[i-1]]+(comp(sa[i-1],sa[i])?1:0);

        for(int i=0;i<=len;i++)
            rnk[i]=tmp[i];

        // Every rank is now distinct, so the order is final. Another round
        // would only re-sort by the primary key alone and reproduce sa and rnk.
        if(rnk[sa[len]]==len)    break;
    }
}

// Builds sa (via getsa), then the LCP array with Kasai's algorithm.
// Afterwards rnk is the inverse permutation of sa.
// lcp[r] holds the longest common prefix length of suffixes sa[r] and sa[r+1]
// for r in [0, len-1]. lcp[0] is 0 because sa[0] is the empty suffix, and
// lcp[len] is never written, so it stays 0.
void getlcp()
{
    memset(lcp,0,sizeof(lcp));
    getsa();   // getsa() clears rnk itself; no separate memset is needed here
    for(int i=0;i<=len;i++)
        rnk[sa[i]]=i;

    // Current match length. Moving from suffix i to suffix i+1 lowers the LCP
    // with the rank-predecessor by at most one, so h is reused.
    int h=0;
    for(int i=0;i<len;i++)
    {
        if(h>0)    h--;
        // Every real suffix should have a rank-predecessor, because the empty
        // suffix sits at rank 0. Guard anyway: reading sa[-1] is out of bounds.
        if(rnk[i]==0)
        {
            h=0;
            continue;
        }
        const int j=sa[rnk[i]-1];
        while(i+h<len&&j+h<len&&buf[i+h]==buf[j+h])
            h++;
        lcp[rnk[i]-1]=h;
    }
}

void debug()
{
    for(int i=0;i<=len;i++)
    {
        int l=sa[i];
        if(l==len)
        {
            printf("%d %d
",l,lcp[i]);
        }
        else
        {
            for(int j=l;j<len;j++)
                printf("%c ",buf[j]);
            printf("%d
",lcp[i]);
        }
    }
    
}
int T;
int main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%s",buf);
        int res=0;
        getlcp();
        res+=(len+1)*len/2;
        for(int i=0;i<=len;i++)
            res-=lcp[i];
        printf("%d
",res);
    }
    
    return 0;
}
Original article: https://www.cnblogs.com/program-ccc/p/5236662.html