CF 528D. Fuzzy Search NTT

CF 528D. Fuzzy Search NTT

题目大意

给出文本串S和模式串T和k,S,T为DNA序列(只含ATGC)。对于S中的每个位置(i),只要中[i-k,i+k]有一个位置匹配了字符(i),那么就认为(i)可以匹配。求S中有多少位置匹配了T。

思路

一共有四个字母,我们分别计算每个字母是否可行,其他不管。
最后四个都满足的位置就是一个合法位置(指的是初始位置)。
设g[i]表示S_i位置是否是枚举的字母,f[i]表示M_i是否是是枚举的字母。
他们满足条件只需要右斜对角线==len
发现每个点又是右斜对角线,反转ntt

错误

有点zz,4写成了m,还忘记删掉调试了,wrong了两发,ntt真好调试(不用调试)。

代码

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+7,mod=998244353;
int read() {
	int x=0,f=1;char s=getchar();
	for(;s>'9'||s<'0';s=getchar()) if(s=='-') f=-1;
	for(;s>='0'&&s<='9';s=getchar()) x=x*10+s-'0';
	return x*f;
}
int n,m,k,limit=1,l,r[N];
char S[N],T[N];
int q_pow(int a,int b) {
	int ans=1;
	while(b) {
		if(b&1) ans=1LL*ans*a%mod;
		a=1LL*a*a%mod;
		b>>=1;
	}
	return ans;
}
void ntt(int *a,int type) {
	for(int i=0;i<=limit;++i)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(int mid=1;mid<limit;mid<<=1) {
		int Wn=q_pow(3,(mod-1)/(mid<<1));
		for(int i=0;i<limit;i+=(mid<<1)) {
			for(int j=0,w=1;j<mid;j++,w=1LL*w*Wn%mod) {
				int x=a[i+j],y=1LL*w*a[i+j+mid]%mod;
				a[i+j]=(x+y)%mod;
				a[i+j+mid]=(x+mod-y)%mod;
			}
		}
	}
	if(type==-1) {
		reverse(&a[1],&a[limit]);
		int inv=q_pow(limit,mod-2);
		for(int i=0;i<=limit;++i) a[i]=1LL*a[i]*inv%mod;
	}
}
int AAA[N],f[N],g[N],sum[N];
void solve(char x) {
	memset(f,0,sizeof(f));
	memset(g,0,sizeof(g));
	int tong,h=0,d=0;
	sum[0]=(S[0]==x);
	for(int i=1;i<n;++i) sum[i]=sum[i-1]+(S[i]==x);
	for(int i=0;i<n;++i) {
		int y=i+k>=n ? sum[n-1] : sum[i+k];
		int x=i-k-1 < 0 ? 0 : sum[i-k-1];
		g[i]=(bool)(y-x);
	}
	for(int i=0;i<m;++i) f[m-i-1]=(T[i]==x);
	// for(int i=0;i<n;++i) cout<<g[i]<<" ";cout<<"
";
	// for(int i=0;i<m;++i) cout<<f[i]<<" ";cout<<"
";

	ntt(g,1),ntt(f,1);
	for(int i=0;i<=limit;++i) f[i]=1LL*f[i]*g[i]%mod;
	ntt(f,-1);
	int gs=0;
	for(int i=0;i<m;++i) gs+=(T[i]==x);
	for(int i=m-1,js=0;i<=n-1;++i,++js) AAA[js]+=(f[i]==gs);
}
int main() {
	n=read(),m=read(),k=read();
	scanf("%s%s",S,T);
	while(limit<=n+m-2) limit<<=1,l++;
	for(int i=0;i<=limit;++i)
		r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	solve('A'),solve('T'),solve('G'),solve('C');
	int tot=0;
	for(int i=0;i<=n-m+1;++i) tot+=(AAA[i]==4);
	printf("%d
",tot);
	return 0;
}
原文地址:https://www.cnblogs.com/dsrdsr/p/10704441.html