O(nlog^2/loglogn)的cdq FFT

论文鸽在群里说了一下这个东西,我也实现了一下,发现效果还不错。

由于这个 exp 的 (O(nlog n)) 算法非常的慢,所以我们一般采用 (O(nlog^2 n)) 的分治 FFT 来求解。

普通的分治 FFT 已经可以与论文鸽的 (O(nlog n)) exp 五五开了,但是有没有更快的方法呢?

注意,这个优化只能在 cdq FFT 的时候采用,也就是说不能优化 n 个一次多项式的卷积之类的问题。

(O(nlog^2n))

我们先来回忆一下普通的分治 FFT 是如何做的。

我们假设 (F(x) = e^{G(X)}),这里我们知道 (G(x)),我们要求解 (F(x))

两侧求导,得 (F^{'}(x) = e^{G(X)} imes G^{'}(X) = F(X) imes G^{'}(X))

也就是说我们是对这个式子进行求解:(F^{'}(x) = F(X) imes G^{'}(X))

我们采用 (solve(l, r)) 表示求解 (F(X)) 的第 (l) 项到第 (r) 项。

取区间中点 (mid)

先调用 (solve(l, mid)) 来求解出前半部分。

再计算左侧对右侧的贡献。

再调用 (solve(mid + 1, r)) 来求解出后半部分。

这就是普通的分治 FFT。

(O(frac{nlog^2n}{log log n}))

首先分治 FFT 是一个树状结构,我们往往可以尝试增加一层往下的分支数来优化深度。

我们设分支数为 (B)

如果直接分治 FFT,那么需要计算每个儿子对后面儿子的贡献,每次计算需要一个长度为 (O(n / B)) 的卷积((n)为目前分治区间长度)。

也就是说时间复杂度 (T(n) = B imes T(frac{n}{B}) + B imes n imes log {frac{n}{b}}),大力求解得 (B=2) 时最优。

我是不是在玩你。

好的我们继续。

我们真的是每一对儿子都要用一次卷积来计算贡献吗?

我们可以先求出这个儿子的点值,考虑它对后面儿子的贡献,这个儿子的点值就不用重复计算了。

存储前面儿子对它的贡献的时候,你也可以直接存储点值,最后一次 FFT 转换即可,也不用多次计算了。

而卷上的是 (G^{'}(x)) 的一个区间,这个区间的点值也可以提前计算。

也就是说我们只需要 (O(B)) 次长度为 (O(n / B)) 的 FFT 了!

因此计算贡献的部分复杂度变为了 (O(B^2 imes frac{n}{B} + B imes frac{n}{b} imes log {frac{n}{b}}))

(O(Bn + n imes log {frac{n}{b}}))

时间复杂度 (T(n) = B imes T(frac{n}{B}) + Bn + n imes log {frac{n}{b}})。不错,求解一下。

发现 (B = O(log n)) 的时候最优秀,时间复杂度为(O(frac{nlog^2n}{log log n}))

事实上由于计算贡献的时候非常的****,可以使用 avx2 进行优化,亲测一定的常数优化之后进行 (4 imes 10^6) 的 exp 只需要 1.5s。

当然其他类似的 cdq FFT 也可以这样进行优化,祝大家早日吊打 (O(nlog n))

贴代码:

#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#include<bits/stdc++.h>
#define rep(i, l, r) for(int i = (l), i##end = (r);i <= i##end;++i)
const int maxn = 1 << 19 | 1;
typedef long long ll;
typedef unsigned long long u64, ull;
const int mod = 998244353;
struct istream {
	static const int size = 1 << 21;
	char buf[size], *vin;
	inline istream() {
		fread(buf,1,size,stdin);
		vin = buf - 1;
	}
	inline istream& operator >> (int & x) {
		for(x = *++vin & 15;isdigit(*++vin);) x = x * 10 + (*vin & 15);
		return * this;
	}
} cin;
struct ostream {
	static const int size = 1 << 21;
	char buf[size], *vout;
	unsigned map[10000];
	inline ostream() {
		for(int i = 0;i < 10000;++i) {
			int p = i;
			map[i] = p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
		}
		vout = buf + size;
	}
	inline ~ ostream()
	{ fwrite(vout,1,buf + size - vout,stdout); }
	inline ostream& operator << (int x) {
		for(;x > 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
		do *--vout = x % 10 + 48; while(x /= 10);
		return * this;
	}
	inline ostream& operator << (char x) {
		*--vout = x;
		return * this;
	}
} cout;
inline ll pow(ll a,int b,ll ans = 1){ for(;b;b >>= 1, a = a * a % mod) if(b & 1) ans = ans * a % mod; return ans; }
inline ll inverse(ll x){ return pow(x, mod - 2); }
int wn[1 << 13], rev[1 << 14], inv[maxn], lim, invlim;
inline void init_(int n) {
	int N = 1; for(;N < n;) N <<= 1;
	for(int i = 1;i < N;i <<= 1) {
		const int w = pow(3, mod / i / 2); wn[i] = 1;
		for(int j = 1;j < i;++j) wn[i + j] = (ll) wn[i + j - 1] * w % mod;
	}
	for(int i = 1;i <= N;i <<= 1) {
		for(int j = 1;j < i;++j) rev[i + j] = rev[i + (j >> 1)] >> 1 | j % 2 * i / 2;
	}
}
inline void init(int len) {
	lim = len; invlim = mod - (mod - 1) / lim;
}
inline void reduce(int & x) {
	x += x >> 31 & mod;
}
static u64 t[1 << 13];
inline void fft(int * a,int type) {
	for(int i = 0;i < lim;++i) t[i] = a[rev[i + lim]];
#define trans(i, j, k) 
	{ 
		const u64 x = wn[i + k] * t[i + j + k] % mod; 
		t[i + j + k] = t[j + k] + mod - x, t[j + k] += x; 
	}
	for(int i = 1;i < lim;i <<= 1) {
		if(i == 1) {
			for(int j = 0;j < lim;j += 8) {
				trans(1, j, 0);
				trans(1, j + 2, 0);
				trans(1, j + 4, 0);
				trans(1, j + 6, 0);
			}
		} else if(i == 2) {
			for(int j = 0;j < lim;j += 8) {
				trans(2, j, 0);
				trans(2, j, 1);
				trans(2, j + 4, 0);
				trans(2, j + 4, 1);
			}
		} else {
			for(int j = 0;j < lim;j += i + i) for(int k = 0;k < i;k += 4) {
				trans(i, j, k + 0);
				trans(i, j, k + 1);
				trans(i, j, k + 2);
				trans(i, j, k + 3);
			}
		}
	}
	if(type == 1) {
		for(int i = 0;i < lim;++i) a[i] = t[i] % mod;
	}
	if(type == 0) {
		a[0] = t[0] * invlim % mod;
		for(int i = 1;i < lim;++i) a[i] = t[lim - i] * invlim % mod;
	}
}
inline void fill(int * a, const int * b, int len) {
	memcpy(a, b, len << 2), memset(a + len, 0, lim - len << 2);
}
typedef std::function<int(int, int*)> fc;
struct solver {
	static const int C = 128;
	static const int B = 64;
	int n, N;
	int rem[maxn], g[maxn], * MM;

	int M[B][(maxn << 1) / B];
	u64 g0[maxn << 2];
	inline void Init(int len, int * multi) {
		MM = multi;
		for(n = len, N = 1;N < len;N <<= 1);
		for(int mid = (N + N) / B;mid > 1;mid /= B) {
			init(mid); 
			for(int j = 0;j + 1 < B;++j) {
				if(j * mid / 2 < n) {
					for(int i = 0;i < mid;++i) M[j][mid + i] = MM[i + j * mid / 2];
					fft(M[j] + mid, 1);
				}
			}
		}
	}
	inline void solve(int l, int r, u64 * g0, const fc & xxx) {
		if(r - l < C) {
			for(int i = l;i < r;++i) {
				int j = l;
				u64 x = rem[i];
#define T(o) (u64) g[j + o] * MM[i - j - o]
				for(;j + 15 < i;j += 16) {
					x = (x + T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7) + 
						 	 T(8) + T(9) + T(10) + T(11) + T(12) + T(13) + T(14) + T(15)) % mod;
				}
				if(j + 7 < i) x += T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7), j += 8;
				if(j + 3 < i) x += T(0) + T(1) + T(2) + T(3), j += 4;
				if(j + 1 < i) x += T(0) + T(1), j += 2;
				if(j < i) x += T(0);
#undef T
				rem[i] = x % mod;
				g[i] = xxx(i, rem + i);
			}
			return ;
		}
		const int DT = (r - l) / B;
		if(l) memset(g0, 0, r - l << 4);
		int end = 0;
		for(;end < B && l + end * DT < n;++end);
		for(int i = 0;i < end;++i) {
			int L = l + i * DT, R = L + DT;
			if(i) {
				static int T[maxn];
				init(DT + DT);
				for(int j = 0;j < lim;++j) T[j] = g0[2 * i * DT + j] % mod;
				fft(T, 2);
				for(int j = L;j < R;++j) rem[j] = (rem[j] + (ll) invlim * t[lim - j + L - DT]) % mod;
			}
			solve(L, R, g0 + (r - l << 1), xxx);
			if(i != end - 1) {
				init(DT + DT);
				static int b[maxn];
				fill(b, g + L, R - L), fft(b, 1);
				for(int j = i + 1;j < end;++j) {
					ull * g1 = g0 + lim * j;
					if(i == B / 2) {
						for(int k = 0;k < lim;++k) {
							g1[k] = (g1[k] + (ll) b[k] * M[j - i - 1][lim + k]) % mod;
						}
					} else {
						for(int k = 0;k < lim;++k) {
							g1[k] += (ll) b[k] * M[j - i - 1][lim + k];
						}
					}
				}
			}
		}
	}
	inline void solve(fc x) { solve(0, N, g0, x); }
};
int n, a[maxn], b[maxn];
int main() {
	static solver ln, exp;
	cin >> n;
	for(int i = 0;i < n;++i) {
		cin >> a[i]; if(i) a[i] = mod - a[i];
		b[i] = (ll) a[i] * i % mod;
	}
	inv[1] = 1;
	for(int i = 2;i < n;++i) {
		inv[i] = ll(mod - mod / i) * inv[mod % i] % mod;
	}
	init_((n + n) / solver::B + 1);
	ln.Init(n, a);
	ln.solve([](int pos, int * now) { return reduce(*now -= b[pos + 1]), *now; });
	for(int i = 1;i < n;++i) {
		b[i] = (ll) ln.g[i - 1] * inv[2] % mod;
	}
	exp.Init(n, b);
	exp.solve([](int pos, int * now) { return int(pos == 0 ? 1 : (ll) *now * inv[pos] % mod); });
	for(int i = n - 1;i >= 0;--i) {
		cout << ' ' << exp.g[i];
	}
}

原文地址:https://www.cnblogs.com/skip1978/p/12408384.html