poj 2429 GCD & LCM Inverse

http://poj.org/problem?id=2429

  数论的题通常都十分简明清晰,同时,相对其他的算法而言,也是相当的有难度的!

  这题原以为是简单的模板题,结果是搞了我一个晚上,wa了不下20次.......不过最终还是皇天不负有心人,迎来了一个accepted~

  这题题意很明确,给你两个数的gcd和lcm,让你求出原来两个数。如果有多种情况,就选和最小的两个。设所求的两个数分别是a,b,化简所求的式子:gcd(a/gcd(a,b),b/gcd(a,b))=1,同时还有a*b=gcd(a,b)*lcm(a,b)。

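  令m=a/gcd(a,b)  n=b/gcd(a,b)  那么就是要解这样的一组式子:m*n=lcm(a,b)/gcd(a,b) 而且 gcd(m,n)=1。很容易就想到一个思路:分解质因数,然后搜索满足题意的答案。由于 gcd(m,n)=1,同一个质因数的全部幂次必须整体分给 m 或者整体分给 n。搜索并不困难:在数据范围里面(m*n < 2^63),m*n 最多只包含 15 种不同的质因数(前 15 个质数之积约为 6.1×10^17,再乘上第 16 个质数 53 就超过了 2^63),也就是最多 2^15 种分法,因此搜索是不二之选。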

  然而,现在的问题是,如何在给定的时间里将给出的数分解质因数。注意,数值可高达 2^63 ≈ 9.2×10^18,试除到平方根也要约 3×10^9 次运算,因此朴素的分解方法是行不通的,所以这题要用到随机算法中的Pollard_Rho快速分解质因数的算法(配合Miller-Rabin素性测试)。这题的思路就这么多!

  下面这个是我几经修改才提交通过的代码:

Accepted Code:

View Code
  1 #include <cstdio>
  2 #include <cstring>
  3 #include <cstdlib>
  4 #include <cmath>
  5 #include <algorithm>
  6 
  7 #define debug 0
  8 
  9 typedef __int64 ll;
 10 
 11 ll min2(ll a, ll b) {return a < b ? a : b;}
 12 ll max2(ll a, ll b) {return a > b ? a : b;}
 13 
 14 ll pri[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71};
 15 const ll inf = 0x7fffffffffffffff;
 16 ll mina, minb;
 17 ll g_d[30], g_n[30], g_b;
 18 int top;
 19 
 20 void swap(ll &a, ll &b){
 21     ll t = a;
 22 
 23     a = b;
 24     b = t;
 25 }
 26 
 27 ll gcd(ll a, ll b){
 28     while (b){
 29         ll c = a % b;
 30         a = b;
 31         b = c;
 32     }
 33     return a;
 34 }
 35 
 36 ll multi(ll a, ll b, ll n){
 37     ll tmp = 0;
 38 
 39     a %= n;
 40     while (b){
 41         if (b & 1){
 42             tmp += a;
 43             tmp %= n;
 44         }
 45         a <<= 1;
 46         a %= n;
 47         b >>= 1;
 48     }
 49 
 50     return tmp;
 51 }
 52 
 53 ll power(ll a, ll b, ll n){
 54     ll tmp = 1;
 55 
 56     a %= n;
 57     while (b){
 58         if (b & 1) tmp = multi(tmp, a, n);
 59         a = multi(a, a, n);
 60         b >>= 1;
 61     }
 62 
 63     return tmp;
 64 }
 65 
 66 bool mr(ll n){
 67     if (n == 2) return true;
 68     if (n < 2 || !(n & 1)) return false;
 69 
 70     ll k = 0, i, j, m, a;
 71 
 72     m = n - 1;
 73     while (!(m & 1)) m >>= 1, k++;
 74     for (i = 0; i < 10; i++){
 75         if (pri[i] >= n) return true;
 76         a = power(pri[i], m, n);
 77         if (a == 1) continue;
 78         for (j = 0; j < k; j++){
 79             if (a == n - 1) break;
 80             a = multi(a, a, n);
 81         }
 82         if (j == k) return false;
 83     }
 84 
 85     return true;
 86 }
 87 
 88 ll p_rho(ll c, ll n){
 89     ll i, x, y, k, d;
 90 
 91     i = 1;
 92     x = y = rand() % n;
 93     k = 2;
 94     do {
 95         i++;
 96         d = gcd(n + y - x, n);
 97         if (d > 1 && d < n) return d;
 98         if (i == k) y = x, k <<= 1;
 99         x = (multi(x, x, n) + n - 1) % n;
100     }while (y != x);
101 
102     return n;
103 }
104 
105 void rho(ll n){
106     if (mr(n) || n < 2) {
107         if (n >= 2 && g_b != 1 && g_b % n == 0){
108             g_d[top] = n;
109             g_n[top] = 0;
110             while (g_b % g_d[top] == 0){
111                 g_n[top]++;
112                 g_b /= g_d[top];
113             }
114             g_d[top] = power(g_d[top], g_n[top], inf);
115             top++;
116         }
117         return ;
118     }
119     ll t = n;
120 
121     while (t >= n) t = p_rho(rand() % (n - 1) + 1, n);
122     rho(t);
123     rho(n / t);
124 }
125 
126 ll labs(ll a){
127     return max2(a, -a);
128 }
129 
130 void dfs(int pos, int end, ll *a, ll b, ll cur){
131     if (pos >= end){
132         #if debug
133         printf("reach  %I64d\n", cur);
134         #endif
135         if (labs(cur - b / cur) < labs(mina - minb)){
136             #if debug
137             printf("modify  %I64d\n", cur);
138             #endif
139             mina = cur;
140             minb = b / cur;
141         }
142         return ;
143     }
144     ll t = (ll) sqrt((double) b) + 1;
145 
146     dfs(pos + 1, end, a, b, cur);
147     if (a[pos] * cur <= t){
148         dfs(pos + 1, end, a, b, cur * a[pos]);
149     }
150 }
151 
152 int main(){
153     ll a, b;
154     #if debug
155     printf("inf %I64d\n", inf);
156     #endif
157 
158     while (~scanf("%I64d%I64d", &a, &b)){
159         top = 0;
160         b /= a;
161         g_b = b;
162         rho(b);
163         #if debug
164         puts("");
165         for (int i = 0; i < top; i++){
166             printf("%I64d\n", g_d[i]);
167         }
168         puts("");
169         #endif
170         mina = b;
171         minb = 1;
172         dfs(0, top, g_d, b, 1);
173         #if debug
174         printf("pass dfs\n");
175         #endif
176         if (mina > minb) swap(mina, minb);
177         printf("%I64d %I64d\n", mina * a, minb * a);
178     }
179 
180     return 0;
181 }

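  刚开始我是如下面的代码一样的做法,每次都不断找到最小的质因数,然后用来除原数。本地测试结果基本没问题,但是交上去却wa了,希望有大牛路过可以指导指导!(补注:一个很可能的原因是 main 里的数组 d[12]、n[12] 开小了——lcm/gcd < 2^63 时最多可以有 15 种不同的质因数,超过 12 种就会越界写;而上面 AC 的代码用的是 g_d[30]。)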

Wrong Answer Code:

View Code
  1 #include <cstdio>
  2 #include <cstring>
  3 #include <cstdlib>
  4 #include <cmath>
  5 
  6 #define debug 0
  7 
  8 typedef __int64 ll;
  9 
 10 ll min2(ll a, ll b) {return a < b ? a : b;}
 11 ll max2(ll a, ll b) {return a > b ? a : b;}
 12 
 13 int pri[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
 14 const ll inf = 0x7fffffffffffffff;
 15 ll mina, minb;
 16 
 17 ll gcd(ll a, ll b){
 18     while (b){
 19         ll c = a % b;
 20         a = b;
 21         b = c;
 22     }
 23     return a;
 24 }
 25 
 26 ll multi(ll a, ll b, ll n){
 27     ll tmp = 0;
 28 
 29     a %= n;
 30     while (b){
 31         if (b & 1){
 32             tmp += a;
 33             tmp %= n;
 34         }
 35         a <<= 1;
 36         a %= n;
 37         b >>= 1;
 38     }
 39 
 40     return tmp;
 41 }
 42 
 43 ll power(ll a, ll b, ll n){
 44     ll tmp = 1;
 45 
 46     a %= n;
 47     while (b){
 48         if (b & 1) tmp = multi(tmp, a, n);
 49         a = multi(a, a, n);
 50         b >>= 1;
 51     }
 52 
 53     return tmp;
 54 }
 55 
 56 bool mr(ll n){
 57     if (n == 2) return true;
 58     if (n < 2 || ~n & 1) return false;
 59 
 60     ll k = 0, i, j, m, a;
 61 
 62     m = n - 1;
 63     while (!(m & 1)) m >>= 1, k++;
 64     for (i = 0; i < 10; i++){
 65         if (pri[i] >= n) return true;
 66         a = power(pri[i], m, n);
 67         if (a == 1) continue;
 68         for (j = 0; j < k; j++){
 69             if (a == n - 1) break;
 70             a = multi(a, a, n);
 71         }
 72         if (j == k) return false;
 73     }
 74 
 75     return true;
 76 }
 77 
 78 ll p_rho(ll c, ll n){
 79     ll i, x, y, k, d;
 80 
 81     i = 1;
 82     x = y = rand() % n;
 83     k = 2;
 84     do {
 85         i++;
 86         d = gcd(n + y - x, n);
 87         if (d > 1 && d < n) return d;
 88         if (i == k) y = x, k <<= 1;
 89         x = (multi(x, x, n) + n - c) % n;
 90     }while (y != x);
 91 
 92     return n;
 93 }
 94 
 95 ll rho(ll n){
 96     if (mr(n)) return n;
 97 
 98     ll t = n;
 99     ll a = n, b = n;
100     #if debug
101     printf("pass\n");
102     #endif
103     while (t >= n) t = p_rho(rand() % (n - 1) + 1, n);
104     if (t >= 2) a = rho(t);
105     if ((n / t) >= 2) b = rho(n / t);
106     return min2(a, b);
107 }
108 
109 ll labs(ll a){
110     return max2(a, -a);
111 }
112 
113 void dfs(int pos, int end, ll *a, ll b, ll cur){
114     if (pos >= end){
115         #if debug
116         printf("reach  %I64d\n", cur);
117         #endif
118         if (labs(cur - b / cur) < labs(mina - minb)){
119             #if debug
120             printf("modify  %I64d\n", cur);
121             #endif
122             mina = cur;
123             minb = b / cur;
124         }
125         return ;
126     }
127     ll t = (ll) sqrt((double) b) + 1;
128 
129     if (a[pos] * cur <= t){
130         dfs(pos + 1, end, a, b, cur * a[pos]);
131     }
132     dfs(pos + 1, end, a, b, cur);
133 }
134 
135 void swap(ll &a, ll &b){
136     ll t = a;
137     a = b;
138     b = t;
139 }
140 
141 int main(){
142     ll a, b, t;
143     ll d[12], n[12];
144     int top;
145     #if debug
146     printf("inf %I64d\n", inf);
147     #endif
148 
149     while (~scanf("%I64d%I64d", &a, &b)){
150         top = 0;
151         b /= a;
152         t = b;
153         while (b != 1){
154             d[top] = rho(b);
155             n[top] = 0;
156             while (b % d[top] == 0){
157                 n[top]++;
158                 b /= d[top];
159             }
160             d[top] = power(d[top], n[top], inf);
161             top++;
162         }
163         #if debug
164         puts("");
165         for (int i = 0; i < top; i++){
166             printf("%I64d\n", d[i]);
167         }
168         puts("");
169         #endif
170         mina = t;
171         minb = 1;
172         dfs(0, top, d, t, 1);
173         #if debug
174         printf("pass dfs\n");
175         #endif
176         if (mina > minb) swap(mina, minb);
177         printf("%I64d %I64d\n", mina * a, minb * a);
178     }
179 
180     return 0;
181 }

  用一个相对稳定的Pollard_Rho模板做这题:

View Code
  1 #include <cstdio>
  2 #include <cstring>
  3 #include <cstdlib>
  4 #include <cmath>
  5 #include <algorithm>
  6 #include <iostream>
  7 
  8 #define debug 0
  9 
 10 using namespace std;
 11 
 12 typedef __int64 ll;
 13 
 14 ll rd(){
 15     return rand() * rand();
 16 }
 17 
 18 ll gcd(ll a, ll b){
 19     if (!b) return a;
 20     return gcd(b, a % b);
 21 }
 22 
 23 ll multi(ll a, ll b, ll n){
 24     ll tmp = 0;
 25 
 26     while (b){
 27         if (b & 1){
 28             tmp += a;
 29             tmp %= n;
 30         }
 31         a <<= 1;
 32         a %= n;
 33         b >>= 1;
 34     }
 35 
 36     return tmp;
 37 }
 38 
 39 ll power(ll a, ll m, ll n){
 40     ll tmp = 1;
 41 
 42     a %= n;
 43     while (m){
 44         if (m & 1) tmp = multi(tmp, a, n);
 45         a = multi(a, a, n);
 46         m >>= 1;
 47     }
 48 
 49     return tmp;
 50 }
 51 
 52 bool mr(ll n){
 53     if (n == 2) return true;
 54     if (n < 2 || ~n & 1) return false;
 55 
 56     int t = 0;
 57     ll a, x, y, u = n - 1;
 58 
 59     while (~u & 1) t++, u >>= 1;
 60     for (int i = 0; i < 8; i++){
 61         a = rd() % (n - 1) + 1;
 62         x = power(a, u, n);
 63         for (int j = 0; j < t; j++){
 64             y = multi(x, x, n);
 65             if (y == 1 && x != 1 && x != n - 1) return false;
 66             x = y;
 67         }
 68         if (x != 1) return false;
 69     }
 70 
 71     return true;
 72 }
 73 
 74 ll rho(ll n, ll c){
 75     ll x, y, d, i = 1, k = 2;
 76 
 77     x = rd() % (n - 1) + 1;
 78     y = x;
 79     while (true){
 80         i++;
 81         x = (multi(x, x, n) + c) % n;
 82         d = gcd(y - x, n);
 83         if (1 < d && d < n) return d;
 84         if (x == y) return n;
 85         if (i == k) y = x, k <<= 1;
 86     }
 87 }
 88 
 89 void fac(ll n, int k, ll *p, int &cnt){
 90     if (n == 1) return ;
 91     if (mr(n)){
 92         p[cnt++] = n;
 93         return ;
 94     }
 95 
 96     ll t = n;
 97 
 98     while (t >= n) t = rho(t, k--);
 99     fac(t, k, p, cnt);
100     fac(n / t, k, p, cnt);
101 }
102 
103 ll mina, minb;
104 
105 ll max2(ll a, ll b){
106     return a > b ? a : b;
107 }
108 
109 ll absi64(ll a){
110     return max2(a, -a);
111 }
112 
113  void dfs(int pos, int end, ll *a, ll b, ll cur){
114      if (pos >= end){
115          #if debug
116          printf("reach  %I64d\n", cur);
117          #endif
118          if (absi64(cur - b / cur) < absi64(mina - minb)){
119              #if debug
120              printf("modify  %I64d\n", cur);
121              #endif
122              mina = cur;
123              minb = b / cur;
124          }
125          return ;
126      }
127      ll t = (ll) sqrt((double) b) + 1;
128 
129      dfs(pos + 1, end, a, b, cur);
130      if (a[pos] * cur <= t){
131          dfs(pos + 1, end, a, b, cur * a[pos]);
132      }
133  }
134 
135 void swap(ll &a, ll &b){
136     ll t = a;
137     a = b;
138     b = t;
139 }
140 
141 int main(){
142     ll a, b, t;
143     int cnt, m;
144     ll f[80];
145 
146     while (cin >> a >> b){
147         b /= a;
148         cnt = 0;
149         fac(b, 107, f, cnt);
150         sort(f, f + cnt);
151         #if debug
152         for (int i = 0; i < cnt; i++){
153             cout << f[i] << endl;
154         }
155         #endif
156         m = f[cnt] = 0;
157         t = f[0];
158         for (int i = 1; i <= cnt; i++){
159             if (t == f[i]) f[m] *= t;
160             else {
161                 t = f[i];
162                 m++;
163                 f[m] = t;
164             }
165         }
166         #if debug
167         cout << endl;
168         for (int i = 0; i < m; i++){
169             cout << "multi  " << f[i] << endl;
170         }
171         #endif
172         mina = 1;
173         minb = b;
174         if (b != 1) dfs(0, m, f, b, 1);
175         if (mina > minb) swap(mina, minb);
176         cout << mina * a << ' ' << minb * a << endl;
177     }
178 
179     return 0;
180 }

  ——written by Lyon

原文地址:https://www.cnblogs.com/LyonLys/p/poj_2429_Lyon.html