Skip to content

Commit

Permalink
[ML-KEM] Add experimental support for ML-KEM-512-IPD (#1516)
Browse files Browse the repository at this point in the history
Add support and testing for ML-KEM-512-IPD,
as specified in FIPS 203 Initial Public Draft.
This is an intermediate step to support the final
standardized ML-KEM once FIPS 203 is finalized.

We do not plan to support the IPD version long-term,
it will be surpassed by the final FIPS 203 (ML-KEM) definition.
  • Loading branch information
dkostic authored Apr 22, 2024
1 parent c295aef commit 56f3569
Show file tree
Hide file tree
Showing 26 changed files with 1,638 additions and 1,674 deletions.
2 changes: 2 additions & 0 deletions crypto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ add_library(
kyber/kem_kyber.c
lhash/lhash.c
mem.c
ml_kem/ml_kem_512_ipd.c
ml_kem/ml_kem.c
obj/obj.c
obj/obj_xref.c
ocsp/ocsp_asn.c
Expand Down
1 change: 1 addition & 0 deletions crypto/evp_extra/evp_extra_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,7 @@ static const struct KnownKEM kKEMs[] = {
{"Kyber512r3", NID_KYBER512_R3, 800, 1632, 768, 32, "kyber/kat/kyber512r3.txt"},
{"Kyber768r3", NID_KYBER768_R3, 1184, 2400, 1088, 32, "kyber/kat/kyber768r3.txt"},
{"Kyber1024r3", NID_KYBER1024_R3, 1568, 3168, 1568, 32, "kyber/kat/kyber1024r3.txt"},
{"MLKEM512IPD", NID_MLKEM512IPD, 800, 1632, 768, 32, "ml_kem/kat/mlkem512ipd.txt"},
};

class PerKEMTest : public testing::TestWithParam<KnownKEM> {};
Expand Down
1 change: 1 addition & 0 deletions crypto/kem/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ typedef struct {
extern const KEM_METHOD kem_kyber512r3_method;
extern const KEM_METHOD kem_kyber768r3_method;
extern const KEM_METHOD kem_kyber1024r3_method;
extern const KEM_METHOD kem_ml_kem_512_ipd_method;

// KEM structure and helper functions.
typedef struct {
Expand Down
15 changes: 14 additions & 1 deletion crypto/kem/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@
#include "../internal.h"
#include "internal.h"
#include "../kyber/kem_kyber.h"
#include "../ml_kem/ml_kem.h"

// The KEM parameters listed below are taken from corresponding specifications.
//
// Kyber: - https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
// - Kyber is not standardized yet, so we use the latest specification
// from Round 3 of NIST PQC project.

#define AWSLC_NUM_BUILT_IN_KEMS 3
#define AWSLC_NUM_BUILT_IN_KEMS 4

// TODO(awslc): placeholder OIDs, replace with the real ones when available.
static const uint8_t kOIDKyber512r3[] = {0xff, 0xff, 0xff, 0xff};
static const uint8_t kOIDKyber768r3[] = {0xff, 0xff, 0xff, 0xff};
static const uint8_t kOIDKyber1024r3[] = {0xff, 0xff, 0xff, 0xff};
static const uint8_t kOIDMLKEM512IPD[] = {0xff, 0xff, 0xff, 0xff};

static const KEM built_in_kems[AWSLC_NUM_BUILT_IN_KEMS] = {
{
Expand Down Expand Up @@ -60,6 +62,17 @@ static const KEM built_in_kems[AWSLC_NUM_BUILT_IN_KEMS] = {
KYBER_R3_SHARED_SECRET_LEN, // kem.shared_secret_len
&kem_kyber1024r3_method, // kem.method
},
{
NID_MLKEM512IPD, // kem.nid
kOIDMLKEM512IPD, // kem.oid
sizeof(kOIDMLKEM512IPD), // kem.oid_len
"MLKEM512 IPD", // kem.comment
MLKEM512IPD_PUBLIC_KEY_BYTES, // kem.public_key_len
MLKEM512IPD_SECRET_KEY_BYTES, // kem.secret_key_len
MLKEM512IPD_CIPHERTEXT_BYTES, // kem.ciphertext_len
MLKEM512IPD_SHARED_SECRET_LEN, // kem.shared_secret_len
&kem_ml_kem_512_ipd_method, // kem.method
},
};

const KEM *KEM_find_kem_by_nid(int nid) {
Expand Down
24 changes: 24 additions & 0 deletions crypto/kem/kem_methods.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "internal.h"

#include "../kyber/kem_kyber.h"
#include "../ml_kem/ml_kem.h"

static int kyber512r3_keygen(uint8_t *public_key,
uint8_t *secret_key) {
Expand Down Expand Up @@ -77,3 +78,26 @@ const KEM_METHOD kem_kyber1024r3_method = {
kyber1024r3_encaps,
kyber1024r3_decaps,
};

static int ml_kem_512_ipd_keygen(uint8_t *public_key,
uint8_t *secret_key) {
return ml_kem_512_ipd_keypair(public_key, secret_key) == 0;
}

static int ml_kem_512_ipd_encaps(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key) {
return ml_kem_512_ipd_encapsulate(ciphertext, shared_secret, public_key) == 0;
}

static int ml_kem_512_ipd_decaps(uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key) {
return ml_kem_512_ipd_decapsulate(shared_secret, ciphertext, secret_key) == 0;
}

const KEM_METHOD kem_ml_kem_512_ipd_method = {
ml_kem_512_ipd_keygen,
ml_kem_512_ipd_encaps,
ml_kem_512_ipd_decaps,
};
8 changes: 8 additions & 0 deletions crypto/ml_kem/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@ NOTE: THIS IS AN IMPLEMENTATION OF THE DRAFT VERSION OF FIPS 203, NOT THE FINAL

**Source code origin and modifications.** The source code was imported from a branch of the official repository of the Crystals-Kyber team that follows the standard draft: https://github.com/pq-crystals/kyber/tree/standard. The code was taken at [commit](https://github.com/pq-crystals/kyber/commit/11d00ff1f20cfca1f72d819e5a45165c1e0a2816) as of 03/26/2024. At the moment, only the reference C implementation is imported.

The following changes were made to the source code in `ml_kem_ipd_ref_common` directory:
- `randombytes.{h|c}` are deleted because we are using the randomness generation functions provided by AWS-LC.
- `kem.c`: call to randombytes function is replaced with a call to pq_custom_randombytes and the appropriate header file is included (crypto/rand_extra/pq_custom_randombytes.h).
- `fips202.{h|c}` are deleted and the ones from `crypto/kyber/pqcrystals_kyber_ref_common` directory are used.
- `symmetric-shake.c`: unnecessary include of fips202.h is removed.
- `api.h`: `pqcrystals` prefix substituted with `ml_kem` (to be able build alongside `crypto/kyber`).

The KATs were generated by compiling and running the KAT generator tests from the official repository. Specifically, running `make` in the `ref` folder produces `nistkat/PQCgenKAT_kem512` binary that can generates the test vectors.
702 changes: 702 additions & 0 deletions crypto/ml_kem/kat/mlkem512ipd.txt

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions crypto/ml_kem/ml_kem.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

#include "ml_kem.h"
#include "ml_kem_ipd_ref_common/api.h"

// Note: These methods currently default to using the reference code for ML_KEM.
// In a future where AWS-LC has optimized options available, those can be
// conditionally (or based on compile-time flags) called here, depending on
// platform support.

int ml_kem_512_ipd_keypair(uint8_t *public_key /* OUT */,
uint8_t *secret_key /* OUT */) {
return ml_kem_512_ref_keypair(public_key, secret_key);
}

int ml_kem_512_ipd_encapsulate(uint8_t *ciphertext /* OUT */,
uint8_t *shared_secret /* OUT */,
const uint8_t *public_key /* IN */) {
return ml_kem_512_ref_enc(ciphertext, shared_secret, public_key);
}

int ml_kem_512_ipd_decapsulate(uint8_t *shared_secret /* OUT */,
const uint8_t *ciphertext /* IN */,
const uint8_t *secret_key /* IN */) {
return ml_kem_512_ref_dec(shared_secret, ciphertext, secret_key);
}

25 changes: 25 additions & 0 deletions crypto/ml_kem/ml_kem.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

#ifndef ML_KEM_H
#define ML_KEM_H

#include <stdint.h>
#include <openssl/base.h>

#define MLKEM512IPD_SHARED_SECRET_LEN (32)
#define MLKEM512IPD_PUBLIC_KEY_BYTES (800)
#define MLKEM512IPD_SECRET_KEY_BYTES (1632)
#define MLKEM512IPD_CIPHERTEXT_BYTES (768)

int ml_kem_512_ipd_keypair(uint8_t *public_key /* OUT */,
uint8_t *secret_key /* OUT */);

int ml_kem_512_ipd_encapsulate(uint8_t *ciphertext /* OUT */,
uint8_t *shared_secret /* OUT */,
const uint8_t *public_key /* IN */);

int ml_kem_512_ipd_decapsulate(uint8_t *shared_secret /* OUT */,
const uint8_t *ciphertext /* IN */,
const uint8_t *secret_key /* IN */);
#endif // ML_KEM_H
19 changes: 19 additions & 0 deletions crypto/ml_kem/ml_kem_512_ipd.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

// The following two lines have to be in that order, first the definition of
// KYBER_K, and then the inclusion of params.h so that the correct version
// of Kyber would be selected. KYBER_K equal to 2 corresponds to ML-KEM-512.
// Both lines also have to come before all the source files.
#define KYBER_K 2
#include "./ml_kem_ipd_ref_common/params.h"

#include "./ml_kem_ipd_ref_common/cbd.c"
#include "./ml_kem_ipd_ref_common/indcpa.c"
#include "./ml_kem_ipd_ref_common/kem.c"
#include "./ml_kem_ipd_ref_common/ntt.c"
#include "./ml_kem_ipd_ref_common/poly.c"
#include "./ml_kem_ipd_ref_common/polyvec.c"
#include "./ml_kem_ipd_ref_common/reduce.c"
#include "./ml_kem_ipd_ref_common/symmetric-shake.c"
#include "./ml_kem_ipd_ref_common/verify.c"
79 changes: 19 additions & 60 deletions crypto/ml_kem/ml_kem_ipd_ref_common/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,24 @@
#define API_H

#include <stdint.h>

#define pqcrystals_kyber512_SECRETKEYBYTES 1632
#define pqcrystals_kyber512_PUBLICKEYBYTES 800
#define pqcrystals_kyber512_CIPHERTEXTBYTES 768
#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64
#define pqcrystals_kyber512_ENCCOINBYTES 32
#define pqcrystals_kyber512_BYTES 32

#define pqcrystals_kyber512_ref_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES
#define pqcrystals_kyber512_ref_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES
#define pqcrystals_kyber512_ref_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES
#define pqcrystals_kyber512_ref_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES
#define pqcrystals_kyber512_ref_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES
#define pqcrystals_kyber512_ref_BYTES pqcrystals_kyber512_BYTES

int pqcrystals_kyber512_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int pqcrystals_kyber512_ref_keypair(uint8_t *pk, uint8_t *sk);
int pqcrystals_kyber512_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int pqcrystals_kyber512_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int pqcrystals_kyber512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);

#define pqcrystals_kyber768_SECRETKEYBYTES 2400
#define pqcrystals_kyber768_PUBLICKEYBYTES 1184
#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088
#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64
#define pqcrystals_kyber768_ENCCOINBYTES 32
#define pqcrystals_kyber768_BYTES 32

#define pqcrystals_kyber768_ref_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES
#define pqcrystals_kyber768_ref_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES
#define pqcrystals_kyber768_ref_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES
#define pqcrystals_kyber768_ref_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES
#define pqcrystals_kyber768_ref_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES
#define pqcrystals_kyber768_ref_BYTES pqcrystals_kyber768_BYTES

int pqcrystals_kyber768_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int pqcrystals_kyber768_ref_keypair(uint8_t *pk, uint8_t *sk);
int pqcrystals_kyber768_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int pqcrystals_kyber768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int pqcrystals_kyber768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);

#define pqcrystals_kyber1024_SECRETKEYBYTES 3168
#define pqcrystals_kyber1024_PUBLICKEYBYTES 1568
#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568
#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64
#define pqcrystals_kyber1024_ENCCOINBYTES 32
#define pqcrystals_kyber1024_BYTES 32

#define pqcrystals_kyber1024_ref_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES
#define pqcrystals_kyber1024_ref_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES
#define pqcrystals_kyber1024_ref_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES
#define pqcrystals_kyber1024_ref_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES
#define pqcrystals_kyber1024_ref_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES
#define pqcrystals_kyber1024_ref_BYTES pqcrystals_kyber1024_BYTES

int pqcrystals_kyber1024_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int pqcrystals_kyber1024_ref_keypair(uint8_t *pk, uint8_t *sk);
int pqcrystals_kyber1024_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int pqcrystals_kyber1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int pqcrystals_kyber1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);
#include <openssl/base.h>

int ml_kem_512_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int ml_kem_512_ref_keypair(uint8_t *pk, uint8_t *sk);
int ml_kem_512_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int ml_kem_512_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int ml_kem_512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);

int ml_kem_768_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int ml_kem_768_ref_keypair(uint8_t *pk, uint8_t *sk);
int ml_kem_768_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int ml_kem_768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int ml_kem_768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);

int ml_kem_1024_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins);
int ml_kem_1024_ref_keypair(uint8_t *pk, uint8_t *sk);
int ml_kem_1024_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins);
int ml_kem_1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk);
int ml_kem_1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk);

#endif
Loading

0 comments on commit 56f3569

Please sign in to comment.