【HNOI2017】礼物(FFT)

显然, y i y_i yi 加上 c c c 可以看成是 x i x_i xi 减去 c c c

所以就变成了 x i x_i xi 加上一个整数(可正可负)。

现将 x x x 环拆成一个长度为 2 n 2n 2n 的序列 a a a(复制一遍),把 y y y 环拆成一个长度为 n n n 的序列 b b b

那么旋转操作就可以看成是 b b b 序列与 a a a 序列中每一个长度为 n n n 的子串匹配求值。

也就是说,求这个东西的最小值: ∑ j = 0 n − 1 ( a i + j − b j + c ) 2 sum_{j=0}^{n-1}(a_{i+j}-b_j+c)^2 j=0n1(ai+jbj+c)2 0 ≤ i < n 0leq i< n 0i<n)。

接下来推式子:(设 x x x 环的和是 s u m a suma suma y y y 环的和是 s u m b sumb sumb x x x 环的平方和是 p o w a powa powa y y y 环的平方和是 p o w a powa powa

min ⁡ i = 0 n − 1 ∑ j = 0 n − 1 ( a i + j − b j + c ) 2 = min ⁡ i = 0 n − 1 ∑ j = 0 n − 1 ( a i + j − b j ) 2 + c 2 + 2 ( a i + j − b j ) c = min ⁡ i = 0 n − 1 [ n c 2 + ∑ j = 0 n − 1 ( a i + j − b j ) 2 + 2 c ∑ j = 0 n − 1 ( a i + j − b j ) ] = min ⁡ i = 0 n − 1 [ n c 2 + ∑ j = 0 n − 1 ( a i + j 2 + b j 2 − 2 a i + j b j ) + 2 c ( ∑ j = 0 n − 1 a i + j − ∑ j = 0 n − 1 b j ) ] = min ⁡ i = 0 n − 1 ( n c 2 + p o w a + p o w b − 2 ∑ j = 0 n − 1 a i + j b j + 2 c ( s u m a − s u m b ) ) egin{aligned} &min_{i=0}^{n-1} sum_{j=0}^{n-1}(a_{i+j}-b_j+c)^2\ =&min_{i=0}^{n-1}sum_{j=0}^{n-1}(a_{i+j}-b_j)^2+c^2+2(a_{i+j}-b_j)c\ =&min_{i=0}^{n-1} [nc^2+sum_{j=0}^{n-1}(a_{i+j}-b_j)^2+2csum_{j=0}^{n-1}(a_{i+j}-b_j)]\ =&min_{i=0}^{n-1} [nc^2+sum_{j=0}^{n-1}(a_{i+j}^2+b_j^2-2a_{i+j}b_j)+2c(sum_{j=0}^{n-1} a_{i+j}-sum_{j=0}^{n-1} b_j)]\ =&min_{i=0}^{n-1} (nc^2+powa+powb-2sum_{j=0}^{n-1}a_{i+j}b_j+2c(suma-sumb)) end{aligned} ====i=0minn1j=0n1(ai+jbj+c)2i=0minn1j=0n1(ai+jbj)2+c2+2(ai+jbj)ci=0minn1[nc2+j=0n1(ai+jbj)2+2cj=0n1(ai+jbj)]i=0minn1[nc2+j=0n1(ai+j2+bj22ai+jbj)+2c(j=0n1ai+jj=0n1bj)]i=0minn1(nc2+powa+powb2j=0n1ai+jbj+2c(sumasumb))

发现 p o w a + p o w b powa+powb powa+powb 是定值,而和 c c c 有关的 n c 2 + 2 c ( s u m a − s u m b ) nc^2+2c(suma-sumb) nc2+2c(sumasumb) 可以用二次函数最值 O ( 1 ) O(1) O(1) 求,也就是说我们只需要求 − 2 ∑ j = 0 n − 1 a i + j b j -2sum_{j=0}^{n-1}a_{i+j}b_j 2j=0n1ai+jbj 的最小值,即 ∑ j = 0 n − 1 a i + j b j sum_{j=0}^{n-1}a_{i+j}b_j j=0n1ai+jbj 的最大值。

我们可以把这个形式改变一下,把 ∑ j = 0 n − 1 a i + j b j sum_{j=0}^{n-1}a_{i+j}b_j j=0n1ai+jbj 改成 ∑ i − j = k a i b j sum_{i-j=k}a_ib_j ij=kaibj k k k 是旋转角度)。

