多项式除法及取模

多项式除法及取模

http://blog.miskcoo.com/2015/05/polynomial-division

概述

给出一个 $n$ 次多项式 $A(x)$ 和一个 $m\ (m \le n)$ 次多项式 $B(x)$,要求求出两个多项式 $D(x), R(x)$,满足

$$A(x) = D(x)B(x) + R(x) \tag{1}$$

其中 $\deg D \le \deg A - \deg B = n - m$,$\deg R < m$。

可以在 $O(n \log n)$ 的时间内求解。

原理

首先,我们先想办法消除 $R(x)$ 的影响。我们定义

$$A^R(x) = x^n A\left(\dfrac{1}{x}\right)$$

实际上就是将 $A(x)$ 的系数翻转,例如

$$\begin{aligned} A(x) &= x^3 + 2x^2 + 4x + 1 \\ A^R(x) &= x^3\left(x^{-3} + 2x^{-2} + 4x^{-1} + 1\right) = 1 + 2x + 4x^2 + x^3 \end{aligned}$$

接下来,我们将 $(1)$ 式中的 $x$ 用 $\dfrac{1}{x}$ 替换,并在两边同乘 $x^n$:

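$$\begin{aligned} x^n A\left(\dfrac{1}{x}\right) &= x^{n-m} D\left(\dfrac{1}{x}\right) \cdot x^m B\left(\dfrac{1}{x}\right) + x^{n-m+1} \cdot x^{m-1} R\left(\dfrac{1}{x}\right) \\ A^R(x) &= D^R(x) B^R(x) + x^{n-m+1} R^R(x) \end{aligned}$$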

观察发现,此时 $x^{n-m+1}R^R(x)$ 的非零项次数都不低于 $n-m+1$(这里 $D^R$ 按 $n-m$ 次翻转,$R^R$ 按 $m-1$ 次翻转),而 $D^R(x)$ 的最高次项次数为 $n-m$。因此在模 $x^{n-m+1}$ 的意义下,余式的影响被完全消去,我们有

$$A^R(x) \equiv D^R(x) B^R(x) \pmod{x^{n-m+1}}$$

那么我们就可以用一次多项式求逆的复杂度求出 $D^R(x) \equiv A^R(x)\left(B^R(x)\right)^{-1} \pmod{x^{n-m+1}}$,翻转即得 $D(x)$;再代回 $(1)$ 式就可以得到 $R(x) = A(x) - D(x)B(x)$。

Code

  1. 将 $B$ 翻转,求出其在 $\bmod\ x^{n-m+1}$ 意义下的逆元 $B'^R$
  2. 将 $A$ 翻转,计算 $D^R(x) \equiv A^R(x) \cdot B'^R(x) \pmod{x^{n-m+1}}$,再翻转得到 $D(x)$
  3. 求出 $R(x) = A(x) - D(x)B(x)$

不打换行真的会死.......

洛谷 P4512

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
// Buffered single-byte reader over stdin.
// Returns the next input byte as a value in 0..255, or EOF once stdin is
// exhausted. The byte is widened through unsigned char so that an input byte
// 0xFF can never be confused with EOF (with the former `char` return type the
// two were indistinguishable).
inline int nc() {
	static char buf[100000], *l = buf, *r = buf;
	if(l == r) {
		// Buffer drained: refill it with one bulk fread.
		r = (l = buf) + fread(buf, 1, sizeof(buf), stdin);
		if(l == r) return EOF;
	}
	return static_cast<unsigned char>(*l++);
}
// Parses the next decimal integer from the nc() byte stream into x.
// Non-digit bytes before the number are skipped; if a '-' appears among them
// the result is negated. If the input ends before any digit is seen, x is left
// as 0 (previously the skip loop spun forever on EOF).
template<class T> void read(T &x) {
	x = 0;
	int f = 1, ch = nc();
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') f = -1;
		ch = nc();
	}
	while(ch >= '0' && ch <= '9') {
		x = x * 10 - '0' + ch;
		ch = nc();
	}
	x *= f;
}
typedef long long ll;
const int g = 3;              // generator used to build the NTT roots of unity
const int mod = 998244353;    // prime modulus, 119 * 2^23 + 1
const int phi = mod - 1;      // Euler's totient of the prime modulus
const int maxn = 400000 + 5;  // capacity of every coefficient / scratch array
const int maxlog = 20;        // log2 bound on the sizes covered by the NTT twiddle table
int n, m; ll F[maxn], G[maxn], D[maxn], R[maxn];

// Defined later in this file; declared here so inv() can call it.
ll quick_power(ll x, ll y);

// Modular inverse by Fermat's little theorem (mod is prime): a^(mod - 2).
// A real function instead of the former function-like macro, so the argument
// is type-checked and evaluated exactly once.
inline ll inv(ll a) {
	return quick_power(a, mod - 2);
}
// Brings a value in [0, 2 * mod) back into [0, mod) with a single
// conditional subtraction (used after adding two residues).
inline ll sum(ll x) {
	if(x < mod) return x;
	return x - mod;
}
// Brings a value in [-mod, mod) back into [0, mod) with a single
// conditional addition (used after subtracting two residues).
inline ll dec(ll x) {
	if(x >= 0) return x;
	return x + mod;
}
// Computes x^y modulo `mod` by binary exponentiation (square-and-multiply).
// The base is first reduced into [0, mod), so callers may pass negative or
// unreduced values. Without this, a negative x produced a negative result,
// and an x of about 3.04e9 or more overflowed the 64-bit product x * x.
// Precondition: y >= 0 (a negative exponent never terminates the loop).
ll quick_power(ll x, ll y) {
	x %= mod;
	if(x < 0) x += mod;
	ll re = 1;
	while(y) {
		if(y & 1) re = re * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return re;
}
namespace polynomial {
	int rev[maxn];
	ll w[maxlog][maxn][2];
	void init() {
		for(int i = 1, s = 0; i < maxn; i <<= 1, ++s) {
			ll wn0 = quick_power(g, phi / (i * 2));
			ll wn1 = quick_power(g, -phi / (i * 2) + phi);
			w[s][0][0] = w[s][0][1] = 1;
			for(int k = 1; k < i; ++k) {
				w[s][k][0] = w[s][k - 1][0] * wn0 % mod;
				w[s][k][1] = w[s][k - 1][1] * wn1 % mod;
			}
		}
	}
	void init_rev(int n, int L) {
		for(int i = 1; i < n; ++i) {
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
		}
	}
	void FFT(ll *A, int n, int f) {
		int d = f == -1;
		for(int i = 0; i < n; ++i) if(i > rev[i]) {
			swap(A[i], A[rev[i]]);
		}
		for(int i = 1, s = 0; i < n; i <<= 1, ++s) {
			for(int j = 0, p = i << 1; j < n; j += p) {
				ll *u = A + j, *v = A + j + i;
				for(int k = 0; k < i; ++k, ++u, ++v) {
					ll x = *u, y = *v * w[s][k][d] % mod;
					*u = sum(x + y);
					*v = dec(x - y);
				}
			}
		}
		if(f == -1) {
			ll r = inv(n);
			for(int i = 0; i < n; ++i) {
				A[i] = A[i] * r % mod;
			}
		}
	}
	void inverse(int step, ll *A, ll *B) {
		static ll T[maxn];
		
		if(step == 1) {
			B[0] = inv(A[0]);
			return;
		}
		
		inverse((step + 1) >> 1, A, B);
		
		int n = 1, L = 0, m = step << 1;
		for(n = 1; n <= m; n <<= 1) ++L;
		init_rev(n, L);
		
		copy(A, A + step, T), fill(T + step, T + n, 0);

		FFT(T, n, 1);
		FFT(B, n, 1);
		for(int i = 0; i < n; ++i) {
			B[i] = dec(2 - T[i] * B[i] % mod) * B[i] % mod;
		}
		FFT(B, n, -1);
		
		fill(B + step, B + n, 0);
	}
	void division(ll *A, ll *B, int n, int m, ll *D, ll *R) {
		static ll A0[maxn], B0[maxn];
		memset(A0, 0, sizeof(A0));
		memset(B0, 0, sizeof(B0));

		int len = n - m + 1;

		reverse_copy(B, B + m + 1, A0);
		inverse(len, A0, B0);

		int p, L = 0, tmp = len << 1;
		for(p = 1; p <= tmp; p <<= 1) ++L;
		init_rev(p, L);
		
		reverse_copy(A, A + n + 1, A0), fill(A0 + len, A0 + p, 0);
		
		FFT(A0, p, 1);
		FFT(B0, p, 1);
		for(int i = 0; i < p; ++i) {
			A0[i] = A0[i] * B0[i] % mod;
		}
		FFT(A0, p, -1);
		
		reverse(A0, A0 + len);
		copy(A0, A0 + len, D);
		
		L = 0, tmp = n;
		for(p = 1; p <= tmp; p <<= 1) ++L;
		init_rev(p, L);
		
		copy(B, B + m + 1, B0), fill(B0 + m + 1, B0 + p, 0);
		fill(A0 + len, A0 + p, 0);
		
		FFT(A0, p, 1);
		FFT(B0, p, 1);
		for(int i = 0; i < p; ++i) {
			A0[i] = A0[i] * B0[i] % mod;
		}
		FFT(A0, p, -1);
		
		for(int i = 0; i < m; ++i) {
			R[i] = dec(A[i] - A0[i]);
		}
	}
}
// Reads n, m, then the n + 1 coefficients of A and the m + 1 coefficients of B
// (index i holds the coefficient of x^i). Prints the quotient's n - m + 1
// coefficients on the first line and the remainder's m coefficients on the
// second line.
int main() {
//	freopen("testdata.in", "r", stdin);
	read(n), read(m);
	for(int i = 0; i <= n; ++i) read(F[i]);
	for(int i = 0; i <= m; ++i) read(G[i]);

	polynomial :: init();
	polynomial :: division(F, G, n, m, D, R);

	for(int i = 0; i <= n - m; ++i) {
		if(i) printf(" ");
		printf("%lld", D[i]);
	}
	// The newline literal here was previously split across two physical lines,
	// which is an unterminated string literal and does not compile.
	printf("\n");
	for(int i = 0; i < m; ++i) {
		if(i) printf(" ");
		printf("%lld", R[i]);
	}
	printf("\n");

	return 0;
}
原文地址:https://www.cnblogs.com/ljzalc1022/p/12909522.html