类欧几里得算法 重学笔记

Solution

以前学过,但是太烂,而且很有局限性,今重学一遍。

假设我们要求的是:

$$\sum_{x=0}^{n} x^{k_1}\left\lfloor\frac{ax+b}{c}\right\rfloor^{k_2}$$

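可以分为以下几种情况进行讨论:

  1. $a=0$ 或者 $\left\lfloor\frac{an+b}{c}\right\rfloor=0$

此时对所有 $0\le x\le n$,$\left\lfloor\frac{ax+b}{c}\right\rfloor$ 都等于同一个常数 $t=\left\lfloor\frac{an+b}{c}\right\rfloor$,答案就是 $t^{k_2}\sum_{x=0}^{n}x^{k_1}$,直接用 $k_1$ 次幂的前缀和就好了。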

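  2. $a\ge c$

令 $q=\left\lfloor\frac{a}{c}\right\rfloor,\ r=a\bmod c$,则 $\left\lfloor\frac{ax+b}{c}\right\rfloor=qx+\left\lfloor\frac{rx+b}{c}\right\rfloor$,那么答案就是:

$$\sum_{x=0}^{n} x^{k_1}\left(qx+\left\lfloor\frac{rx+b}{c}\right\rfloor\right)^{k_2}$$

$$=\sum_{i=0}^{k_2} q^i\binom{k_2}{i}\sum_{x=0}^{n} x^{k_1+i}\left\lfloor\frac{rx+b}{c}\right\rfloor^{k_2-i}$$

注意新的两个指数之和仍为 $k_1+k_2$,所以只要递归时把所有指数和不超过上限的答案一起求出来即可。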

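  3. $b\ge c$

令 $q=\left\lfloor\frac{b}{c}\right\rfloor,\ r=b\bmod c$,则 $\left\lfloor\frac{ax+b}{c}\right\rfloor=q+\left\lfloor\frac{ax+r}{c}\right\rfloor$,同理可以得到答案就是:

$$\sum_{i=0}^{k_2}\binom{k_2}{i}q^i\sum_{x=0}^{n} x^{k_1}\left\lfloor\frac{ax+r}{c}\right\rfloor^{k_2-i}$$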

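  4. $\max(a,b)<c$

令 $M=\left\lfloor\frac{an+b}{c}\right\rfloor$,可以把 $\left\lfloor\frac{ax+b}{c}\right\rfloor^{k_2}$ 拆成差分之和,变成:

$$\left\lfloor\frac{ax+b}{c}\right\rfloor^{k_2}=\sum_{j=0}^{\lfloor\frac{ax+b}{c}\rfloor-1}\left((j+1)^{k_2}-j^{k_2}\right)$$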

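由于 $\left\lfloor\frac{ax+b}{c}\right\rfloor>j\iff ax\ge cj+c-b\iff x>\left\lfloor\frac{cj+c-b-1}{a}\right\rfloor$,交换求和顺序,答案就是:

$$\sum_{j=0}^{M-1}\left((j+1)^{k_2}-j^{k_2}\right)\sum_{x=0}^{n} x^{k_1}\left[x>\left\lfloor\frac{cj+c-b-1}{a}\right\rfloor\right]$$

$$=\sum_{j=0}^{M-1}\left((j+1)^{k_2}-j^{k_2}\right)\sum_{i=0}^{n} i^{k_1}-\sum_{j=0}^{M-1}\left((j+1)^{k_2}-j^{k_2}\right)\sum_{i=0}^{\lfloor\frac{cj+c-b-1}{a}\rfloor}i^{k_1}$$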

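前面这部分裂项相消,恰好等于 $M^{k_2}\sum_{i=0}^{n}i^{k_1}$,只需要 $k_1$ 次幂的前缀和。考虑如何算后面那一部分。你发现 $\sum_{i=0}^{y}i^{k_1}$ 是关于 $y$ 的 $k_1+1$ 次多项式,设其 $j$ 次项系数为 $B_j$。再把 $(x+1)^{k_2}-x^{k_2}$ 展开成 $\sum_{i=0}^{k_2-1}\binom{k_2}{i}x^i$(把求和变量 $j$ 改记为 $x$),后面那一部分就可以写成:

$$\sum_{i=0}^{k_2-1}\binom{k_2}{i}\sum_{j=0}^{k_1+1} B_j\sum_{x=0}^{M-1} x^i\left\lfloor\frac{cx+c-b-1}{a}\right\rfloor^j$$

这正是参数为 $(n,a,b,c)=(M-1,\,c,\,c-b-1,\,a)$ 的同类问题(由 $b<c$ 知 $c-b-1\ge 0$),也可以递归了。由于 $a<c$,新问题中的 $a$ 大于新问题中的 $c$,下一层会进入情况 2 取模。这和欧几里得算法一样,递归层数是 $O(\log)$ 级别的。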

Code

#include <bits/stdc++.h>
using namespace std;

// Loop-index type. `register` was removed in C++17 (clang rejects it by
// default), so it is not used; note `int` expands to `long long` below.
#define Int int
// Prime modulus for all arithmetic.
#define mod 1000000007
// Every "int" in this file is 64-bit so that a product of two residues fits.
#define int long long
// Bound for all fixed-size tables (exponents here are at most 10, plus the
// extra columns needed by the interpolation system).
#define MAXN 15

// Reads one decimal integer from stdin into t. Non-digit characters before the
// number are skipped; the result is negated only when the character right in
// front of the first digit is '-'. At end of input t is left as 0.
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar ();  // int, not char, so EOF is distinguishable
	int f = 1;
	while (c != EOF && (c < '0' || c > '9')){
		// Only a '-' directly before the digits is a sign; separators such as
		// ' ', '\n', '\r' or '\t' must not flip it.
		f = (c == '-') ? -1 : 1;
		c = getchar ();
	}
	while (c >= '0' && c <= '9'){
		t = (t << 3) + (t << 1) + c - '0';
		c = getchar ();
	}
	t *= f;
}
// Reads several integers in order: read (a, b) is read (a); read (b).
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Prints x in decimal to stdout, preceded by '-' when x is negative.
template <typename T> inline void write (T x){
	if (x < 0){
		putchar ('-');
		x = -x;
	}
	char digits[40];
	int len = 0;
	do{
		digits[len ++] = char (x % 10 + '0');
		x /= 10;
	}while (x > 0);
	while (len > 0) putchar (digits[-- len]);
}
// In-place maximum update: a becomes the larger of a and b.
template <typename T> void chkmax (T &a,T b){a = (a < b) ? b : a;}
// In-place minimum update: a becomes the smaller of a and b.
template <typename T> void chkmin (T &a,T b){a = (b < a) ? b : a;}

// Modular arithmetic modulo `mod` (1e9+7, prime).
// add/dec/Add/Sub expect both operands already reduced into [0, mod).
int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
// a^b mod `mod` for b >= 0 by binary exponentiation. The base is reduced first
// so that mul(a, a) never multiplies an unreduced value (a >= ~3.04e9 would
// overflow 64-bit). qkpow(x, 0) == 1 for every x, including 0^0 == 1.
int qkpow (int a,int b){
	a %= mod;
	int res = 1;
	for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
	return res;
}
// Modular inverse by Fermat's little theorem (mod is prime); inv(0) yields 0.
int inv (int x){return qkpow (x,mod - 2);}
void Add (int &a,int b){a = add (a,b);}
void Sub (int &a,int b){a = dec (a,b);}

// MAXN x MAXN table of residues, zero-filled on construction.
// getit() stores its answers in it as table[k1][k2].
struct node{
	int t[MAXN][MAXN];
	node(){
		for (auto &row : t)
			for (auto &cell : row) cell = 0;
	}
	// Row access, so that table[k1][k2] reaches t[k1][k2].
	int * operator [](const int key){return t[key];}
};

// Binomial coefficients C[i][j] mod `mod` (filled in main for i <= 10).
int C[MAXN][MAXN];
// Scratch augmented matrix for the interpolation systems solved by Func.
int mat[MAXN][MAXN];
struct Func{//处理每个F(i,x) sum_{k=0}^{x} k^i 的i+1次函数
	int a[MAXN];
	int & operator [](const int key){return a[key];}
	void Gauss (int K){
		for (Int i = 0;i <= K;++ i){
			int tmp = i;
			for (Int j = i + 1;j <= K;++ j) if (mat[j][i]){tmp = j;break;}
			if (tmp ^ i) swap (mat[tmp],mat[i]);
			for (Int j = i + 1,iv = inv (mat[i][i]);j <= K;++ j){
				int del = mul (mat[j][i],iv);
				for (Int k = i;k <= K + 1;++ k) Sub (mat[j][k],mul (del,mat[i][k]));
			}
		}
		for (Int i = K;~i;-- i){
			for (Int j = i + 1;j <= K;++ j) Sub (mat[i][K + 1],mul (a[j],mat[i][j]));
			a[i] = mul (mat[i][K + 1],inv (mat[i][i]));
		}
	}
	void gen(int k){
		for (Int i = 0,res = 0;i <= k + 1;++ i) Add (res,qkpow (i,k)),mat[i][k + 2] = res;
		for (Int i = 0;i <= k + 1;++ i) for (Int j = 0,res = 1;j <= k + 1;++ j,res = mul (res,i)) mat[i][j] = res;
		Gauss (k + 1);
	}
	int getit (int k,int x){
		int res = 0;
		for (Int i = k + 1;i >= 0;-- i) res = add (a[i],mul (res,x));
		return res;
	}
}f[MAXN];

// Brute-force reference: evaluates
//   sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   (mod `mod`)
// one term at a time (n + 1 iterations). main() never calls it; it is only
// useful for cross-checking getit() on small inputs.
int relans (int n,int a,int b,int c,int k1,int k2){
	int total = 0;
	int x = 0;
	while (x <= n){
		const int fl = (a * x + b) / c % mod;
		const int term = mul (qkpow (x,k1),qkpow (fl,k2));
		Add (total,term);
		++ x;
	}
	return total;
}

// Returns the table S[k1][k2], for every k1 + k2 <= 10, of
//   S(k1, k2) = sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   (mod `mod`).
// Assumes n, a, b >= 0 and c >= 1 (not checked here; TODO confirm input ranges).
// The branches follow the derivation:
//   1. the floor is constant,
//   2. a >= c (peel off floor(a/c) * x),
//   3. b >= c (peel off floor(b/c)),
//   4. the swap step that recurses on (M - 1, c, c - b - 1, a).
// Each level needs the child's answers for shifted exponent pairs, which is
// why the whole table is computed at once.
node getit (int n,int a,int b,int c){
	node res;
	// Case 1: floor((ax+b)/c) is the same value t for every 0 <= x <= n,
	// so S(k1, k2) = t^{k2} * F_{k1}(n).
	if (a == 0 || a * n + b < c){
		const int t = (a * n + b) / c % mod;
		for (int e1 = 0;e1 <= 10;++ e1){
			const int prefix = f[e1].getit (e1,n);
			int pw = 1;  // t^{e2}
			for (int e2 = 0;e1 + e2 <= 10;++ e2,pw = mul (pw,t))
				res[e1][e2] = mul (pw,prefix);
		}
		return res;
	}
	// Case 2: a >= c. floor((ax+b)/c) = (a/c)*x + floor(((a%c)x+b)/c);
	// expand the power binomially, shifting exponents from the floor to x.
	if (a >= c){
		const int quo = a / c;
		node sub = getit (n,a % c,b,c);
		for (int e1 = 0;e1 <= 10;++ e1)
			for (int e2 = 0;e1 + e2 <= 10;++ e2){
				int pw = 1;  // quo^i
				for (int i = 0;i <= e2;++ i,pw = mul (pw,quo))
					Add (res[e1][e2],mul (mul (pw,C[e2][i]),sub[e1 + i][e2 - i]));
			}
		return res;
	}
	// Case 3: b >= c. floor((ax+b)/c) = b/c + floor((ax+b%c)/c); expand binomially.
	if (b >= c){
		const int quo = b / c;
		node sub = getit (n,a,b % c,c);
		for (int e1 = 0;e1 <= 10;++ e1)
			for (int e2 = 0;e1 + e2 <= 10;++ e2){
				int pw = 1;  // quo^i
				for (int i = 0;i <= e2;++ i,pw = mul (pw,quo))
					Add (res[e1][e2],mul (mul (pw,C[e2][i]),sub[e1][e2 - i]));
			}
		return res;
	}
	// Case 4: a < c, b < c, and M >= 1 because case 1 did not apply.
	//   S(k1, k2) = M^{k2} F_{k1}(n)
	//             - sum_{i<k2} C(k2,i) sum_{j<=k1+1} [x^j]F_{k1} * S'(i, j),
	// where S' is the table for (M - 1, c, c - b - 1, a).
	const int M = (a * n + b) / c;
	node sub = getit (M - 1,c,c - b - 1,a);
	for (int e1 = 0;e1 <= 10;++ e1){
		const int prefix = f[e1].getit (e1,n);
		res[e1][0] = prefix;  // k2 == 0: the floor term is 1
		for (int e2 = 1;e1 + e2 <= 10;++ e2){
			int cur = mul (qkpow (M,e2),prefix);
			for (int i = 0;i < e2;++ i)
				for (int j = 0;j <= e1 + 1;++ j)
					Sub (cur,mul (mul (C[e2][i],f[e1][j]),sub[i][j]));
			res[e1][e2] = cur;
		}
	}
	return res;
}

// Precomputes the power-sum polynomials F_0..F_10 and the binomials C[0..10][*].
// It then answers T queries "n a b c k1 k2", each with
//   sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   mod 1e9+7.
// `signed` because `int` is macro-expanded to `long long` in this file.
// Queries are assumed to satisfy k1 + k2 <= 10 (the size of getit's table).
signed main(){
	for (int i = 0;i <= 10;++ i) f[i].gen (i);
	for (int i = 0;i <= 10;++ i){
		C[i][0] = 1;
		for (int j = 1;j <= i;++ j) C[i][j] = add (C[i - 1][j],C[i - 1][j - 1]);
	}
	int T;read (T);
	while (T --> 0){
		int n,a,b,c,k1,k2;read (n,a,b,c,k1,k2);
		node ans = getit (n,a,b,c);
		write (ans[k1][k2]);
		putchar ('\n');
	}
	return 0;
}
原文地址:https://www.cnblogs.com/Dark-Romance/p/15008649.html