子集变换

本来这个东西要放到多项式那一章的,但想想还是算了,毕竟应用场景区别还是蛮大的。
听说还有叫子集反演的黑科技,有时间再补吧。


快速沃尔什变换(FWT)

问题是这样的:给两个长为 (2^n(nle 17)) 的序列 (A,B),设

[C_i=sum_{joperatorname{opt}k=i}A_jB_k ]

(operatorname{opt}) 分别为 ( m or,and,xor) 时求出 (C)
考虑怎么做。
我们仿照 FFT 的思路,先把 (A,B) 变成点值表示,在对应项相乘,然后再变回序列。
我们一个一个运算来看

或运算的本质是枚举子集。
首先有一个显而易见的性质:若 (ioperatorname{or} j=i,ioperatorname{or} k=i),则有 ((joperatorname{or}k)operatorname{or}i=i)
所以我们考虑这么构造:

[{ m DWT}(A_i)=sum_{joperatorname{or}i=i}A_j ]

考虑这么做的正确性:

[{ m DWT}(A_i){ m DWT}(B_i)=sum_{joperatorname{or}i=i}A_jsum_{koperatorname{or}i=i}B_k=sum_{(joperatorname{or}k)operatorname{or}i=i}A_jB_k={ m DWT}(C_i) ]

现在考虑怎么 (O(nlog n)) 做 FWT。借鉴 DFT 的思路进行最高位 01 分类,记 (A_0)(A) 序列下标最高位为 (0) 的序列,(A_1) 为下标最高位为 (1) 的序列,易知两个序列长度为当前序列的一半。得到以下式子:

[{ m DWT}(A)=({ m DWT}(A_0),{ m DWT}(A_0)+{ m DWT}(A_1)) ]

括号表示把两个序列当成字符串拼接起来,(+) 代表序列对应位置的值相加。
很明显可以看出这个过程就是将子集的值不断往上累加的过程。
再考虑还原这个序列。之前是累加的过程,那还原就减掉即可。式子是:

[A=(A_0,A_1-A_0) ]

与运算的本质是枚举超集。
那么我们完全按照或运算相反的方向推即可。

[{ m DWT}(A)=({ m DWT}(A_0)+{ m DWT}(A_1),{ m DWT}(A_0)) ]

[A=(A_1-A_0,A_0) ]

异或

(d(x,y)={ m popcount}(xoperatorname{and}y)mod 2)
则有 (d(i,j)oplus d(i,k)=d(i,joplus k))。构造

[{ m DWT}(A_i)=sum_{d(j,i)=0}A_j-sum_{d(j,i)=1}A_j ]

然后证明 DWT 运算的正确性。我翻遍全网只找到了 xht37 的证明,然而没看懂,其他根本没人证,所以我也不证了。我看懂啦!但我懒得写了……
变换的式子:

[{ m DWT}(A)=({ m DWT}(A_0+A_1),{ m DWT}(A_0-A_1)) ]

[A=(frac{A_0+A_1}{2},frac{A_0-A_1}{2}) ]

#include <bits/stdc++.h>
#define Cpy(f,g,n) memcpy(f,g,sizeof(int[n]))
using namespace std;

const int N=(1<<17)+5,P=998244353;
int n,m,a[N],b[N],c[N],A[N],B[N];

void OR(int f[],int tyof)
{
    for(int len=2,k=1;len<=n;len*=2,k*=2)
        for(int i=0;i<n;i+=len)
            for(int j=0;j<k;++j)
                (f[i+j+k]+=1LL*f[i+j]*tyof%P)%=P;
}

void AND(int f[],int tyof)
{
    for(int len=2,k=1;len<=n;len*=2,k*=2)
        for(int i=0;i<n;i+=len)
            for(int j=0;j<k;++j)
                (f[i+j]+=1LL*f[i+j+k]*tyof%P)%=P;
}

