【HDU-5785】Interesting(回文串的性质+回文自动机+map空间优化)

题目链接:https://vjudge.net/problem/HDU-5785

题目大意

给定一个字符串,求有多少对三元组 ((i, j, k)) 满足 (1≤i≤j<k≤|S|),要求 (S[i,...j])(S[j+1, .. k]) 都为回文串,对 (1e9+7) 取模。

思路

是对回文的端点进行计数,可上回文自动机。

由回文树的性质可知,假设字符串的第 (i) 个点在回文树上的编号为 (p_{i}),那么其在 (fail) 树上的祖先为以第 (i) 个点为右端点的所有回文串。

假设回文树的右端为 (i),其长度为 (len_{p_{i}}),那么左端点就为 (i-len_{p_{i}}+1)

那么其左端点之和就为 ((i-len_{p_{i}}+1)+(i-len_{fail[p_{i}]}+1)+(i-len_{fail[fail[p_{i}]]}+1)+...)

每次在回文树上增加节点时,维护 (len_{p_{i}}) 之和,记为 (sum_{p_{i}})

那么则将式子转化成 ((i+i+..+i) + (1+1+...+1) - sum_{p_{i}}),其中 ((i+i+...+i) = num_{p_{i}}) 即以这个点为右端点时回文串个数,要求这个可以做洛谷模板题

从右向左同理,在转换时需要注意小细节。

但是由于空间特别卡,需要用 (map) 来优化空间。

AC代码

#include <bits/stdc++.h>

typedef long long ll;
using namespace std;
const int MAXN = 1e6 + 5;
const int MAXC = 26;
const int mod = 1000000007;

int nm[MAXN][2];

char str[MAXN];

class PAM {
public:
    struct node {
        map<int, int> ch;// int ch[MAXC];
        int fail, len, num;
        ll sum;
    } T[MAXN];
    int las, tot;

    inline int get_fail(int x, int pos) {
        while (str[pos - T[x].len - 1] != str[pos]) {
            x = T[x].fail;
        }
        return x;
    }

    void init() {
        T[0].ch.clear(), T[1].ch.clear();
        // memset(T[0].ch, 0, sizeof(T[0].ch)), memset(T[1].ch, 0, sizeof(T[1].ch));
        T[0].fail = 1, T[1].fail = 0;
        T[0].len = 0, T[1].len = -1;
        T[0].num = T[1].num = T[0].sum = T[1].sum = 0;
        las = 0, tot = 1;
    }

    void insert1(char s[], int len) {
        s[0] = -1;
        for (int i = 1; i <= len; i++) {
            int p = get_fail(las, i);
            if (!T[p].ch[s[i]-'a']) {
                T[++tot].len = T[p].len + 2;
                T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
                int u = get_fail(T[p].fail, i);
                T[tot].fail = T[u].ch[s[i]-'a'];
                T[tot].num = T[T[tot].fail].num + 1;
                T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
                T[p].ch[s[i]-'a'] = tot;
            }
            las = T[p].ch[s[i]-'a'];
            nm[i][0] = ((ll) T[las].num * (i + 1) % mod - T[las].sum + mod) % mod;
        }
    }

    void insert2(char s[], int len) {
        s[0] = 0;
        for (int i = 1; i <= len; i++) {
            int p = get_fail(las, i);
            if (!T[p].ch[s[i]-'a']) {
                T[++tot].len = T[p].len + 2;
                T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
                int u = get_fail(T[p].fail, i);
                T[tot].fail = T[u].ch[s[i]-'a'];
                T[tot].num = T[T[tot].fail].num + 1;
                T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
                T[p].ch[s[i]-'a'] = tot;
            }
            las = T[p].ch[s[i]-'a'];
            int pos = len - i + 1;
            nm[pos][1] = ((ll) T[las].num * (pos - 1) % mod + T[las].sum) % mod;
        }
    }

} tree;


int main() {
    while (~scanf("%s", str + 1)) {
        int n = strlen(str + 1);
        tree.init();
        tree.insert1(str, n);

        int n2 = n >> 1;
        for (int i = 1; i <= n2; i++) swap(str[i], str[n-i+1]);

        tree.init();
        tree.insert2(str, n);

        ll res = 0;
        for (int i = 1; i < n; i++) {
            res = (res + (ll) nm[i][0] * nm[i + 1][1] % mod) % mod;
        }
        printf("%lld
", res);
    }
}
原文地址:https://www.cnblogs.com/tudouuuuu/p/14076213.html