File Coverage

rootmod.c
Criterion Covered Total %
statement 450 480 93.7
branch 342 444 77.0
condition n/a
subroutine n/a
pod n/a
total 792 924 85.7


line stmt bran cond sub pod time code
1             /******************************************************************************/
2             /* MODULAR ROOTS */
3             /******************************************************************************/
4              
5             #include
6             #include
7             #include
8             #include "ptypes.h"
9             #define FUNC_isqrt 1
10             #define FUNC_is_perfect_square 1
11             #define FUNC_gcd_ui 1
12             #define FUNC_ipow 1
13             #include "util.h"
14             #include "sort.h"
15             #include "mulmod.h"
16             #include "factor.h"
17             #include "rootmod.h"
18              
19             /* Pick one or both */
20             #define USE_ROOTMOD_SPLITK 1 /* enables rootmod_composite1 */
21             #define USE_ROOTMOD_SPLITN 1 /* enables rootmod_composite2 */
22              
23              
24             /******************************************************************************/
25             /* SQRT(N) MOD M */
26             /******************************************************************************/
27              
28             /* _sqrtmod_prime assumes 1 < a < p, n > 1, p > 2, p prime.
29             * _sqrtmod_prime_power assumes 1 < a < p, n > 1, p > 2, p prime.
30             * If any of these are not true, the result is undefined.
31             *
32             * _sqrtmod_composite takes care of the edge conditions and factors n.
33             *
34             * _sqrtmod_composite and _sqrtmod_prime_power always return UV_MAX
35             * if no root exists, while any other return value will be a valid root.
36             *
37             * The exported functions sqrtmod(a,n) and rootmod(a,2,n) further:
38             * - verify the result and return success / fail in a separate int.
39             * - always returns the smaller of the two roots.
40             *
41             * sqrtmodp / rootmodp does the same except n is assumed prime.
42             */
43              
44             #if !USE_MONTMATH
45 119           static UV _sqrtmod_prime(UV a, UV p) {
46 119 100         if ((p % 4) == 3) {
47 35           return powmod(a, (p+1)>>2, p);
48             }
49 84 100         if ((p % 8) == 5) { /* Atkin's algorithm. Faster than Legendre. */
50             UV a2, alpha, beta, b;
51 29           a2 = addmod(a,a,p);
52 29           alpha = powmod(a2,(p-5)>>3,p);
53 29           beta = mulmod(a2,sqrmod(alpha,p),p);
54 29 50         b = mulmod(alpha, mulmod(a, (beta ? beta-1 : p-1), p), p);
55 29           return b;
56             }
57 55 100         if ((p % 16) == 9) { /* Müller's algorithm extending Atkin */
58 13           UV a2, alpha, beta, b, d = 1;
59 13           a2 = addmod(a,a,p);
60 13           alpha = powmod(a2, (p-9)>>4, p);
61 13           beta = mulmod(a2, sqrmod(alpha,p), p);
62 13 100         if (sqrmod(beta,p) != p-1) {
63 20 100         do { d += 2; } while (kronecker_uu(d,p) != -1 && d < p);
    50          
64 10           alpha = mulmod(alpha, powmod(d,(p-9)>>3,p), p);
65 10           beta = mulmod(a2, mulmod(sqrmod(d,p),sqrmod(alpha,p),p), p);
66             }
67 13 50         b = mulmod(alpha, mulmod(a, mulmod(d,(beta ? beta-1 : p-1),p),p),p);
68 13           return b;
69             }
70              
71             /* Verify Euler condition for odd p */
72 42 50         if ((p & 1) && powmod(a,(p-1)>>1,p) != 1) return 0;
    100          
73              
74             /* Algorithm 1.5.1 from Cohen. Tonelli/Shanks. */
75             {
76             UV x, q, e, t, z, r, m, b;
77 26           q = p-1;
78 26           e = valuation_remainder(q, 2, &q);
79 26           t = 3;
80 86 100         while (kronecker_uu(t, p) != -1) {
81 60           t += 2;
82 60 50         if (t == 201) { /* exit if p looks like a composite */
83 0 0         if ((p % 2) == 0 || powmod(2, p-1, p) != 1 || powmod(3, p-1, p) != 1)
    0          
    0          
84 0           return 0;
85 60 50         } else if (t >= 20000) { /* should never happen */
86 0           return 0;
87             }
88             }
89 26           z = powmod(t, q, p);
90 26           b = powmod(a, q, p);
91 26           r = e;
92 26           q = (q+1) >> 1;
93 26           x = powmod(a, q, p);
94 71 100         while (b != 1) {
95 45           t = b;
96 153 50         for (m = 0; m < r && t != 1; m++)
    100          
97 108           t = sqrmod(t, p);
98 45 50         if (m >= r) break;
99 45           t = powmod(z, UVCONST(1) << (r-m-1), p);
100 45           x = mulmod(x, t, p);
101 45           z = mulmod(t, t, p);
102 45           b = mulmod(b, z, p);
103 45           r = m;
104             }
105 26           return x;
106             }
107             }
108             #else
109             static UV _sqrtmod_prime(UV a, UV p) {
110             const uint64_t npi = mont_inverse(p), mont1 = mont_get1(p);
111             a = mont_geta(a,p);
112              
113             if ((p % 4) == 3) {
114             UV b = mont_powmod(a, (p+1)>>2, p);
115             return mont_recover(b, p);
116             }
117              
118             if ((p % 8) == 5) { /* Atkin's algorithm. Faster than Legendre. */
119             UV a2, alpha, beta, b;
120             a2 = addmod(a,a,p);
121             alpha = mont_powmod(a2,(p-5)>>3,p);
122             beta = mont_mulmod(a2,mont_sqrmod(alpha,p),p);
123             beta = submod(beta, mont1, p);
124             b = mont_mulmod(alpha, mont_mulmod(a, beta, p), p);
125             return mont_recover(b, p);
126             }
127             if ((p % 16) == 9) { /* Müller's algorithm extending Atkin */
128             UV a2, alpha, beta, b, d = 1;
129             a2 = addmod(a,a,p);
130             alpha = mont_powmod(a2, (p-9)>>4, p);
131             beta = mont_mulmod(a2, mont_sqrmod(alpha,p), p);
132             if (mont_sqrmod(beta,p) != submod(0,mont1,p)) {
133             do { d += 2; } while (kronecker_uu(d,p) != -1 && d < p);
134             d = mont_geta(d,p);
135             alpha = mont_mulmod(alpha, mont_powmod(d,(p-9)>>3,p), p);
136             beta = mont_mulmod(a2, mont_mulmod(mont_sqrmod(d,p),mont_sqrmod(alpha,p),p), p);
137             beta = mont_mulmod(submod(beta,mont1,p), d, p);
138             } else {
139             beta = submod(beta, mont1, p);
140             }
141             b = mont_mulmod(alpha, mont_mulmod(a, beta, p), p);
142             return mont_recover(b, p);
143             }
144              
145             /* Verify Euler condition for odd p */
146             if ((p & 1) && mont_powmod(a,(p-1)>>1,p) != mont1) return 0;
147              
148             /* Algorithm 1.5.1 from Cohen. Tonelli/Shanks. */
149             {
150             UV x, q, e, t, z, r, m, b;
151             q = p-1;
152             e = valuation_remainder(q, 2, &q);
153             t = 3;
154             while (kronecker_uu(t, p) != -1) {
155             t += 2;
156             if (t == 201) { /* exit if p looks like a composite */
157             if ((p % 2) == 0 || powmod(2, p-1, p) != 1 || powmod(3, p-1, p) != 1)
158             return 0;
159             } else if (t >= 20000) { /* should never happen */
160             return 0;
161             }
162             }
163             t = mont_geta(t, p);
164             z = mont_powmod(t, q, p);
165             b = mont_powmod(a, q, p);
166             r = e;
167             q = (q+1) >> 1;
168             x = mont_powmod(a, q, p);
169             while (b != mont1) {
170             t = b;
171             for (m = 0; m < r && t != mont1; m++)
172             t = mont_sqrmod(t, p);
173             if (m >= r) break;
174             t = mont_powmod(z, UVCONST(1) << (r-m-1), p);
175             x = mont_mulmod(x, t, p);
176             z = mont_mulmod(t, t, p);
177             b = mont_mulmod(b, z, p);
178             r = m;
179             }
180             return mont_recover(x, p);
181             }
182             return 0;
183             }
184             #endif
185              
186 167           static UV _sqrtmod_prime_power(UV a, UV p, UV e) {
187             UV r, s, n, pk, apk, ered, np;
188              
189 167 100         if (e == 1) {
190 121 100         if (a >= p) a %= p;
191 121 100         if (p == 2 || a == 0) return a;
    100          
192 84           r = _sqrtmod_prime(a,p);
193 84 100         if (p-r < r) r = p-r;
194 84 100         return (sqrmod(r,p) == a) ? r : UV_MAX;
195             }
196              
197 46           n = ipow(p,e);
198 46           pk = p*p;
199              
200 46 100         if ((a % n) == 0)
201 1           return 0;
202              
203 45 100         if ((a % pk) == 0) {
204 2           apk = a / pk;
205 2           s = _sqrtmod_prime_power(apk, p, e-2);
206 2 100         if (s == UV_MAX) return UV_MAX;
207 1           return s * p;
208             }
209              
210 43 50         if ((a % p) == 0)
211 0           return UV_MAX;
212              
213 43 100         ered = (p > 2 || e < 5) ? (e+1)>>1 : (e+3)>>1;
    100          
214 43           s = _sqrtmod_prime_power(a, p, ered);
215 43 100         if (s == UV_MAX) return UV_MAX;
216              
217 39 100         np = (p != 2 || (n > (UV_MAX/p))) ? n : n * p;
    50          
218 39           r = addmod(s, gcddivmod(submod(a,sqrmod(s,np),np), addmod(s,s,np), n), n);
219 39 100         if (n-r < r) r = n-r;
220 39 100         if (sqrmod(r,n) != (a % n)) return UV_MAX;
221 29           return r;
222             }
223              
224 40           static UV _sqrtmod_composite(UV a, UV n) {
225             factored_t nf;
226             UV r, s, t, fe, N, inv;
227             uint32_t i, root;
228              
229 40 50         if (n == 0) return UV_MAX;
230 40 50         if (a >= n) a %= n;
231 40 50         if (n <= 2 || a <= 1) return a;
    50          
232 40 100         if (is_perfect_square_ret(a,&root)) return root;
233              
234 33           nf = factorint(n);
235 33           N = ipow(nf.f[0], nf.e[0]);
236 33           r = _sqrtmod_prime_power(a, nf.f[0], nf.e[0]);
237 33 100         if (r == UV_MAX) return UV_MAX;
238 31 100         for (i = 1; i < nf.nfactors; i++) {
239 12           fe = ipow(nf.f[i], nf.e[i]);
240 12           s = _sqrtmod_prime_power(a, nf.f[i], nf.e[i]);
241 12 100         if (s == UV_MAX) return UV_MAX;
242 11           inv = modinverse(N, fe);
243 11           t = mulmod(inv, submod(s % fe,r % fe,fe), fe);
244 11           r = addmod(r, mulmod(N,t,n), n);
245 11           N *= fe;
246             }
247 19           return r;
248             }
249              
250             /* Micro-optimization for fast returns with small values */
251             #define NSMALL 16
252             static char _small[NSMALL-3+1][NSMALL-2+1] = {
253             {0},
254             {0,0},
255             {0,0,2},
256             {0,3,2,0},
257             {3,0,2,0,0},
258             {0,0,2,0,0,0},
259             {0,0,2,0,0,4,0},
260             {0,0,2,5,4,0,0,3},
261             {0,5,2,4,0,0,0,3,0},
262             {0,0,2,0,0,0,0,3,0,0},
263             {0,4,2,0,0,0,0,3,6,0,5},
264             {4,0,2,0,0,7,6,3,0,5,0,0},
265             {0,0,2,0,6,0,0,3,5,0,0,0,0},
266             {0,0,2,0,0,0,0,3,0,0,0,0,0,0},
267             };
268              
269 21           static bool _sqrtmod_small_return(UV *s, UV a, UV n) {
270 21 50         if (n == 0) return 0;
271 21 50         if (a >= n) a %= n;
272 21 50         if (n > 2 && a > 1) {
    100          
273 12           a = _small[n-3][a-2];
274 12 100         if (a == 0) return 0;
275             }
276 11 50         if (s != 0) *s = a;
277 11           return 1;
278             }
279 60           static bool _sqrtmod_return(UV r, UV *s, UV a, UV p) {
280 60 100         if (p-r < r) r = p-r;
281 60 100         if (mulmod(r, r, p) != (a % p)) return 0;
282 46 50         if (s != 0) *s = r;
283 46           return 1;
284             }
285 20           bool sqrtmodp(UV *s, UV a, UV p) {
286 20 50         if (p == 0) return 0;
287 20 50         if (a >= p) a %= p;
288 20 50         if (p <= NSMALL || a <= 1) return _sqrtmod_small_return(s,a,p);
    50          
289 20           return _sqrtmod_return(_sqrtmod_prime(a,p), s, a, p);
290             }
291              
292 61           bool sqrtmod(UV *s, UV a, UV n) {
293             /* return rootmod(s, a, 2, n); */
294 61 50         if (n == 0) return 0;
295 61 50         if (a >= n) a %= n;
296 61 100         if (n <= NSMALL || a <= 1) return _sqrtmod_small_return(s,a,n);
    100          
297 40           return _sqrtmod_return(_sqrtmod_composite(a,n), s, a, n);
298             }
299              
300              
301              
302              
303             /******************************************************************************/
304             /* K-TH ROOT OF N MOD M */
305             /******************************************************************************/
306              
307 108           static bool _rootmod_return(UV r, UV *s, UV a, UV k, UV p) {
308 108 100         if (k == 2 && p-r < r) r = p-r;
    100          
309 108 100         if (powmod(r, k, p) != (a % p)) return 0;
310 70 50         if (s != 0) *s = r;
311 70           return 1;
312             }
313              
314              
315             /* Generalized Tonelli-Shanks for k-th root mod a prime, with k prime */
316 14           static UV _ts_prime(UV a, UV k, UV p, UV *z) {
317             UV A, B, y, x, r, T, ke, t;
318              
319             /* Assume: k > 1, 1 < a < p, p > 2, k prime, p prime */
320              
321 36 100         for (r = p-1; !(r % k); r /= k) ;
322             /* p-1 = r * k^e => ke = ipow(k,e) = (p-1)/r */
323 14           ke = (p-1)/r;
324              
325 14           x = powmod(a, modinverse(k % r, r), p);
326 14           B = mulmod(powmod(x, k, p), modinverse(a, p), p);
327              
328 33 100         for (T = 2, y = 1; y == 1; T++) {
329 19           t = powmod(T, r, p);
330 19           y = powmod(t, ke/k, p);
331             }
332              
333 22 100         while (ke != k) {
334 8           ke = ke/k;
335 8           T = t;
336 8           t = powmod(t, k, p);
337 8           A = powmod(B, ke/k, p);
338 13 100         while (A != 1) {
339 5           x = mulmod(x, T, p);
340 5           B = mulmod(B, t, p);
341 5           A = mulmod(A, y, p);
342             }
343             }
344 14 50         if (z) *z = t;
345 14           return x;
346             }
347              
348             #if USE_ROOTMOD_SPLITK
349             /* Alternate, taking prime p but composite k. */
350             /* k-th root using Tonelli-Shanks for prime k and p */
351             /* This works much better for me than AMM (Holt 2003 or Cao/Sha/Fan 2011). */
352             /* See Algorithm 3.3 of van de Woestijne (2006). */
353             /* https://www.opt.math.tugraz.at/~cvdwoest/maths/dissertatie.pdf */
354             /* Also see Pari's Tonelli-Shanks by Bill Allombert, 2014,2017, which seems */
355             /* to be the same algorithm. */
356              
357             /* Algorithm 3.3, step 2 "Find generator" */
358 18           static void _find_ts_generator(UV *py, UV *pm, /* a not used */ UV k, UV p) {
359             UV e, r, y, m, x, ke1;
360             /* Assume: k > 2, 1 < a < p, p > 2, k prime, p prime */
361             /* e = valuation_remainder(p-1,k,&r); */
362 43 100         for (e = 0, r = p-1; !(r % k); r /= k) e++;
363 18           ke1 = ipow(k, e-1);
364 44 100         for (x = 2, m = 1; m == 1; x++) {
365 26           y = powmod(x, r, p);
366 26 100         if (y != 1)
367 23           m = powmod(y, ke1, p);
368 26 50         MPUassert(x < p, "bad Tonelli-Shanks input\n");
369             }
370 18           *py = y;
371 18           *pm = m;
372 18           }
373              
374 20           static UV _ts_rootmod(UV a, UV k, UV p, UV y, UV m) {
375             UV e, r, A, x, l, T, z, kz;
376              
377             /* Assume: k > 2, 1 < a < p, p > 2, k prime, p prime */
378             /* It is not expected to work with prime powers. */
379              
380             /* e = valuation_remainder(p-1,k,&r); */
381 50 100         for (e = 0, r = p-1; !(r % k); r /= k) e++;
382             /* p-1 = r * k^e */
383 20           x = powmod(a, modinverse(k % r, r), p);
384 20 100         A = (a == 0) ? 0 : mulmod(powmod(x, k, p), modinverse(a, p), p);
385              
386 20 50         if (y == 0 && A != 1)
    0          
387 0           _find_ts_generator(&y, &m /* ,a */, k, p);
388              
389 23 100         while (A != 1) {
390 14 100         for (l = 1, T = A; T != 1; l++) {
391 11 100         if (l >= e) return 0;
392 5           z = T;
393 5           T = powmod(T, k, p);
394             }
395 3           kz = negmod( znlog_solve(z, m, p, k), k); /* k = znorder(m,p) */
396 3           m = powmod(m, kz, p);
397 3           T = powmod(y, kz * ipow(k, e-l), p);
398             /* In the loop we always end with l < e, so e always gets smaller */
399 3           e = l-1;
400 3           x = mulmod(x, T, p);
401 3           y = powmod(T, k, p);
402 3 50         if (y <= 1) return 0; /* In theory this will never be hit. */
403 3           A = mulmod(A, y, p);
404             }
405 14           return x;
406             }
407              
408 0           static UV _compute_generator(UV l, UV e, UV r, UV p) {
409 0           UV x, y, m = 1;
410 0           UV lem1 = ipow(l, e-1);
411 0 0         for (x = 2; m == 1; x++) {
412 0           y = powmod(x, r, p);
413 0 0         if (y == 1) continue;
414 0           m = powmod(y, lem1, p);
415             }
416 0           return y; /* We might want to also return m */
417             }
418              
419             /* Following Pari, we calculate a root of unity to allow finding other roots */
420 65           static UV _rootmod_prime_splitk(UV a, UV k, UV p, UV *zeta) {
421             UV g;
422              
423 65 50         if (zeta) *zeta = 1;
424 65 50         if (a >= p) a %= p;
425 65 100         if (a == 0 || (a == 1 && !zeta)) return a;
    100          
    50          
426              
427             /* Assume: k >= 2, 1 < a < p, p > 2, p prime */
428              
429 44 100         if (k == 2) {
430 15 50         if (zeta) *zeta = p-1;
431 15           return _sqrtmod_prime(a, p);
432             }
433              
434             /* See Algorithm 2.1 of van de Woestijne (2006), or Lindhurst (1997) */
435             /* The latter's proposition 7 generalizes to composite p */
436              
437 29           g = gcd_ui(k, p-1);
438              
439 29 100         if (g != 1) {
440             uint32_t i;
441 17           factored_t nf = factorint(g);
442 35 100         for (i = 0; a != 0 && i < nf.nfactors; i++) {
    100          
443 18           UV y, m, F = nf.f[i], E = nf.e[i];
444 18 50         if (zeta) {
445             UV REM, V, Y;
446 0           V = valuation_remainder(p-1, F, &REM);
447 0           Y = _compute_generator(F, V, REM, p);
448 0           *zeta = mulmod(*zeta, powmod(Y, ipow(F, V-E), p), p);
449             }
450 18           _find_ts_generator(&y, &m /* ,a */, F, p);
451 38 100         while (E-- > 0)
452 20           a = _ts_rootmod(a, F, p, y, m);
453             }
454             }
455 29 100         if (g != k) {
456 17           UV kg = k/g, pg = (p-1)/g;
457 17           a = powmod(a, modinverse(kg % pg, pg), p);
458             }
459 29           return a;
460             }
461             #endif
462              
463              
464             #if 0 /* For testing purposes only. */
465             static UV _trial_rootmod(UV a, UV k, UV n) {
466             UV r;
467             if (n == 0) return 0;
468             if (a >= n) a %= n;
469             if (a <= 1) return a;
470             for (r = 2; r < n; r++)
471             if (powmod(r, k, n) == a)
472             return r;
473             return 0;
474             }
475             static UV* _trial_allsqrtmod(UV* nroots, UV a, UV n) {
476             UV i, *roots, numr = 0, allocr = 16;
477              
478             if (n == 0) return 0;
479             if (a >= n) a %= n;
480              
481             New(0, roots, allocr, UV);
482             for (i = 0; i <= n/2; i++) {
483             if (mulmod(i,i,n) == a) {
484             if (numr >= allocr-1) Renew(roots, allocr += 256, UV);
485             roots[numr++] = i;
486             if (i != 0 && 2*i != n)
487             roots[numr++] = n-i;
488             }
489             }
490             sort_uv_array(roots, numr);
491             *nroots = numr;
492             return roots;
493             }
494             static UV* _trial_allrootmod(UV* nroots, UV a, UV g, UV n) {
495             UV i, *roots, numr = 0, allocr = 16;
496              
497             if (n == 0) return 0;
498             if (a >= n) a %= n;
499              
500             New(0, roots, allocr, UV);
501             for (i = 0; i < n; i++) {
502             if (powmod(i,g,n) == a) {
503             if (numr >= allocr-1) Renew(roots, allocr += 256, UV);
504             roots[numr++] = i;
505             }
506             }
507             *nroots = numr;
508             return roots;
509             }
510             #endif
511              
512              
513             /******************************************************************************/
514             /* K-TH ROOT OF N MOD M (splitk) */
515             /******************************************************************************/
516              
517             #if USE_ROOTMOD_SPLITK
518             /* Given a solution to r^k = a mod p^(e-1), return r^k = a mod p^e */
519 90           static bool _hensel_lift(UV *re, UV r, UV a, UV k, UV pe) {
520             UV f, fp, d;
521              
522             /* UV pe = ipow(p, e); */
523 90 100         if (a >= pe) a %= pe;
524 90           f = submod(powmod(r, k, pe), a, pe);
525 90 100         if (f == 0) { *re = r; return 1; }
526 48           fp = mulmod(k, powmod(r, k-1, pe), pe);
527 48           d = divmod(f, fp, pe);
528 48 100         if (d == 0) return 0; /* We need a different base root */
529 2           *re = submod(r, d, pe);
530 2           return 1;
531             }
532              
533 20           static UV _rootmod_composite1(UV a, UV k, UV n) {
534             factored_t nf;
535             UV fac[MPU_MAX_DFACTORS], exp[MPU_MAX_DFACTORS];
536             UV f, g, e, r;
537             uint32_t i;
538              
539             /* Assume: k >= 2, 1 < a < n, n > 2, n composite */
540              
541             #if 0
542             /* For square roots of p^k with gcd(a,p)==1, this is straightforward. */
543             if (k == 2 && (i = primepower(n, &f)) && (a % f) > 1) {
544             UV x = _sqrtmod_prime(a % f, f);
545             UV r = n/f;
546             UV j = powmod(x, r, n);
547             UV k = powmod(a, (n - r - r + 1) >> 1, n);
548             return mulmod(j, k, n);
549             }
550             #endif
551              
552 20           nf = factorint(n);
553 47 100         for (i = 0; i < nf.nfactors; i++) {
554 34           f = fac[i] = nf.f[i];
555             /* Find a root mod f. If none exists, there is no root for n. */
556 34           r = _rootmod_prime_splitk(a%f, k, f, 0);
557 34 100         if (powmod(r, k, f) != (a%f)) return 0;
558             /* If we have a prime power, use Hensel lifting to solve for p^e */
559 32 100         if (nf.e[i] > 1) {
560 16           UV fe = f;
561 68 100         for (e = 2; e <= nf.e[i]; e++) {
562 57           fe *= f;
563             /* We aren't guaranteed a solution, though we usually get one. */
564 57 100         if (!_hensel_lift(&r, r, a, k, fe)) {
565             /* Search for a different base root */
566 26           UV t, m = fe / (f*f);
567 46 100         for (t = 1; t < f; t++) {
568 33 100         if (_hensel_lift(&r, r + t*m, a, k, fe))
569 13           break;
570             }
571             /* That didn't work, do a stronger but time consuming search. */
572 26 100         if (t >= f) {
573 13           UV afe = a % fe;
574 92 100         for (r = (a % f); r < fe; r += f)
575 87 100         if (powmod(r, k, fe) == afe)
576 8           break;
577 13 100         if (r >= fe) return 0;
578             }
579             }
580             }
581 11           fac[i] = fe;
582             }
583 27           exp[i] = r;
584             }
585 13 50         if (chinese(&g, 0, exp, fac, nf.nfactors) != 1) return 0;
586 13           return g;
587             }
588             #endif
589              
590             /******************************************************************************/
591             /* K-TH ROOT OF N MOD M (splitn) */
592             /******************************************************************************/
593              
594             /* _rootmod_composite2 factors k and combines:
595             * _rootmod_kprime takes prime k along with factored n:
596             * _rootmod_prime_power splits p^e into primes (prime k):
597             * _rootmod_prime finds a root (prime p and prime k)
598             * _sqrtmod_prime (if k==2)
599             * _ts_prime
600             */
601              
602             #if USE_ROOTMOD_SPLITN && !USE_ROOTMOD_SPLITK
603             static UV _rootmod_prime(UV a, UV k, UV p) {
604             UV r, g;
605              
606             /* Assume: p is prime, k is prime */
607              
608             if (a >= p) a %= p;
609             if (p == 2 || a == 0) return a;
610             if (k == 2) {
611             r = _sqrtmod_prime(a,p);
612             return (sqrmod(r,p) == a) ? r : UV_MAX;
613             }
614              
615             /* If co-prime, we have one root */
616             g = gcd_ui(k, p-1);
617             if (g == 1)
618             return powmod(a, modinverse(k % (p-1), p-1), p);
619              
620             /* Check generalized Euler's criterion */
621             if (powmod(a, (p-1)/g, p) != 1)
622             return UV_MAX;
623              
624             return _ts_prime(a, k, p, 0);
625             }
626              
627             static UV _rootmod_prime_power(UV a, UV k, UV p, UV e) {
628             UV r, s, t, n, np, pk, apk, ered;
629              
630             /* Assume: p is prime, k is prime, e >= 1 */
631              
632             if (k == 2) return _sqrtmod_prime_power(a, p, e);
633             if (e == 1) return _rootmod_prime(a, k, p);
634              
635             n = ipow(p,e);
636             pk = ipow(p,k);
637             /* Note: a is not modded */
638              
639             if ((a % n) == 0)
640             return 0;
641              
642             if ((a % pk) == 0) {
643             apk = a / pk;
644             s = _rootmod_prime_power(apk, k, p, e-k);
645             if (s == UV_MAX) return UV_MAX;
646             return s * p;
647             }
648              
649             if ((a % p) == 0)
650             return UV_MAX;
651              
652             ered = (p > 2 || e < 5) ? (e+1)>>1 : (e+3)>>1;
653             s = _rootmod_prime_power(a, k, p, ered);
654             if (s == UV_MAX) return UV_MAX;
655              
656             np = (p != k || (n > (UV_MAX/p))) ? n : n * p;
657             t = powmod(s, k-1, np);
658             r = addmod(s, gcddivmod(submod(a,mulmod(t,s,np),np), mulmod(k,t,np), n), n);
659             if (powmod(r, k, n) != (a % n)) return UV_MAX;
660             return r;
661             }
662              
663             static UV _rootmod_kprime(UV a, UV k, factored_t nf) {
664             UV N, fe, r, s, t, inv;
665             uint32_t i;
666              
667             /* Assume: k is prime */
668              
669             N = ipow(nf.f[0], nf.e[0]);
670             r = _rootmod_prime_power(a, k, nf.f[0], nf.e[0]);
671             if (r == UV_MAX) return UV_MAX;
672             for (i = 1; i < nf.nfactors; i++) {
673             fe = ipow(nf.f[i], nf.e[i]);
674             s = _rootmod_prime_power(a, k, nf.f[i], nf.e[i]);
675             if (s == UV_MAX) return UV_MAX;
676             inv = modinverse(N, fe);
677             t = mulmod(inv, submod(s % fe,r % fe,fe), fe);
678             r = addmod(r, mulmod(N,t,nf.n), nf.n);
679             N *= fe;
680             }
681             return r;
682             }
683              
684             static UV _rootmod_composite2(UV a, UV k, UV n) {
685             factored_t nf;
686             UV r, kfac[MPU_MAX_FACTORS];
687             uint32_t i, kfactors;
688              
689             if (n == 0) return 0;
690             if (a >= n) a %= n;
691              
692             if (n <= 2 || a <= 1) return a;
693             if (k <= 1) return (k == 0) ? 1 : a;
694              
695             /* Factor n */
696             nf = factorint(n);
697              
698             if (is_prime(k))
699             return _rootmod_kprime(a, k, nf);
700              
701             kfactors = factor(k, kfac);
702             r = a;
703             for (i = 0; i < kfactors; i++) { /* for each prime k */
704             r = _rootmod_kprime(r, kfac[i], nf);
705             if (r == UV_MAX) { /* Bad path. We have to use a fallback method. */
706             #if USE_ROOTMOD_SPLITK
707             r = _rootmod_composite1(a,k,n);
708             #else
709             UV *roots, numr;
710             roots = allrootmod(&numr,a,k,n);
711             r = (numr > 0) ? roots[0] : UV_MAX;
712             Safefree(roots);
713             #endif
714             break;
715             }
716             }
717             return r;
718             }
719             #endif
720              
721 0           bool rootmodp(UV *s, UV a, UV k, UV p) {
722             UV r;
723             uint32_t R;
724 0 0         if (p == 0) return 0;
725 0 0         if (a >= p) a %= p;
726              
727             /* return _rootmod_return(_trial_rootmod(a,k,n), s, a, k, p); */
728              
729 0 0         if (p <= 2 || a <= 1) r = a;
    0          
730 0 0         else if (k <= 1) r = (k == 0) ? 1 : a;
    0          
731 0 0         else if (k < BITS_PER_WORD && is_power_ret(a,k,&R)) r = R;
    0          
732             #if USE_ROOTMOD_SPLITK
733 0           else r = _rootmod_prime_splitk(a,k,p,0);
734             #else
735             else r = _rootmod_composite2(a,k,p);
736             #endif
737 0           return _rootmod_return(r, s, a, k, p);
738             }
739              
740 108           bool rootmod(UV *s, UV a, UV k, UV n) {
741             UV r;
742             uint32_t R;
743 108 50         if (n == 0) return 0;
744 108 50         if (a >= n) a %= n;
745              
746             /* return _rootmod_return(_trial_rootmod(a,k,n), s, a, k, n); */
747              
748 108 100         if (n <= 2 || a <= 1) r = a;
    100          
749 93 100         else if (k <= 1) r = (k == 0) ? 1 : a;
    100          
750 60 50         else if (k < BITS_PER_WORD && is_power_ret(a,k,&R)) r = R;
    100          
751             #if USE_ROOTMOD_SPLITK
752 51 100         else if (is_prime(n)) r = _rootmod_prime_splitk(a,k,n,0);
753 20           else r = _rootmod_composite1(a,k,n);
754             #else
755             else r = _rootmod_composite2(a,k,n);
756             #endif
757 108           return _rootmod_return(r, s, a, k, n);
758             }
759              
760              
761              
762              
763             /******************************************************************************/
764             /* SQRTMOD AND ROOTMOD RETURNING ALL RESULTS */
765             /******************************************************************************/
766              
767              
768             /* We could alternately just let the allocation fail */
769             #define MAX_ROOTS_RETURNED 600000000
770              
771             /* Combine roots using Cartesian product CRT */
772 39           static UV* _rootmod_cprod(UV* nroots,
773             UV nr1, UV *roots1, UV p1,
774             UV nr2, UV *roots2, UV p2) {
775             UV i, j, nr, *roots, inv;
776              
777 39           nr = nr1 * nr2;
778 39 50         if (nr > MAX_ROOTS_RETURNED) croak("Maximum returned roots exceeded");
779 39 50         New(0, roots, nr, UV);
780              
781 39           inv = modinverse(p1, p2);
782 169 100         for (i = 0; i < nr1; i++) {
783 130           UV r1 = roots1[i];
784 368 100         for (j = 0; j < nr2; j++) {
785 238           UV r2 = roots2[j];
786             #if 0
787             UV ca[2], cn[2];
788             ca[0] = r1; cn[0] = p1;
789             ca[1] = r2; cn[1] = p2;
790             if (chinese(roots + i * nr2 + j, 0, ca, cn, 2) != 1)
791             croak("chinese fail in allrootmod");
792             #else
793 238           UV t = mulmod(inv, submod(r2 % p2,r1 % p2,p2), p2);
794 238           roots[i * nr2 + j] = addmod(r1, mulmod(p1,t,p1*p2), p1*p2);
795             #endif
796             }
797             }
798 39           Safefree(roots1);
799 39           Safefree(roots2);
800 39           *nroots = nr;
801 39           return roots;
802             }
803              
804 36           static UV* _one_root(UV* nroots, UV r) {
805             UV *roots;
806 36           New(0, roots, 1, UV);
807 36           roots[0] = r;
808 36           *nroots = 1;
809 36           return roots;
810             }
811 0           static UV* _two_roots(UV* nroots, UV r, UV s) {
812             UV *roots;
813 0           New(0, roots, 2, UV);
814 0           roots[0] = r; roots[1] = s;
815 0           *nroots = 2;
816 0           return roots;
817             }
818              
819              
820             /* allsqrtmod algorithm from Hugo van der Sanden, 2021 */
821              
822 114           static UV* _allsqrtmodpk(UV *nroots, UV a, UV p, UV k) {
823 114           UV *roots, *roots2, nr2 = 0;
824             UV i, j, pk, pj, q, q2, a2;
825              
826 114           pk = ipow(p,k);
827 114           *nroots = 0;
828              
829 114 100         if ((a % p) == 0) {
830 37 100         if ((a % pk) == 0) {
831 20           UV low = ipow(p, k >> 1);
832 20 100         UV high = (k & 1) ? low * p : low;
833 20 50         if (low > MAX_ROOTS_RETURNED) croak("Maximum returned roots exceeded");
834 20 50         New(0, roots, low, UV);
835 59 100         for (i = 0; i < low; i++)
836 39           roots[i] = high * i;
837 20           *nroots = low;
838 20           return roots;
839             }
840 17           a2 = a / p;
841 17 100         if ((a2 % p) != 0)
842 3           return 0;
843 14           pj = pk / p;
844 14           roots2 = _allsqrtmodpk(&nr2, a2/p, p, k-2);
845 14 100         if (roots2 == 0) return 0;
846 9           *nroots = nr2 * p;
847 9 50         if (*nroots > MAX_ROOTS_RETURNED) croak("Maximum returned roots exceeded");
848 9 50         New(0, roots, *nroots, UV);
849 28 100         for (i = 0; i < nr2; i++)
850 59 100         for (j = 0; j < p; j++)
851 40           roots[i*p+j] = roots2[i] * p + j * pj;
852 9           Safefree(roots2);
853 9           return roots;
854             }
855              
856 77           q = _sqrtmod_prime_power(a,p,k);
857 77 100         if (q == UV_MAX) return 0;
858              
859 59           New(0, roots, 4, UV);
860 59           roots[0] = q; roots[1] = pk - q;
861 59 100         if (p != 2) { *nroots = 2; }
862 17 100         else if (k == 1) { *nroots = 1; }
863 12 100         else if (k == 2) { *nroots = 2; }
864             else {
865 4           pj = pk / p;
866 4           q2 = mulmod(q, pj-1, pk);
867 4           roots[2] = q2; roots[3] = pk - q2;
868 4           *nroots = 4;
869             }
870 59           return roots;
871             }
872              
873 100           static UV* _allsqrtmodfact(UV *nroots, UV a, factored_t nf) {
874             factored_t rf;
875             UV *roots, *roots1, *roots2, nr, nr1, nr2, p, k, pk;
876             uint32_t i;
877              
878 100           p = nf.f[0], k = nf.e[0];
879 100           *nroots = 0;
880              
881             /* nr1,roots1 are roots of p^k for the first prime power */
882 100           roots1 = _allsqrtmodpk(&nr1, a, p, k);
883 100 100         if (roots1 == 0) return 0;
884 79 100         if (nf.nfactors == 1) {
885 48           *nroots = nr1;
886 48           return roots1;
887             }
888              
889             /* rf = nf with the first factor removed */
890 31           pk = ipow(p, k);
891 31           rf.n = nf.n/pk;
892 31           rf.nfactors = nf.nfactors-1;
893 70 100         for (i = 0; i < rf.nfactors; i++) {
894 39           rf.f[i] = nf.f[i+1];
895 39           rf.e[i] = nf.e[i+1];
896             }
897              
898             /* nr2,roots2 are roots of all the rest, found recursively */
899 31           roots2 = _allsqrtmodfact(&nr2, a, rf);
900 31 100         if (roots2 == 0) { Safefree(roots1); return 0; }
901              
902 27           roots = _rootmod_cprod(&nr, nr1, roots1, pk, nr2, roots2, rf.n);
903              
904 27           *nroots = nr;
905 27           return roots;
906             }
907              
908 34           UV* allsqrtmod(UV* nroots, UV a, UV n) {
909 34           UV *roots, numr = 0;
910              
911 34 50         if (n == 0) return 0;
912 34 50         if (a >= n) a %= n;
913              
914             /* return _trial_allsqrtmod(nroots, a, n); */
915              
916 34 100         if (n <= 2) return _one_root(nroots, a);
917              
918 32           roots = _allsqrtmodfact(&numr, a, factorint(n));
919 32 100         if (numr > 0) sort_uv_array(roots, numr);
920 32           *nroots = numr;
921 32           return roots;
922             }
923              
924              
925             /* allrootmod factors k and combines:
926             * _allrootmod_kprime takes prime k and factored n:
927             * _allrootmod_prime_power splits p^e into primes:
928             * _allrootmod_prime finds all the roots for prime p and prime k
929             * _ts_prime (could alternately call _rootmod_prime_splitk)
930             */
931              
932 49           static UV* _allrootmod_prime(UV* nroots, UV a, UV k, UV p) {
933 49           UV r, g, z, r2, *roots, numr = 0;
934              
935 49           *nroots = 0;
936 49 100         if (a >= p) a %= p;
937              
938             /* Assume: p is prime, k is prime */
939              
940             /* simple case */
941 49 100         if (p == 2 || a == 0) return _one_root(nroots, a);
    100          
942              
943             /* If co-prime, we have one root */
944 44           g = gcd_ui(k, p-1);
945 44 100         if (g == 1) {
946 28           r = powmod(a, modinverse(k % (p-1), p-1), p);
947 28           return _one_root(nroots, r);
948             }
949             /* At this point k < p. (k is a prime so if k>=p, g=1) */
950              
951             /* Check generalized Euler's criterion:
952             * r^k = a mod p has a solution iff a^((p-1)/gcd(p-1,k)) = 1 mod p */
953 16 100         if (powmod(a, (p-1)/g, p) != 1)
954 2           return 0;
955              
956             /* Special case p=3 for performance */
957 14 50         if (p == 3) return (k == 2 && a == 1) ? _two_roots(nroots, 1, 2) : 0;
    0          
    0          
958              
959             /* functionally identical: r = _rootmod_prime_splitk(a, k, p, &z); */
960 14           r = _ts_prime(a, k, p, &z);
961 14 50         if (powmod(r,k,p) != a || z == 0) croak("allrootmod: failed to find root");
    50          
962              
963 14 50         New(0, roots, k, UV);
964 14           roots[numr++] = r;
965 44 100         for (r2 = mulmod(r, z, p); r2 != r && numr < k; r2 = mulmod(r2, z, p) )
    50          
966 30           roots[numr++] = r2;
967 14 50         if (r2 != r) croak("allrootmod: excess roots found");
968              
969 14           *nroots = numr;
970 14           return roots;
971             }
972              
973              
974 87           static UV* _allrootmod_prime_power(UV* nroots, UV a, UV k, UV p, UV e) {
975 87           UV n, i, j, pk, s, t, r, numr = 0, *roots = 0, nr2 = 0, *roots2 = 0;
976              
977             #if 0
978             MPUassert(p >= 2, "_allrootmod_prime_power must be given a prime modulus");
979             MPUassert(e >= 1, "_allrootmod_prime_power must be given exponent >= 1");
980             MPUassert(k >= 2, "_allrootmod_prime_power must be given k >= 2");
981             MPUassert(is_prime(k), "_allrootmod_prime_power must be given prime k");
982             MPUassert(is_prime(p), "_allrootmod_prime_power must be given prime p");
983             #endif
984              
985 87 100         if (e == 1) return _allrootmod_prime(nroots, a, k, p);
986              
987 38           n = ipow(p,e);
988 38           pk = ipow(p, k);
989             /* Note: a is not modded */
990              
991 38 100         if ((a % n) == 0) {
992              
993 2           t = ((e-1) / k) + 1;
994 2           s = ipow(p,t);
995 2           numr = ipow(p,e-t);
996 2 50         New(0, roots, numr, UV);
997 10 100         for (i = 0; i < numr; i++)
998 8           roots[i] = mulmod(i, s, n);
999              
1000 36 100         } else if ((a % pk) == 0) {
1001              
1002 5           UV apk = a / pk;
1003 5           UV pe1 = ipow(p, k-1);
1004 5           UV pek = ipow(p, e-k+1);
1005 5           roots2 = _allrootmod_prime_power(&nr2, apk, k, p, e-k);
1006 5           numr = pe1 * nr2;
1007 5 50         New(0, roots, numr, UV);
1008 9 100         for (i = 0; i < nr2; i++)
1009 44 100         for (j = 0; j < pe1; j++)
1010 40           roots[i*pe1+j] = addmod(mulmod(roots2[i],p,n), mulmod(j,pek,n), n);
1011 5           Safefree(roots2);
1012              
1013 31 100         } else if ((a % p) != 0) {
1014              
1015 16 50         UV np = (n > (UV_MAX/p)) ? n : n*p;
1016 16 100         UV ered = (p > 2 || e < 5) ? (e+1)>>1 : (e+3)>>1;
    50          
1017 16           roots2 = _allrootmod_prime_power(&nr2, a, k, p, ered);
1018              
1019 16 100         if (k != p) {
1020              
1021 4 100         for (j = 0; j < nr2; j++) {
1022 2           s = roots2[j];
1023 2           t = powmod(s, k-1, n);
1024 2           r = addmod(s,gcddivmod(submod(a,mulmod(t,s,n),n),mulmod(k,t,n),n),n);
1025 2           roots2[j] = r;
1026             }
1027 2           roots = roots2;
1028 2           numr = nr2;
1029              
1030             } else {
1031              
1032             /* Step 1, transform roots, eliding any that aren't valid */
1033 44 100         for (j = 0; j < nr2; j++) {
1034 30           s = roots2[j];
1035 30           t = powmod(s, k-1, np);
1036 30           r = addmod(s,gcddivmod(submod(a,mulmod(t,s,np),np),mulmod(k,t,np),n),n);
1037 30 100         if (powmod(r, k, n) == (a % n))
1038 17           roots2[numr++] = r;
1039             }
1040 14           nr2 = numr;
1041              
1042             /* Step 2, Expand out by k */
1043 14 100         if (nr2 > 0) {
1044 12           numr = nr2 * k;
1045 12 50         New(0, roots, numr, UV);
1046 29 100         for (j = 0; j < nr2; j++) {
1047 17           r = roots2[j];
1048 76 100         for (i = 0; i < k; i++)
1049 59           roots[j*k+i] = mulmod(r, addmod( mulmod(i,n/p,n), 1, n), n);
1050             }
1051             }
1052 14           Safefree(roots2);
1053              
1054             /* Step 3, Remove any duplicates */
1055 14 50         if (numr == 2 && roots[0] == roots[1])
    0          
1056 0           numr = 1;
1057 14 100         if (numr > 2) {
1058 12           sort_uv_array(roots, numr);
1059 59 100         for (j = 0, i = 1; i < numr; i++)
1060 47 100         if (roots[j] != roots[i])
1061 32           roots[++j] = roots[i];
1062 12           numr = j+1;
1063             }
1064              
1065             }
1066             }
1067 38           *nroots = numr;
1068 38           return roots;
1069             }
1070              
1071 90           static UV* _allrootmod_kprime(UV* nroots, UV a, UV k, factored_t nf) {
1072 90           UV fe, N, *roots = 0, *roots2, numr = 0, nr2;
1073             uint32_t i;
1074              
1075 90 100         if (k == 2) return _allsqrtmodfact(nroots, a, nf);
1076              
1077 53           *nroots = 0;
1078 53           N = ipow(nf.f[0], nf.e[0]);
1079 53           roots = _allrootmod_prime_power(&numr, a, k, nf.f[0], nf.e[0]);
1080 53 100         if (numr == 0) { Safefree(roots); return 0; }
1081 47 100         for (i = 1; i < nf.nfactors; i++) {
1082 13           fe = ipow(nf.f[i], nf.e[i]);
1083 13           roots2 = _allrootmod_prime_power(&nr2, a, k, nf.f[i], nf.e[i]);
1084 13 100         if (nr2 == 0) { Safefree(roots); Safefree(roots2); return 0; }
1085             /* Cartesian product using CRT. roots and roots2 are freed. */
1086 12           roots = _rootmod_cprod(&numr, numr, roots, N, nr2, roots2, fe);
1087 12           N *= fe;
1088             }
1089 34 50         MPUassert(N == nf.n, "allrootmod: Incorrect factoring");
1090              
1091 34           *nroots = numr;
1092 34           return roots;
1093             }
1094              
1095 47           UV* allrootmod(UV* nroots, UV a, UV k, UV n) {
1096             factored_t nf;
1097 47           UV numr = 0, *roots = 0;
1098             UV kfac[MPU_MAX_FACTORS+1];
1099             uint32_t i, kfactors;
1100              
1101             /* return _trial_allrootmod(nroots, a, k, n); */
1102              
1103 47           *nroots = 0;
1104 47 50         if (n == 0) return 0;
1105 47 50         if (a >= n) a %= n;
1106              
1107 47 100         if (n <= 2 || k == 1)
    50          
1108 1           return _one_root(nroots, a); /* n=1 => [0], n=2 => [0] or [1] */
1109              
1110 46 100         if (k == 0) {
1111 2 100         if (a != 1) return 0;
1112 1 50         if (n > MAX_ROOTS_RETURNED) croak("Maximum returned roots exceeded");
1113 1 50         New(0, roots, n, UV);
1114 14 100         for (i = 0; i < n; i++)
1115 13           roots[i] = i;
1116 1           *nroots = n;
1117 1           return roots;
1118             }
1119              
1120             /* Factor n */
1121 44           nf = factorint(n);
1122              
1123 44 100         if (is_prime(k)) {
1124              
1125 31           roots = _allrootmod_kprime(&numr, a, k, nf);
1126              
1127             } else { /* Split k into primes */
1128              
1129 13           kfactors = factor(k, kfac);
1130 13           roots = _allrootmod_kprime(&numr, a, kfac[0], nf);
1131              
1132 24 100         for (i = 1; numr > 0 && i < kfactors; i++) { /* for each prime k */
    100          
1133 11           UV j, t, allocr = numr, primek = kfac[i];
1134 11           UV *roots2 = 0, nr2 = 0, *roots3 = 0, nr3 = 0;
1135 11 50         New(0, roots3, allocr, UV);
1136 57 100         for (j = 0; j < numr; j++) { /* get a list from each root */
1137 46           roots2 = _allrootmod_kprime(&nr2, roots[j], primek, nf);
1138 46 100         if (nr2 == 0) continue;
1139             /* Append to roots3 */
1140 23 50         if (nr3 + nr2 > MAX_ROOTS_RETURNED) croak("Maximum returned roots exceeded");
1141 23 100         if (nr3 + nr2 >= allocr) Renew(roots3, allocr += nr2, UV);
    50          
1142 89 100         for (t = 0; t < nr2; t++)
1143 66           roots3[nr3++] = roots2[t];
1144 23           Safefree(roots2);
1145             }
1146             /* We've walked through all the roots combining to roots3 */
1147 11           Safefree(roots);
1148 11           roots = roots3;
1149 11           numr = nr3;
1150             }
1151              
1152             }
1153 44 100         if (numr > 1)
1154 30           sort_uv_array(roots, numr);
1155 44           *nroots = numr;
1156 44           return roots;
1157             }