Skip to content

Commit

Permalink
crypto/internal/mlkem768: move to crypto/internal/fips/mlkem
Browse files Browse the repository at this point in the history
In the process, replace out-of-module imports with their FIPS versions.

For #69536

Change-Id: I83e900b7c38ecf760382e5dca7fd0b1eaa5a5589
Reviewed-on: https://go-review.googlesource.com/c/go/+/626879
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Russ Cox <[email protected]>
Auto-Submit: Filippo Valsorda <[email protected]>
Reviewed-by: Daniel McCarney <[email protected]>
Reviewed-by: Michael Knyszek <[email protected]>
  • Loading branch information
FiloSottile authored and gopherbot committed Nov 19, 2024
1 parent 9854fc3 commit 7e6b38e
Show file tree
Hide file tree
Showing 23 changed files with 43 additions and 6,835 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mlkem768
package mlkem

import (
"crypto/internal/fips/sha3"
"errors"
"internal/byteorder"

"golang.org/x/crypto/sha3"
)

// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
Expand Down Expand Up @@ -164,18 +163,18 @@ func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte {
// It implements ByteDecode₁₂, according to FIPS 203, Algorithm 6.
func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) {
if len(b) != encodingSize12 {
return T{}, errors.New("mlkem768: invalid encoding length")
return T{}, errors.New("mlkem: invalid encoding length")
}
var f T
for i := 0; i < n; i += 2 {
d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16
const mask12 = 0b1111_1111_1111
var err error
if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil {
return T{}, errors.New("mlkem768: invalid polynomial encoding")
return T{}, errors.New("mlkem: invalid polynomial encoding")
}
if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil {
return T{}, errors.New("mlkem768: invalid polynomial encoding")
return T{}, errors.New("mlkem: invalid polynomial encoding")
}
b = b[3:]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mlkem768
package mlkem

import (
"math/big"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package mlkem768 implements the quantum-resistant key encapsulation method
// Package mlkem implements the quantum-resistant key encapsulation method
// ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
//
// Only the recommended ML-KEM-768 parameter set is provided.
//
// [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
package mlkem768
package mlkem

// This package targets security, correctness, simplicity, readability, and
// reviewability as its primary goals. All critical operations are performed in
Expand All @@ -21,11 +21,10 @@ package mlkem768
// background at https://words.filippo.io/kyber-math/ useful.

import (
"crypto/rand"
"crypto/subtle"
"crypto/internal/fips/drbg"
"crypto/internal/fips/sha3"
"crypto/internal/fips/subtle"
"errors"

"golang.org/x/crypto/sha3"
)

const (
Expand Down Expand Up @@ -125,7 +124,7 @@ type decryptionKey struct {
}

// GenerateKey768 generates a new decapsulation key, drawing random bytes from
// crypto/rand. The decapsulation key must be kept secret.
// a DRBG. The decapsulation key must be kept secret.
func GenerateKey768() (*DecapsulationKey768, error) {
// The actual logic is in a separate function to outline this allocation.
dk := &DecapsulationKey768{}
Expand All @@ -134,9 +133,9 @@ func GenerateKey768() (*DecapsulationKey768, error) {

func generateKey(dk *DecapsulationKey768) *DecapsulationKey768 {
var d [32]byte
rand.Read(d[:])
drbg.Read(d[:])
var z [32]byte
rand.Read(z[:])
drbg.Read(z[:])
return kemKeyGen(dk, &d, &z)
}

Expand All @@ -150,7 +149,7 @@ func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {

func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
if len(seed) != SeedSize {
return nil, errors.New("mlkem768: invalid seed length")
return nil, errors.New("mlkem: invalid seed length")
}
d := (*[32]byte)(seed[:32])
z := (*[32]byte)(seed[32:])
Expand Down Expand Up @@ -212,7 +211,7 @@ func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) *DecapsulationKey768 {
}

// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from crypto/rand.
// encapsulation key, drawing random bytes from a DRBG.
//
// The shared key must be kept secret.
func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {
Expand All @@ -223,7 +222,7 @@ func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {

func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) {
var m [messageSize]byte
rand.Read(m[:])
drbg.Read(m[:])
// Note that the modulus check (step 2 of the encapsulation key check from
// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
return kemEncaps(cc, ek, &m)
Expand Down Expand Up @@ -260,10 +259,12 @@ func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, erro
// Algorithm 14.
func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
if len(ekPKE) != encryptionKeySize {
return nil, errors.New("mlkem768: invalid encapsulation key length")
return nil, errors.New("mlkem: invalid encapsulation key length")
}

ek.h = sha3.Sum256(ekPKE[:])
h := sha3.New256()
h.Write(ekPKE)
h.Sum(ek.h[:0])

for i := range ek.t {
var err error
Expand Down Expand Up @@ -333,7 +334,7 @@ func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]
// The shared key must be kept secret.
func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
if len(ciphertext) != CiphertextSize768 {
return nil, errors.New("mlkem768: invalid ciphertext length")
return nil, errors.New("mlkem: invalid ciphertext length")
}
c := (*[CiphertextSize768]byte)(ciphertext)
// Note that the hash check (step 3 of the decapsulation input check from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mlkem768
package mlkem

import (
"bytes"
"crypto/internal/fips/sha3"
"crypto/rand"
_ "embed"
"encoding/hex"
"flag"
"testing"

"golang.org/x/crypto/sha3"
)

func TestRoundTrip(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions src/crypto/tls/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/internal/fips/mlkem"
"crypto/internal/fips/tls13"
"crypto/internal/hpke"
"crypto/internal/mlkem768"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
Expand Down Expand Up @@ -160,11 +160,11 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
if err != nil {
return nil, nil, nil, err
}
seed := make([]byte, mlkem768.SeedSize)
seed := make([]byte, mlkem.SeedSize)
if _, err := io.ReadFull(config.rand(), seed); err != nil {
return nil, nil, nil, err
}
keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey768(seed)
keyShareKeys.kyber, err = mlkem.NewDecapsulationKey768(seed)
if err != nil {
return nil, nil, nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions src/crypto/tls/handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"crypto"
"crypto/hmac"
"crypto/internal/fips/hkdf"
"crypto/internal/fips/mlkem"
"crypto/internal/fips/tls13"
"crypto/internal/mlkem768"
"crypto/rsa"
"crypto/subtle"
"errors"
Expand Down Expand Up @@ -481,7 +481,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {

ecdhePeerData := hs.serverHello.serverShare.data
if hs.serverHello.serverShare.group == x25519Kyber768Draft00 {
if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize768 {
if len(ecdhePeerData) != x25519PublicKeySize+mlkem.CiphertextSize768 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
Expand Down
4 changes: 2 additions & 2 deletions src/crypto/tls/handshake_server_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"context"
"crypto"
"crypto/hmac"
"crypto/internal/fips/mlkem"
"crypto/internal/fips/tls13"
"crypto/internal/mlkem768"
"crypto/rsa"
"errors"
"hash"
Expand Down Expand Up @@ -223,7 +223,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
ecdhData := clientKeyShare.data
if selectedGroup == x25519Kyber768Draft00 {
ecdhGroup = X25519
if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize768 {
if len(ecdhData) != x25519PublicKeySize+mlkem.EncapsulationKeySize768 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid Kyber client key share")
}
Expand Down
18 changes: 9 additions & 9 deletions src/crypto/tls/key_schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ package tls
import (
"crypto/ecdh"
"crypto/hmac"
"crypto/internal/fips/mlkem"
"crypto/internal/fips/sha3"
"crypto/internal/fips/tls13"
"crypto/internal/mlkem768"
"errors"
"hash"
"io"

"golang.org/x/crypto/sha3"
)

// This file contains the functions necessary to compute the TLS 1.3 key
Expand Down Expand Up @@ -54,11 +53,11 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcrip
type keySharePrivateKeys struct {
curveID CurveID
ecdhe *ecdh.PrivateKey
kyber *mlkem768.DecapsulationKey768
kyber *mlkem.DecapsulationKey768
}

// kyberDecapsulate implements decapsulation according to Kyber Round 3.
func kyberDecapsulate(dk *mlkem768.DecapsulationKey768, c []byte) ([]byte, error) {
func kyberDecapsulate(dk *mlkem.DecapsulationKey768, c []byte) ([]byte, error) {
K, err := dk.Decapsulate(c)
if err != nil {
return nil, err
Expand All @@ -68,7 +67,7 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey768, c []byte) ([]byte, error

// kyberEncapsulate implements encapsulation according to Kyber Round 3.
func kyberEncapsulate(ek []byte) (c, ss []byte, err error) {
k, err := mlkem768.NewEncapsulationKey768(ek)
k, err := mlkem.NewEncapsulationKey768(ek)
if err != nil {
return nil, nil, err
}
Expand All @@ -77,13 +76,14 @@ func kyberEncapsulate(ek []byte) (c, ss []byte, err error) {
}

func kyberSharedSecret(c, K []byte) []byte {
// Package mlkem768 implements ML-KEM, which compared to Kyber removed a
// Package mlkem implements ML-KEM, which compared to Kyber removed a
// final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber.
// See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3.
h := sha3.NewShake256()
h.Write(K)
ch := sha3.Sum256(c)
h.Write(ch[:])
ch := sha3.New256()
ch.Write(c)
h.Write(ch.Sum(nil))
out := make([]byte, 32)
h.Read(out)
return out
Expand Down
4 changes: 2 additions & 2 deletions src/crypto/tls/key_schedule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ package tls

import (
"bytes"
"crypto/internal/fips/mlkem"
"crypto/internal/fips/tls13"
"crypto/internal/mlkem768"
"crypto/sha256"
"encoding/hex"
"strings"
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestTrafficKey(t *testing.T) {
}

func TestKyberEncapsulate(t *testing.T) {
dk, err := mlkem768.GenerateKey768()
dk, err := mlkem.GenerateKey768()
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion src/go/build/deps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ var depsRules = `
< crypto/internal/fips/hmac
< crypto/internal/fips/check
< crypto/internal/fips/hkdf
< crypto/internal/fips/mlkem
< crypto/internal/fips/ssh
< crypto/internal/fips/tls12
< crypto/internal/fips/tls13
Expand Down Expand Up @@ -525,7 +526,6 @@ var depsRules = `
CRYPTO, FMT, math/big
< crypto/internal/boring/bbig
< crypto/rand
< crypto/internal/mlkem768
< crypto/ed25519
< encoding/asn1
< golang.org/x/crypto/cryptobyte/asn1
Expand Down
62 changes: 0 additions & 62 deletions src/vendor/golang.org/x/crypto/sha3/doc.go

This file was deleted.

Loading

0 comments on commit 7e6b38e

Please sign in to comment.