P4717 快速沃尔什变换FWT 模板题

#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef long long ll;
typedef pair<int,int> Pii;
const ll mod=998244353;
const int maxn = 3e6+10;
ll powmod(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
// head

int a[maxn],b[maxn],c[maxn];
void FWT_or(int *a,int N,int opt)
{
    for(int i=1;i<N;i<<=1)
        for(int p=i<<1,j=0;j<N;j+=p)
            for(int k=0;k<i;++k)
                if(opt==1)a[i+j+k]=(a[j+k]+a[i+j+k])%mod;
                else a[i+j+k]=(a[i+j+k]+mod-a[j+k])%mod;
}
void FWT_and(int *a,int N,int opt)
{
    for(int i=1;i<N;i<<=1)
        for(int p=i<<1,j=0;j<N;j+=p)
            for(int k=0;k<i;++k)
                if(opt==1)a[j+k]=(a[j+k]+a[i+j+k])%mod;
                else a[j+k]=(a[j+k]+mod-a[i+j+k])%mod;
}
void FWT_xor(int *a,int N,int opt) //opt=1 正变换 opt=-1 逆变换
{
    ll inv2=powmod(2,mod-2);
    for(int i=1;i<N;i<<=1)
        for(int p=i<<1,j=0;j<N;j+=p)
            for(int k=0;k<i;++k)
            {
                int X=a[j+k],Y=a[i+j+k];
                a[j+k]=(X+Y)%mod;a[i+j+k]=(X+mod-Y)%mod;
                if(opt==-1)a[j+k]=1ll*a[j+k]*inv2%mod,a[i+j+k]=1ll*a[i+j+k]*inv2%mod;
            }
}
int main()
{
    int n;
    scanf("%d",&n);n=1<<n;
    for(int i=0;i<n;i++) scanf("%d",&a[i]);
    for(int i=0;i<n;i++) scanf("%d",&b[i]);
    
    FWT_or(a,n,1);FWT_or(b,n,1);
    for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%mod;
    FWT_or(c,n,-1);
    for(int i=0;i<n;i++) printf("%d ",c[i]); printf("
");
    
    FWT_or(a,n,-1),FWT_or(b,n,-1);
    FWT_and(a,n,1),FWT_and(b,n,1);
    for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%mod;
    FWT_and(c,n,-1);
    for(int i=0;i<n;i++) printf("%d ",c[i]); printf("
");
    
    FWT_and(a,n,-1),FWT_and(b,n,-1);
    FWT_xor(a,n,1),FWT_xor(b,n,1);
    for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%mod;
    FWT_xor(c,n,-1);
    for(int i=0;i<n;i++) printf("%d ",c[i]); printf("
");
    return 0;
}
原文地址:https://www.cnblogs.com/stranger-/p/11232773.html