残缺的字符串

题目

我终于会用多项式来做字符串匹配了

我们设一个匹配函数(f(x))表示(A)串的第(x)位能否匹配

我们把原字符串上的每一个字母都改成数字,比如('a')变成(1),('b')变成(2)

于是我们定义

[f(x)=sum_{i=1}^m(A_{i+x-1}-B_i)^2 ]

我们考虑一下上面的那个柿子,只有当(A_{i+x-1})始终等于(B_i)的时候,(f(x)=0),只要有一位上是不一样的,(f(x)>0)

也就是这一位匹配上了就是(0),没匹配上技就是大于(0)

我们考虑加入通配符,只需要强行使得这一位是(0)就好了

我们强行定义通配符为(0)

于是现在

[f(x)=sum_{i=1}^m(A_{i+x-1}-B_i)^2A_{i+x-1}B_i ]

拆开就是三个柿子

[sum_{i=1}^mA_{i+x-1}^3B_i ]

[sum_{i=1}^mA_{i+x-1}B_i^3 ]

[-2sum_{i=1}^mA_{i+x-1}^2B_i^2 ]

我们发现我们翻转一下(B)串就能变成卷积的形式了

于是三遍(ntt)之后输出所有为(0)的下标就好了

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
const int maxn=2e6+5;
const int G[2]={3,332748118};
int n,m,len,rev[maxn],ans;
int a[maxn],b[maxn];
int a1[maxn],a2[maxn],a3[maxn];
int b1[maxn],b2[maxn],b3[maxn];
char S[maxn],T[maxn];
const int mod=998244353;
inline int ksm(int a,int b) {
    int S=1;
    while(b) {if(b&1) S=1ll*S*a%mod;b>>=1;a=1ll*a*a%mod;}
    return S;
}
inline void NTT(int *f,int o) {
    for(re int i=0;i<len;i++) if(i<rev[i]) std::swap(f[i],f[rev[i]]);
    for(re int i=2;i<=len;i<<=1) {
        int ln=i>>1,og1=ksm(G[o],(mod-1)/i);
        for(re int l=0;l<len;l+=i) {
            int t,og=1;
            for(re int x=l;x<l+ln;++x) {
                t=1ll*f[x+ln]*og%mod;
                f[x+ln]=(f[x]-t+mod)%mod;
                f[x]=(f[x]+t)%mod;
                og=1ll*og*og1%mod;
            }
        }
    }
    if(!o) return;
    int inv=ksm(len,mod-2);
    for(re int i=0;i<len;i++) f[i]=1ll*inv*f[i]%mod;
}
int main() {
    scanf("%d%d",&m,&n);
    scanf("%s",S+1);scanf("%s",T+1);
    for(re int i=1;i<=m;i++) {
        if(S[i]=='*') {b[m-i+1]=0;continue;}
        b[m-i+1]=S[i]-'a'+1;
    }
    for(re int i=1;i<=n;i++) {
        if(T[i]=='*') {a[i]=0;continue;}
        a[i]=T[i]-'a'+1;
    }
    len=1;while(len<=n+m) len<<=1;
    for(re int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
    for(re int i=1;i<=n;i++) a1[i]=a[i]*a[i]*a[i];
    for(re int i=1;i<=n;i++) a2[i]=a[i];
    for(re int i=1;i<=n;i++) a3[i]=a[i]*a[i];
    for(re int i=1;i<=m;i++) b1[i]=b[i];
    for(re int i=1;i<=m;i++) b2[i]=b[i]*b[i]*b[i];
    for(re int i=1;i<=m;i++) b3[i]=2*b[i]*b[i];
    NTT(a1,0),NTT(a2,0),NTT(a3,0);
    NTT(b1,0),NTT(b2,0),NTT(b3,0);
    for(re int i=0;i<len;i++)
        b1[i]=(1ll*b1[i]*a1[i])%mod,
        b2[i]=(1ll*b2[i]*a2[i])%mod,
        b3[i]=(1ll*b3[i]*a3[i])%mod;
    NTT(b1,1),NTT(b2,1),NTT(b3,1);
    for(re int i=m+1;i<=n+m&&i-1<=n;i++)
        b1[i]=(b1[i]+b2[i]-b3[i]+mod)%mod,ans+=(b1[i]==0);
    printf("%d
",ans);
    for(re int i=m+1;i<=n+m&&i-1<=n;i++)
    if(!b1[i]) printf("%d ",i-m);
    puts("");
    return 0;
}
原文地址:https://www.cnblogs.com/asuldb/p/10780618.html