diff --git a/src/crypto/internal/nistec/nistec_test.go b/src/crypto/internal/nistec/nistec_test.go index 309f68be16a9fc..9103608c18a0f9 100644 --- a/src/crypto/internal/nistec/nistec_test.go +++ b/src/crypto/internal/nistec/nistec_test.go @@ -8,6 +8,7 @@ import ( "bytes" "crypto/elliptic" "crypto/internal/nistec" + "fmt" "internal/testenv" "math/big" "math/rand" @@ -165,6 +166,86 @@ func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic } } +func TestScalarMult(t *testing.T) { + t.Run("P224", func(t *testing.T) { + testScalarMult(t, nistec.NewP224Point, elliptic.P224()) + }) + t.Run("P256", func(t *testing.T) { + testScalarMult(t, nistec.NewP256Point, elliptic.P256()) + }) + t.Run("P384", func(t *testing.T) { + testScalarMult(t, nistec.NewP384Point, elliptic.P384()) + }) + t.Run("P521", func(t *testing.T) { + testScalarMult(t, nistec.NewP521Point, elliptic.P521()) + }) +} + +func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) { + G := newPoint().SetGenerator() + checkScalar := func(t *testing.T, scalar []byte) { + p1, err := newPoint().ScalarBaseMult(scalar) + fatalIfErr(t, err) + p2, err := newPoint().ScalarMult(G, scalar) + fatalIfErr(t, err) + if !bytes.Equal(p1.Bytes(), p2.Bytes()) { + t.Error("[k]G != ScalarBaseMult(k)") + } + + d := new(big.Int).SetBytes(scalar) + d.Sub(c.Params().N, d) + d.Mod(d, c.Params().N) + g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar)))) + fatalIfErr(t, err) + g1.Add(g1, p1) + if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) { + t.Error("[N - k]G + [k]G != ∞") + } + } + + byteLen := len(c.Params().N.Bytes()) + bitLen := c.Params().N.BitLen() + t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) }) + t.Run("1", func(t *testing.T) { + checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen))) + }) + t.Run("N-1", func(t *testing.T) { + checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes()) + }) + t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) }) + t.Run("N+1", func(t *testing.T) { + checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes()) + }) + t.Run("all1s", func(t *testing.T) { + s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen)) + s.Sub(s, big.NewInt(1)) + checkScalar(t, s.Bytes()) + }) + if testing.Short() { + return + } + for i := 0; i < bitLen; i++ { + t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) { + s := new(big.Int).Lsh(big.NewInt(1), uint(i)) + checkScalar(t, s.FillBytes(make([]byte, byteLen))) + }) + } + // Test N+1...N+32 since they risk overlapping with precomputed table values + // in the final additions. + for i := int64(2); i <= 32; i++ { + t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) { + checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes()) + }) + } +} + +func fatalIfErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + func BenchmarkScalarMult(b *testing.B) { b.Run("P224", func(b *testing.B) { benchmarkScalarMult(b, nistec.NewP224Point().SetGenerator(), 28) diff --git a/src/crypto/internal/nistec/p256_asm.go b/src/crypto/internal/nistec/p256_asm.go index 6ea161eb499532..99a22b833f028c 100644 --- a/src/crypto/internal/nistec/p256_asm.go +++ b/src/crypto/internal/nistec/p256_asm.go @@ -364,6 +364,21 @@ func p256PointDoubleAsm(res, in *P256Point) // Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order. type p256OrdElement [4]uint64 +// p256OrdReduce ensures s is in the range [0, ord(G)-1]. +func p256OrdReduce(s *p256OrdElement) { + // Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G), + // keeping the result if it doesn't underflow. + t0, b := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0) + t1, b := bits.Sub64(s[1], 0xbce6faada7179e84, b) + t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b) + t3, b := bits.Sub64(s[3], 0xffffffff00000000, b) + tMask := b - 1 // zero if subtraction underflowed + s[0] ^= (t0 ^ s[0]) & tMask + s[1] ^= (t1 ^ s[1]) & tMask + s[2] ^= (t2 ^ s[2]) & tMask + s[3] ^= (t3 ^ s[3]) & tMask +} + // Add sets q = p1 + p2, and returns q. The points may overlap. func (q *P256Point) Add(r1, r2 *P256Point) *P256Point { var sum, double P256Point @@ -393,6 +408,7 @@ func (r *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) { } scalarReversed := new(p256OrdElement) p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar)) + p256OrdReduce(scalarReversed) r.p256BaseMult(scalarReversed) return r, nil @@ -407,6 +423,7 @@ func (r *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error) } scalarReversed := new(p256OrdElement) p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar)) + p256OrdReduce(scalarReversed) r.Set(q).p256ScalarMult(scalarReversed) return r, nil diff --git a/src/crypto/internal/nistec/p256_ordinv.go b/src/crypto/internal/nistec/p256_ordinv.go index 86a7a230bdce8e..1274fb7fd3f5cc 100644 --- a/src/crypto/internal/nistec/p256_ordinv.go +++ b/src/crypto/internal/nistec/p256_ordinv.go @@ -25,6 +25,7 @@ func P256OrdInverse(k []byte) ([]byte, error) { x := new(p256OrdElement) p256OrdBigToLittle(x, (*[32]byte)(k)) + p256OrdReduce(x) // Inversion is implemented as exponentiation by n - 2, per Fermat's little theorem. //