bzoj3160 万径人踪灭

题目描述

题解:

认为大爷讲的最好。

代码:

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define N 200050
#define MOD 1000000007
const double Pi = acos(-1.0);
ll fastpow(ll x,int y)
{
    ll ret = 1;
    while(y)
    {
        if(y&1)ret=ret*x%MOD;
        x=x*x%MOD;
        y>>=1;
    }
    return ret;
}
struct cp
{
    double x,y;
    cp(){}
    cp(double x,double y):x(x),y(y){}
};
cp operator + (cp &a,cp &b)
{
    return cp(a.x+b.x,a.y+b.y);
}
cp operator - (cp &a,cp &b)
{
    return cp(a.x-b.x,a.y-b.y);
}
cp operator * (cp &a,cp &b)
{
    return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
char ch[N],s[4*N];
int to[2*N];
void fft(cp *a,int len,int k)
{
    for(int i=0;i<len;i++)
        if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        cp w0(cos(Pi/i),k*sin(Pi/i));
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w(1,0);
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1 = a[j+o],w2 = a[j+o+i]*w;
                a[j+o] = w1+w2;
                a[j+o+i] = w1-w2;
            }
        }
    }
}
cp a[2*N],c[2*N];
int len,lim=1,l;
ll ans[2*N],sum;
void sol()
{
    fft(a,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*a[i];
    fft(c,lim,-1);
    for(int i=0;i<lim;i++)ans[i]+=(ll)(c[i].x/lim+0.5);
}
int sl;
void get_s()
{
    s[0]='!',s[sl=1]='#';
    for(int i=0;i<len;i++)
    {
        s[++sl]=ch[i];
        s[++sl]='#';
    }
    s[sl+1] = '@';
}
int rp[2*N];
void manacher()
{
    int mid = 0,mx = 0;
    for(int i=1;i<=sl;i++)
    {
        if(i<=mx)rp[i] = min(rp[2*mid-i],mx-i+1);
        else rp[i] = 1;
        while(s[i+rp[i]]==s[i-rp[i]])rp[i]++;
        if(i+rp[i]-1>mx)mx=i+rp[i]-1,mid=i;
        sum-=rp[i]/2;
    }
}
int main()
{
    scanf("%s",ch);
    len = strlen(ch);
    while(lim<2*len)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    for(int i=0;i<len;i++)a[i].x=(ch[i]=='a');
    sol();
    for(int i=0;i<lim;i++)a[i].x=a[i].y=0;
    for(int i=0;i<len;i++)a[i].x=(ch[i]=='b');
    sol();
    for(int i=0;i<lim;i++)ans[i]>>=1;
    for(int i=0;i<2*len;i+=2)ans[i]++;
    get_s();
    manacher();
    for(int i=0;i<lim;i++)(sum+=fastpow(2,ans[i])-1ll)%=MOD;
    printf("%lld
",(sum+MOD)%MOD);
    return 0;
}
原文地址:https://www.cnblogs.com/LiGuanlin1124/p/10258681.html