题目链接:https://vjudge.net/problem/HDU-5785
题目大意
给定一个字符串,求有多少对三元组 ((i, j, k)) 满足 (1≤i≤j<k≤|S|),要求 (S[i,...j]) 和 (S[j+1, .. k]) 都为回文串,对 (1e9+7) 取模。
思路
是对回文的端点进行计数,可上回文自动机。
由回文树的性质可知,假设字符串的第 (i) 个点在回文树上的编号为 (p_{i}),那么其在 (fail) 树上的祖先为以第 (i) 个点为右端点的所有回文串。
假设回文树的右端为 (i),其长度为 (len_{p_{i}}),那么左端点就为 (i-len_{p_{i}}+1)。
那么其左端点之和就为 ((i-len_{p_{i}}+1)+(i-len_{fail[p_{i}]}+1)+(i-len_{fail[fail[p_{i}]]}+1)+...)
每次在回文树上增加节点时,维护 (len_{p_{i}}) 之和,记为 (sum_{p_{i}})。
那么则将式子转化成 ((i+i+..+i) + (1+1+...+1) - sum_{p_{i}}),其中 ((i+i+...+i) = num_{p_{i}}) 即以这个点为右端点时回文串个数,要求这个可以做洛谷模板题。
从右向左同理,在转换时需要注意小细节。
但是由于空间特别卡,需要用 (map) 来优化空间。
AC代码
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int MAXN = 1e6 + 5;
const int MAXC = 26;
const int mod = 1000000007;
int nm[MAXN][2];
char str[MAXN];
class PAM {
public:
struct node {
map<int, int> ch;// int ch[MAXC];
int fail, len, num;
ll sum;
} T[MAXN];
int las, tot;
inline int get_fail(int x, int pos) {
while (str[pos - T[x].len - 1] != str[pos]) {
x = T[x].fail;
}
return x;
}
void init() {
T[0].ch.clear(), T[1].ch.clear();
// memset(T[0].ch, 0, sizeof(T[0].ch)), memset(T[1].ch, 0, sizeof(T[1].ch));
T[0].fail = 1, T[1].fail = 0;
T[0].len = 0, T[1].len = -1;
T[0].num = T[1].num = T[0].sum = T[1].sum = 0;
las = 0, tot = 1;
}
void insert1(char s[], int len) {
s[0] = -1;
for (int i = 1; i <= len; i++) {
int p = get_fail(las, i);
if (!T[p].ch[s[i]-'a']) {
T[++tot].len = T[p].len + 2;
T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
int u = get_fail(T[p].fail, i);
T[tot].fail = T[u].ch[s[i]-'a'];
T[tot].num = T[T[tot].fail].num + 1;
T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
T[p].ch[s[i]-'a'] = tot;
}
las = T[p].ch[s[i]-'a'];
nm[i][0] = ((ll) T[las].num * (i + 1) % mod - T[las].sum + mod) % mod;
}
}
void insert2(char s[], int len) {
s[0] = 0;
for (int i = 1; i <= len; i++) {
int p = get_fail(las, i);
if (!T[p].ch[s[i]-'a']) {
T[++tot].len = T[p].len + 2;
T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
int u = get_fail(T[p].fail, i);
T[tot].fail = T[u].ch[s[i]-'a'];
T[tot].num = T[T[tot].fail].num + 1;
T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
T[p].ch[s[i]-'a'] = tot;
}
las = T[p].ch[s[i]-'a'];
int pos = len - i + 1;
nm[pos][1] = ((ll) T[las].num * (pos - 1) % mod + T[las].sum) % mod;
}
}
} tree;
int main() {
while (~scanf("%s", str + 1)) {
int n = strlen(str + 1);
tree.init();
tree.insert1(str, n);
int n2 = n >> 1;
for (int i = 1; i <= n2; i++) swap(str[i], str[n-i+1]);
tree.init();
tree.insert2(str, n);
ll res = 0;
for (int i = 1; i < n; i++) {
res = (res + (ll) nm[i][0] * nm[i + 1][1] % mod) % mod;
}
printf("%lld
", res);
}
}