题解 luoguP5466 【[PKUSC2018]神仙的游戏】

显然要是没有限制那就全都是(1)对吧,

所以考虑这些(0, 1)到底限制了啥。

首先看到(border), 容易想到当年写(kmp)求最短回文子串的那题。因为(kmp)(next)数组实际上就是最长(border)。我们也在那题积累一个结论, 就是对于原来的串一个长度为(k)(border), 原串一定是由长度为(n -k)的子串循环而成的,但是最后一部分并不一定是一个完整的循环节。(实际上你也可以看成第一个不一定是一个完整的循环节, 这一部分画图很容易看出来)。那么如果有一组(0,1)的位置距离恰好是(n-k),那显然就是不行的, 因为按照要求, 它们应该相等。显然, 如果这个距离是(n-k)的因数也是不行的。

然后一个简单的暴力就出来了, 暴力枚举所有(0, 1), 求出它们的距离, 然后取并集, 对于这个集合里的长度暴力枚举倍数来判断是否合法, 复杂度(O(n^2+nln n))

然后我们考虑优化这个做法, 优化这个我只知道有数据结构和多项式两条路,感觉数据结构不太行那就搞多项式。我们要求的是距离, 就设(F(x) = sum_{i=0}^{n-1}x^if_i), 其中(f_i)表示距离为(i)(0,1)点对有多少对。 然后设(A(x)=sum_{i=0}^{n-1}[s[i] =='0']x^i, B(x)=sum_{i=0}^{n-1}[s[i]=='1']x^{-i}),考虑怎么解决这个负数幂次, 让(B(x)=x^{n}B(frac{1}{x})),最后求答案的时候直接位置加上(n)就好了。

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>

using namespace std;

#define R register
#define LL long long
const int N = 2e6 + 10;
const int P = 998244353;
const int G = 3;
const int Gi = 332748118;

inline int read() {
	int x = 0, f = 1; char a = getchar();
	for(; a > '9' || a < '0'; a = getchar()) if(a == '-') f = -1;
	for(; a >= '0' && a <= '9'; a = getchar()) x = x * 10 + a - '0';
	return x * f;
}

inline int qpow(int x, int k) { 
	int res = 1;
	while(k) {
		if(k & 1) res = 1LL * res * x % P;
		x = 1LL * x * x % P; k >>= 1;
	} return res;
}

int bit, rev[N], lim = 1;
inline void NTT(int *A, int type = 1) {
	for(R int i = 0; i < lim; i ++) if(i > rev[i]) swap(A[i], A[rev[i]]);
	for(R int dep = 1; dep < lim; dep <<= 1) {
		int Wn = qpow(type == 1 ? G : Gi, (P - 1) / (dep << 1));
		for(R int j = 0, len = dep << 1; j < lim; j += len) {
			int w = 1;
			for(R int k = 0; k < dep; k ++, w = 1LL * w * Wn % P) {
				int x = A[j + k], y = 1LL * A[j + k + dep] * w % P;
				A[j + k] = (x + y) % P;
				A[j + k + dep] = (x - y + P) % P;
			}
		}
	}
	if(type == -1) {
		int inv = qpow(lim, P - 2);
		for(R int i = 0; i < lim; i ++) A[i] = 1LL * A[i] * inv % P;
	}
}

int n; 
char s[N];
int a[N], b[N];
int main() {
	//freopen(".in", "r", stdin);
	//freopen(".out", "w", stdout);
	scanf("%s", s); n = strlen(s);
	while(lim < (n << 1)) lim <<= 1, bit ++;
	for(R int i = 0; i < lim; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
	for(R int i = 0; i < n; i ++) a[i] = (s[i] == '0');
	for(R int i = 0; i < n; i ++) b[i] = (s[n - i - 1] == '1');
	NTT(a); NTT(b);
	for(R int i = 0; i < lim; i ++) a[i] = 1LL * a[i] * b[i] % P;
	NTT(a, -1);
	LL ans = 1LL * n * n;
	for(R int i = 1; i < n; i ++) {
		int f = 1;
		for(R int j = i; j < n; j += i) if(a[n - 1 - j] | a[n - 1 + j]) {f = 0; break; }
		if(f) ans ^= 1LL * (n - i) * (n - i);
	}
	printf("%lld
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/clover4/p/13525964.html