bzoj 3160: 万径人踪灭【FFT+manacher】

考虑正难则反,我们计算所有对称子序列个数,再减去连续的
这里减去连续的很简单,manacher即可
然后考虑总的,注意到关于一个中心对称的两点下标和相同(这样也能包含以空位为对称中心的方案),所以设f[i]为下标和为i的对称中心一共有多少对相同字符,这样总答案就是( sum_{i=0}{2*n-2}2{f[i]}-1 )(减掉的1是减掉空集)
然后考虑f怎么求,( f[i]=((sum_{j=0}^{i-1}s[j]s[i-j])+1)/2 ),除2是因为每一对都被算了两遍
暴力是不行的,但是这个(j,i-j)看着非常卷积,所以考虑怎么用FFT优化一下,把a字符和b字符分开算,以算a字符为例,做一个数组a[i]=(s[i]
'a')?1:0,然后多项式是( sum_{j=0}^{i-1}a[j]*a[i-j]) ),这样就只算了对称且两点都是a的方案数,b同理
然后减掉manacher结果即可

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int N=500005,mod=1e9+7;
int n,f[N],bt,lm,re[N];
long long ans;
char c[N],s[N];
struct cd
{
	double a,b;
	cd(double A=0,double B=0)
	{
		a=A,b=B;
	}
	cd operator + (const cd &x) const
	{
		return cd(a+x.a,b+x.b);
	}
	cd operator - (const cd &x) const
	{
		return cd(a-x.a,b-x.b);
	}
	cd operator * (const cd &x) const
	{
		return cd(a*x.a-b*x.b,a*x.b+b*x.a);
	}
}a[N],b[N];
void dft(cd a[],int f)
{
	for(int i=0;i<lm;i++)
		if(i<re[i])
			swap(a[i],a[re[i]]);
	for(int i=1;i<lm;i<<=1)
	{
		cd wi=cd(cos(M_PI/i),f*sin(M_PI/i));
		for(int k=0;k<lm;k+=(i<<1))
		{
			cd w=cd(1,0),x,y;
			for(int j=0;j<i;j++)
			{
				x=a[j+k],y=w*a[i+j+k];
				a[j+k]=x+y,a[i+j+k]=x-y;
				w=w*wi;
			}
		}
	}
	if(f==-1)
		for(int i=0;i<lm;i++)
			a[i].a/=lm;
}
long long ksm(long long a,long long b)
{
	long long r=1;
	while(b)
	{
		if(b&1)
			r=r*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return r;
}
int main()
{
	scanf("%s",c);
	n=strlen(c);
	for(bt=0;(1<<bt)<=2*n;bt++);
	lm=(1<<bt);
	for(int i=0;i<lm;i++)
		re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	for(int i=0;i<lm;i++)
		a[i].a=(c[i]=='a')?1:0,a[i].b=0;
	dft(a,1);
	for(int i=0;i<lm;i++)
		b[i]=a[i]*a[i];
	for(int i=0;i<lm;i++)
		a[i].a=(c[i]=='b')?1:0,a[i].b=0;
	dft(a,1);
	for(int i=0;i<lm;i++)
		b[i]=b[i]+a[i]*a[i];
	dft(b,-1);
	for(int i=0;i<=2*(n-1);i++)
		ans=(ans+ksm(2,((int)(b[i].a+0.5)+1)/2)-1)%mod;//cerr<<ans<<endl;
	for(int i=0;i<n;i++)
		s[(i+1)*2]=c[i],s[(i+1)*2+1]='#';
	s[0]='$',s[1]='#',s[2*n+2]='&';
	for(int i=1,mx=0,w;i<2*n+1;i++)
	{
		if(i<mx)
			f[i]=min(f[2*w-i],mx-i);
		else
			f[i]=1;
		for(;s[i-f[i]]==s[i+f[i]];f[i]++);
		if(i+f[i]>mx)
			mx=i+f[i],w=i;
		ans=(ans-f[i]/2+mod)%mod;
	}
	printf("%lld
",(ans%mod+mod)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/lokiii/p/10036804.html