LOJ #3219. 「PA 2019」Iloczyny Fibonacciego (斐波那契表示性质+FFT)

https://loj.ac/problem/3219

题解:

  • \(F_{n+m}=F_nF_m+F_{n-1}F_{m-1}\)
  • \(F_nF_m=F_{n+m}-F_{n-1}F_{m-1}\)
  • \(\cdots\)
  • \(F_nF_m=F_{n+m}-F_{n+m-2}+F_{n+m-4}-\cdots+(-1)^{\min(n,m)}F_{|n-m|}\)(这里 \(F_0=F_1=1,\ F_2=2\))

用FFT求\(A(x)\cdot B(x)\),再求\(A(-x)\cdot B(\frac{1}{x})\),即可得到答案是两个斐波那契表示(非01最简表示)的差。

考虑把一个斐波那契表示(非01最简)化简成01表示,即化简\(\sum_{i=1}^n a_iF_i\),可以用分治:

\(\sum_{i=1}^n a_iF_i=\sum_{i=1}^n (a_i \bmod 2)F_i+2\sum_{i=1}^n \left\lfloor \frac{a_i}{2}\right\rfloor F_i\)

问题在于:两个01最简表示怎么在 \(O(n)\) 内做加法?

  • 考虑倒着做,每次在 \(i\) 这个位置加 \(a_i+b_i\) 次 1,不难发现复杂度挺对的。

那么最后的问题:做差

  • 还是倒着做,如果 \(i\) 位置要 \(-1\),找到最小的 \(j\) 满足 \(a_j=1\) 且 \(j\ge i\),发现是把它 \(-1\),然后给 \(i-1,\dots,j-2\) 这些位置各加 1(因为 \(F_j-F_i=\sum_{k=i-1}^{j-2}F_k\)),不难发现复杂度还是对的。

时间复杂度:\(O(n\log n)\)

Code:

#include<bits/stdc++.h>
// Loop shorthands; the bound y is evaluated once into _b.
// fo: i = x .. y inclusive, ascending.
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
// ff: i = x .. y-1, ascending (half-open).
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
// fd: i = x .. y inclusive, descending.
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
// Prints a newline. The string literal must stay on one line: a raw line
// break inside the quotes is an unterminated literal and does not compile.
#define hh pp("\n")
using namespace std;

