File Coverage

mulmod.h
Criterion Covered Total %
statement 27 27 100.0
branch 18 20 90.0
condition n/a
subroutine n/a
pod n/a
total 45 47 95.7


line stmt bran cond sub pod time code
1             #ifndef MPU_MULMOD_H
2             #define MPU_MULMOD_H
3              
4             #include "ptypes.h"
5              
6             /* if n is smaller than this, you can multiply without overflow */
7             #define HALF_WORD (UVCONST(1) << (BITS_PER_WORD/2))
8             /* This will be true if we think mulmods are fast */
9             #define MULMODS_ARE_FAST 1
10              
11             /* x86-64 ARM RISC-V
12             * umul 64->128 mul -> rdx:rax umulh/mul mulhu/mulu
13             * smul 64->128 imul -> rdx:rax smulh/mul mulsu/muls
14             * udiv 128->64 div -> q:rax r:rdx divu (RV128I)
15             * sdiv 128->64 idiv -> q:rax r:rdx divs (RV128I)
16             * clmul 64->128 pclmulqdq -> xmm pmull/pmull2 clmul/clmulh
17             *
18             * __int128 (GCC, clang, CUDA 11.5+)
19             * MSVC std::_Unsigned128 in <__msvc_int128.hpp>
20             * C23 _BitInt(128) (clang 14+, gcc 14+)
21             */
22              
23              
24             #if (BITS_PER_WORD == 32) && HAVE_UINT64
25              
26             /* We have 64-bit available, but UV is 32-bit. Do the math in 64-bit.
27             * Even if it is emulated, it should be as fast or faster than us doing it.
28             */
29             #define addmod(a,b,n) (UV)( ((uint64_t)(a) + (b)) % (n) )
30             #define mulmod(a,b,n) (UV)( ((uint64_t)(a) * (b)) % (n) )
31             #define sqrmod(a,n) (UV)( ((uint64_t)(a) * (a)) % (n) )
32              
33             #elif defined(__GNUC__) && defined(__x86_64__)
34              
35             /* GCC on a 64-bit Intel x86, help from WraithX and Wojciech Izykowski */
36             /* Beware: if (a*b)/c > 2^64, there will be an FP exception */
37 36590           static INLINE UV _mulmod(UV a, UV b, UV n) {
38             UV d, dummy; /* d will get a*b mod c */
39 36590           __asm__ ("mulq %3\n\t" /* mul a*b -> rdx:rax */
40             "divq %4\n\t" /* (a*b)/c -> quot in rax remainder in rdx */
41             :"=a"(dummy), "=&d"(d) /* output */
42             :"a"(a), "r"(b), "r"(n) /* input */
43             :"cc" /* mulq and divq can set conditions */
44             );
45 36590           return d;
46             }
47             #define mulmod(a,b,n) _mulmod(a,b,n)
48             #define sqrmod(a,n) _mulmod(a,a,n)
49             /* A version for _MSC_VER:
50             * __asm { mov rax, qword ptr a
51             * mul qword ptr b
52             * div qword ptr c
53             * mov qword ptr d, rdx } */
54              
55             /* addmod from Kruppa 2010 page 67 */
56 17511           static INLINE UV _addmod(UV a, UV b, UV n) {
57 17511           UV t = a-n;
58 17511           a += b;
59 17511           __asm__ ("add %2, %1\n\t" /* t := t + b */
60             "cmovc %1, %0\n\t" /* if (carry) a := t */
61             :"+r" (a), "+&r" (t)
62             :"r" (b)
63             :"cc"
64             );
65 17511           return a;
66             }
67             #define addmod(a,b,n) _addmod(a,b,n)
68              
69             #elif BITS_PER_WORD == 64 && HAVE_UINT128
70              
71             /* We're 64-bit, using a modern gcc, and the target has some 128-bit type.
72             * The actual number of targets that have this implemented are limited.
73             * However, the late 2020 Apple M1 Macs use this. */
74              
75             #define mulmod(a,b,n) (UV)( ((uint128_t)(a) * (b)) % (n) )
76             #define sqrmod(a,n) (UV)( ((uint128_t)(a) * (a)) % (n) )
77              
78             #else
79              
80             /* UV is the largest integral type available (that we know of). */
81             #undef MULMODS_ARE_FAST
82             #define MULMODS_ARE_FAST 0
83              
84             /* Do it by hand */
85             static INLINE UV _mulmod(UV a, UV b, UV n) {
86             UV r = 0;
87             if (a >= n) a %= n; /* Careful attention from the caller should make */
88             if (b >= n) b %= n; /* these unnecessary. */
89             if ((a|b) < HALF_WORD) return (a*b) % n;
90             if (a < b) { UV t = a; a = b; b = t; }
91             if (n <= (UV_MAX>>1)) {
92             while (b > 0) {
93             if (b & 1) { r += a; if (r >= n) r -= n; }
94             b >>= 1;
95             if (b) { a += a; if (a >= n) a -= n; }
96             }
97             } else {
98             while (b > 0) {
99             if (b & 1) r = ((n-r) > a) ? r+a : r+a-n; /* r = (r + a) % n */
100             b >>= 1;
101             if (b) a = ((n-a) > a) ? a+a : a+a-n; /* a = (a + a) % n */
102             }
103             }
104             return r;
105             }
106              
107             #define mulmod(a,b,n) _mulmod(a,b,n)
108             #define sqrmod(a,n) _mulmod(a,a,n)
109              
110             #endif
111              
112             #ifndef addmod
113             static INLINE UV addmod(UV a, UV b, UV n) {
114             return ((n-a) > b) ? a+b : a+b-n;
115             }
116             #endif
117              
118 3844           static INLINE UV submod(UV a, UV b, UV n) {
119 3844           UV t = n-b; /* Evaluate as UV, then hand to addmod */
120 3844           return addmod(a, t, n);
121             }
122              
123             /* a^2 + c mod n */
124             #define sqraddmod(a, c, n) addmod(sqrmod(a,n), c, n)
125             /* a*b + c mod n */
126             #define muladdmod(a, b, c, n) addmod(mulmod(a,b,n), c, n)
127             /* a*b - c mod n */
128             #define mulsubmod(a, b, c, n) submod(mulmod(a,b,n), c, n)
129              
130             /* a^k mod n */
131             #ifndef HALF_WORD
132             static INLINE UV powmod(UV a, UV k, UV n) {
133             UV t = 1;
134             if (a >= n) a %= n;
135             while (k) {
136             if (k & 1) t = mulmod(t, a, n);
137             k >>= 1;
138             if (k) a = sqrmod(a, n);
139             }
140             return t;
141             }
142             #else
143 313           static INLINE UV powmod(UV a, UV k, UV n) {
144 313           UV t = 1;
145 313 50         if (a >= n) a %= n;
146 313 100         if (n < HALF_WORD) {
147 1260 100         while (k) {
148 1038 100         if (k & 1) t = (t*a)%n;
149 1038           k >>= 1;
150 1038 100         if (k) a = (a*a)%n;
151             }
152             } else {
153 4010 100         while (k) {
154 3919 100         if (k & 1) t = mulmod(t, a, n);
155 3919           k >>= 1;
156 3919 100         if (k) a = sqrmod(a, n);
157             }
158             }
159 313           return t;
160             }
161             #endif
162              
163             /* a^k + c mod n */
164             #define powaddmod(a, k, c, n) addmod(powmod(a,k,n),c,n)
165              
166 15           static INLINE UV negmod(UV a, UV n) {
167 15 50         if (a >= n) a %= n;
168 15 100         return (a) ? (n-a) : 0;
169             }
170              
171             #endif