BZOJ 4503 两个串

思路

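我们把小写字母 a–z 编码为 1–26、通配符 `?` 编码为 0(其中 $a$ 为文本串 S,$b$ 为模式串 T),并定义函数 $F(x,y)=(a_x-b_y)^2\,b_y$。当 $b_y$ 是通配符($b_y=0$)或 $a_x=b_y$ 时 $F(x,y)=0$,否则 $F(x,y)>0$;因此某个起点完全匹配,当且仅当对应位置上 $F$ 的和为 0。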

对每个起点 $0\le i\le n-m$,令

$$
\begin{aligned}
P(i)&=\sum_{j=0}^{m-1}(a_{i+j}-b_j)^2\,b_j\\
&=\sum_{j=0}^{m-1}\left(a_{i+j}^2-2a_{i+j}b_j+b_j^2\right)b_j\\
&=\sum_{j=0}^{m-1}\left(a_{i+j}^2b_j-2a_{i+j}b_j^2+b_j^3\right)\\
&=\sum_{j=0}^{m-1}a_{i+j}^2b_j-2\sum_{j=0}^{m-1}a_{i+j}b_j^2+\sum_{j=0}^{m-1}b_j^3
\end{aligned}
$$

位置 $i$ 匹配当且仅当 $P(i)=0$。

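上式中 $a$ 与 $b$ 的下标之和 $(i+j)+j$ 随 $j$ 变化,还不是卷积。尝试把 $b$ 翻转一下:令 $b'_k=b_{m-1-k}$(对应代码中的 `reverse(t,t+m)`),则 $b_j=b'_{m-1-j}$,代入得

$$
\begin{aligned}
&\sum_{j=0}^{m-1}(a_{i+j}-b_j)^2\,b_j\\
=&\sum_{j=0}^{m-1}(a_{i+j}-b'_{m-1-j})^2\,b'_{m-1-j}\\
=&\sum_{j=0}^{m-1}\left(a_{i+j}^2-2a_{i+j}b'_{m-1-j}+(b'_{m-1-j})^2\right)b'_{m-1-j}\\
=&\sum_{j=0}^{m-1}a_{i+j}^2\,b'_{m-1-j}-2\sum_{j=0}^{m-1}a_{i+j}\,(b'_{m-1-j})^2+\sum_{j=0}^{m-1}(b'_{m-1-j})^3
\end{aligned}
$$

现在每一项中 $a$ 与 $b'$ 的下标之和恒为 $(i+j)+(m-1-j)=i+m-1$,所以枚举 $k=i+m-1$ 就变成卷积了:前两项分别是 $a^2$ 与 $b'$、$a$ 与 $b'^2$ 的卷积在第 $k$ 项的值,第三项与 $i$ 无关,是常数 $\sum_j b_j^3$。
把三个式子分别用 NTT 卷起来再相加,结果的第 $i+m-1$ 项为 0 时,位置 $i$ 就是一个匹配。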

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
// Global NTT parameters. MOD = 998244353 = 119 * 2^23 + 1 is the usual
// NTT-friendly prime, and G = 3 is a primitive root modulo it.
const int MAXN = 300000;            // capacity of every global work array
const int MOD = 998244353;          // modulus for all arithmetic below
const int G = 3;                    // root used by the forward transform
const int invG = (MOD + 1) / G;     // == 332748118, because 3 * 332748118 == MOD + 1
int rev[MAXN];                      // bit-reversal permutation, filled by cal_rev
// Fill rev[0..n) with the bit-reversal permutation of lim-bit indices.
// Each entry is derived from rev[idx >> 1], so indices are visited in
// increasing order; rev[0] is computed from itself and therefore stays 0
// because rev is a zero-initialised global.
void cal_rev(int n,int lim){
    int topBit=lim-1;
    for(int idx=0;idx<n;++idx){
        int shifted=rev[idx>>1]>>1;   // reverse of idx without its lowest bit
        int lowBit=idx&1;             // lowest bit of idx moves to the top
        rev[idx]=shifted|(lowBit<<topBit);
    }
}
// Modular exponentiation by repeated squaring: returns a^b mod MOD.
// The base is assumed to lie in [0, MOD) so every product fits in 64 bits.
int pow(int a,int b){
    int result=1;
    for(;b!=0;b>>=1,a=(1LL*a*a)%MOD){
        if(b&1)
            result=(1LL*result*a)%MOD;
    }
    return result;
}
// In-place iterative radix-2 NTT of a[0..n) modulo MOD.
//   opt != 0 : forward transform using root G.
//   opt == 0 : inverse transform using invG, then every entry is scaled by n^{-1}.
// n must be a power of two and rev[] must already hold the matching
// bit-reversal permutation (see cal_rev). Entries are expected to be in
// [0, MOD); the butterfly keeps them there. `lim` is not used by this function.
void NTT(int *a,int opt,int n,int lim){
    // Reorder into bit-reversed order; swap each pair only once.
    for(int i=0;i<n;++i){
        int partner=rev[i];
        if(i<partner)
            swap(a[i],a[partner]);
    }
    // Merge blocks of length `half` into blocks of length 2*half.
    for(int half=1;2*half<=n;half<<=1){
        int blockLen=half<<1;
        int step=pow((opt)?G:invG,(MOD-1)/blockLen);   // primitive blockLen-th root
        for(int start=0;start<n;start+=blockLen){
            int *lo=a+start;
            int *hi=a+start+half;
            int w=1;
            for(int k=0;k<half;++k){
                int v=(1LL*hi[k]*w)%MOD;
                hi[k]=(lo[k]-v+MOD)%MOD;
                lo[k]=(lo[k]+v)%MOD;
                w=(1LL*w*step)%MOD;
            }
        }
    }
    if(opt)
        return;
    // Inverse transform: divide by n via its modular inverse (MOD is prime).
    int invN=pow(n,MOD-2);
    for(int i=0;i<n;++i)
        a[i]=(1LL*a[i]*invN)%MOD;
}
// Global work storage (zero-initialised because it has static storage duration).
//   s    : encoded text S ('?' -> 0, otherwise S[i]-'a'+1)
//   t    : encoded pattern T, same encoding, reversed in main
//   x    : set to 1 for the first n entries in main, never read afterwards
int s[MAXN],t[MAXN],x[MAXN];
//   a, b : scratch buffers transformed by NTT for each of the three products
//   c    : accumulated pointwise products, inverse-transformed at the end
int a[MAXN],b[MAXN],c[MAXN];
//   n, m : lengths of S and T
int n,m;
//   ans[1..cnt] : 0-based start positions of the matches found
int ans[MAXN],cnt;
// Input buffer reused for reading both strings.
char S[MAXN];
// Reads a text S and a pattern T (T may contain '?' wildcards) and prints the
// number of start positions where T matches S, followed by those 0-based
// positions, one per line.
//
// Characters are encoded as '?' -> 0 and otherwise c - 'a' + 1. For each start
// position i the value
//     P(i) = sum_{j=0}^{m-1} (s[i+j] - t[j])^2 * t[j]
// is zero exactly when every non-wildcard character of T equals S there.
// After reversing t (t'[k] = t[m-1-k]) every term is a convolution read at
// index i+m-1:
//     P(i) = sum t'^3  -  2 * (s (*) t'^2)  +  (s^2 (*) t')
// The three products are accumulated pointwise in the NTT domain in c[], and a
// single inverse NTT yields P(i) mod MOD at c[i+m-1].
//
// NOTE(review): the test `c[i] == 0` is modulo MOD. Assuming the input consists
// of lowercase letters and '?', each term is at most 26^3 = 17576, so it is
// exact for m <= 56795. For longer patterns, a nonzero P(i) that happens to be
// a multiple of MOD would be reported as a match.
signed main(){
    scanf("%s",S);
    n=strlen(S);
    for(int i=0;i<n;i++){
        s[i]=(S[i]=='?')?0:S[i]-'a'+1;
        x[i]=1;
    }
    scanf("%s",S);
    m=strlen(S);
    for(int i=0;i<m;i++)
        t[i]=(S[i]=='?')?0:S[i]-'a'+1;
    reverse(t,t+m);

    // Smallest power of two >= n+m, so the linear convolution (length n+m-1)
    // does not wrap around.
    // NOTE(review): midlen must not exceed MAXN; confirm against the input bounds.
    int midlen=1,midlim=0;
    while(midlen<(n+m))
        midlen<<=1,midlim++;
    cal_rev(midlen,midlim);

    // Term 1: sum_j t'[j]^3 does not depend on i. A cyclic convolution with an
    // all-ones array produces that total at every index.
    for(int i=0;i<midlen;i++)
        b[i]=1;
    for(int i=0;i<midlen;i++)
        a[i]=(t[i]*t[i]*t[i])%MOD;
    NTT(b,1,midlen,midlim);
    NTT(a,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        c[i]=(a[i]*b[i])%MOD;

    // Term 2: subtract 2 * (s (*) t'^2). Reduce the product first so the
    // result stays a canonical residue in [0, MOD). The original
    // `c - a*b + MOD` could go negative, and NTT's butterflies assume
    // non-negative inputs.
    for(int i=0;i<midlen;i++)
        b[i]=(2*s[i])%MOD;
    for(int i=0;i<midlen;i++)
        a[i]=(t[i]*t[i])%MOD;
    NTT(b,1,midlen,midlim);
    NTT(a,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        c[i]=(c[i]-a[i]*b[i]%MOD+MOD)%MOD;

    // Term 3: add s^2 (*) t'. Here c < MOD and a*b < MOD^2, so the sum fits in long long.
    for(int i=0;i<midlen;i++)
        b[i]=(s[i]*s[i])%MOD;
    for(int i=0;i<midlen;i++)
        a[i]=(t[i])%MOD;
    NTT(b,1,midlen,midlim);
    NTT(a,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        c[i]=(c[i]+a[i]*b[i])%MOD;

    NTT(c,0,midlen,midlim);

    // P(i) lives at index i+m-1; a zero there means T matches S at position i.
    for(int i=m-1;i<n;i++)
        if(!c[i])
            ans[++cnt]=i-m+1;
    printf("%lld\n",cnt);
    for(int i=1;i<=cnt;i++)
        printf("%lld\n",ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/dreagonm/p/10757618.html