luogu P4173 残缺的字符串 FFT

温馨提示:倘若下角标看不清的话您可以尝试放大。

倘若没有通配符的话可以用KMP搞一搞。

听巨佬说通配符可以用FFT搞一搞。

我们先考虑一下没有通配符的怎么搞。我们设a=1,b=2,...,然后我们构造一个这样的函数(displaystyle P_x=sum_{i=0}^{m-1}(A_i-B_{x-m+1+i})^2),但且仅当A和B在x的位置上匹配完成的时候(P_x) 为0.。至于为什么是平方,主要是为了防止正数和负数相互抵消。

至于通配符,我们设它为0,我们尝试重新构造一下(displaystyle P_x=sum_{i=0}^{m-1}(A_i-B_{x-m+1+i})^2A_iB_{x-m+1+i}),这样我们就能满足"通配"这一条件了。

那我们怎么快速求解呢?我们将式子先展开一下,

(displaystyle P_x=sum_{i=0}^{m-1}(A_i^3B_{x-m+1+i}-2A_i^2B_{x-m+1+i}^2+A_iB_{x-m+1+i}^3))

(displaystyle=sum_{i=0}^{m-1}A_i^3B_{x-m+1+i}-2sum_{i=0}^{m-1}A_i^2B_{x-m+1+i}^2+sum_{i=0}^{m-1}A_iB_{x-m+1+i}^3)

还是老方法,我们尝试将A翻转一下设(A_i=C_{m-i-1}),带进原来的式子。

(displaystyle=sum_{i=0}^{m-1}C_{m-i-1}^3B_{x-m+1+i}-2sum_{i=0}^{m-1}C_{m-i-1}^2B_{x-m+1+i}^2+sum_{i=0}^{m-1}C_{m-i-1}B_{x-m+1+i}^3)

发现了什么吗?C和B的下角标之和等于x,所以我们换一种写法。

(displaystyle=sum_{i+j=x}C_{i}^3B_{j}-2sum_{i+j=x}C_{i}^2B_{j}^2+sum_{i+j=x}C_{i}B_{j}^3)

是不是好看了很多?这很明显是一个卷积...

剩下的FFT一波带走就可以了。

#include<bits/stdc++.h>
#define LL long long
#define DB double
using namespace std;
int n, m, lim;
const int N = 1200010;
const DB PI = acos(-1);
int r[N], cnt[N];
DB a[N], b[N];
char s1[N], s2[N];
struct xu 
{
	DB x, y;
	xu(DB X = 0, DB Y = 0) {x = X, y = Y;}
	friend xu operator +(const xu &a, const xu &b)
	{return (xu) {a.x + b.x, a.y + b.y};}
	friend xu operator -(const xu &a, const xu &b)
	{return (xu) {a.x - b.x, a.y - b.y};}
	friend xu operator *(const xu &a, const xu &b)
	{return (xu) {a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x};}
} A[N], B[N], ans[N];
void FFT(xu *A, int lim, int opt) 
{
	for (int i = 0; i < lim; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
	for (int i = 0; i < lim; ++i)
		if (i < r[i])swap(A[i], A[r[i]]);
	int len;
	xu wn, w, x, y;
	for (int mid = 1; mid < lim; mid <<= 1) 
	{
		len = mid << 1;
		wn = (xu) {cos(PI / mid), opt*sin(PI / mid)};
		for (int j = 0; j < lim; j += len) 
		{
			w = (xu) {1, 0};
			for (int k = j; k < j + mid; ++k, w = w * wn) 
			{
				x = A[k]; y = A[k + mid] * w;
				A[k] = x + y; A[k + mid] = x - y;
			}
		}
	}
	if (opt == 1)return;
	for (int i = 0; i < lim; ++i)A[i].x /= lim;
}
int main() 
{
	cin >> n >> m;
	scanf("%s", s1); scanf("%s", s2);
	lim = 1;
	while (lim <= (n + m))lim <<= 1;
	reverse(s1, s1 + n);
	for (int i = 0; i < n; ++i)a[i] = (s1[i] == '*') ? 0 : s1[i] - 'a' + 1;
	for (int i = 0; i < m; ++i)b[i] = (s2[i] == '*') ? 0 : s2[i] - 'a' + 1;

	for (int i = 0; i < lim; ++i)A[i] = (xu) {a[i]*a[i]*a[i], 0}, B[i] = (xu) {b[i], 0};
	FFT(A, lim, 1); FFT(B, lim, 1);
	for (int i = 0; i < lim; ++i)ans[i] = ans[i] + A[i] * B[i];

	for (int i = 0; i < lim; ++i)A[i] = (xu) {a[i], 0}, B[i] = (xu) {b[i]*b[i]*b[i], 0};
	FFT(A, lim, 1); FFT(B, lim, 1);
	for (int i = 0; i < lim; ++i)ans[i] = ans[i] + A[i] * B[i];

	for (int i = 0; i < lim; ++i)A[i] = (xu) {a[i]*a[i], 0}, B[i] = (xu) {b[i]*b[i], 0};
	FFT(A, lim, 1); FFT(B, lim, 1);
	for (int i = 0; i < lim; ++i)ans[i] = ans[i] - A[i] * B[i] * (xu) {2, 0};
	FFT(ans, lim, -1);

	for (int i = n - 1; i < m; ++i)
		if (!(int)(ans[i].x + 0.5))cnt[++cnt[0]] = i - n + 2;
	cout << cnt[0] << endl;
	for (int i = 1; i <= cnt[0]; ++i)printf("%d ", cnt[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/wljss/p/12020196.html