AH/HNOI 2017 礼物

题目链接

描述

两个序列 (x, y),可以将一个序列每个值同时加非负整数 (c),其中一个序列可以循环移位,要求最小化:

[sum_{i = 1}^{n}(x_i - y_i) ^ 2 ]

题解

循环移位 (Leftrightarrow) 断环成链。显然那个序列循环移位不影响,而且强制加值在 (x) 上, (c) 可以为负整数(可以理解为如果是负数,则把这个的绝对值加到 (y) 上,差保持不变),不妨让 y 移位,将 y 数组复制一倍到末尾,由于循环移位,所以 (sum_{i=1}^{n} y_{j + i} = sum_{i=1}^{n} y_i)

[ans = min{ sum_{i = 1}^{n} (x_i + c - y_{j + i}) ^ 2 } ]

把里面这个东西拿出来:

[sum_{i = 1}^{n} (x_i + c - y_{j + i}) ^ 2 = sum_{i=1}^{n}x_i ^ 2 + sum_{i=1}^{n}y_i^2 + nc^2 + 2csum(x_i - y_i) - 2 sum x_i y_{i + j} ]

要让这个式子尽量小:

  • 前两项是定值
  • (3, 4) 项是一个关于 (c) 的开口向上二次函数,由于要求取整数,所以算对称轴,算一下最近的两个整数取最优值即可。

比较棘手的是最后一项 (sum x_i y_{i + j}) (可以忽略系数),感觉可以转化成卷积的形式,用套路性的反转序列试试看:

新建一个数组 (z),令 (z_{2n - i + 1} = y_i)

(sum x_i y_{i + j} = sum x_i z_{2n + 1 - i - j}),很显然的一个卷积,即从 (j) 位开始的答案记录在了 (2n - j + 1) 位的系数上。

Tips

  • C++ 如果是负数整除会上取整,注意特判

时间复杂度

(O(Nlog_2N))

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long LL;

const int N = 2e5 + 5;
const double PI = acos(-1);

int n, m, c, lim = 1, len, rev[N], x[N], y[N];
LL ans = 0;

struct CP{
	double x, y;
	CP operator + (const CP &b) const { return (CP){ x + b.x, y + b.y }; }
	CP operator - (const CP &b) const { return (CP){ x - b.x, y - b.y }; }
	CP operator * (const CP &b) const { return (CP){ x * b.x - y * b.y, x * b.y + y * b.x }; }
} F[N], G[N];

void FFT(CP a[], int opt) {
	for (int i = 0; i < lim; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int m = 1; m <= lim; m <<= 1) {
		CP wn = (CP){ cos(2 * PI / m), opt * sin(2 * PI / m) };
		for (int i = 0; i < lim; i += m) {
			CP w = (CP){ 1, 0 };
			for (int j = 0; j < (m >> 1); j++) {
				CP u = a[i + j], t = w * a[i + j + (m >> 1)];
				a[i + j] = u + t, a[i + j + (m >> 1)] = u - t;
				w = w * wn;
			}
		}
	}
}

int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", x + i), ans += x[i] * x[i], c += x[i];
	for (int i = 1; i <= n; i++) scanf("%d", y + i), y[i + n] = y[i], ans += y[i] * y[i], c -= y[i];
	// x = l 是对称轴
	int l = -c / n;
	if (c > 0) l--;
	ans += min(l * l * n + 2 * l * c, (l + 1) * (l + 1) * n + 2 * (l + 1) * c);
	LL v = 0;
	for (int i = 1; i <= n; i++) F[i].x = x[i];
	for (int i = 1; i <= 2 * n; i++) G[i].x = y[2 * n - i + 1];
 	while (lim <= 2 * n) lim <<= 1, ++len;
 	for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
 	FFT(F, 1); FFT(G, 1);
 	for (int i = 0; i < lim; i++) F[i] = F[i] * G[i];
 	FFT(F, -1);
	for (int i = n + 1; i <= 2 * n; i++) v = max(v, (LL)(F[i].x / lim + 0.5));
	printf("%lld
", ans - 2 * v);
	return 0;
}
原文地址:https://www.cnblogs.com/dmoransky/p/12463436.html