FWT 学习笔记

FWT用来干什么

快速处理

\[ c[k] = \sum_{i \,\mathrm{op}\, j = k} a[i] \cdot b[j], \qquad \mathrm{op} \in \{\mathrm{or},\ \mathrm{and},\ \mathrm{xor}\} \]

  • 记号:\(a + b\) 表示 \(a, b\) 逐项相加,即 \((a + b)[i] = a[i] + b[i]\)

  • 记号:\(a * b\) 表示 \(a\) 与 \(b\) 在所选位运算(or / and / xor)下的卷积

  • 这种卷积具有乘法分配律 ((a + b) * c = a * c + b * c)

根据下标最高位为 \(0\) 或 \(1\),将多项式 \(a, b, c\) 各分为两半:\(a_0, b_0, c_0\)(最高位为 0)与 \(a_1, b_1, c_1\)(最高位为 1)

这样就可以不考虑最高位

or

[c_0 = a_0 * b_0 ]

因为一旦最高位为 1 权值就统计到 (c_1) 上了

[c_1 = a_0 * b_1 + a_1 * b_0 + a_1 * b_1 ]

[= (a_0 + a_1) * (b_0 + b_1) - a_0 *b_0 ]

[= (a_0 + a_1) * (b_0 + b_1) - c_0 ]

这样问题就缩小一倍了

边界:(c[0] = a[0] * b[0])

and

\[ c_1 = a_1 * b_1 \]

一旦最高位不为 1 权值就统计到 (c_0) 上了

\[ c_0 = a_0 * b_1 + a_1 * b_0 + a_0 * b_0 \]

[= (a_0 + a_1) * (b_0 + b_1) - c_1 ]

xor

[c_0 = a_0 * b_0 + a_1 * b_1 ]

[c_1 = a_1 * b_0 + a_0 * b_1 ]

(x_0 = (a_0 + a_1) * (b_0 + b_1), x_1 = (a_0 - a_1) * (b_0 - b_1))

这样 \(c_0 = \frac{x_0 + x_1}{2},\ c_1 = \frac{x_0 - x_1}{2}\)

然后就没了。

(code)

#include <bits/stdc++.h>
using namespace std;
#define rg register
/**
 * Reads the next decimal integer from standard input.
 *
 * Every character before the first digit is skipped; if any skipped
 * character is '-', the result is negated. Digits are then accumulated
 * until the first non-digit character, which is consumed as well.
 *
 * @return The parsed value, or 0 if end of input is reached before any
 *         digit is seen.
 *
 * The accumulator is a plain int, so values outside the int range overflow.
 */
inline int read(){
    // Kept as a macro on purpose: main() also calls gc() after this point.
    #define gc getchar
    // getchar() returns an int so that EOF stays distinguishable from data
    // and so that the value is always valid to pass to isdigit().
    int ch = gc();
    int x = 0, f = 0;
    // Stop at EOF: isdigit(EOF) is false, so without this check the loop
    // would spin forever on exhausted input.
    while(ch != EOF && !std::isdigit(ch)) f |= (ch == '-'), ch = gc();
    while(std::isdigit(ch)) x = x * 10 + (ch - '0'), ch = gc();
    return f ? -x : x;
}
// Capacity of every working buffer: 2^17 entries.
constexpr int N = 1 << 17;

// Buffers for the OR convolution: inputs a, b and result c.
int a[N];
int b[N];
int c[N];

// Buffers for the AND convolution: inputs d, e and result f.
int d[N];
int e[N];
int f[N];

// Buffers for the XOR convolution: inputs g, h and result i.
int g[N];
int h[N];
int i[N];

// Exponent read from input; main() processes vectors of length 1 << n.
int n;
// Modulus used by all convolutions.
constexpr int mod = 998244353;
// (mod + 1) / 2; since mod is odd this is the inverse of 2 modulo mod.
constexpr int inv2 = (mod + 1) / 2;

/**
 * Adds mod to x when x is negative and leaves it unchanged otherwise.
 *
 * Callers use it after a subtraction to bring a value from (-mod, mod)
 * back into [0, mod).
 */
inline void Mod(int &x){
    if(x < 0){
        x += mod;
    }
}
/**
 * OR convolution by divide and conquer on the highest index bit.
 *
 * Writes c[k] = sum of a[x] * b[y] over all x | y == k, modulo mod.
 *
 * @param a   first input of length lim; overwritten during the computation
 * @param b   second input of length lim; overwritten during the computation
 * @param c   output buffer of length lim
 * @param lim current length; assumed to be a power of two with all input
 *            values in [0, mod) - the callers in this file satisfy both
 */
