@@ -220,7 +220,7 @@ library Math {
220
220
}
221
221
222
222
/**
223
- * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
223
+ * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
224
224
*
225
225
* If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0.
226
226
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
@@ -230,16 +230,48 @@ library Math {
230
230
function invMod (uint256 a , uint256 n ) internal pure returns (uint256 ) {
231
231
unchecked {
232
232
if (n == 0 ) return 0 ;
233
- uint256 r1 = a % n;
234
- uint256 r2 = n;
235
- int256 t1 = 0 ;
236
- int256 t2 = 1 ;
237
- while (r1 != 0 ) {
238
- uint256 q = r2 / r1;
239
- (t1, t2, r2, r1) = (t2, t1 - t2 * int256 (q), r1, r2 - r1 * q);
233
+
234
+ // The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
235
+ // Used to compute integers x and y such that: ax + ny = gcd(a, n).
236
+ // When the gcd is 1, then the inverse of a modulo n exists and it's x.
237
+ // ax + ny = 1
238
+ // ax = 1 + (-y)n
239
+ // ax ≡ 1 (mod n) # x is the inverse of a modulo n
240
+
241
+ // If the remainder is 0 the gcd is n right away.
242
+ uint256 remainder = a % n;
243
+ uint256 gcd = n;
244
+
245
+ // Therefore the initial coefficients are:
246
+ // ax + ny = gcd(a, n) = n
247
+ // 0a + 1n = n
248
+ int256 x = 0 ;
249
+ int256 y = 1 ;
250
+
251
+ while (remainder != 0 ) {
252
+ uint256 quotient = gcd / remainder;
253
+
254
+ (gcd, remainder) = (
255
+ // The old remainder is the next gcd to try.
256
+ remainder,
257
+ // Compute the next remainder.
258
+ // Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
259
+ // where gcd is at most n (capped to type(uint256).max)
260
+ gcd - remainder * quotient
261
+ );
262
+
263
+ (x, y) = (
264
+ // Increment the coefficient of a.
265
+ y,
266
+ // Decrement the coefficient of n.
267
+ // Can overflow, but the result is casted to uint256 so that the
268
+ // next value of y is "wrapped around" to a value between 0 and n - 1.
269
+ x - y * int256 (quotient)
270
+ );
240
271
}
241
- if (r2 != 1 ) return 0 ;
242
- return t1 < 0 ? (n - uint256 (- t1)) : uint256 (t1);
272
+
273
+ if (gcd != 1 ) return 0 ; // No inverse exists.
274
+ return x < 0 ? (n - uint256 (- x)) : uint256 (x); // Wrap the result if it's negative.
243
275
}
244
276
}
245
277
0 commit comments