S ( k ) = ∑ i − j = k a i b j S(k)=sum_{i-j=k}a_ib_j S(k)=ij=kaibj,设 a i ^ = a 2 n − i widehat{a_i}=a_{2n-i} ai =a2ni,则:

S ( k ) = ∑ i − j = k 0 ≤ j < n a i b j = ∑ ( 2 n − i ) − j = k 0 ≤ j < n a i ^ b j = ∑ i + j = 2 n − k 0 ≤ j < n a i ^ b j S(k)=sum_{i-j=k}^{0leq j< n}a_ib_j=sum_{(2n-i)-j=k}^{0leq j< n}widehat{a_i}b_j=sum_{i+j=2n-k}^{0leq j< n}widehat{a_i}b_j S(k)=ij=k0j<naibj=(2ni)j=k0j<nai bj=i+j=2nk0j<nai bj

n − k n-k nk 代入:

S ( 2 n − k ) = ∑ i + j = 2 n − ( 2 n − k ) 0 ≤ j < n a i ^ b j = ∑ i + j = k 0 ≤ j < n a i ^ b j S(2n-k)=sum_{i+j=2n-(2n-k)}^{0leq j< n}widehat{a_i}b_j=sum_{i+j=k}^{0leq j< n}widehat{a_i}b_j S(2nk)=i+j=2n(2nk)0j<nai bj=i+j=k0j<nai bj

有木有觉得很熟悉?

我们可以把 a i ^ widehat{a_i} ai 看成一个多项式 A A A 的系数, b j b_j bj 看成另一个多项式 B B B 的系数,那么 S ( 2 n − k ) S(2n-k) S(2nk) 就是 A × B A imes B A×B k k k 项的系数。

看会我们原来的问题:求 ∑ i − j = k a i b j sum_{i-j=k}a_ib_j ij=kaibj 的最大值,也就是求 S ( k ) S(k) S(k) 的最大值( 0 ≤ k < n 0leq k<n 0k<n),也就是 A × B A imes B A×B ( n + 1 ) ∼ 2 n (n+1)sim 2n (n+1)2n 项系数的最大值。

那么现在就好处理了,我们先用 FFT 把 A × B A imes B A×B 算出来,再取最大值。

代码如下:

#include<bits/stdc++.h>

#define N 50010
#define PN 262200
#define INF 0x7fffffff

using namespace std;

struct Complex
{
	double x,y;
	Complex(){};
	Complex(double xx,double yy){x=xx,y=yy;}
}a[PN],b[PN],c[PN];

Complex operator + (Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator - (Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator * (Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

const double pi=acos(-1);

int n,m,suma,sumb,powa,powb;
int limit=1,bit,rev[PN];

void FFT(Complex *a,int opt)
{
	for(int i=0;i<limit;i++)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(int mid=1;mid<limit;mid<<=1)
	{
		Complex wn=Complex(cos(pi/mid),opt*sin(pi/mid));
		for(int i=0,len=(mid<<1);i<limit;i+=len)
		{
			Complex now=Complex(1,0);
			for(int j=0;j<mid;j++,now=now*wn)
			{
				Complex x=a[i+j],y=now*a[i+mid+j];
				a[i+j]=x+y,a[i+mid+j]=x-y;
			}
		}
	}
	if(opt==-1)
		for(int i=0;i<limit;i++)
			a[i].x/=limit;
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=0;i<n;i++)
	{
		scanf("%lf",&a[i].x);
		a[n+i].x=a[i].x;
		suma+=a[i].x;
		powa+=a[i].x*a[i].x;
	}
	for(int i=0;i<n;i++)
	{
		scanf("%lf",&b[i].x);
		sumb+=b[i].x;
		powb+=b[i].x*b[i].x;
	}
	int d=round(-1.0*(suma-sumb)/n);
	reverse(a,a+2*n+1);//求a^
	while(limit<=3*n)
		limit<<=1,bit++;
	for(int i=0;i<limit;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	FFT(a,1),FFT(b,1);
	for(int i=0;i<limit;i++)
		c[i]=a[i]*b[i];
	FFT(c,-1);
	int ans=-INF;
	for(int i=n+1;i<=2*n;i++)
		ans=max(ans,(int)round(c[i].x));//这里用round是怕被卡精度
	printf("%d
",n*d*d+2*d*(suma-sumb)+powa+powb-2*ans);
	return 0;
}
原文地址:https://www.cnblogs.com/ez-lcw/p/14448658.html