Skip to content

Commit

Permalink
Merge pull request tlsfuzzer#968 from tlsfuzzer/full-fuzz-for-mlkem-s…
Browse files Browse the repository at this point in the history
…hare

tls13-mlkem: verify that the server verifies correctness of the whole key
  • Loading branch information
tomato42 authored Oct 24, 2024
2 parents 40c0768 + 6127c13 commit a0c066c
Showing 1 changed file with 66 additions and 34 deletions.
100 changes: 66 additions & 34 deletions scripts/test-tls13-mlkem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import traceback
import sys
import getopt
import copy
from itertools import chain
from random import sample

Expand All @@ -30,7 +31,7 @@
from tlslite.utils.compat import ML_KEM_AVAILABLE


version = 3
version = 4


def help_msg():
Expand All @@ -56,25 +57,27 @@ def help_msg():
print(" 'x25519mlkem768,secp256r1mlkem768,secp384r1mlkem1024' by default")
print(" --cookie expect the server to send \"cookie\" extension in")
print(" Hello Retry Request message")
print(" --no-fuzz Do not generate many ciphertexts with malformed PQC shares")
print(" --help this message")


def main():
host = "localhost"
port = 4433
num_limit = None
num_limit = 400
run_exclude = set()
expected_failures = {}
last_exp_tmp = None
ciphers = None
cookie = False
fuzz = True
kems = [GroupName.secp256r1mlkem768,
GroupName.x25519mlkem768,
GroupName.secp384r1mlkem1024]

argv = sys.argv[1:]
opts, args = getopt.getopt(argv, "h:p:e:x:X:n:C:",
["help", "kems=", "cookie"])
["help", "kems=", "cookie", "no-fuzz"])
for opt, arg in opts:
if opt == '-h':
host = arg
Expand Down Expand Up @@ -103,6 +106,8 @@ def main():
kems = [getattr(GroupName, i) for i in arg.split(",")]
elif opt == "--cookie":
cookie = True
elif opt == "--no-fuzz":
fuzz = False
elif opt == '--help':
help_msg()
sys.exit(0)
Expand Down Expand Up @@ -289,41 +294,68 @@ def main():

conversations["{0}: malformed classical part".format(GroupName.toStr(group))] = conversation

conversation = Connect(host, port)
node = conversation
ext = dict(default_ext)
groups = [group]
key_shares = [key_share_gen(group)]
# last 32 bytes of the public key are "rho" in ML-KEM; i.e.
# seed that can have any values; change only values of the t_hat
if group == GroupName.x25519mlkem768:
key_shares[0].key_exchange[1] = 0xff
pqc_start = 0
pqc_length = 384 * 3
elif group == GroupName.secp256r1mlkem768:
key_shares[0].key_exchange[67] = 0xff
# length of the secp256r1 key share
pqc_start = 65
pqc_length = 384 * 3
else:
assert group == GroupName.secp384r1mlkem1024
key_shares[0].key_exchange[98] = 0xff
ext[ExtensionType.key_share] = ClientKeyShareExtension().create(key_shares)
ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
.create([TLS_1_3_DRAFT, (3, 3)])
ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
.create(groups)
sig_algs = [SignatureScheme.rsa_pss_rsae_sha256,
SignatureScheme.rsa_pss_pss_sha256,
SignatureScheme.ecdsa_secp256r1_sha256,
SignatureScheme.ed25519,
SignatureScheme.ed448]
ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
.create(sig_algs)
ext[ExtensionType.signature_algorithms_cert] = SignatureAlgorithmsCertExtension()\
.create(SIG_ALL)
node = node.add_child(ClientHelloGenerator(
ciphers + [CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV],
extensions=ext))
node = node.add_child(ExpectAlert(
AlertLevel.fatal,
AlertDescription.illegal_parameter))
node = node.add_child(ExpectClose())

conversations["{0}: malformed pqc part".format(GroupName.toStr(group))] = conversation
# length of the secp384r1 key share
pqc_start = 97
pqc_length = 384 * 4

if not fuzz:
pqc_length = 6

clean_key_share = key_share_gen(group)
for i in range(0, pqc_length):
conversation = Connect(host, port)
node = conversation
ext = dict(default_ext)
groups = [group]
key_share = copy.deepcopy(clean_key_share)
# replace variable with the smallest invalid one (q = 3329)
if i % 3 == 0:
key_share.key_exchange[pqc_start + i] = 0x01
key_share.key_exchange[pqc_start + i + 1] &= 0xf0
key_share.key_exchange[pqc_start + i + 1] |= 0x0d
elif i % 3 == 2:
key_share.key_exchange[pqc_start + i - 1] &= 0x0f
key_share.key_exchange[pqc_start + i - 1] |= 0x10
key_share.key_exchange[pqc_start + i] = 0xd0
else:
# as values are 12 bits long, every three bytes we point
# the variables we already changed
continue
key_shares = [key_share]
ext[ExtensionType.key_share] = ClientKeyShareExtension().create(key_shares)
ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
.create([TLS_1_3_DRAFT, (3, 3)])
ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
.create(groups)
sig_algs = [SignatureScheme.rsa_pss_rsae_sha256,
SignatureScheme.rsa_pss_pss_sha256,
SignatureScheme.ecdsa_secp256r1_sha256,
SignatureScheme.ed25519,
SignatureScheme.ed448]
ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
.create(sig_algs)
ext[ExtensionType.signature_algorithms_cert] = SignatureAlgorithmsCertExtension()\
.create(SIG_ALL)
node = node.add_child(ClientHelloGenerator(
ciphers + [CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV],
extensions=ext))
node = node.add_child(ExpectAlert(
AlertLevel.fatal,
AlertDescription.illegal_parameter))
node = node.add_child(ExpectClose())

conversations["{0}: malformed pqc part, variable {1}".format(GroupName.toStr(group), i * 8 // 12)] = conversation

conversation = Connect(host, port)
node = conversation
Expand Down

0 comments on commit a0c066c

Please sign in to comment.