快速沃尔什变换 FWT 学习笔记【多项式】

〇、前言

之前看到异或就担心是 FWT,然后才开始想别的。

这次学了 FWT 以后,以后判断应该就很快了吧?

参考资料

一、FWT 是什么

FWT 是快速沃尔什变换。它和快速傅里叶变换一样,原本都用于物理中的频谱分析。

但是由于它可分治的特点,在算法竞赛中常被用来计算位运算卷积。

二、FWT 能干什么

它可以在 (O(nlog n)) 的时间复杂度内由数组 (a,b) 得到数组 (c),满足

[ ewcommand{and}{ mathrm{and} } ewcommand{or}{ mathrm{or} } ewcommand{xor}{ mathrm{xor} } forall iin[0,n) c_i=sum_{joplus k=i}a_j imes b_k ]

其中 (oplus) 可以代表“与”,“或”,“异或”中的任意一种运算。

这叫做位运算卷积。

三、与、或卷积

我们需要把 (a,b) 数组分别转化为 (a',b') 来通过一次乘法解决多个乘法问题。

对于或,我们有:若 (jor i=i,kor i=i)((jor k)or i=i)

虽然这样看上去和题目要求还差了一点,但是我们如果这样想呢:

构造数组 (a',b')

[a'_i=sum_{jor i=i}a_j\ b'_i=sum_{kor i=i}b_k ]

即通过正变换(a) 转化为 (a'),由 (b) 转化为 (b')

那么

