MTT学习小记

这是个毒瘤题才有的毒瘤东西……奶一口NOI不考

拆系数FFT:


考虑做NTT时模数不是NTT模数((2^a*b+1))怎么办?

很容易想到拆次数FFT。

比如说现在求(a*b),设(m=sqrt mo(2^{15}))

那么把(a[i])拆成(a0[i]+a1[i]*m),b[i]拆成(b0[i]+b1[i]*m)

那么(a[i]*b[j]=a0[i]*b0[j]+(a0[i]*b1[j]+a1[i]*b0[j])*m+(a1[i]*b1[j])*m^2)

由于(a0,b1,b0,b1)的大小都不到,所以FFT不会爆精度。

那么这个最好也需要4(正)+3(逆)=7,复杂度不能接受。

DFT:


一开始要做4次DFT,我们两两一起做。

假设现在有两个序列A、B,要求DFT(A)和DFT(B)

(P=A+B*i,Q=A-B*i)

只用做P的DFT,便可得到Q的DFT。

(DFT(Q)[n-i]=conj(DFT(P)[i]))(conj)为共轭,就是把虚部系数取反。

证明:

(conj(DFT(P)[i]))

(=conj(sum_{j=0}^{n-1}w_n^{ji}*(A[i]+B[i]*sqrt{-1}))

(=sum_{j=0}^{n-1}conj(w_n^{ji})*conj(A[i]+B[i]*sqrt{-1}))

(=sum_{j=0}^{n-1}w_n^{-ij}*(A[i]-B[i]*sqrt{-1}))

(=DFT(Q)[n-i])

这样求出了(DFT(P),DFT(Q))

那么(DFT(A)=(DFT(P)+DFT(Q))*({1over2},0),\DFT(B)=(DFT(P)-DFT(Q))*(-1)*(0, {1over2}))

这个东西显然有一个条件是(A、B)只能实部有值,不然会混乱了无法提出来的。

IDFT:


接下来是IDFT,同样的可以两个 一起做。

如果有(A、B),都只有实部有值,设(C=DFT(A)+i*DFT(B))

显然(IDFT(C))的实部就是A,虚部就是B

这样我们就用四次DFT完成啦!

Code:


#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i <  B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define ll long long
#define pp printf
#define hh pp("
")
#define db double
using namespace std;

const db pi = acos(-1);

const int mo = 1e9 + 7;

struct P {
	db x, y;
	P(db _x = 0, db _y = 0) { x = _x, y = _y;}
	P operator + (P b) { return P(x + b.x, y + b.y);}
	P operator - (P b){  return P(x - b.x, y - b.y);}
	P operator * (P b) { return P(x * b.x - y * b.y, x * b.y + y * b.x);}
};

const int nm = 1 << 18;
P w[nm]; int r[nm];
P c0[nm], c1[nm], c2[nm], c3[nm];

void dft(P *a, int n) {
	ff(i, 0, n) {
		r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
		if(i < r[i]) swap(a[i], a[r[i]]);
	} P b;
	for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i)
		ff(k, 0, i) b = a[i + j + k] * w[i + k], a[i + j + k] = a[j + k] - b, a[j + k] = a[j + k] + b;
}
void rev(P *a, int n) {
	reverse(a + 1, a + n);
	ff(i, 0, n) a[i].x /= n, a[i].y /= n;
}
P conj(P a) { return P(a.x, -a.y);}
void fft(ll *a, ll *b, int n) {
	#define qz(x) ((ll) round(x))
//	ff(i, 0, n) c0[i] = P(a[i], 0), c1[i] = P(b[i], 0);
//	dft(c0, n); dft(c1, n);
//	ff(i, 0, n) c0[i] = c0[i] * c1[i];
//	dft(c0, n); rev(c0, n);
//	ff(i, 0, n) a[i] = qz(c0[i].x);
	ff(i, 0, n) c0[i] = P(a[i] & 32767, a[i] >> 15), c1[i] = P(b[i] & 32767, b[i] >> 15);
	dft(c0, n); dft(c1, n);
	ff(i, 0, n) {
		P k, d0, d1, d2, d3;
		int j = (n - i) & (n - 1);
		k = conj(c0[j]);
		d0 = (k + c0[i]) * P(0.5, 0);
		d1 = (k - c0[i]) * P(0, 0.5);
		k = conj(c1[j]);
		d2 = (k + c1[i]) * P(0.5, 0);
		d3 = (k - c1[i]) * P(0, 0.5);
		c2[i] = d0 * d2 + d1 * d3 * P(0, 1);
		c3[i] = d0 * d3 + d1 * d2;
	}
	dft(c2, n); dft(c3, n); rev(c2, n); rev(c3, n);
	ff(i, 0, n) {
		a[i] = qz(c2[i].x) + (qz(c2[i].y) % mo << 30) + (qz(c3[i].x) % mo << 15);
		a[i] %= mo;
	}
}

ll a[nm], b[nm]; 

int main() {
	for(int i = 1; i < nm; i *= 2) ff(j, 0, i)
		w[i + j] = P(cos(pi * j / i), sin(pi * j / i));
	fo(i, 0, 15) a[i] = b[i] = mo - 1;
	fft(a, b, 32);
	ff(i, 0, 32) pp("%lld ", a[i]);
}
原文地址:https://www.cnblogs.com/coldchair/p/11129417.html