BSGS学习笔记

BSGS算法

BSGS算法使用来求解(y)的方程

[x ^ y equiv zpmod p ]

其中(gcd(x, p) = 1), 我们将(y)写做一个(am - b)的形式, 其中(a in (1, m + 1]), (b in [0, m))

那么这样, 原式就变成了

[egin{aligned} x^{am - b} &equiv z pmod p\ x^{am} &equiv z * x ^ b pmod p end{aligned} ]

我们枚举每一个(b), 将(z * x ^ b)存进(hash)或者(map)之类的东西

之后对于左边枚举每一个(a), 如果(x ^ {am} \% p)(hash)或者(map)中, 答案就为(a * m - mp[x ^ {am} \% p])

复杂度为(O(max(m, p / m))), 易证得(m)(sqrt{p})的时候最优

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#include <cmath>
#include <map>
#define itn int
#define reaD read
using namespace std;

int p, x, y, s, m;
map<int, int> mp; 

inline int read()
{
	int x = 0, w = 1; char c = getchar();
	while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
	while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
	return x * w;
}

int fpow(int x, int k)
{
	int res = 1;
	while(k)
	{
		if(k & 1) res = 1ll * res * x % p;
		x = 1ll * x * x % p;
		k >>= 1; 
	}
	return res; 
}

int main()
{
	p = reaD(); x = read(); y = read(); m = sqrt(p) + 1; s = y; 
	for(int i = 0; i < m; i++) mp[s] = i, s = 1ll * s * x % p; 
	s = 1; int t = fpow(x, m); 
	for(int i = 1; i <= m + 1; i++)
	{
		s = 1ll * s * t % p; 
		if(mp.count(s))
		{
			printf("%d
", i * m - mp[s]);
			return 0; 
		}
	}
	puts("no solution"); 
	return 0;
}

ExBSGS算法

好像所有有(Ex)的算法似乎都是在不互质的情况下诶

这里, (ExBSGS)是用来处理(gcd(x, p) != 1)的情况的, 我们可以有这样一个式子, 设(gcd(x, p) = d)

注意, 此时若(d)不整除(z), 方程无解, 若(d)整除(z), 则有

[egin{aligned} frac{x ^ y}{d} &equiv frac{z}{d} pmod{frac{p}{d}}\ frac{x}{d} * x^{y - 1} &equiv frac{z}{d} pmod{frac{p}{d}}\ end{aligned} ]

注意到(frac{x}{d})变成了一个系数, 当(gcd(x, p / d))不等于1时不断地除以(gcd(x, p / d)), 我们最终可以得到

[frac{x ^ k}{d}*x^{y - k} equiv frac{z}{d} pmod{frac{p}{d}} ]

注意到此时式子中的(d)不是第一次算(gcd)(d)了, 他是所有不为1的(gcd)的积

带上个系数跑BSGS, 最后答案加上(k)即可

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#include <cmath>
#include <map>
#define itn int
#define reaD read
#define LL long long
#define MOD 233333
using namespace std;

int x, y, p, d, m, cnt, sum; 
struct MAP {
	LL ha[MOD+5]; int id[MOD+5];
	void clear() {for (int i = 0; i < MOD; i++) ha[i] = id[i] = -1; }
	int count(LL x) {
		LL pos = x%MOD;
		while (true) {
			if (ha[pos] == -1) return 0;
			if (ha[pos] == x) return 1;
			++pos; if (pos >= MOD) pos -= MOD;
		}
	}
	void insert(LL x, int idex) {
		LL pos = x%MOD; 
		while (true) {
			if (ha[pos] == -1 || ha[pos] == x) { ha[pos] = x, id[pos] = idex; return; }
			++pos; if (pos >= MOD) pos -= MOD; 
		}
	}
	int query(LL x) {
		LL pos = x%MOD;
		while (true) {
			if (ha[pos] == x) return id[pos];
			++pos; if (pos >= MOD) pos -= MOD;
		}
	}
}mp;

int gcd(int n, int m) { return m ? gcd(m, n % m) : n; }

int fpow(int x, int y)
{
	int res = 1;
	while(y)
	{
		if(y & 1) res = 1ll * res * x % p;
		x = 1ll * x * x % p;
		y >>= 1; 
	}
	return res; 
}

int exbsgs(int x, int y, int p)
{
	if(y == 1) return 0;
	cnt = 0; sum = 1; mp.clear();
	while((d = gcd(x, p)) != 1)
	{
		if(y % d) return -1;
		cnt++; p /= d; y /= d; sum = 1ll * sum * (x / d) % p;
		if(y == sum) return cnt; 
	}
	m = sqrt(p) + 1;
	for(int i = 0; i < m; i++) mp.insert(y, i), y = 1ll * y * x % p;
	y = sum; x = fpow(x, m);
	for(int i = 1; i <= m + 1; i++)
	{
		y = 1ll * y * x % p;
		if(mp.count(y)) return i * m - mp.query(y) + cnt; 
	}
	return -1; 
}

int main()
{
	while(scanf("%d%d%d", &x, &p, &y) != EOF)
	{
		if(!x && !p && !y) break; 
		int ans = exbsgs(x, y, p);
		ans == -1 ? puts("No Solution") : printf("%d
", ans); 
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ztlztl/p/11024724.html