回文串计数

传送门

题目描述有点复杂,但其实就是要求给定字符串中不相交的回文串个数。

提到不相交的话……我们好像突然想到以前做最长双回文串的时候拼接的这么个思路……所以我们依然可以先用manacher处理出来每个点的最长回文半径,那么在这个点能拓展出来的回文串之内,每一个在中心点前面的点都可以作为一个回文串的开始,与之相对应的,在中心点后面的点可以作为一个回文串的结束。这样的话其实相当于是一次区间修改,我们可以使用差分维护。之后,对于以一个点为末尾的回文串,与之不相交的回文串个数是开头在其之后的所有回文串个数。(我们是从前向后计算的,所以不用计算前面的),这样的话我们计算一下后缀和,每个点分别 乘起来,和就是结果。(当然倒着求前缀和也可以)

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<set>
#include<vector>
#include<queue>
#define pb push_back
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')

using namespace std;
typedef long long ll;
const int M = 40005;
const int N = 2000005;
const ll mod = 51123987;

ll read()
{
    ll ans = 0,op = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
    if(ch == '-') op = -1;
    ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
    ans *= 10;
    ans += ch - '0';
    ch = getchar();
    }
    return ans * op;
}

ll p[N<<1],mx,mid,len,ml[N<<1],mr[N<<1],sum,l;
char s[N<<1],c[N];

int change()
{
    l = strlen(c);
    int j = 2;
    s[0] = '!',s[1] = '#';
    rep(i,0,l-1) s[j++] = c[i],s[j++] = '#';
    s[j] = '&';
    return j;
}

void manacher()
{
    len = change(),mx = mid = 1;
    rep(i,1,len-1)
    {
    if(i < mx) p[i] = min(mx-i,p[(mid<<1)-i]);
    else p[i] = 1;
    while(s[i-p[i]] == s[i+p[i]]) p[i]++;
    if(mx < i + p[i]) mid = i,mx = i + p[i];
    mr[(i+1)>>1]++,mr[(i+p[i])>>1]--;
    ml[((i-p[i])>>1)+1]++,ml[(i>>1)+1]--;
    }
    //rep(i,0,l+1) printf("%lld ",mr[i]);enter;
    //rep(i,0,l+1) printf("%lld ",ml[i]);enter;
}

void solve()
{
    rep(i,1,l) ml[i] += ml[i-1],mr[i] += mr[i-1];
    ml[l+1] = 0;
    per(i,l,0) ml[i] += ml[i+1];
    rep(i,1,l) sum += mr[i] * ml[i+1];
}

int main()
{
    scanf("%s",c);
    manacher();
    solve();
    printf("%lld
",sum);
    return 0;
}

还有一道题与它是一样的,CF17E。传送门

这题要求的是相交的回文串个数,我们只要先求出一共有多少回文串(k),之后我们计算出来回文串对数的总数(k×(k-1)/2)减去所有不重合的即可。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<set>
#include<vector>
#include<queue>
#define pb push_back
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')

using namespace std;
typedef long long ll;
const int M = 40005;
const int N = 2000005;
const ll mod = 51123987;

ll read()
{
    ll ans = 0,op = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
    if(ch == '-') op = -1;
    ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
    ans *= 10;
    ans += ch - '0';
    ch = getchar();
    }
    return ans * op;
}

ll n,mx,mid,ml[N<<1],mr[N<<1],sum,cur,len;
ll p[N<<1];
char s[N<<1],c[N];

int change()
{
    int l = strlen(c),j = 2;
    s[0] = '!',s[1] = '#';
    rep(i,0,l-1) s[j++] = c[i],s[j++] = '#';
    s[j] = '$';
    return j;
}

void manacher()
{
    len = change(),mx = mid = 1;
    rep(i,1,len-1)
    {
    if(i < mx) p[i] = min(mx-i,p[(mid<<1)-i]);
    else p[i] = 1;
    while(s[i-p[i]] == s[i+p[i]]) p[i]++;
    if(mx < i + p[i]) mid = i,mx = i + p[i];
    if(!(i&1)) sum++;//a lowercase letter needs add 1.
    sum += (p[i]-1) >> 1,sum %= mod;
    }
    sum = sum * (sum-1) >> 1,sum %= mod;
}

void solve()
{
    for(int i = 2;i <= len-1;i += 2) ml[i-p[i]+2]++,ml[i+2]--,mr[i]++,mr[i+p[i]]--;
    for(int i = 1;i <= len-1;i += 2) ml[i-p[i]+2]++,ml[i+1]--,mr[i+1]++,mr[i+p[i]]--;
    for(int i = 2;i <= len-1;i += 2) ml[i] += ml[i-2],ml[i] %= mod,mr[i] += mr[i-2],mr[i] %= mod;
    ml[len] = 0;
    for(int i = len-2;i >= 1;i -= 2) ml[i] += ml[i+2],ml[i] %= mod;
    for(int i = 2;i <= len-1;i += 2) cur += mr[i] * ml[i+2] % mod,cur %= mod;
    sum -= cur,sum = (sum + mod) % mod;
}

int main()
{
    n = read();
    scanf("%s",c);
    manacher();
    solve();
    printf("%I64d
",sum);
    return 0;
}
原文地址:https://www.cnblogs.com/captain1/p/9776075.html