File Coverage

aks.c
Criterion Covered Total %
statement 156 163 95.7
branch 115 156 73.7
condition n/a
subroutine n/a
pod n/a
total 271 319 84.9


line stmt bran cond sub pod time code
1             #include
2             #include
3             #include
4             #include
5             #include
6              
7             /* The AKS primality algorithm for native integers.
8             *
9             * There are three versions here:
10             * V6 The v6 algorithm from the latest AKS paper.
11             * https://www.cse.iitk.ac.in/users/manindra/algebra/primality_v6.pdf
12             * BORNEMANN Improvements from Bernstein, Voloch, and a clever r/s
13             * selection from Folkmar Bornemann. Similar to Bornemann's
14             * 2003 Pari/GP implementation:
15             * https://homepage.univie.ac.at/Dietrich.Burde/pari/aks.gp
16             * BERN41 My implementation of theorem 4.1 from Bernstein's 2003 paper.
17             * https://cr.yp.to/papers/aks.pdf
18             *
19             * Each one is orders of magnitude faster than the previous, and by default
20             * we use Bernstein 4.1 as it is by far the fastest.
21             *
22             * Note that AKS is very, very slow compared to other methods. It is, however,
23             * polynomial in log(N), and log-log performance graphs show nice straight
24             * lines for both implementations. However APR-CL and ECPP both start out
25             * much faster and the slope will be less for any sizes of N that we're
26             * interested in.
27             *
28             * For native 64-bit integers this is purely a coding exercise, as BPSW is
29             * a million times faster and gives proven results.
30             *
31             *
32             * When n < 2^(wordbits/2)-1, we can do a straightforward intermediate:
33             * r = (r + a * b) % n
34             * If n is larger, then these are replaced with:
35             * r = addmod( r, mulmod(a, b, n), n)
36             * which is a lot more work, but keeps us correct.
37             *
38             * Software that does polynomial convolutions followed by a modulo can be
39             * very fast, but will fail when n >= (2^wordbits)/r.
40             *
41             * This is all much easier in GMP.
42             *
43             * Copyright 2012-2016, Dana Jacobsen.
44             */
45              
46             #define SQRTN_SHORTCUT 1
47              
48             #define IMPL_V6 0 /* From the primality_v6 paper */
49             #define IMPL_BORNEMANN 0 /* From Bornemann's 2002 implementation */
50             #define IMPL_BERN41 1 /* From Bernstein's early 2003 paper */
51              
52             #include "ptypes.h"
53             #include "aks.h"
54             #define FUNC_isqrt 1
55             #define FUNC_gcd_ui 1
56             #include "util.h"
57             #include "cache.h"
58             #include "mulmod.h"
59             #include "factor.h"
60              
61             #if IMPL_BORNEMANN || IMPL_BERN41
62             /* We could use lgamma, but it isn't in MSVC and not in pre-C99. The only
63             * sure way to find if it is available is test compilation (ala autoconf).
64             * Instead, we'll just use our own implementation.
65             * See http://mrob.com/pub/ries/lanczos-gamma.html for alternates. */
66 1116           static double log_gamma(double x)
67             {
68             static const double log_sqrt_two_pi = 0.91893853320467274178;
69             static const double lanczos_coef[8+1] =
70             { 0.99999999999980993, 676.5203681218851, -1259.1392167224028,
71             771.32342877765313, -176.61502916214059, 12.507343278686905,
72             -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 };
73 1116           double base = x + 7.5, sum = 0;
74             int i;
75 10044 100         for (i = 8; i >= 1; i--)
76 8928           sum += lanczos_coef[i] / (x + (double)i);
77 1116           sum += lanczos_coef[0];
78 1116           sum = log_sqrt_two_pi + log(sum/x) + ( (x+0.5)*log(base) - base );
79 1116           return sum;
80             }
81              
82             /* Note: For lgammal we need logl in the above.
83             * Max error drops from 2.688466e-09 to 1.818989e-12. */
84             #undef lgamma
85             #define lgamma(x) log_gamma(x)
86             #endif
87              
88             #if IMPL_BERN41
89 372           static double log_binomial(UV n, UV k)
90             {
91 372 50         if (n < k) return 0;
92 372           return log_gamma(n+1) - log_gamma(k+1) - log_gamma(n-k+1);
93             }
94 93           static double log_bern41_binomial(UV r, UV d, UV i, UV j, UV s)
95             {
96 93           return log_binomial( 2*s, i)
97 93           + log_binomial( d, i)
98 93           + log_binomial( 2*s-i, j)
99 93           + log_binomial( r-2-d, j);
100             }
101 93           static int bern41_acceptable(UV n, UV r, UV s)
102             {
103 93           double scmp = ceil(sqrt( (r-1)/3.0 )) * log(n);
104 93           UV d = (UV) (0.5 * (r-1));
105 93           UV i = (UV) (0.475 * (r-1));
106 93           UV j = i;
107 93 50         if (d > r-2) d = r-2;
108 93 50         if (i > d) i = d;
109 93 50         if (j > (r-2-d)) j = r-2-d;
110 93           return (log_bern41_binomial(r,d,i,j,s) >= scmp);
111             }
112             #endif
113              
114             #if 0
115             /* Naive znorder. Works well if limit is small. Note arguments. */
116             static UV order(UV r, UV n, UV limit) {
117             UV j;
118             UV t = 1;
119             for (j = 1; j <= limit; j++) {
120             t = mulmod(t, n, r);
121             if (t == 1)
122             break;
123             }
124             return j;
125             }
126             static void poly_print(UV* poly, UV r)
127             {
128             int i;
129             for (i = r-1; i >= 1; i--) {
130             if (poly[i] != 0)
131             printf("%lux^%d + ", poly[i], i);
132             }
133             if (poly[0] != 0) printf("%lu", poly[0]);
134             printf("\n");
135             }
136             #endif
137              
138 540           static void poly_mod_mul(UV* px, UV* py, UV* res, UV r, UV mod)
139             {
140             UV degpx, degpy;
141             UV i, j, pxi, pyj, rindex;
142              
143             /* Determine max degree of px and py */
144 13440 100         for (degpx = r-1; degpx > 0 && !px[degpx]; degpx--) ; /* */
    100          
145 11705 50         for (degpy = r-1; degpy > 0 && !py[degpy]; degpy--) ; /* */
    100          
146             /* We can sum at least j values at once */
147 540 50         j = (mod >= HALF_WORD) ? 0 : (UV_MAX / ((mod-1)*(mod-1)));
148              
149 540 100         if (j >= degpx || j >= degpy) {
    50          
150             /* res will be written completely, so no need to set */
151 17910 100         for (rindex = 0; rindex < r; rindex++) {
152 17595           UV sum = 0;
153 17595           j = rindex;
154 308130 100         for (i = 0; i <= degpx; i++) {
155 290535 100         if (j <= degpy)
156 206550           sum += px[i] * py[j];
157 290535 100         j = (j == 0) ? r-1 : j-1;
158             }
159 17595           res[rindex] = sum % mod;
160             }
161             } else {
162 225           memset(res, 0, r * sizeof(UV)); /* Zero result accumulator */
163 16650 100         for (i = 0; i <= degpx; i++) {
164 16425           pxi = px[i];
165 16425 50         if (pxi == 0) continue;
166 16425 50         if (mod < HALF_WORD) {
167 1215450 100         for (j = 0; j <= degpy; j++) {
168 1199025           pyj = py[j];
169 1199025 100         rindex = i+j; if (rindex >= r) rindex -= r;
170 1199025           res[rindex] = (res[rindex] + (pxi*pyj) ) % mod;
171             }
172             } else {
173 0 0         for (j = 0; j <= degpy; j++) {
174 0           pyj = py[j];
175 0 0         rindex = i+j; if (rindex >= r) rindex -= r;
176 0           res[rindex] = muladdmod(pxi, pyj, res[rindex], mod);
177             }
178             }
179             }
180             }
181 540           memcpy(px, res, r * sizeof(UV)); /* put result in px */
182 540           }
183 1435           static void poly_mod_sqr(UV* px, UV* res, UV r, UV mod)
184             {
185             UV c, d, s, sum, rindex, maxpx;
186 1435           UV degree = r-1;
187 1435           int native_sqr = (mod > isqrt(UV_MAX/(2*r))) ? 0 : 1;
188              
189 1435           memset(res, 0, r * sizeof(UV)); /* zero out sums */
190             /* Discover index of last non-zero value in px */
191 18730 50         for (s = degree; s > 0; s--)
192 18730 100         if (px[s] != 0)
193 1435           break;
194 1435           maxpx = s;
195             /* 1D convolution */
196 193910 100         for (d = 0; d <= 2*degree; d++) {
197             UV *pp1, *pp2, *ppend;
198 192475 100         UV s_beg = (d <= degree) ? 0 : d-degree;
199 192475           UV s_end = ((d/2) <= maxpx) ? d/2 : maxpx;
200 192475 100         if (s_end < s_beg) continue;
201 175180           sum = 0;
202 175180           pp1 = px + s_beg;
203 175180           pp2 = px + d - s_beg;
204 175180           ppend = px + s_end;
205 175180 100         if (native_sqr) {
206 10160 100         while (pp1 < ppend)
207 7240           sum += 2 * *pp1++ * *pp2--;
208             /* Special treatment for last point */
209 2920           c = px[s_end];
210 2920 100         sum += (s_end*2 == d) ? c*c : 2*c*px[d-s_end];
211 2920 100         rindex = (d < r) ? d : d-r; /* d % r */
212 2920           res[rindex] = (res[rindex] + sum) % mod;
213             #if HAVE_UINT128
214             } else {
215 172260           uint128_t max = ((uint128_t)1 << 127) - 1;
216 172260           uint128_t c128, sum128 = 0;
217              
218 2988450 100         while (pp1 < ppend) {
219 2816190           c128 = ((uint128_t)*pp1++) * ((uint128_t)*pp2--);
220 2816190 50         if (c128 > max) c128 %= mod;
221 2816190           c128 <<= 1;
222 2816190 50         if (c128 > max) c128 %= mod;
223 2816190           sum128 += c128;
224 2816190 50         if (sum128 > max) sum128 %= mod;
225             }
226 172260           c128 = px[s_end];
227 172260 100         if (s_end*2 == d) {
228 78300           c128 *= c128;
229             } else {
230 93960           c128 *= px[d-s_end];
231 93960 50         if (c128 > max) c128 %= mod;
232 93960           c128 <<= 1;
233             }
234 172260 50         if (c128 > max) c128 %= mod;
235 172260           sum128 += c128;
236 172260 50         if (sum128 > max) sum128 %= mod;
237 172260 100         rindex = (d < r) ? d : d-r; /* d % r */
238 172260           res[rindex] = ((uint128_t)res[rindex] + sum128) % mod;
239             #else
240             } else {
241             while (pp1 < ppend) {
242             UV p1 = *pp1++;
243             UV p2 = *pp2--;
244             sum = addmod(sum, mulmod(2, mulmod(p1, p2, mod), mod), mod);
245             }
246             c = px[s_end];
247             if (s_end*2 == d)
248             sum = addmod(sum, sqrmod(c, mod), mod);
249             else
250             sum = addmod(sum, mulmod(2, mulmod(c, px[d-s_end], mod), mod), mod);
251             rindex = (d < r) ? d : d-r; /* d % r */
252             res[rindex] = addmod(res[rindex], sum, mod);
253             #endif
254             }
255             }
256 1435           memcpy(px, res, r * sizeof(UV)); /* put result in px */
257 1435           }
258              
259 55           static UV* poly_mod_pow(UV* pn, UV power, UV r, UV mod)
260             {
261             UV *res, *temp;
262              
263 55 50         Newz(0, res, r, UV);
264 55 50         New(0, temp, r, UV);
265 55           res[0] = 1;
266              
267 1545 100         while (power) {
268 1490 100         if (power & 1) poly_mod_mul(res, pn, temp, r, mod);
269 1490           power >>= 1;
270 1490 100         if (power) poly_mod_sqr(pn, temp, r, mod);
271             }
272 55           Safefree(temp);
273 55           return res;
274             }
275              
276 55           static int test_anr(UV a, UV n, UV r)
277             {
278             UV* pn;
279             UV* res;
280             UV i;
281 55           int retval = 1;
282              
283 55 50         Newz(0, pn, r, UV);
284 55 50         if (a >= n) a %= n;
285 55           pn[0] = a;
286 55           pn[1] = 1;
287 55           res = poly_mod_pow(pn, n, r, n);
288 55           res[n % r] = addmod(res[n % r], n - 1, n);
289 55           res[0] = addmod(res[0], n - a, n);
290              
291 3470 100         for (i = 0; i < r; i++)
292 3415 50         if (res[i] != 0)
293 0           retval = 0;
294 55           Safefree(res);
295 55           Safefree(pn);
296 55           return retval;
297             }
298              
299             /*
300             * Avanzi and Mihǎilescu, 2007
301             * http://www.uni-math.gwdg.de/preda/mihailescu-papers/ouraks3.pdf
302             * "As a consequence, one cannot expect the present variants of AKS to
303             * compete with the earlier primality proving methods like ECPP and
304             * cyclotomy." - conclusion regarding memory consumption
305             */
306 11           bool is_aks_prime(UV n)
307             {
308 11           UV r, s, a, starta = 1;
309              
310 11 100         if (n < 2)
311 2           return 0;
312 9 100         if (n == 2)
313 1           return 1;
314              
315 8 50         if (powerof(n) > 1)
316 0           return 0;
317              
318 8 50         if (n > 11 && ( !(n%2) || !(n%3) || !(n%5) || !(n%7) || !(n%11) )) return 0;
    50          
    50          
    50          
    50          
    50          
319             /* if (!is_prob_prime(n)) return 0; */
320              
321             #if IMPL_V6
322             {
323             UV sqrtn = isqrt(n);
324             double log2n = log(n) / log(2); /* C99 has a log2() function */
325             UV limit = (UV) floor(log2n * log2n);
326              
327             MPUverbose(1, "# aks limit is %lu\n", (unsigned long) limit);
328              
329             for (r = 2; r < n; r++) {
330             if ((n % r) == 0)
331             return 0;
332             #if SQRTN_SHORTCUT
333             if (r > sqrtn)
334             return 1;
335             #endif
336             if (znorder(n, r) > limit)
337             break;
338             }
339              
340             if (r >= n)
341             return 1;
342              
343             s = (UV) floor(sqrt(r-1) * log2n);
344             }
345             #endif
346             #if IMPL_BORNEMANN
347             {
348             UV fac[MPU_MAX_FACTORS+1];
349             UV slim;
350             double c1, c2, x;
351             double const t = 48;
352             double const t1 = (1.0/((t+1)*log(t+1)-t*log(t)));
353             double const dlogn = log(n);
354             r = next_prime( (UV) (t1*t1 * dlogn*dlogn) );
355             while (!is_primitive_root(n,r,1))
356             r = next_prime(r);
357              
358             slim = (UV) (2*t*(r-1));
359             c1 = lgamma(r-1);
360             c2 = dlogn * floor(sqrt(r));
361             { /* Binary search for first s in [1,slim] where x >= 0 */
362             UV i = 1;
363             UV j = slim;
364             while (i < j) {
365             s = i + (j-i)/2;
366             x = (lgamma(r-1+s) - c1 - lgamma(s+1)) / c2 - 1.0;
367             if (x < 0) i = s+1;
368             else j = s;
369             }
370             s = i-1;
371             }
372             s = (s+3) >> 1;
373             /* Bornemann checks factors up to (s-1)^2, we check to max(r,s) */
374             /* slim = (s-1)*(s-1); */
375             slim = (r > s) ? r : s;
376             MPUverbose(2, "# aks trial to %lu\n", slim);
377             if (trial_factor(n, fac, 2, slim) > 1)
378             return 0;
379             if (slim >= HALF_WORD || (slim*slim) >= n)
380             return 1;
381             }
382             #endif
383             #if IMPL_BERN41
384             {
385             UV slim, fac[MPU_MAX_FACTORS+1];
386 8           double const log2n = log(n) / log(2);
387             /* Tuning: Initial 'r' selection. Search limit for 's'. */
388 8 50         double const r0 = ((log2n > 32) ? 0.010 : 0.003) * log2n * log2n;
389 8 50         UV const rmult = (log2n > 32) ? 6 : 30;
390              
391 8 100         r = next_prime(r0 < 2 ? 2 : (UV)r0); /* r must be at least 3 */
392 53 100         while ( !is_primitive_root(n,r,1) || !bern41_acceptable(n,r,rmult*(r-1)) )
    100          
393 45           r = next_prime(r);
394              
395             { /* Binary search for first s in [1,slim] where conditions met */
396 8           UV bi = 1;
397 8           UV bj = rmult * (r-1);
398 79 100         while (bi < bj) {
399 71           s = bi + (bj-bi)/2;
400 71 100         if (!bern41_acceptable(n, r, s)) bi = s+1;
401 52           else bj = s;
402             }
403 8           s = bj;
404 8 50         if (!bern41_acceptable(n, r, s)) croak("AKS: bad s selected");
405             /* S goes from 2 to s+1 */
406 8           starta = 2;
407 8           s = s+1;
408             }
409             /* Check divisibility to s * (s-1) to cover both gcd conditions */
410 8           slim = s * (s-1);
411 8 50         MPUverbose(2, "# aks trial to %lu\n", (unsigned long)slim);
412 8 100         if (trial_factor(n, fac, 2, slim) > 1)
413 6           return 0;
414 6 50         if (slim >= HALF_WORD || (slim*slim) >= n)
    100          
415 3           return 1;
416             /* Check b^(n-1) = 1 mod n for b in [2..s] */
417 58 100         for (a = 2; a <= s; a++) {
418 56 100         if (powmod(a, n-1, n) != 1)
419 1           return 0;
420             }
421             }
422             #endif
423              
424 2 50         MPUverbose(1, "# aks r = %lu s = %lu\n", (unsigned long) r, (unsigned long) s);
425              
426             /* Almost every composite will get recognized by the first test.
427             * However, we need to run 's' tests to have the result proven for all n
428             * based on the theorems we have available at this time. */
429 57 100         for (a = starta; a <= s; a++) {
430 55 50         if (! test_anr(a, n, r) )
431 0           return 0;
432 55 50         MPUverbose(2, ".");
433             }
434 2 50         MPUverbose(2, "\n");
435 2           return 1;
436             }