Skip to content

Commit c58aa12

Browse files
committed
crypto/internal/nistec: reduce P-256 scalar
Unlike the rest of nistec, the P-256 assembly doesn't use complete addition formulas, meaning that p256PointAdd[Affine]Asm won't return the correct value if the two inputs are equal. This was (undocumentedly) ignored in the scalar multiplication loops because as long as the input point is not the identity and the scalar is lower than the order of the group, the addition inputs can't be the same. As part of the math/big rewrite, we went however from always reducing the scalar to only checking its length, under the incorrect assumption that the scalar multiplication loop didn't require reduction. Added a reduction, and while at it added it in P256OrdInverse, too, to enforce a universal reduction invariant on p256OrdElement values. Note that if the input point is the infinity, the code currently still relies on undefined behavior, but that's easily tested to behave acceptably, and will be addressed in a future CL. Fixes #58647 Fixes CVE-2023-24532 (Filed with the "safe APIs like complete addition formulas are good" dept.) Change-Id: I7b2c75238440e6852be2710fad66ff1fdc4e2b24
1 parent badd748 commit c58aa12

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

nistec_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package nistec_test
77
import (
88
"bytes"
99
"crypto/elliptic"
10+
"fmt"
1011
"math/big"
1112
"math/rand"
1213
"testing"
@@ -163,6 +164,86 @@ func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic
163164
}
164165
}
165166

167+
func TestScalarMult(t *testing.T) {
168+
t.Run("P224", func(t *testing.T) {
169+
testScalarMult(t, nistec.NewP224Point, elliptic.P224())
170+
})
171+
t.Run("P256", func(t *testing.T) {
172+
testScalarMult(t, nistec.NewP256Point, elliptic.P256())
173+
})
174+
t.Run("P384", func(t *testing.T) {
175+
testScalarMult(t, nistec.NewP384Point, elliptic.P384())
176+
})
177+
t.Run("P521", func(t *testing.T) {
178+
testScalarMult(t, nistec.NewP521Point, elliptic.P521())
179+
})
180+
}
181+
182+
func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
183+
G := newPoint().SetGenerator()
184+
checkScalar := func(t *testing.T, scalar []byte) {
185+
p1, err := newPoint().ScalarBaseMult(scalar)
186+
fatalIfErr(t, err)
187+
p2, err := newPoint().ScalarMult(G, scalar)
188+
fatalIfErr(t, err)
189+
if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
190+
t.Error("[k]G != ScalarBaseMult(k)")
191+
}
192+
193+
d := new(big.Int).SetBytes(scalar)
194+
d.Sub(c.Params().N, d)
195+
d.Mod(d, c.Params().N)
196+
g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
197+
fatalIfErr(t, err)
198+
g1.Add(g1, p1)
199+
if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) {
200+
t.Error("[N - k]G + [k]G != ∞")
201+
}
202+
}
203+
204+
byteLen := len(c.Params().N.Bytes())
205+
bitLen := c.Params().N.BitLen()
206+
t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
207+
t.Run("1", func(t *testing.T) {
208+
checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
209+
})
210+
t.Run("N-1", func(t *testing.T) {
211+
checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes())
212+
})
213+
t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) })
214+
t.Run("N+1", func(t *testing.T) {
215+
checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes())
216+
})
217+
t.Run("all1s", func(t *testing.T) {
218+
s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
219+
s.Sub(s, big.NewInt(1))
220+
checkScalar(t, s.Bytes())
221+
})
222+
if testing.Short() {
223+
return
224+
}
225+
for i := 0; i < bitLen; i++ {
226+
t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
227+
s := new(big.Int).Lsh(big.NewInt(1), uint(i))
228+
checkScalar(t, s.FillBytes(make([]byte, byteLen)))
229+
})
230+
}
231+
// Test N+1...N+32 since they risk overlapping with precomputed table values
232+
// in the final additions.
233+
for i := int64(2); i <= 32; i++ {
234+
t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) {
235+
checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes())
236+
})
237+
}
238+
}
239+
240+
// fatalIfErr stops the current test immediately when err is non-nil.
func fatalIfErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
246+
166247
func BenchmarkScalarMult(b *testing.B) {
167248
b.Run("P224", func(b *testing.B) {
168249
benchmarkScalarMult(b, nistec.NewP224Point().SetGenerator(), 28)

p256_asm.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,21 @@ func p256PointDoubleAsm(res, in *P256Point)
364364
// p256OrdElement is a P-256 scalar, represented in the Montgomery domain
// (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
type p256OrdElement [4]uint64

// p256OrdReduce ensures s is in the range [0, ord(G)-1].
//
// It runs in constant time: the subtraction is always performed and the
// result is selected with bitmasks rather than a branch.
func p256OrdReduce(s *p256OrdElement) {
	// Since 2 * ord(G) > 2²⁵⁶, a single conditional subtraction of ord(G)
	// is enough: keep the difference only if it didn't underflow.
	r0, borrow := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0)
	r1, borrow := bits.Sub64(s[1], 0xbce6faada7179e84, borrow)
	r2, borrow := bits.Sub64(s[2], 0xffffffffffffffff, borrow)
	r3, borrow := bits.Sub64(s[3], 0xffffffff00000000, borrow)

	// mask is all ones when the subtraction did not underflow (s >= ord(G)),
	// and zero when it did (s already reduced).
	mask := borrow - 1
	s[0] = (s[0] &^ mask) | (r0 & mask)
	s[1] = (s[1] &^ mask) | (r1 & mask)
	s[2] = (s[2] &^ mask) | (r2 & mask)
	s[3] = (s[3] &^ mask) | (r3 & mask)
}
381+
367382
// Add sets q = p1 + p2, and returns q. The points may overlap.
368383
func (q *P256Point) Add(r1, r2 *P256Point) *P256Point {
369384
var sum, double P256Point
@@ -393,6 +408,7 @@ func (r *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
393408
}
394409
scalarReversed := new(p256OrdElement)
395410
p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar))
411+
p256OrdReduce(scalarReversed)
396412

397413
r.p256BaseMult(scalarReversed)
398414
return r, nil
@@ -407,6 +423,7 @@ func (r *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error)
407423
}
408424
scalarReversed := new(p256OrdElement)
409425
p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar))
426+
p256OrdReduce(scalarReversed)
410427

411428
r.Set(q).p256ScalarMult(scalarReversed)
412429
return r, nil

p256_ordinv.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func p256OrdInverse(k []byte) ([]byte, error) {
2525

2626
x := new(p256OrdElement)
2727
p256OrdBigToLittle(x, (*[32]byte)(k))
28+
p256OrdReduce(x)
2829

2930
// Inversion is implemented as exponentiation by n - 2, per Fermat's little theorem.
3031
//

0 commit comments

Comments
 (0)