inline void mulor(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        // Length 1: the convolution is a single product.
        *c = 1ll * (*a) * (*b) % mod;
        return;
    }

    int *aHigh = a + half;
    int *bHigh = b + half;
    int *cHigh = c + half;

    // Upper halves become (a_0 + a_1) and (b_0 + b_1), reduced into [0, mod).
    for(int k = 0; k < half; ++k){
        aHigh[k] += a[k] - mod;
        Mod(aHigh[k]);
        bHigh[k] += b[k] - mod;
        Mod(bHigh[k]);
    }

    // Lower half yields c_0 = a_0 * b_0; upper half yields the full product.
    mulor(a, b, c, half);
    mulor(aHigh, bHigh, cHigh, half);

    // c_1 = (a_0 + a_1) * (b_0 + b_1) - c_0.
    for(int k = 0; k < half; ++k){
        cHigh[k] -= c[k];
        Mod(cHigh[k]);
    }
}
/**
 * AND convolution by divide and conquer on the highest index bit.
 *
 * Writes c[k] = sum of a[x] * b[y] over all x & y == k, modulo mod.
 *
 * @param a   first input of length lim; overwritten during the computation
 * @param b   second input of length lim; overwritten during the computation
 * @param c   output buffer of length lim
 * @param lim current length; assumed to be a power of two with all input
 *            values in [0, mod) - the callers in this file satisfy both
 */
inline void muland(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        // Length 1: the convolution is a single product.
        *c = 1ll * (*a) * (*b) % mod;
        return;
    }

    int *aHigh = a + half;
    int *bHigh = b + half;
    int *cHigh = c + half;

    // Lower halves become (a_0 + a_1) and (b_0 + b_1), reduced into [0, mod).
    for(int k = 0; k < half; ++k){
        a[k] += aHigh[k] - mod;
        Mod(a[k]);
        b[k] += bHigh[k] - mod;
        Mod(b[k]);
    }

    // Upper half yields c_1 = a_1 * b_1; lower half yields the full product.
    muland(a, b, c, half);
    muland(aHigh, bHigh, cHigh, half);

    // c_0 = (a_0 + a_1) * (b_0 + b_1) - c_1.
    for(int k = 0; k < half; ++k){
        c[k] -= cHigh[k];
        Mod(c[k]);
    }
}
/**
 * XOR convolution by divide and conquer on the highest index bit.
 *
 * Writes c[k] = sum of a[x] * b[y] over all x ^ y == k, modulo mod.
 *
 * With the halves split by the top bit, let
 *   x_0 = (a_0 + a_1) * (b_0 + b_1),  x_1 = (a_0 - a_1) * (b_0 - b_1);
 * then c_0 = (x_0 + x_1) / 2 and c_1 = (x_0 - x_1) / 2.
 *
 * @param a   first input of length lim; overwritten during the computation
 * @param b   second input of length lim; overwritten during the computation
 * @param c   output buffer of length lim
 * @param lim current length; assumed to be a power of two with all input
 *            values in [0, mod) - the callers in this file satisfy both
 */
inline void mulxor(int *a, int *b, int *c, int lim){
    // Length 1: the convolution is a single product.
    if(!(lim >>= 1)) return (void) (*c = 1ll * (*a) * (*b) % mod);
    for(int i = 0; i < lim; ++i){
        // Snapshot both halves first so the sum and the difference are both
        // computed from the original values.
        const int a0 = a[i], a1 = a[i + lim];
        const int b0 = b[i], b1 = b[i + lim];
        a[i] = a0 + a1;
        a[i + lim] = a0 - a1;
        b[i] = b0 + b1;
        // The partner of b[i] is b[i + lim]; indexing b[i - lim] here would
        // read outside the current segment.
        b[i + lim] = b0 - b1;
        // Differences lie in (-mod, mod); sums lie in [0, 2 * mod).
        Mod(a[i + lim]); Mod(b[i + lim]); Mod(a[i] -= mod); Mod(b[i] -= mod);
    }
    // Lower half produces x_0, upper half produces x_1.
    mulxor(a, b, c, lim); mulxor(a + lim, b + lim, c + lim, lim);
    for(int i = 0; i < lim; ++i){
        const int x0 = c[i], x1 = c[i + lim];
        // Halve by multiplying with the modular inverse of 2; adding mod
        // keeps the difference non-negative before the reduction.
        c[i] = 1ll * (x0 + x1) * inv2 % mod;
        c[i + lim] = 1ll * (x0 - x1 + mod) * inv2 % mod;
    }
}
/**
 * Reads n and two vectors of length 2^n, then prints their OR, AND and XOR
 * convolutions modulo mod, one per line.
 */
signed main(){
    n = read();
    const int len = 1 << n;

    // Each convolution overwrites its inputs, so every input vector is
    // stored three times, once per operation.
    for(int k = 0; k < len; ++k){
        const int value = read();
        a[k] = value;
        d[k] = value;
        g[k] = value;
    }
    for(int k = 0; k < len; ++k){
        const int value = read();
        b[k] = value;
        e[k] = value;
        h[k] = value;
    }

    mulor(a, b, c, len);
    muland(d, e, f, len);
    mulxor(g, h, i, len);

    // Output order: OR, AND, XOR.
    const int *results[] = {c, f, i};
    for(const int *row : results){
        for(int k = 0; k < len; ++k){
            printf("%d ", row[k]);
        }
        puts("");
    }

    // Consume two more characters from stdin before exiting.
    getchar();
    getchar();
    return 0;
}
原文地址:https://www.cnblogs.com/XiaoVsun/p/13054137.html