HDU-6599 I Love Palindrome String(回文自动机+字符串hash)

题目链接

题意:给定一个字符串(|S|le 3 imes 10^5)
对于每个 (iin [1,|S|]) 求有多少子串(s_ls_{l+1}cdots s_r)满足下面条件

  • (r-l+1 = i)
  • (s_ls_{l+1}cdots s_r)是一个回文串
  • (s_ls_{l+1}cdots s_{lfloor(l+r)/2 floor})也是一个回文串

回文树学习:https://blog.csdn.net/Clove_unique/article/details/53750322

根据条件可知(s_ls_{l+1}cdots s_{lfloor(l+r)/2 floor-1} == s_{lfloor(l+r)/2 floor}cdots s_r)
先用回文自动机求出本质不同的回文串

  • 第一种方法是回文树中创建新节点(意味着新的回文串)时判断该回文串是否满足上面条件,再建好树之后从后向前计算个数时累加即可
  • 第二种是利用回文树中fail指针特性,每次循环找fail所指的回文串是否满足条件,直到根节点或找到为止,此次寻找结果可以更新到下一层fail指针所指的节点中(避免多次重复寻找,不然会T)

第一种方法代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 3e5+10;
typedef unsigned long long ull;
typedef long long ll;
ull has[N],pn[N],base = 131;
ull getVal(int l,int r){
    return has[r] - has[l-1] * pn[r-l+1];
}
ll res[N];
char s[N];
namespace PAT{
    const int SZ = 6e5+10;
    int ch[SZ][26],fail[SZ],cnt[SZ],len[SZ],tot,last,pos[SZ],satisfy[SZ];
    void init(int n){
        for(int i=0;i<=n+10;i++){
            fail[i] = cnt[i] = len[i] = 0;
            for(int j=0;j<26;j++)ch[i][j] = 0;
        }
        s[0] = -1;fail[0] = 1;last = 0;
        len[0] = 0;len[1] = -1,tot = 1;
    }
    inline int newnode(int x){
        len[++tot] = x;return tot;
    }
    inline int getfail(int x,int n){
        while(s[n-len[x]-1] != s[n])x = fail[x];
        return x;
    }
    void create(char *s,int n){
        s[0] = -1;
        for(int i=1;i<=n;++i){
            int t = s[i]- 'a';
            int p = getfail(last,i);
            if(!ch[p][t]){
                int q = newnode(len[p]+2);
                fail[q] = ch[getfail(fail[p],i)][t];
                ch[p][t] = q;
                int need = (len[q] + 1) /2;
                if(len[q] == 1 || getVal(i-len[q]+1,i-len[q] + need) == getVal(i-need +1,i))satisfy[q] = 1;
                else satisfy[q] = 0;
            }
            ++cnt[last = ch[p][t]];
        }
    }
    void solve(){
        for(int i=tot;i>=2;i--){
            cnt[fail[i]] += cnt[i];
            if(satisfy[i])res[len[i]] += cnt[i];
        }
    }
}
int main(){
    pn[0] = 1;
    for(int i=1;i<N;i++)pn[i] = pn[i-1] * base;

    while(~scanf("%s",s+1)){
        int n = strlen(s+1);
        has[0] = 1;
        for(int i=1;i<=n;i++){
            has[i] = has[i-1] * base + s[i];
        }
        for(int i=1;i<=n;i++)res[i] = 0;
        PAT::init(n);
        PAT::create(s,n);
        PAT::solve();
        for(int i=1;i<n;i++)printf("%lld ",res[i]);
        printf("%lld
",res[n]);
    }
    return 0;
}

第二种方法code

#include <bits/stdc++.h>
using namespace std;
const int N = 3e5+10;
typedef long long ll;
char s[N];
ll res[N];
namespace PAT{
    const int SZ = 6e5+10;
    int ch[SZ][26],fail[SZ],cnt[SZ],len[SZ],tot,last;
    int be[SZ],ok[SZ];
    void init(int n){
        for(int i=0;i<=n+10;i++){
            fail[i] = cnt[i] = len[i] = 0;
            for(int j=0;j<26;j++)ch[i][j] = 0;
            be[i] = ok[i] = 0;
        }
        s[0] = -1;fail[0] = 1;last = 0;
        len[0] = 0;len[1] = -1;tot = 1;
    }
    inline int newnode(int x){
        len[++tot] = x;return tot;
    }
    inline int getfail(int x,int n){
        while(s[n-len[x]-1] != s[n])x = fail[x];
        return x;
    }
    void create(char *s){
        s[0] = -1;
        for(int i=1;s[i];++i){
            int t = s[i]- 'a';
            int p = getfail(last,i);
            if(!ch[p][t]){
                int q = newnode(len[p]+2);
                fail[q] = ch[getfail(fail[p],i)][t];
                ch[p][t] = q;
            }
            ++cnt[last = ch[p][t]];
        }
    }
    void solve(){
        for(int i=tot;i>=2;i--){
            if(be[i] == 0)be[i] = i;
            while(be[i] >= 2 && len[be[i]] > (len[i] + 1)/2)be[i] = fail[be[i]];
            if(len[be[i]] == (len[i]+1)/2)ok[i] = 1;
            be[fail[i]] = be[i];
        }
        for(int i=tot;i>=2;i--){
            cnt[fail[i]] += cnt[i];
            if(ok[i]) res[len[i]] += cnt[i];
        }
    }
}
int main(){
    while(~scanf("%s",s+1)){
        int n = strlen(s+1);
        for(int i=1;i<=n;i++)res[i] = 0;
        PAT::init(n);
        PAT::create(s);
        PAT::solve();
        for(int i=1;i<n;i++)printf("%lld ",res[i]);
        printf("%lld
",res[n]);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/1625--H/p/11288662.html