SGU 140 扩展欧几里得

 题目大意:

给定序列a[] , p , b

希望找到一个序列 x[] , 使a1*x1 + a2*x2 + ... + an*xn = b (mod p)

这里很容易写成 a1*x1 + a2*x2 + ... + an*xn + yp = b

-> a1*x1 + a2*x2 + ... + an*xn + y1*p + y2*p + .... + yn*p = b

->(a1*x1+y1*p) + (a2*x2+y2*p) + ... + (an*xn+yn*p) = b

 y[]是必然有解的 , 这里每一个值都可以看做一个二元方程,用扩展欧几里得求解得到的就是

f1*gcd(a1,p) + f2*gcd(a2,p) + ... + fn*gcd(an,p) = b  (1.1)

这里f[]是未知的 ,只要求扩展欧几里得的过程中记录当答案为ai*xi + yi*p = gcd(ai,p) 是xi的值

那么求出合法的fi , 那么正确的解就是 xi = xi*fi

而式子1.1又可以逐个求扩展欧几里得,然后再逆向求回来

ll cur = b;
bool flag=true;
for(int i=n ; i>=1 ; i--){
  ll tmp;
  ll d = ex_gcd(t[i-1] , tmp , g[i] , f[i]);
  if(cur%d!=0){flag=false;break;}
  f[i] = cur/d*f[i];
  cur -= f[i]*g[i];
}

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <cmath>
 4 #include <ctime>
 5 #include <cstdlib>
 6 #include <set>
 7 #include <iostream>
 8 using namespace std;
 9 
10 #define ll long long
11 int n , p , b;
12 int  a[105];
13 ll g[105] , x[105] , y[105];
14 ll t[105] , f[105];
15 
16 ll ex_gcd(ll a , ll &x , ll b , ll &y)
17 {
18     if(b==0){
19         x = 1 , b = 0;
20         return a;
21     }
22     ll ans = ex_gcd(b , x , a%b , y);
23     ll t=x ;
24     x=y , y=t-(a/b)*y;
25     return ans;
26 }
27 
28 ll gcd(ll a , ll b){return b?gcd(b,a%b):a;}
29 
30 int main()
31 {
32     scanf("%d%d%d" , &n , &p , &b);
33     for(int i=1 ; i<=n ; i++){
34         scanf("%d" , &a[i]);
35         g[i] = ex_gcd(a[i] , x[i] , p , y[i]);
36     }
37 
38     t[0] = 0 , t[1] = g[1];
39     for(int i=2 ; i<=n ; i++){
40         t[i] = gcd(t[i-1], g[i]);
41     }
42     if(b%t[n]!=0){
43         puts("NO");
44         return 0;
45     }
46     ll cur = b;
47     bool flag=true;
48     for(int i=n ; i>=1 ; i--){
49         ll tmp;
50         ll d = ex_gcd(t[i-1] , tmp , g[i] , f[i]);
51         if(cur%d!=0){flag=false;break;}
52         f[i] = cur/d*f[i];
53         cur -= f[i]*g[i];
54     }
55     if(!flag) {
56         puts("NO");
57         return 0;
58     }
59     puts("YES");
60     for(int i=1 ; i<=n ; i++){
61         ll v = x[i]*f[i];
62         v = (v%p+p)%p;
63         if(i<n) printf("%I64d " , v);
64         else printf("%I64d
" , v);
65     }
66     return 0;
67 }
原文地址:https://www.cnblogs.com/CSU3901130321/p/4812363.html