File Coverage

src/int/i31_moddiv.c
Criterion Covered Total %
statement 150 150 100.0
branch 22 22 100.0
condition n/a
subroutine n/a
pod n/a
total 172 172 100.0


line stmt bran cond sub pod time code
1             /*
2             * Copyright (c) 2018 Thomas Pornin
3             *
4             * Permission is hereby granted, free of charge, to any person obtaining
5             * a copy of this software and associated documentation files (the
6             * "Software"), to deal in the Software without restriction, including
7             * without limitation the rights to use, copy, modify, merge, publish,
8             * distribute, sublicense, and/or sell copies of the Software, and to
9             * permit persons to whom the Software is furnished to do so, subject to
10             * the following conditions:
11             *
12             * The above copyright notice and this permission notice shall be
13             * included in all copies or substantial portions of the Software.
14             *
15             * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16             * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17             * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18             * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
19             * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
20             * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21             * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22             * SOFTWARE.
23             */
24              
25             #include "inner.h"
26              
27             /*
28             * In this file, we handle big integers with a custom format, i.e.
29             * without the usual one-word header. Value is split into 31-bit words,
30             * each stored in a 32-bit slot (top bit is zero) in little-endian
31             * order. The length (in words) is provided explicitly. In some cases,
32             * the value can be negative (using two's complement representation). In
33             * some cases, the top word is allowed to have a 32th bit.
34             */
35              
36             /*
37             * Negate big integer conditionally. The value consists of 'len' words,
38             * with 31 bits in each word (the top bit of each word should be 0,
39             * except possibly for the last word). If 'ctl' is 1, the negation is
40             * computed; otherwise, if 'ctl' is 0, then the value is unchanged.
41             */
42             static void
43 8260           cond_negate(uint32_t *a, size_t len, uint32_t ctl)
44             {
45             size_t k;
46             uint32_t cc, xm;
47              
48 8260           cc = ctl;
49 8260           xm = -ctl >> 1;
50 148680 100         for (k = 0; k < len; k ++) {
51             uint32_t aw;
52              
53 140420           aw = a[k];
54 140420           aw = (aw ^ xm) + cc;
55 140420           a[k] = aw & 0x7FFFFFFF;
56 140420           cc = aw >> 31;
57             }
58 8260           }
59              
60             /*
61             * Finish modular reduction. Rules on input parameters:
62             *
63             * if neg = 1, then -m <= a < 0
64             * if neg = 0, then 0 <= a < 2*m
65             *
66             * If neg = 0, then the top word of a[] may use 32 bits.
67             *
68             * Also, modulus m must be odd.
69             */
70             static void
71 8260           finish_mod(uint32_t *a, size_t len, const uint32_t *m, uint32_t neg)
72             {
73             size_t k;
74             uint32_t cc, xm, ym;
75              
76             /*
77             * First pass: compare a (assumed nonnegative) with m.
78             * Note that if the final word uses the top extra bit, then
79             * subtracting m must yield a value less than 2^31, since we
80             * assumed that a < 2*m.
81             */
82 8260           cc = 0;
83 148680 100         for (k = 0; k < len; k ++) {
84             uint32_t aw, mw;
85              
86 140420           aw = a[k];
87 140420           mw = m[k];
88 140420           cc = (aw - mw - cc) >> 31;
89             }
90              
91             /*
92             * At this point:
93             * if neg = 1, then we must add m (regardless of cc)
94             * if neg = 0 and cc = 0, then we must subtract m
95             * if neg = 0 and cc = 1, then we must do nothing
96             */
97 8260           xm = -neg >> 1;
98 8260           ym = -(neg | (1 - cc));
99 8260           cc = neg;
100 148680 100         for (k = 0; k < len; k ++) {
101             uint32_t aw, mw;
102              
103 140420           aw = a[k];
104 140420           mw = (m[k] ^ xm) & ym;
105 140420           aw = aw - mw - cc;
106 140420           a[k] = aw & 0x7FFFFFFF;
107 140420           cc = aw >> 31;
108             }
109 8260           }
110              
111             /*
112             * Compute:
113             * a <- (a*pa+b*pb)/(2^31)
114             * b <- (a*qa+b*qb)/(2^31)
115             * The division is assumed to be exact (i.e. the low word is dropped).
116             * If the final a is negative, then it is negated. Similarly for b.
117             * Returned value is the combination of two bits:
118             * bit 0: 1 if a had to be negated, 0 otherwise
119             * bit 1: 1 if b had to be negated, 0 otherwise
120             *
121             * Factors pa, pb, qa and qb must be at most 2^31 in absolute value.
122             * Source integers a and b must be nonnegative; top word is not allowed
123             * to contain an extra 32th bit.
124             */
125             static uint32_t
126 4130           co_reduce(uint32_t *a, uint32_t *b, size_t len,
127             int64_t pa, int64_t pb, int64_t qa, int64_t qb)
128             {
129             size_t k;
130             int64_t cca, ccb;
131             uint32_t nega, negb;
132              
133 4130           cca = 0;
134 4130           ccb = 0;
135 74340 100         for (k = 0; k < len; k ++) {
136             uint32_t wa, wb;
137             uint64_t za, zb;
138             uint64_t tta, ttb;
139              
140             /*
141             * Since:
142             * |pa| <= 2^31
143             * |pb| <= 2^31
144             * 0 <= wa <= 2^31 - 1
145             * 0 <= wb <= 2^31 - 1
146             * |cca| <= 2^32 - 1
147             * Then:
148             * |za| <= (2^31-1)*(2^32) + (2^32-1) = 2^63 - 1
149             *
150             * Thus, the new value of cca is such that |cca| <= 2^32 - 1.
151             * The same applies to ccb.
152             */
153 70210           wa = a[k];
154 70210           wb = b[k];
155 70210           za = wa * (uint64_t)pa + wb * (uint64_t)pb + (uint64_t)cca;
156 70210           zb = wa * (uint64_t)qa + wb * (uint64_t)qb + (uint64_t)ccb;
157 70210 100         if (k > 0) {
158 66080           a[k - 1] = za & 0x7FFFFFFF;
159 66080           b[k - 1] = zb & 0x7FFFFFFF;
160             }
161              
162             /*
163             * For the new values of cca and ccb, we need a signed
164             * right-shift; since, in C, right-shifting a signed
165             * negative value is implementation-defined, we use a
166             * custom portable sign extension expression.
167             */
168             #define M ((uint64_t)1 << 32)
169 70210           tta = za >> 31;
170 70210           ttb = zb >> 31;
171 70210           tta = (tta ^ M) - M;
172 70210           ttb = (ttb ^ M) - M;
173 70210           cca = *(int64_t *)&tta;
174 70210           ccb = *(int64_t *)&ttb;
175             #undef M
176             }
177 4130           a[len - 1] = (uint32_t)cca;
178 4130           b[len - 1] = (uint32_t)ccb;
179              
180 4130           nega = (uint32_t)((uint64_t)cca >> 63);
181 4130           negb = (uint32_t)((uint64_t)ccb >> 63);
182 4130           cond_negate(a, len, nega);
183 4130           cond_negate(b, len, negb);
184 4130           return nega | (negb << 1);
185             }
186              
187             /*
188             * Compute:
189             * a <- (a*pa+b*pb)/(2^31) mod m
190             * b <- (a*qa+b*qb)/(2^31) mod m
191             *
192             * m0i is equal to -1/m[0] mod 2^31.
193             *
194             * Factors pa, pb, qa and qb must be at most 2^31 in absolute value.
195             * Source integers a and b must be nonnegative; top word is not allowed
196             * to contain an extra 32th bit.
197             */
198             static void
199 4130           co_reduce_mod(uint32_t *a, uint32_t *b, size_t len,
200             int64_t pa, int64_t pb, int64_t qa, int64_t qb,
201             const uint32_t *m, uint32_t m0i)
202             {
203             size_t k;
204             int64_t cca, ccb;
205             uint32_t fa, fb;
206              
207 4130           cca = 0;
208 4130           ccb = 0;
209 4130           fa = ((a[0] * (uint32_t)pa + b[0] * (uint32_t)pb) * m0i) & 0x7FFFFFFF;
210 4130           fb = ((a[0] * (uint32_t)qa + b[0] * (uint32_t)qb) * m0i) & 0x7FFFFFFF;
211 74340 100         for (k = 0; k < len; k ++) {
212             uint32_t wa, wb;
213             uint64_t za, zb;
214             uint64_t tta, ttb;
215              
216             /*
217             * In this loop, carries 'cca' and 'ccb' always fit on
218             * 33 bits (in absolute value).
219             */
220 70210           wa = a[k];
221 70210           wb = b[k];
222 70210           za = wa * (uint64_t)pa + wb * (uint64_t)pb
223 70210           + m[k] * (uint64_t)fa + (uint64_t)cca;
224 70210           zb = wa * (uint64_t)qa + wb * (uint64_t)qb
225 70210           + m[k] * (uint64_t)fb + (uint64_t)ccb;
226 70210 100         if (k > 0) {
227 66080           a[k - 1] = (uint32_t)za & 0x7FFFFFFF;
228 66080           b[k - 1] = (uint32_t)zb & 0x7FFFFFFF;
229             }
230              
231             #define M ((uint64_t)1 << 32)
232 70210           tta = za >> 31;
233 70210           ttb = zb >> 31;
234 70210           tta = (tta ^ M) - M;
235 70210           ttb = (ttb ^ M) - M;
236 70210           cca = *(int64_t *)&tta;
237 70210           ccb = *(int64_t *)&ttb;
238             #undef M
239             }
240 4130           a[len - 1] = (uint32_t)cca;
241 4130           b[len - 1] = (uint32_t)ccb;
242              
243             /*
244             * At this point:
245             * -m <= a < 2*m
246             * -m <= b < 2*m
247             * (this is a case of Montgomery reduction)
248             * The top word of 'a' and 'b' may have a 32-th bit set.
249             * We may have to add or subtract the modulus.
250             */
251 4130           finish_mod(a, len, m, (uint32_t)((uint64_t)cca >> 63));
252 4130           finish_mod(b, len, m, (uint32_t)((uint64_t)ccb >> 63));
253 4130           }
254              
255             /* see inner.h */
256             uint32_t
257 118           br_i31_moddiv(uint32_t *x, const uint32_t *y, const uint32_t *m, uint32_t m0i,
258             uint32_t *t)
259             {
260             /*
261             * Algorithm is an extended binary GCD. We maintain four values
262             * a, b, u and v, with the following invariants:
263             *
264             * a * x = y * u mod m
265             * b * x = y * v mod m
266             *
267             * Starting values are:
268             *
269             * a = y
270             * b = m
271             * u = x
272             * v = 0
273             *
274             * The formal definition of the algorithm is a sequence of steps:
275             *
276             * - If a is even, then a <- a/2 and u <- u/2 mod m.
277             * - Otherwise, if b is even, then b <- b/2 and v <- v/2 mod m.
278             * - Otherwise, if a > b, then a <- (a-b)/2 and u <- (u-v)/2 mod m.
279             * - Otherwise, b <- (b-a)/2 and v <- (v-u)/2 mod m.
280             *
281             * Algorithm stops when a = b. At that point, they both are equal
282             * to GCD(y,m); the modular division succeeds if that value is 1.
283             * The result of the modular division is then u (or v: both are
284             * equal at that point).
285             *
286             * Each step makes either a or b shrink by at least one bit; hence,
287             * if m has bit length k bits, then 2k-2 steps are sufficient.
288             *
289             *
290             * Though complexity is quadratic in the size of m, the bit-by-bit
291             * processing is not very efficient. We can speed up processing by
292             * remarking that the decisions are taken based only on observation
293             * of the top and low bits of a and b.
294             *
295             * In the loop below, at each iteration, we use the two top words
296             * of a and b, and the low words of a and b, to compute reduction
297             * parameters pa, pb, qa and qb such that the new values for a
298             * and b are:
299             *
300             * a' = (a*pa + b*pb) / (2^31)
301             * b' = (a*qa + b*qb) / (2^31)
302             *
303             * the division being exact.
304             *
305             * Since the choices are based on the top words, they may be slightly
306             * off, requiring an optional correction: if a' < 0, then we replace
307             * pa with -pa, and pb with -pb. The total length of a and b is
308             * thus reduced by at least 30 bits at each iteration.
309             *
310             * The stopping conditions are still the same, though: when a
311             * and b become equal, they must be both odd (since m is odd,
312             * the GCD cannot be even), therefore the next operation is a
313             * subtraction, and one of the values becomes 0. At that point,
314             * nothing else happens, i.e. one value is stuck at 0, and the
315             * other one is the GCD.
316             */
317             size_t len, k;
318             uint32_t *a, *b, *u, *v;
319             uint32_t num, r;
320              
321 118           len = (m[0] + 31) >> 5;
322 118           a = t;
323 118           b = a + len;
324 118           u = x + 1;
325 118           v = b + len;
326 118           memcpy(a, y + 1, len * sizeof *y);
327 118           memcpy(b, m + 1, len * sizeof *m);
328 118           memset(v, 0, len * sizeof *v);
329              
330             /*
331             * Loop below ensures that a and b are reduced by some bits each,
332             * for a total of at least 30 bits.
333             */
334 4248 100         for (num = ((m[0] - (m[0] >> 5)) << 1) + 30; num >= 30; num -= 30) {
335             size_t j;
336             uint32_t c0, c1;
337             uint32_t a0, a1, b0, b1;
338             uint64_t a_hi, b_hi;
339             uint32_t a_lo, b_lo;
340             int64_t pa, pb, qa, qb;
341             int i;
342              
343             /*
344             * Extract top words of a and b. If j is the highest
345             * index >= 1 such that a[j] != 0 or b[j] != 0, then we want
346             * (a[j] << 31) + a[j - 1], and (b[j] << 31) + b[j - 1].
347             * If a and b are down to one word each, then we use a[0]
348             * and b[0].
349             */
350 4130           c0 = (uint32_t)-1;
351 4130           c1 = (uint32_t)-1;
352 4130           a0 = 0;
353 4130           a1 = 0;
354 4130           b0 = 0;
355 4130           b1 = 0;
356 4130           j = len;
357 74340 100         while (j -- > 0) {
358             uint32_t aw, bw;
359              
360 70210           aw = a[j];
361 70210           bw = b[j];
362 70210           a0 ^= (a0 ^ aw) & c0;
363 70210           a1 ^= (a1 ^ aw) & c1;
364 70210           b0 ^= (b0 ^ bw) & c0;
365 70210           b1 ^= (b1 ^ bw) & c1;
366 70210           c1 = c0;
367 70210           c0 &= (((aw | bw) + 0x7FFFFFFF) >> 31) - (uint32_t)1;
368             }
369              
370             /*
371             * If c1 = 0, then we grabbed two words for a and b.
372             * If c1 != 0 but c0 = 0, then we grabbed one word. It
373             * is not possible that c1 != 0 and c0 != 0, because that
374             * would mean that both integers are zero.
375             */
376 4130           a1 |= a0 & c1;
377 4130           a0 &= ~c1;
378 4130           b1 |= b0 & c1;
379 4130           b0 &= ~c1;
380 4130           a_hi = ((uint64_t)a0 << 31) + a1;
381 4130           b_hi = ((uint64_t)b0 << 31) + b1;
382 4130           a_lo = a[0];
383 4130           b_lo = b[0];
384              
385             /*
386             * Compute reduction factors:
387             *
388             * a' = a*pa + b*pb
389             * b' = a*qa + b*qb
390             *
391             * such that a' and b' are both multiple of 2^31, but are
392             * only marginally larger than a and b.
393             */
394 4130           pa = 1;
395 4130           pb = 0;
396 4130           qa = 0;
397 4130           qb = 1;
398 132160 100         for (i = 0; i < 31; i ++) {
399             /*
400             * At each iteration:
401             *
402             * a <- (a-b)/2 if: a is odd, b is odd, a_hi > b_hi
403             * b <- (b-a)/2 if: a is odd, b is odd, a_hi <= b_hi
404             * a <- a/2 if: a is even
405             * b <- b/2 if: a is odd, b is even
406             *
407             * We multiply a_lo and b_lo by 2 at each
408             * iteration, thus a division by 2 really is a
409             * non-multiplication by 2.
410             */
411             uint32_t r, oa, ob, cAB, cBA, cA;
412             uint64_t rz;
413              
414             /*
415             * r = GT(a_hi, b_hi)
416             * But the GT() function works on uint32_t operands,
417             * so we inline a 64-bit version here.
418             */
419 128030           rz = b_hi - a_hi;
420 128030           r = (uint32_t)((rz ^ ((a_hi ^ b_hi)
421 128030           & (a_hi ^ rz))) >> 63);
422              
423             /*
424             * cAB = 1 if b must be subtracted from a
425             * cBA = 1 if a must be subtracted from b
426             * cA = 1 if a is divided by 2, 0 otherwise
427             *
428             * Rules:
429             *
430             * cAB and cBA cannot be both 1.
431             * if a is not divided by 2, b is.
432             */
433 128030           oa = (a_lo >> i) & 1;
434 128030           ob = (b_lo >> i) & 1;
435 128030           cAB = oa & ob & r;
436 128030           cBA = oa & ob & NOT(r);
437 128030           cA = cAB | NOT(oa);
438              
439             /*
440             * Conditional subtractions.
441             */
442 128030           a_lo -= b_lo & -cAB;
443 128030           a_hi -= b_hi & -(uint64_t)cAB;
444 128030           pa -= qa & -(int64_t)cAB;
445 128030           pb -= qb & -(int64_t)cAB;
446 128030           b_lo -= a_lo & -cBA;
447 128030           b_hi -= a_hi & -(uint64_t)cBA;
448 128030           qa -= pa & -(int64_t)cBA;
449 128030           qb -= pb & -(int64_t)cBA;
450              
451             /*
452             * Shifting.
453             */
454 128030           a_lo += a_lo & (cA - 1);
455 128030           pa += pa & ((int64_t)cA - 1);
456 128030           pb += pb & ((int64_t)cA - 1);
457 128030           a_hi ^= (a_hi ^ (a_hi >> 1)) & -(uint64_t)cA;
458 128030           b_lo += b_lo & -cA;
459 128030           qa += qa & -(int64_t)cA;
460 128030           qb += qb & -(int64_t)cA;
461 128030           b_hi ^= (b_hi ^ (b_hi >> 1)) & ((uint64_t)cA - 1);
462             }
463              
464             /*
465             * Replace a and b with new values a' and b'.
466             */
467 4130           r = co_reduce(a, b, len, pa, pb, qa, qb);
468 4130           pa -= pa * ((r & 1) << 1);
469 4130           pb -= pb * ((r & 1) << 1);
470 4130           qa -= qa * (r & 2);
471 4130           qb -= qb * (r & 2);
472 4130           co_reduce_mod(u, v, len, pa, pb, qa, qb, m + 1, m0i);
473             }
474              
475             /*
476             * Now one of the arrays should be 0, and the other contains
477             * the GCD. If a is 0, then u is 0 as well, and v contains
478             * the division result.
479             * Result is correct if and only if GCD is 1.
480             */
481 118           r = (a[0] | b[0]) ^ 1;
482 118           u[0] |= v[0];
483 2006 100         for (k = 1; k < len; k ++) {
484 1888           r |= a[k] | b[k];
485 1888           u[k] |= v[k];
486             }
487 118           return EQ0(r);
488             }