FWT快速沃尔什变换

FWT

$$C(i) = \sum_{j \,@\, k = i} A(j)\,B(k)$$

其中 $@$ 表示 and、or、xor 三种位运算之一。

DWT

设第 $i$ 个点值中 $x^j$ 代入的是 $f(i,j)$。由于位运算意义下的乘法需满足 $x^i x^j = x^{i@j}$,

所以 $f(i,j)$ 需满足:

$$f(i,j)\,f(i,k) = f(i,\, j@k)$$

以此为依据构造 $f(i,j)$:

对于 and,$f(i,j) = [\,i \,\&\, j = i\,]$,即 $[\,i \subseteq j\,]$;

对于 or,$f(i,j) = [\,i \mid j = i\,]$,即 $[\,j \subseteq i\,]$;

对于 xor,$f(i,j) = (-1)^{\operatorname{popcount}(i \,\&\, j)}$。

三者都满足 $f(i,j) = \prod_p f(i_p, j_p)$,其中 $i_p$ 表示 $i$ 在二进制下的第 $p$ 位。

记 $len = \log_2 n$,$i_{len}$ 为 $i$ 的最高位,$i_{[1,len-1]}$ 为 $i$ 的其余低位($j$ 同理)。前一半的 $j$ 满足 $j_{len} = 0$,后一半的 $j$ 满足 $j_{len} = 1$,于是:

$$
\begin{aligned}
DWT(C)_i &= \sum_{j=0}^{n-1} f(i,j)\,C(j) \\
&= \sum_{j=0}^{n/2-1} f(i,j)\,C(j) + \sum_{j=n/2}^{n-1} f(i,j)\,C(j) \\
&= \sum_{j=0}^{n/2-1} f(i_{len},j_{len})\,f(i_{[1,len-1]},j_{[1,len-1]})\,C(j) + \sum_{j=n/2}^{n-1} f(i_{len},j_{len})\,f(i_{[1,len-1]},j_{[1,len-1]})\,C(j) \\
&= f(i_{len},0)\sum_{j=0}^{n/2-1} f(i_{[1,len-1]},j_{[1,len-1]})\,C(j) + f(i_{len},1)\sum_{j=n/2}^{n-1} f(i_{[1,len-1]},j_{[1,len-1]})\,C(j)
\end{aligned}
$$

对于 $i \in [0, n/2-1]$,记 $C_L$、$C_R$ 分别为 $C$ 的前一半和后一半系数,则

$$
\begin{aligned}
DWT(C)_i &= f(0,0)\,DWT(C_L)_i + f(0,1)\,DWT(C_R)_i \\
DWT(C)_{i+n/2} &= f(1,0)\,DWT(C_L)_i + f(1,1)\,DWT(C_R)_i
\end{aligned}
$$

到最底层($n = 1$)时,$DWT(C)_0 = \sum_{j=0}^{0} f(0,j)\,C(j) = f(0,0)\,C(0)$,而 $f(0,0)$ 对于 xor、and、or 都是 $1$,所以递归到最底层时的点值就是系数本身。

然后就可以左右分治了。值得注意的是,DFT 是按下标奇偶分治,而 DWT 是按左右两半(即最高位)分治。

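and:

$$
\begin{aligned}
DWT(C)_i &= DWT(C_L)_i + DWT(C_R)_i \\
DWT(C)_{i+n/2} &= DWT(C_R)_i
\end{aligned}
$$

or:

$$
\begin{aligned}
DWT(C)_i &= DWT(C_L)_i \\
DWT(C)_{i+n/2} &= DWT(C_L)_i + DWT(C_R)_i
\end{aligned}
$$

xor:

$$
\begin{aligned}
DWT(C)_i &= DWT(C_L)_i + DWT(C_R)_i \\
DWT(C)_{i+n/2} &= DWT(C_L)_i - DWT(C_R)_i
\end{aligned}
$$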

IDWT

由DWT的式子可以解得:

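and:

$$
\begin{aligned}
DWT(C_L)_i &= DWT(C)_i - DWT(C)_{i+n/2} \\
DWT(C_R)_i &= DWT(C)_{i+n/2}
\end{aligned}
$$

or:

$$
\begin{aligned}
DWT(C_L)_i &= DWT(C)_i \\
DWT(C_R)_i &= DWT(C)_{i+n/2} - DWT(C)_i
\end{aligned}
$$

xor:

$$
\begin{aligned}
DWT(C_L)_i &= \frac{DWT(C)_i + DWT(C)_{i+n/2}}{2} \\
DWT(C_R)_i &= \frac{DWT(C)_i - DWT(C)_{i+n/2}}{2}
\end{aligned}
$$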

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <vector>

typedef long long LL;
typedef unsigned long long uLL;

// Competitive-programming shorthands.
#define fir first
#define sec second
// Container size as int; the argument is parenthesized so SZ(a + b)-style
// expressions expand correctly.
#define SZ(x) static_cast<int>((x).size())
#define MP(x, y) std::make_pair(x, y)
#define PB(x) push_back(x)
#define debug(...) fprintf(stderr, __VA_ARGS__)
// Writes "GO" and a newline to stderr (quick "reached here" marker).
#define GO debug("GO\n")
// Loop macros. The bound is evaluated once into i##end. `register` was
// dropped because it is ill-formed in C++17.
// rep:  i = a, a+1, ..., b   (inclusive, ascending)
#define rep(i, a, b) for (int i = (a), i##end = (b); (i) <= i##end; ++ (i))
// drep: i = a, a-1, ..., b   (inclusive, descending)
#define drep(i, a, b) for (int i = (a), i##end = (b); (i) >= i##end; -- (i))
// REP:  i = a, a+1, ..., b-1 (half-open, ascending)
#define REP(i, a, b) for (int i = (a), i##end = (b); (i) < i##end; ++ (i))

// Reads one decimal integer from stdin. Any non-digit characters before the
// first digit are skipped; if a '-' is among them the result is negated.
// The character immediately after the last digit is consumed.
// Returns 0 if end of input is reached before any digit is found.
inline int read() {
    // c is an int, not a char: it must be able to hold EOF, and isdigit()
    // requires an unsigned-char value or EOF.
    int x = 0, f = 1, c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    if (c == EOF) return 0;
    for (; c != EOF && isdigit(c); c = getchar())
        x = x * 10 + (c - '0');
    return x * f;
}
// Writes integer x to stdout in decimal, with a leading '-' if negative.
// No separator or newline is written. The magnitude is computed in the
// unsigned counterpart of T, so the most negative value (e.g. INT_MIN)
// prints correctly instead of overflowing on `-x`.
template<class T> inline void write(T x) {
    typedef typename std::make_unsigned<T>::type U;
    U mag = static_cast<U>(x);
    if (x < 0) {
        putchar('-');
        mag = static_cast<U>(U(0) - mag);  // two's-complement negation, well defined for unsigned
    }
    char digits[40];  // room for any magnitude up to 128 bits (39 digits)
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(digits[--len]);
}
// If b is smaller than a (tested as a > b), overwrite a with b.
// Returns true exactly when a was changed.
template<typename T> inline bool chkmin(T &a, T b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// If b is larger than a (tested as a < b), overwrite a with b.
// Returns true exactly when a was changed.
template<typename T> inline bool chkmax(T &a, T b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}

using namespace std;

// Capacity of main's static arrays; looks like 2^17 = 131072 plus a little
// slack (NOTE(review): confirm the intended maximum transform size).
const int N = 131082;
// Modulus for all arithmetic below.
const int P = 998244353;
// Inverse of 2 modulo P: 2 * ((P + 1) / 2) = P + 1, which is 1 (mod P) for odd P.
const int inv2 = (P + 1) >> 1;

// Returns a + b with at most one subtraction of P. The result lies in
// [0, P) only when the caller guarantees a + b is in [0, 2P).
LL ADD(LL a, int b) {
    const LL sum = a + b;
    if (sum < P) return sum;
    return sum - P;
}

/*
 * In-place fast Walsh transform of a[0, n) modulo P.
 *
 * Preconditions: n is a power of two (or 0), and every a[i] is in [0, P).
 * x = a[i] and y = a[i + m] are paired entries whose indices differ only
 * in the bit of value m. The parameter t selects the operation:
 *   t == 1       : OR  forward   (x, y) -> (x,           x + y)
 *   t == 2       : AND forward   (x, y) -> (x + y,       y)
 *   other t > 0  : XOR forward   (x, y) -> (x + y,       x - y)
 *   t == -1      : OR  inverse   (x, y) -> (x,           y - x)
 *   t == -2      : AND inverse   (x, y) -> (x - y,       y)
 *   other t <= 0 : XOR inverse   (x, y) -> ((x + y) / 2, (x - y) / 2)
 * All results are kept in [0, P).
 */
void DWT(int a[], int n, int t) {
    // The block walk below only stays inside a[0, n) when n is a power of two.
    assert(n >= 0 && (n & (n - 1)) == 0);
    if (t > 0) {
        for (int len = 2; len <= n; len <<= 1) {
            const int m = len >> 1;
            for (int *p = a; p != a + n; p += len)
                for (int i = 0; i < m; ++i) {
                    const int x = p[i], y = p[i + m];
                    if (t == 1) {           // or: upper half absorbs lower half
                        p[i + m] = ADD(x, y);
                    } else if (t == 2) {    // and: lower half absorbs upper half
                        p[i] = ADD(x, y);
                    } else {                // xor
                        p[i] = ADD(x, y);
                        p[i + m] = ADD(1ll * x - y, P);
                    }
                }
        }
    } else {
        // Each level acts on a different index bit, so the levels commute and
        // undoing them from the widest block down is equivalent.
        for (int len = n; len >= 2; len >>= 1) {
            const int m = len >> 1;
            for (int *p = a; p != a + n; p += len)
                for (int i = 0; i < m; ++i) {
                    const int x = p[i], y = p[i + m];
                    if (t == -1) {          // inverse or
                        p[i + m] = ADD(1ll * y - x, P);
                    } else if (t == -2) {   // inverse and
                        p[i] = ADD(1ll * x - y, P);
                    } else {                // inverse xor: halve via multiplying by inv2
                        p[i] = ADD(x, y) * inv2 % P;
                        p[i + m] = ADD(1ll * x - y, P) * inv2 % P;
                    }
                }
        }
    }
}

// Reads lg2, then 2^lg2 coefficients of A and 2^lg2 coefficients of B, and
// prints the OR, AND and XOR convolutions of A and B modulo P, one per line.
// NOTE(review): the coefficients are assumed to be in [0, P), as DWT requires.
int main()
{
#ifndef ONLINE_JUDGE
    freopen("xhc.in", "r", stdin);
    freopen("xhc.out", "w", stdout);
#endif
    int lg2 = read();
    // Reject sizes that would overflow the shift or the static buffers.
    assert(lg2 >= 0 && lg2 < 31 && (1 << lg2) <= N);
    int n = 1 << lg2;
    static int A[N], B[N], C[N];
    rep (i, 0, n - 1) A[i] = read();
    rep (i, 0, n - 1) B[i] = read();

    // DWT operation codes: 1 = OR, 2 = AND, 3 = XOR; the negated code inverts.
    for (int op = 1; op <= 3; ++op) {
        DWT(A, n, op), DWT(B, n, op);
        rep (i, 0, n - 1) C[i] = 1ll * A[i] * B[i] % P;
        // Also invert A and B so the next operation starts from the
        // original coefficients.
        DWT(C, n, -op), DWT(A, n, -op), DWT(B, n, -op);
        rep (i, 0, n - 1) write(C[i]), putchar(' ');
        putchar('\n');
    }

    return 0;
}

原文地址:https://www.cnblogs.com/cnyali-Tea/p/11262632.html