// Capacity of every global work array and of the FFT tables (2^21 entries).
const int nm = 1 << 21;
// Floating-point type used by the FFT.
typedef double db;
// Floating-point FFT used to multiply integer polynomials (see mtp).
namespace fft {
	const db pi = acos(-1);
	// Complex number: x is the real part, y the imaginary part.
	struct P {
		db x, y;
		P(db _x = 0, db _y = 0) {
			x = _x, y = _y;
		}
	};
	P operator + (P a, P b) { return P(a.x + b.x, a.y + b.y);}
	P operator - (P a, P b) { return P(a.x - b.x, a.y - b.y);}
	P operator * (P a, P b) { return P(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
	// w: twiddle table filled by build(): for every power of two i < nm,
	//    w[i + j] = (cos(pi*j/i), sin(pi*j/i)) for 0 <= j < i.
	// r: bit-reversal permutation, recomputed on every dft call.
	P w[nm]; int r[nm];
	// Precomputes the twiddle table w; must run once before the first dft/mtp.
	void build() {
		for(int i = 1; i < nm; i *= 2) {
			ff(j, 0, i) {
				w[i + j] = P(cos(pi * j / i), sin(pi * j / i));
			}
		}
	}
	// In-place transform of a[0..n-1]. n is expected to be a power of two
	// not larger than nm (the bit-reversal recurrence and the w/r tables
	// rely on this). f == -1 selects the inverse transform, computed by
	// running the same butterflies, reversing a[1..n-1] and dividing by n;
	// any other f (callers pass 1) is the forward transform.
	void dft(P *a, int n, int f) {
		// Permute the input into bit-reversed index order.
		ff(i, 0, n) {
			r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
			if(i < r[i]) swap(a[i], a[r[i]]);
		} P b;
		// Butterflies: i = half-length of the blocks being merged,
		// j = start of the block, k = offset inside the block.
		for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i) ff(k, 0, i)
			b = a[i + j + k] * w[i + k], a[i + j + k] = a[j + k] - b, a[j + k] = a[j + k] + b;
		if(f == -1) {
			reverse(a + 1, a + n);
			ff(i, 0, n) a[i].x /= n, a[i].y /= n;
		}
	}
	// Scratch buffers for mtp.
	P c[nm], d[nm], p[nm];
	// Complex conjugate.
	P conj(P a) { return P(a.x, -a.y);} 
	// Cyclic convolution of length n (a power of two):
	//   a[k] <- round(sum over i + j == k (mod n) of a[i] * b[j]);
	// b is not modified. Both inputs share one forward FFT: a is packed into
	// the real part and b into the imaginary part, and the two spectra are
	// separated with conjugate symmetry. The result is only exact while the
	// values stay well inside double precision -- not checked here.
	void mtp(ll *a, ll *b, int n) {
		ff(i, 0, n) p[i] = P(a[i], b[i]);
		dft(p, n, 1);
		ff(i, 0, n) {
			P k = conj(p[(n - i) % n]);
			// Spectrum of a: (P[i] + conj(P[n-i])) / 2.
			c[i] = (p[i] + k) * P(0.5, 0);
			// Spectrum of b: (P[i] - conj(P[n-i])) / (2i); multiplying by -i/2 divides by 2i.
			d[i] = (p[i] - k) * P(0, -0.5);
			// Pointwise product = spectrum of the convolution.
			c[i] = c[i] * d[i];
		}
		dft(c, n, -1);
		ff(i, 0, n) a[i] = round(c[i].x);
	}
}
using fft::mtp;

int n;  // digit count of the first factor in the current test case
int m;  // digit count of the second factor in the current test case

// Input digits; work() overwrites a with the convolution a * b.
ll a[nm];
ll b[nm];
// Copies of the inputs used for the signed correlation in work().
ll a0[nm];
ll b0[nm];
// Per-offset signed counts that terminate the expansion chains.
ll c0[nm];
// Positive / negative coefficient counts of the expanded product.
ll c[nm];
ll d[nm];
// Normalised 0/1 forms of c and d.
ll p[nm];
ll q[nm];

// Adds F[x] to the 0/1 Fibonacci representation f (f[k] != 0 means F[k] is
// present), carrying so that no two adjacent positions stay set.
// Index convention implied by the cases below: F[0] = F[1] = 1, F[2] = 2,
// F[k] = F[k-1] + F[k-2].
// Assumes f has room for index x + 2; f[0] is only ever cleared here,
// never set (the x == 1 case reads f[0]).
void add(ll *f, int x) {
	// F[0] == F[1]: add at index 1 instead.
	if(x == 0) {
		add(f, x + 1);
		return;
	}
	// F[x] + F[x+1] = F[x+2].
	if(f[x + 1]) {
		f[x + 1] = 0;
		add(f, x + 2);
		return;
	}
	// F[x-1] + F[x] = F[x+1].
	if(f[x - 1]) {
		f[x - 1] = 0;
		add(f, x + 1);
		return;
	}
	// Position already set: 2F[x] = F[x+1] + F[x-2] (and 2F[1] = F[2]).
	if(f[x]) {
		f[x] = 0;
		if(x == 1) {
			add(f, x + 1);
			return;
		}
		add(f, x - 2); add(f, x + 1);
		return;
	}
	// No neighbour set: just mark the position.
	f[x] = 1;
}

// Fibonacci-representation addition: overwrites p[0..n] with the 0/1 form of
//   sum_{i=1..n} (p[i] + q[i]) * F[i].
// Positions are inserted with add() from the highest index down, one unit at
// a time. q may alias p (p is read only before the final copy).
void plu(ll *p, ll *q, int n) {
	static ll acc[nm];
	fill(acc, acc + n + 1, 0LL);
	for (int pos = n; pos >= 1; --pos) {
		const int times = p[pos] + q[pos];
		for (int k = 0; k < times; ++k)
			add(acc, pos);
	}
	copy(acc, acc + n + 1, p);
}

// Converts the Fibonacci representation  value = sum_{i=1..n} c[i] * F[i],
// whose coefficients are arbitrary non-negative integers, into the 0/1
// representation stored in p[1..n].
//
// Binary splitting on the coefficients:
//   sum c[i] F[i] = sum (c[i] mod 2) F[i] + 2 * sum floor(c[i] / 2) F[i],
// so the value of the bits above D is built recursively, doubled with
// plu(p, p), and then bit D of every coefficient is added with plu(p, g).
//
// D is the bit currently processed (callers pass 0). c is only read: bit D
// is taken as (c[i] >> D) & 1. This means no per-level bit table is kept, and
// the recursion depth is bounded only by the bit width of the largest
// coefficient rather than by a fixed table size.
void solve(int D, ll *c, ll *p, int n) {
	// Stop once no coefficient has any bit at or above D.
	ll mx = 0;
	fo(i, 1, n) mx = max(mx, c[i] >> D);
	if(mx == 0) {
		fo(i, 1, n) p[i] = 0;
		return;
	}
	solve(D + 1, c, p, n);
	plu(p, p, n);                        // p = 2 * (value of the higher bits)
	static ll g[nm];
	fo(i, 1, n) g[i] = (c[i] >> D) & 1;  // bit D of every coefficient
	plu(p, g, n);                        // p += sum_i bit_D(c[i]) * F[i]
}

// Subtracts q from p in place; both are 0/1 Fibonacci representations over
// indices 1..n, and p must represent the larger value (otherwise the scan
// for a set bit below runs past the end of p).
// Each set bit i of q, taken from the top down, borrows from the lowest set
// bit hi >= i of p, using F[hi] - F[i] = F[i-1] + F[i] + ... + F[hi-2].
void dec(ll *p, ll *q, int n) {
	for (int i = n; i >= 1; --i) {
		if (!q[i])
			continue;
		int hi = i;
		while (p[hi] == 0)
			++hi;
		--p[hi];
		for (int k = hi - 2; k >= i - 1; --k)
			add(p, k);
	}
}

void work() {
	scanf("%d", &n);
	fo(i, 1, n) scanf("%lld", &a[i]);
	scanf("%d", &m);
	fo(i, 1, m) scanf("%lld", &b[i]);
	
	int tp = 1;
	while(1 << ++ tp <= n + m);
	ff(i, n + 1, 1 << tp) a[i] = 0;
	ff(i, m + 1, 1 << tp) b[i] = 0;
	ff(i, 0, 1 << tp) a0[i] = a[i], b0[i] = b[i];
	mtp(a, b, 1 << tp);
	
//	fo(i, 1, n) if(a0[i]) fo(j, 1, m) if(b0[j]) {
//		int w = abs(i - j);
//		c0[w] += min(i, j) % 2 ? -1 : 1;
//	}
	fo(i, 1, n) a0[i] = a0[i] * (i % 2 ? -1 : 1);
	reverse(b0 + 0, b0 + m + 1);
	mtp(a0, b0, 1 << tp);
	ff(i, 0, (1 << tp) + 100) c0[i] = 0;
	ff(i, 0, 1 << tp) if(a0[i]) {
		int j = abs(i - m);
		if(i - m <= 0) {
			c0[j] += a0[i];
		} else {
			c0[j] += a0[i] * (j % 2 ? -1 : 1);
		}
	}
	
	n = (1 << tp) + 100;
	fo(i, 0, (1 << tp) + 100) c[i] = d[i] = 0;
	fd(i, (1 << tp) - 1, 0) {
		c[i] = a[i] + d[i + 2];
		d[i] = c[i + 2];
		if(c0[i + 2] > 0) c[i] += c0[i + 2]; else
			d[i] -= c0[i + 2];
	}
	c[1] += c[0]; c[0] = 0;
	d[1] += d[0]; d[0] = 0;
	fo(i, 0, n) p[i] = q[i] = 0;
	solve(0, c, p, n);
	solve(0, d, q, n);
	m = 0;
	dec(p, q, n);
	fo(i, 1, n) if(p[i]) m = i;
	pp("%d ", m);
	fo(i, 1, m) pp("%lld ", p[i]); hh;
}

int main() {
	fft :: build();
	int T; scanf("%d", &T);
	fo(ii, 1, T) {
		work();
	}
}
原文地址:https://www.cnblogs.com/coldchair/p/13045587.html