void XOR(int f[],int tyof)
{
    for(int len=2,k=1;len<=n;len*=2,k*=2)
        for(int i=0;i<n;i+=len)
            for(int j=0;j<k;++j)
            {
                (f[i+j]+=f[i+j+k])%=P;
                (f[i+j+k]=(f[i+j]-f[i+j+k]+P)%P-f[i+j+k]+P)%=P;
                f[i+j]=1LL*f[i+j]*tyof%P,f[i+j+k]=1LL*f[i+j+k]*tyof%P;
            }
}

int main()
{
    scanf("%d",&m); n=1<<m;
    for(int i=0;i<n;++i) scanf("%d",a+i);
    for(int i=0;i<n;++i) scanf("%d",b+i);
    Cpy(A,a,n); Cpy(B,b,n);
    OR(a,1); OR(b,1);
    for(int i=0;i<n;++i) c[i]=1LL*a[i]*b[i]%P;
    OR(c,P-1);
    for(int i=0;i<n;++i) printf("%d ",c[i]);
    puts(""); Cpy(a,A,n); Cpy(b,B,n);
    AND(a,1); AND(b,1);
    for(int i=0;i<n;++i) c[i]=1LL*a[i]*b[i]%P;
    AND(c,P-1);
    for(int i=0;i<n;++i) printf("%d ",c[i]);
    puts(""); Cpy(a,A,n); Cpy(b,B,n);
    XOR(a,1); XOR(b,1);
    for(int i=0;i<n;++i) c[i]=1LL*a[i]*b[i]%P;
    XOR(c,P+1>>1);
    for(int i=0;i<n;++i) printf("%d ",c[i]);
    return 0;
}

子集卷积

快速子集变换(也叫 FST),就是解决这样的问题:

[c_k=sum_{ioperatorname{and}j=0\ioperatorname{or}j=k}a_ib_j ]

用集合表示就是:

[c_S=sum_{Tsubseteq S}a_Tb_{S-T} ]

注意到这个和或卷积唯一的区别就是需要两个集合交集为空。我们将这个条件转化为 (|i|+|j|=|icup j|),所以我们可以多记一维个数:

[a_{i,S}=sum_{Tsubseteq S,|T|=i}a_T ]

刚开始时,每个 (a_{|i|,i}) 是初始读入的系数,然后我们把所有的 (a_{|i|}) 都做一遍 FWT,在转移时要保证同一个 (S) 互相转移且交集为空

[c_{i,S}=sum_{j=0}^ia_{j,S}b_{i-j,S} ]

即可。复杂度 (O(n^22^n))

#include <bits/stdc++.h>
#define lb(x) (x&(-x))
#define gchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
using namespace std;
namespace IO{char buf[1<<21],*p1=buf,*p2=buf;int read()
{
    int x=0; char ch=gchar();
    for(;ch<'0'||ch>'9';ch=gchar());
    for(;ch>='0'&&ch<='9';ch=gchar()) x=x*10+ch-48;
    return x;
}}

const int N=21,M=(1<<20)+5,P=1e9+9;
int n,U,c[N][M],a[N][M],b[N][M],popct[M];

void OR(int f[],int tyof)
{
    for(int len=2,k=1;len<=U;len*=2,k*=2)
        for(int i=0;i<U;i+=len)
            for(int j=0;j<k;++j)
                (f[i+j+k]+=1LL*f[i+j]*tyof%P)%=P;
}

int main()
{
    n=IO::read(); U=1<<n;
    for(int i=1;i<U;++i) popct[i]=popct[i-lb(i)]+1;
    for(int i=0;i<U;++i) a[popct[i]][i]=IO::read();
    for(int i=0;i<U;++i) b[popct[i]][i]=IO::read();
    for(int i=0;i<=n;++i) OR(a[i],1),OR(b[i],1);
    for(int i=0;i<=n;++i)
        for(int j=0;j<=i;++j)
            for(int s=0;s<U;++s)
                (c[i][s]+=1LL*a[j][s]*b[i-j][s]%P)%=P;
    for(int i=0;i<=n;++i) OR(c[i],P-1);
    for(int i=0;i<U;++i) printf("%d ",c[popct[i]][i]);
    return 0;
}
原文地址:https://www.cnblogs.com/wzzyr24/p/13291856.html