[egin{aligned} c'_i&=a'_i imes b'_i\ &=sum_{jor i=i}a_jsum_{kor i=i}b_k\ &=sum_{jor i=i}sum_{kor i=i}a_jb_k\ &=sum_{(jor k)or i=i}a_jb_k end{aligned} ]

((jor k)or i=i) 就又是变换完的形式了。

再通过逆变换(c') 转化回 (c),那么 (c) 就是满足 (c_i=sum_{jor k=i}a_jb_k) 的结果了。

同理,由于与运算满足:若 (jand i=i)(kand i=i),则 ((jand k)and i=i)

因此和上面的变换是一样的。

现在我们需要找出 (a o a') 是怎么实现的。

正变换

针对或变换的举例:

[forall iin [0,n) a'_i=sum_{jor i=i}a_i ]

我们可以按位分治。从下到上转移,第 (i) 层的状态 (j)(f[i,j]) 表示所有比 (i) 高的位与 (j) 相同的状态 (k) 的和。即

[f[i,j]=sum_{leftlfloorfrac{k}{2^i} ight floor=leftlfloorfrac{j}{2^i} ight floor,kor j=j}a_k ]

其中 (leftlfloorfrac{k}{2^i} ight floor) 表示将 (k) 在二进制下右移 (i) 位。

如果还不好理解,那么对于 (f[5,1011001110_{(2)}]),满足条件的 (k)

[i=5\ egin{aligned} j&=1011001110\ k&=1011000000or x end{aligned} ]

其中的 (x) 满足 (xor 001110=001110)(k) 必须满足在第 (5sim 9) 位与 (j) 相同。

分析方程,会发现我们是可以利用 (f[i-1]) 的信息的。在 (f[i-1,j]) 中的每一个状态所存的 (sum a_k)(j)(k) 从第 (i) 位到最高位都是相等的,现在我们用到了第 (i) 位,那么就考虑第 (i) 位的取值。

就有了简洁的状态转移,令 (j) 的第 (i) 位是 (0)

[egin{aligned} f[i,j]&=f[i-1,j],\ f[i,j+2^i]&=f[i-1,j]+f[i-1,j+2^i] end{aligned} ]

所以 (a'=f[leftlceillog n ight ceil]),答案就是最上面一层。

同理,与的正变换的方程恰好反过来了

[egin{aligned} f[i,j]&=f[i-1,j]+f[i-1,j+2^i],\ f[i,j+2^i]&=f[i-1,j+2^i] end{aligned} ]

逆变换

逆变换是由 (f[i])(f[i-1]) 的过程。

直接由上面的式子倒过来就可以了。

或:

[egin{aligned} f[i,j]&=f[i+1,j],\ f[i,j+2^i]&=f[i+1,j+2^i]-f[i+1,j] end{aligned} ]

与:

[egin{aligned} f[i,j]&=f[i+1,j]-f[i+1,j+2^i],\ f[i,j+2^i]&=f[i+1,j+2^i] end{aligned} ]

因此卷积的答案最后就存在 (f[1,i]=sum_{joplus k=i}a_jb_k) 里了。

四、异或卷积

这个东西有点麻烦,仍然需要构造。

定义运算 (xotimes y=operatorname{popcount}(xand y)mod 2),称之为 (x)(y) 的奇偶性。

它是一个满足 ((iotimes j)xor (iotimes k)=iotimes(jxor k)) 的运算,所以可以用来做异或卷积。

构造

[a'_i=sum_{iotimes j=0}a_j-sum_{iotimes j=1}a_j\ b'_i=sum_{iotimes k=0}b_k-sum_{iotimes k=1}b_k\ ]

[egin{aligned} c'_i&=sum_{iotimes j=0}a_jsum_{iotimes k=0}b_k-sum_{iotimes j=0}a_jsum_{iotimes k=1}b_k-sum_{iotimes j=1}a_jsum_{iotimes k=0}b_k+sum_{iotimes j=1}a_jsum_{iotimes k=1}b_k\ &=sum_{iotimes(jxor k)=0}a_jb_k-sum_{iotimes(jxor k)=1}a_jb_k end{aligned} ]

解释:式子中的第一行,第一项和第四项构成了 (iotimes(jxor k)=0) 的全部可能性:(00)(11);第二项和第三项构成了 (iotimes(jxor k)=1) 的全部可能性:(01)(10)。所以可以写 (sum),而且由于每项不相交,所以不能乘 (2)

可以发现 (c') 也是一个变换完了的式子,把它逆变换回去就可以了。

正变换

仍然按位分治,同样考虑上面那样逐位转移。

在枚举第 (i) 位的不同时,状态 (j) 和状态 (j+2^i) 都可以从第 (i-1) 层的 (j)(j+2^i) 转移过来。其中 (j) 的第 (i) 位为 (0)

这样的话有四种情况:

  • ([i,j]leftarrow[i-1,j])((0and 0)) 是不变的;
  • ([i,j]leftarrow[i-1,j+2^i])((0and 1)) 是不变的;
  • ([i,j+2^i]leftarrow[i-1,j])((1and 0)) 是不变的;
  • ([i,j+2^i]leftarrow[i-1,j+2^i])((1and 1)) 会改变。

这个图中蓝色(无色)的箭头表示正转移,其他颜色的箭头表示负转移。

也就是说,转移之后,这个状态内部的全部元素进行 (otimes) 的结果都从 (0) 变成了 (1) 或从 (1) 变成了 (0)

那么在最终结果方面就会产生影响,因此那些转移我们把它定为负的。

还有一种理解方法。因为最上面一行是我们正变换的结果,可以通过这个图从上到下来看出它的贡献来源。

(a'_i) 出发,遇到有颜色的边,就要把子数内的贡献取反(( imes -1)),它的意义也是 (kotimes i= eg(kotimes i))

这样对每一个位置就可以满足

[f[i,j]=sum_{kotimes j=0}a_k-sum_{kotimes j=1}a_k ]

了。其中 (k) 只枚举了有效位。

观察图可以发现,状态转移方程是

[egin{aligned} f[i,j]&=f[i-1,j]+f[i-1,j+2^i],\ f[i,j+2^i]&=f[i-1,j]-f[i-1,j+2^i] end{aligned} ]

逆变换

把正变换上下相减,除以 (2) 即可

[egin{aligned} f[i,j]&=frac{f[i+1,j]+f[i+1,j+2^i]}{2},\ f[i,j+2^i]&=frac{f[i+1,j]-f[i+1,f[j+2^i]]}{2} end{aligned} ]

五、代码

#include<cstdio>
#include<cstring>
#define p 998244353
#define inv 499122177ll
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read()
{
    int x=0;
    char ch=gc();
    while(ch<'0'||ch>'9')
        ch=gc();
    while(ch>='0'&&ch<='9')
    {
        x=x*10+(ch&15);
        ch=gc();
    }
    return x;
}
int A[1<<17],B[1<<17],a[1<<17],b[1<<17],n,tot;
void init()
{
    for(int i=0;i<(1<<n);++i)
    {
        a[i]=A[i];
        b[i]=B[i];
    }
}
void Or(int *f)
{
    for(int bs=2;bs<=tot;bs<<=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
                f[i+j+g]=(f[i+j]+f[i+j+g])%p;
    }
}
void iOr(int *f)
{
    for(int bs=tot;bs>=2;bs>>=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
                f[i+j+g]=(f[i+j+g]+p-f[i+j])%p;
    }
}
void And(int *f)
{
    for(int bs=2;bs<=tot;bs<<=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
                f[i+j]=(f[i+j]+f[i+j+g])%p;
    }
}
void iAnd(int *f)
{
    for(int bs=tot;bs>=2;bs>>=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
                f[i+j]=(f[i+j]+p-f[i+j+g])%p;
    }
}
void Xor(int *f)
{
    for(int bs=2;bs<=tot;bs<<=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
            {
                int t0=(f[i+j]+f[i+j+g])%p,t1=(f[i+j]+p-f[i+j+g])%p;
                f[i+j]=t0;
                f[i+j+g]=t1;
            }
    }
}
void iXor(int *f)
{
    for(int bs=tot;bs>=2;bs>>=1)
    {
        int g=(bs>>1);
        for(int i=0;i<tot;i+=bs)
            for(int j=0;j<g;++j)
            {
                int t0=inv*(f[i+j]+f[i+j+g])%p,t1=inv*(f[i+j]+p-f[i+j+g])%p;
                f[i+j]=t0;
                f[i+j+g]=t1;
            }
    }
}
int main()
{
    #ifdef wjyyy
        freopen("a.in","r",stdin);
    #endif
    n=read();
    tot=(1<<n);
    for(int i=0;i<tot;++i)
        A[i]=read();
    for(int i=0;i<tot;++i)
        B[i]=read();
    init();
    Or(a);
    Or(b);
    for(int i=0;i<tot;++i)
        a[i]=(long long)a[i]*b[i]%p;
    iOr(a);
    for(int i=0;i<tot;++i)
        printf("%d ",a[i]);
    puts("");
    init();
    And(a);
    And(b);
    for(int i=0;i<tot;++i)
        a[i]=(long long)a[i]*b[i]%p;
    iAnd(a);
    for(int i=0;i<tot;++i)
        printf("%d ",a[i]);
    puts("");
    init();
    Xor(a);
    Xor(b);
    for(int i=0;i<tot;++i)
        a[i]=(long long)a[i]*b[i]%p;
    iXor(a);
    for(int i=0;i<tot;++i)
        printf("%d ",a[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/wjyyy/p/FWT.html