P4199 万径人踪灭 [Manacher + FFT]

万径人踪灭

给出一个由 01 组成的字符串,问该字符串有多少不同的子序列满足:

  1. 子序列是一个回文序列
  2. 子序列不连续,即这个回文串不可以是原字符串上连续的子串

回文序列不仅要求值回文, 且要求位置回文.


color{red}{正解部分}

Ans=numnumAns = num_{位置对称的回文子序列} - num_{回文子串}

设前者为 num1num_1, 后者为 num2num_2, 其中 num2num_2 可以通过 manachermanacher 得出, 没学过的可以看 这里 .

所以现在只需考虑 num1num_1 怎么计算 .

  • 以整数位置 ii 为对称轴, 设满足 Sij=Si+jS_{i-j}=S_{i+j} 的字母对数为 kk, 算上 Si=SiS_i = S_i 总共产生了 2k+112^{k+1}-1 种回文子序列, 那个 1-1 是减去 全部都不选 的情况 .
  • 以小数位置 ii 为对称轴, 设满足 Sij=Si+jS_{i-j}=S_{i+j} 的字母对数为 kk, 则总共产生了 2k12^k-1 种回文子序列 .

构造多项式 Ai=[Si== a]A_i = [S_i== 'a'], 则

A2i2=a+b=2iAaAbA_{2i}^2=sumlimits_{a+b=2i} A_aA_b

于是 A2i2A_{2i}^2 就表示以 ii 位置为对称轴, Sij=Si+j=aS_{i-j}=S_{i+j}='a' 的对数

于是 A2i22lceil frac{A_{2i}^2}{2} ceil 就表示以 ii 位置为对称轴, Sij=Si+j=aS_{i-j}=S_{i+j}='a' 的对数

同理设 Bi=[Si== b]B_i = [S_i == 'b'], B2i22lceil frac{B_{2i}^2}{2} ceil就表示以 ii 为对称轴, Sij=Si+j=bS_{i-j}=S_{i+j}='b' 的对数 .


然后 Ai2+Bi2A_i^2 + B_i^2 就可以得到以 i/2i/2 为对称轴, Si2j=Si2+jS_{frac{i}{2}-j}=S_{frac{i}{2}+j} 的对数 .

纵使 i/2i/2 为小数也成立, 所以可以完美覆盖上方情况 .


color{red}{实现部分}

其中多项式乘法可以使用 FFTFFT 实现, 没学过的可以看 这里 .

#include<bits/stdc++.h>
typedef long long ll;
#define reg register

const int maxn = 200005;
const int mod = 1e9 + 7;
const double Pi = acos(-1);

int N;
int FFT_Len;
int pw[maxn];
int rev[maxn<<2];
int hw[maxn<<1];

char S[maxn<<1];
char t[maxn<<1];

struct complex{
        double x, y;
        complex(double x=0, double y=0):x(x), y(y) {}
} A[maxn<<2], B[maxn<<2];

complex operator + (complex a, complex b){ return complex(a.x+b.x, a.y+b.y); }
complex operator - (complex a, complex b){ return complex(a.x-b.x, a.y-b.y); }
complex operator * (complex a, complex b){ return complex(a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x); }

int Ksm(int a, ll b){
        int s = 1;
        while(b){
                if(b & 1) s = 1ll*s*a % mod;
                a = 1ll*a*a % mod; b >>= 1;
        }
        return s;
}

int Manacher(){
        int res = 0;

        t[0] = '#';
        for(reg int i = 1; i <= N; i ++) t[i*2-1] = S[i], t[i*2] = '#';
        t[N*2+1] = '#';

        int Max_r = 0, mid = 0;
        for(reg int i = 1; i <= N<<1; i ++){
                if(i <= Max_r) hw[i] = std::min(hw[(mid<<1)-i], Max_r-i+1);
                while(i-hw[i] >= 0 && i+hw[i] <= (N<<1)+1 && t[i-hw[i]] == t[i+hw[i]]) hw[i] ++;
                if(i+hw[i]-1 > Max_r) Max_r = i+hw[i]-1, mid = i;
                res += hw[i]/2;  // !
                if(res >= mod) res -= mod;
        }

        return res;
}

void FFT(complex *F, int opt){
        for(reg int i = 0; i < FFT_Len; i ++)
                if(i < rev[i]) std::swap(F[i], F[rev[i]]);
        for(reg int p = 2; p <= FFT_Len; p <<= 1){
                int half = p >> 1;
                complex t = complex(cos(Pi/half), opt*sin(Pi/half));
                for(reg int i = 0; i < FFT_Len; i += p){
                        complex buf = complex(1, 0);
                        for(reg int k = i; k < i+half; k ++){
                                complex Tmp = buf * F[k + half];
                                F[k + half] = F[k] - Tmp;
                                F[k] = F[k] + Tmp;
                                buf = buf * t;
                        }
                }
        }
}

int Calc(){
        int res = 0;

        for(reg int i = 1; i <= N; i ++) A[i].x = S[i]=='a', B[i].x = S[i]=='b';

        FFT_Len = 1; int bit_n = 0;
        while(FFT_Len <= (N<<1)) bit_n ++, FFT_Len <<= 1;
        for(reg int i = 0; i < FFT_Len; i ++) rev[i] = (rev[i>>1]>>1) | ((i&1) << bit_n-1);
        FFT(A, 1), FFT(B, 1);
        for(reg int i = 0; i < FFT_Len; i ++) A[i] = A[i]*A[i] + B[i]*B[i];
        FFT(A, -1);
        for(reg int i = 0; i < FFT_Len; i ++) A[i].x = (A[i].x + 0.5)/FFT_Len;

        pw[0] = 1;
        for(reg int i = 1; i <= N; i ++) pw[i] = 2ll*pw[i-1] % mod;

        for(reg int i = 1; i <= (N<<1)+1; i ++){
                ll t1 = (A[i].x + 1)/2;
                res += pw[t1] - 1;
                if(res >= mod) res -= mod;
        }

        return res;
}

int main(){
        scanf("%s", S+1);
        N = strlen(S+1);
        int p = Manacher();
        printf("%d
", (1ll*Calc()-p+mod)%mod);
        return 0;
}


原文地址:https://www.cnblogs.com/zbr162/p/11822528.html