FFT/NTT字符串模糊匹配

由于FFT存在严重的浮点精度问题,因此强烈推荐使用NTT。
首先考虑精确匹配:https://www.acwing.com/problem/content/833/
假设我们有短串$s1$(长度为$n$)和长串$s2$(长度为$m$),下标均从$0$开始。
我们定义字符差

$$c(x,y) = s1(x) - s2(y)$$

若$c(x,y) = 0$,则表明$s1$的第$x$个字符与$s2$的第$y$个字符匹配。再定义

$$F(x) = \sum_{i=0}^{n-1} c(i,\, x-n+i+1)$$

即$s2$中长度为$n$、以下标$x$结尾的子串与$s1$的字符差之和。若$F(x) = 0$,似乎表明这个子串与$s1$完全匹配;但正负字符差会相互抵消,例如$ab$与$ba$的字符差之和也为$0$,会被误判为完全匹配。因此我们把$F(x)$换一个表达式:

$$F(x) = \sum_{i=0}^{n-1} \left[ s1(i) - s2(x-n+i+1) \right]^{2}$$

由于每一项都非负,若$F(x) = 0$,则每一项都为$0$,即这个子串与$s1$完全匹配。将其展开:

$$F(x) = \sum_{i=0}^{n-1} s1(i)^2 + \sum_{i=0}^{n-1} s2(x-n+i+1)^2 - 2\sum_{i=0}^{n-1} s1(i)\, s2(x-n+i+1)$$

其中$\sum_{i=0}^{n-1} s1(i)^2$是与$x$无关的常数,$\sum_{i=0}^{n-1} s2(x-n+i+1)^2$可以用前缀和求出,关键是$2\sum_{i=0}^{n-1} s1(i)\, s2(x-n+i+1)$。我们将$s1$翻转得到$s1'$,即$s1'(n-i-1) = s1(i)$,于是

$$2\sum_{i=0}^{n-1} s1(i)\, s2(x-n+i+1) = 2\sum_{i=0}^{n-1} s1'(n-i-1)\, s2(x-n+i+1) = 2\sum_{i+j=x} s1'(i)\, s2(j)$$

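注意到两个下标之和恒为$(n-i-1)+(x-n+i+1)=x$,右边正是一个卷积,可以用NTT求出。记$\mathrm{sum}=\sum_{i=0}^{n-1} s1(i)^2$,$S(x)=\sum_{j=0}^{x} s2(j)^2$(约定$S(-1)=0$),则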

$$F(x) = \mathrm{sum} + S(x) - S(x-n) - 2\sum_{i+j=x} s1'(i)\, s2(j)$$

当$F(x)=0$时,表明$s2$中以$x$结尾的子串与$s1$完全匹配(只需检查$n-1 \le x \le m-1$)。

AC代码(不开O2优化会超时):

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Ask GCC to optimise the functions below at -O2 (the code is reported to
// time out without O2).
#pragma GCC optimize(2)
// Shorthand type aliases.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//typedef __int128 int128;
using ll = long long;
using ull = unsigned long long;
// Common constants; several of them are not referenced by this program.
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e7 + 10, M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
// NTT prime: 998244353 = 119 * 2^23 + 1.
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);
int n, m, tot, bit;  // pattern length, text length, NTT size, log2(NTT size)
char s1[N], s2[N];   // pattern, text
ll S[N], a[N], b[N]; // prefix sums of squared text codes; NTT work buffers
int R[N];            // bit-reversal permutation filled by inif()
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
// Starting from `1 % mod` keeps the result correct even if mod were 1.
ll ksm(ll a, ll b)
{
    ll result = 1 % mod;
    for (ll power = a; b; b >>= 1)
    {
        if (b & 1)
            result = result * power % mod;
        power = power * power % mod;
    }
    return result;
}
// Chooses the NTT size and builds the bit-reversal table.
// Sets tot to the smallest power of two strictly greater than n, bit to
// log2(tot), and R[i] (0 <= i < tot) to i with its low `bit` bits reversed.
void inif(int n)
{
    tot = 1, bit = 0;
    while (tot <= n)
        tot *= 2, ++bit;
    R[0] = 0;
    // Only indices below tot are used by NTT. Starting at 1 also avoids the
    // negative shift `<< (bit - 1)` when tot == 1 (bit == 0).
    for (int i = 1; i < tot; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform of f[0..total) modulo `mod`.
// `total` must be the power of two whose bit-reversal table R was prepared by
// inif(). type == 1: forward transform; type == -1: inverse transform,
// including the division by total.
void NTT(ll f[], int total, int type)
{
    // Bit-reversal reordering so the butterflies can work in place.
    for (int i = 0; i < total; ++i)
        if (i < R[i])
            swap(f[i], f[R[i]]);
    for (int len = 2; len <= total; len *= 2)
    {
        int half = len / 2;
        // Root of unity of order len: 3^((mod-1)/len), or its inverse for the
        // inverse transform (332748118 is the inverse of 3 modulo 998244353).
        ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
        for (int pos = 0; pos < total; pos += len)
        {
            ll w = 1;
            for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
            {
                // 64-bit temporaries: with negative residues as input,
                // x - y + mod can exceed INT_MAX.
                ll x = f[i];
                ll y = w * f[i + half] % mod;
                f[i] = (x + y) % mod;
                f[i + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1)
    {
        // Scale the transformed array itself, over its whole length, by total^{-1}.
        ll inv = ksm(total, mod - 2);
        for (int i = 0; i < total; ++i)
            f[i] = f[i] * inv % mod;
    }
}

int main()
{
    scanf("%d%s%d%s", &n, &s1, &m, &s2);
    for (int i = 0; i < n; ++i)
        a[i] = s1[i] - 'a' + 1;
    for (int i = 0; i < m; ++i)
        b[i] = s2[i] - 'a' + 1;
    reverse(a, a + n);
    ll sum = 0;
    for (int i = 0; i < n; ++i)
        sum = (sum + a[i] * a[i] % mod) % mod;
    S[0] = b[0] * b[0];
    for (int i = 1; i < m; ++i)
        S[i] = (S[i - 1] + b[i] * b[i] % mod) % mod;
    inif(n + m);
    NTT(a, tot, 1), NTT(b, tot, 1);
    for (int i = 0; i < tot; ++i)
        a[i] = a[i] * b[i] % mod;
    NTT(a, tot, -1);
    for (int x = n - 1; x < m; ++x)
    {
        double P = (sum + S[x] - S[x - n] - 2 * a[x]) % mod;
        if (P == 0)
            printf("%d ", x - n + 1);
    }
    return 0;
}

接着我们考虑模糊匹配,即有通配符的情况:https://www.luogu.com.cn/problem/P4173
设通配符的值为0,重新定义字符差

$$c(x,y) = \left[ s1(x) - s2(y) \right]^2 s1(x)\, s2(y)$$

这样,只要两个字符中有一个是通配符(值为$0$),该项就为$0$;否则该项为$0$当且仅当两字符相同。又因为每一项都非负,所以$F(x)=0$恰好表示带通配符的完全匹配。依然将其展开:

$$\begin{aligned}
F(x) &= \sum_{i=0}^{n-1} \left[ s1(i) - s2(x-n+i+1) \right]^2 s1(i)\, s2(x-n+i+1) \\
&= \sum_{i=0}^{n-1} \left[ s1(i)^2 + s2(x-n+i+1)^2 - 2\, s1(i)\, s2(x-n+i+1) \right] s1(i)\, s2(x-n+i+1) \\
&= \sum_{i=0}^{n-1} s1(i)^3 s2(x-n+i+1) + \sum_{i=0}^{n-1} s1(i)\, s2(x-n+i+1)^3 - 2\sum_{i=0}^{n-1} s1(i)^2 s2(x-n+i+1)^2 \\
&= \sum_{i+j=x} s1'(i)^3 s2(j) + \sum_{i+j=x} s1'(i)\, s2(j)^3 - 2\sum_{i+j=x} s1'(i)^2 s2(j)^2
\end{aligned}$$

当$F(x)=0$时,表明完全匹配。三个卷积可以在NTT的点值表示下直接相加,最后只需做一次逆变换。

AC代码:

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Shorthand type aliases.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//typedef __int128 int128;
using ll = long long;
using ull = unsigned long long;
// Common constants; several of them are not referenced by this program.
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e7 + 10, M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
// NTT prime: 998244353 = 119 * 2^23 + 1.
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);
int n, m;            // pattern length, text length
int A[N], B[N];      // encoded pattern / text: '*' -> 0, 'a'..'z' -> 1..26
char s1[N], s2[N];   // pattern, text
int R[N], ans[N];    // bit-reversal table; 1-based match positions
int tot, bit, pos;   // NTT size, log2(NTT size), number of matches
ll a[N], b[N], p[N]; // NTT work buffers; accumulated point-value result
// Fast power: returns a^b modulo `mod` using O(log b) multiplications.
ll ksm(ll a, ll b)
{
	ll acc = 1 % mod; // `1 % mod` rather than 1 so that mod == 1 yields 0
	for (; b; b >>= 1, a = a * a % mod)
		if (b & 1)
			acc = acc * a % mod;
	return acc;
}
// Chooses the NTT size and builds the bit-reversal table.
// Sets tot to the smallest power of two strictly greater than n, bit to
// log2(tot), and R[i] (0 <= i < tot) to i with its low `bit` bits reversed.
void inif(int n)
{
	tot = 1, bit = 0;
	while (tot <= n)
		tot *= 2, ++bit;
	R[0] = 0;
	// Only indices below tot are used by NTT. Starting at 1 also avoids the
	// negative shift `<< (bit - 1)` when tot == 1 (bit == 0).
	for (int i = 1; i < tot; ++i)
		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform of f[0..total) modulo `mod`.
// `total` must be the power of two whose bit-reversal table R was prepared by
// inif(). type == 1: forward transform; type == -1: inverse transform,
// including the division by total (applied to f itself).
void NTT(ll f[], int total, int type)
{
	// Bit-reversal reordering so the butterflies can work in place.
	for (int i = 0; i < total; ++i)
		if (i < R[i])
			swap(f[i], f[R[i]]);
	for (int len = 2; len <= total; len *= 2)
	{
		int half = len / 2;
		// Root of unity of order len: 3^((mod-1)/len), or its inverse for the
		// inverse transform (332748118 is the inverse of 3 modulo 998244353).
		ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
		for (int pos = 0; pos < total; pos += len)
		{
			ll w = 1;
			for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
			{
				// 64-bit temporaries: with negative residues as input,
				// x - y + mod can exceed INT_MAX.
				ll x = f[i];
				ll y = w * f[i + half] % mod;
				f[i] = (x + y) % mod;
				f[i + half] = (x - y + mod) % mod;
			}
		}
	}
	if (type == -1)
	{
		// Scale the transformed array itself, over its whole length, by total^{-1}.
		ll inv = ksm(total, mod - 2);
		for (int i = 0; i < total; ++i)
			f[i] = f[i] * inv % mod;
	}
}
// Wildcard matching: '*' in either string matches any character.
// With '*' -> 0 and 'a'..'z' -> 1..26, the window of s2 ending at x matches
// s1 iff F(x) = sum A'^3 B + sum A' B^3 - 2 sum A'^2 B^2 (convolutions, A'
// the reversed pattern) is 0. Prints the number of matches, then their
// 1-based start positions.
// NOTE: F(x) is only known modulo `mod` and can exceed it for long patterns,
// so a nonzero multiple of mod would be reported as a match (unlikely).
int main()
{
	scanf("%d%d%s%s", &n, &m, s1, s2);
	// Reverse the pattern so every window sum becomes a convolution.
	reverse(s1, s1 + n);
	for (int i = 0; i < n; ++i)
		A[i] = s1[i] == '*' ? 0 : s1[i] - 'a' + 1;
	for (int i = 0; i < m; ++i)
		B[i] = s2[i] == '*' ? 0 : s2[i] - 'a' + 1;
	inif(n + m);
	// Term 1: A^3 * B
	for (int i = 0; i < tot; ++i)
		a[i] = A[i] * A[i] * A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = (p[i] + a[i] * b[i]) % mod;
	// Term 2: A * B^3
	for (int i = 0; i < tot; ++i)
		a[i] = A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i] * B[i] * B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = (p[i] + a[i] * b[i]) % mod;
	// Term 3: -2 * A^2 * B^2. Reduce the product before doubling and
	// re-normalise, so every entry of p handed to NTT stays in [0, mod).
	for (int i = 0; i < tot; ++i)
		a[i] = A[i] * A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i] * B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = ((p[i] - 2 * (a[i] * b[i] % mod)) % mod + mod) % mod;

	// The three terms were summed in point-value form; one inverse suffices.
	NTT(p, tot, -1);
	for (int i = n - 1; i < m; ++i)
		if (p[i] == 0)
			ans[++pos] = i - n + 2;

	printf("%d\n", pos);
	for (int i = 1; i <= pos; ++i)
		printf("%d ", ans[i]);
	return 0;
}

后来在杭电多校中又遇到了这个知识点:
HDU6975:https://acm.hdu.edu.cn/showproblem.php?pid=6975
因为字符只包含0-9和通配符`*`,先不考虑通配符:我们可以枚举字符0-9,分别算出每个子串在该字符上的匹配数。以8为例,将所有为8的位置的值设为1,其余位置设为0,则对单个字符的匹配数为

$$F(x) = \sum_{i=0}^{n-1} s1(i)\, s2(x-n+i+1) = \sum_{i=0}^{n-1} s1'(n-i-1)\, s2(x-n+i+1) = \sum_{i+j=x} s1'(i)\, s2(j)$$

其中$s1'$为$s1$翻转后的串。

求出每个子串的(非通配符)匹配数后再考虑通配符:通配符匹配数 = $s1$中的通配符数 + $s2$子串中的通配符数 − $s1$与$s2$子串在相同位置同时为通配符的数目。前两项用前缀和求出,第三项同样用卷积求出。两部分相加即为子串的总匹配数,$n$减去总匹配数就是不匹配数;最后对不匹配数的计数做前缀和,即得到至多$k$个位置不匹配的子串数。

AC代码:

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Shorthand type aliases.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//typedef __int128 int128;
using ll = long long;
using ull = unsigned long long;
// Common constants; several of them are not referenced by this program.
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e6 + 10, M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
// NTT prime: 998244353 = 119 * 2^23 + 1.
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);
FILE *fp;                  // not referenced by this program
int n, m, tot, bit;        // pattern length, text length, NTT size, log2(NTT size)
char s1[N], s2[N];         // pattern, text
int R[N], ans[N];          // bit-reversal table; counts indexed by mismatches
ll a[N], b[N], f[N], S[N]; // NTT buffers; accumulated result; '*' prefix counts
// Binary exponentiation: returns a^b modulo `mod`.
ll ksm(ll a, ll b)
{
	ll out = 1 % mod; // 1 % mod handles a modulus of 1
	ll sq = a;        // a^(2^k) for the bit currently examined
	while (b != 0)
	{
		out = (b & 1) ? out * sq % mod : out;
		sq = sq * sq % mod;
		b >>= 1;
	}
	return out;
}
// Per-test-case reset plus NTT setup.
// Sets tot to the smallest power of two strictly greater than n, bit to
// log2(tot), and R[i] (0 <= i < tot) to i with its low `bit` bits reversed.
// Clears the buffers the current test case reads.
void inif(int n)
{
	tot = 1, bit = 0;
	while (tot <= n)
		tot *= 2, ++bit;
	// Only the [0, tot) prefixes are read afterwards: get() scans s1/s2 over
	// [0, tot) (so leftovers from a longer previous case must be zero), f is
	// accumulated/transformed over [0, tot), and ans is indexed by mismatch
	// counts no larger than the pattern length, which is below tot.
	memset(s1, 0, sizeof(s1[0]) * tot);
	memset(s2, 0, sizeof(s2[0]) * tot);
	memset(ans, 0, sizeof(ans[0]) * tot);
	memset(f, 0, sizeof(f[0]) * tot);
	R[0] = 0;
	// Starting at 1 also avoids the negative shift when tot == 1 (bit == 0).
	for (int i = 1; i < tot; ++i)
		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform of f[0..total) modulo `mod`.
// `total` must be the power of two whose bit-reversal table R was prepared by
// inif(). type == 1: forward transform; type == -1: inverse transform,
// including the division by total.
void NTT(ll f[], int total, int type)
{
	// Bit-reversal reordering so the butterflies can work in place.
	for (int i = 0; i < total; ++i)
		if (i < R[i])
			swap(f[i], f[R[i]]);
	for (int len = 2; len <= total; len *= 2)
	{
		int half = len / 2;
		// Root of unity of order len: 3^((mod-1)/len), or its inverse for the
		// inverse transform (332748118 is the inverse of 3 modulo 998244353).
		ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
		for (int pos = 0; pos < total; pos += len)
		{
			ll w = 1;
			for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
			{
				// 64-bit temporaries: with negative residues as input,
				// x - y + mod can exceed INT_MAX.
				ll x = f[i];
				ll y = w * f[i + half] % mod;
				f[i] = (x + y) % mod;
				f[i + half] = (x - y + mod) % mod;
			}
		}
	}
	if (type == -1)
	{
		// Scale the whole transformed array by total^{-1}.
		ll inv = ksm(total, mod - 2);
		for (int i = 0; i < total; ++i)
			f[i] = f[i] * inv % mod;
	}
}
void get(char c, int type)
{
	for (int i = 0; i < tot; ++i)
		a[i] = s1[i] == c;
	for (int i = 0; i < tot; ++i)
		b[i] = s2[i] == c;
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
	{
		if (type == 1)
			f[i] = (f[i] + a[i] * b[i] % mod) % mod;
		else
			f[i] = (f[i] - a[i] * b[i] % mod + mod) % mod;
	}
}
// For each test case: for every k in [0, n], print how many length-n windows
// of s2 match s1 with at most k mismatching positions ('*' matches anything).
int main()
{
	int T;
	scanf("%d", &T);
	while (T--)
	{
		// m = |s2| (text), n = |s1| (pattern).
		scanf("%d%d", &m, &n);
		inif(n + m); // resets buffers and chooses the NTT size
		scanf("%s%s", s2, s1);
		reverse(s1, s1 + n);

		// f: per-window count of equal digits, minus positions where both
		// characters are '*' (accumulated in point-value form).
		for (char c = '0'; c <= '9'; ++c)
			get(c, 1);
		get('*', -1);
		NTT(f, tot, -1);
		// sum = number of '*' in s1; S[i] = number of '*' in s2[0..i].
		ll sum = 0;
		for (int i = 0; i < n; ++i)
			sum += s1[i] == '*';
		S[0] = s2[0] == '*';
		for (int i = 1; i < m; ++i)
			S[i] = S[i - 1] + (s2[i] == '*');
		// Only windows ending at i in [n-1, m) are complete; S is read only
		// within [0, m), so stale entries from earlier cases are never used.
		for (int i = n - 1; i < m; ++i)
		{
			ll window = i >= n ? S[i] - S[i - n] : S[i];
			ll matches = ((f[i] + sum + window) % mod + mod) % mod;
			++ans[n - matches];
		}
		// Prefix sums turn "exactly k mismatches" into "at most k".
		for (int i = 0; i <= n; ++i)
		{
			if (i)
				ans[i] += ans[i - 1];
			printf("%d\n", ans[i]);
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/xiaopangpangdehome/p/15080759.html