diff --git a/api/s2n.h b/api/s2n.h index 464c7a13151..4bc5edeb3fc 100644 --- a/api/s2n.h +++ b/api/s2n.h @@ -1521,6 +1521,32 @@ S2N_API extern int s2n_client_hello_get_session_id_length(struct s2n_client_hell */ S2N_API extern int s2n_client_hello_get_session_id(struct s2n_client_hello *ch, uint8_t *out, uint32_t *out_length, uint32_t max_length); +/** + * Retrieves the supported groups received in the supported groups extension. + * + * IANA values for each of the received supported groups are written to the provided + * `supported_groups` array, and `supported_groups_count` is set to the number of received + * supported groups. + * + * `max_count` should be set to the maximum capacity of the `supported_groups` array. If + * `max_count` is less than the number of received supported groups, this function will error. To + * determine how large `supported_groups` should be in advance, use + * `s2n_client_hello_get_extension_length()` with the S2N_EXTENSION_SUPPORTED_GROUPS extension + * type. + * + * If no supported groups extension was received from the peer, or the received supported groups + * extension is malformed, this function will error. + * + * @param ch A pointer to the ClientHello. Can be retrieved from a connection via + * `s2n_connection_get_client_hello()`. + * @param supported_groups An array that will be filled with the received supported groups. + * @param supported_groups_count Set to the number of received supported groups. + * @param max_count The maximum number of supported groups that can fit in `supported_groups`. + * @returns S2N_SUCCESS on success. S2N_FAILURE on failure. + */ +S2N_API extern int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *supported_groups, + uint16_t *supported_groups_count, uint16_t max_count); + /** * Sets the file descriptor for a s2n connection. * diff --git a/error/s2n_errno.c b/error/s2n_errno.c index 6ba2fcadfdb..3315b32a2b4 100644 --- a/error/s2n_errno.c +++ b/error/s2n_errno.c @@ -216,6 +216,7 @@ static const char *no_such_error = "Internal s2n error"; ERR_ENTRY(S2N_ERR_SEND_SIZE, "Retried s2n_send() size is invalid") \ ERR_ENTRY(S2N_ERR_CORK_SET_ON_UNMANAGED, "Attempt to set connection cork management on unmanaged IO") \ ERR_ENTRY(S2N_ERR_UNRECOGNIZED_EXTENSION, "TLS extension not recognized") \ + ERR_ENTRY(S2N_ERR_EXTENSION_NOT_RECEIVED, "The TLS extension was not received") \ ERR_ENTRY(S2N_ERR_INVALID_SCT_LIST, "SCT list is invalid") \ ERR_ENTRY(S2N_ERR_INVALID_OCSP_RESPONSE, "OCSP response is invalid") \ ERR_ENTRY(S2N_ERR_UPDATING_EXTENSION, "Updating extension data failed") \ diff --git a/error/s2n_errno.h b/error/s2n_errno.h index 2b09ba13244..5311e41f9dc 100644 --- a/error/s2n_errno.h +++ b/error/s2n_errno.h @@ -258,6 +258,7 @@ typedef enum { S2N_ERR_SEND_SIZE, S2N_ERR_CORK_SET_ON_UNMANAGED, S2N_ERR_UNRECOGNIZED_EXTENSION, + S2N_ERR_EXTENSION_NOT_RECEIVED, S2N_ERR_INVALID_SCT_LIST, S2N_ERR_INVALID_OCSP_RESPONSE, S2N_ERR_UPDATING_EXTENSION, diff --git a/tests/unit/s2n_client_hello_get_supported_groups_test.c b/tests/unit/s2n_client_hello_get_supported_groups_test.c new file mode 100644 index 00000000000..30151d1a3c3 --- /dev/null +++ b/tests/unit/s2n_client_hello_get_supported_groups_test.c @@ -0,0 +1,409 @@ +/* +* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"). +* You may not use this file except in compliance with the License. +* A copy of the License is located at +* +* http://aws.amazon.com/apache2.0 +* +* or in the "license" file accompanying this file. This file is distributed +* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +* express or implied. See the License for the specific language governing +* permissions and limitations under the License. +*/ + +#include "pq-crypto/s2n_pq.h" +#include "s2n_test.h" +#include "testlib/s2n_testlib.h" +#include "tls/extensions/s2n_client_supported_groups.h" +#include "tls/s2n_client_hello.h" +#include "tls/s2n_tls.h" +#include "utils/s2n_random.h" + +#define TEST_SUPPORTED_GROUPS_LIST_COUNT S2N_RECEIVED_SUPPORTED_GROUPS_MAX + +/* Each supported group is 2 bytes. */ +#define TEST_SUPPORTED_GROUPS_LIST_SIZE (TEST_SUPPORTED_GROUPS_LIST_COUNT * 2) + +/* 2 length bytes + space for the maximum number of supported groups. */ +#define TEST_SUPPORTED_GROUPS_EXTENSION_SIZE (2 + TEST_SUPPORTED_GROUPS_LIST_SIZE) + +struct s2n_client_hello_context { + const struct s2n_security_policy *client_security_policy; + unsigned client_supports_pq : 1; + int invoked_count; +}; + +int s2n_check_received_supported_groups_cb(struct s2n_connection *conn, void *ctx) +{ + EXPECT_NOT_NULL(ctx); + + struct s2n_client_hello_context *context = (struct s2n_client_hello_context *) ctx; + context->invoked_count += 1; + + struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(conn); + EXPECT_NOT_NULL(client_hello); + + bool supported_groups_received = false; + EXPECT_SUCCESS(s2n_client_hello_has_extension(client_hello, S2N_EXTENSION_SUPPORTED_GROUPS, + &supported_groups_received)); + if (!supported_groups_received) { + return S2N_SUCCESS; + } + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(client_hello, supported_groups, + &supported_groups_count, s2n_array_len(supported_groups))); + + const struct s2n_security_policy *security_policy = context->client_security_policy; + uint16_t expected_groups_count = security_policy->ecc_preferences->count; + if (context->client_supports_pq) { + expected_groups_count += security_policy->kem_preferences->tls13_kem_group_count; + } + EXPECT_EQUAL(supported_groups_count, expected_groups_count); + + size_t offset = 0; + for (size_t i = 0; i < security_policy->kem_preferences->tls13_kem_group_count; i++) { + if (!context->client_supports_pq) { + break; + } + + const struct s2n_kem_group *group = security_policy->kem_preferences->tls13_kem_groups[i]; + EXPECT_EQUAL(supported_groups[i], group->iana_id); + offset += 1; + } + + for (size_t i = 0; i < security_policy->ecc_preferences->count; i++) { + const struct s2n_ecc_named_curve *curve = security_policy->ecc_preferences->ecc_curves[i]; + EXPECT_EQUAL(supported_groups[offset + i], curve->iana_id); + } + + return S2N_SUCCESS; +} + +int main(int argc, char **argv) +{ + BEGIN_TEST(); + + DEFER_CLEANUP(struct s2n_cert_chain_and_key *chain_and_key = NULL, s2n_cert_chain_and_key_ptr_free); + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain_and_key, + S2N_DEFAULT_TEST_CERT_CHAIN, S2N_DEFAULT_TEST_PRIVATE_KEY)); + + DEFER_CLEANUP(struct s2n_cert_chain_and_key *ecdsa_chain_and_key = NULL, s2n_cert_chain_and_key_ptr_free); + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&ecdsa_chain_and_key, + S2N_DEFAULT_ECDSA_TEST_CERT_CHAIN, S2N_DEFAULT_ECDSA_TEST_PRIVATE_KEY)); + + /* Safety */ + { + struct s2n_client_hello client_hello = { 0 }; + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(NULL, supported_groups, &supported_groups_count, + S2N_RECEIVED_SUPPORTED_GROUPS_MAX), + S2N_ERR_NULL); + EXPECT_EQUAL(supported_groups_count, 0); + + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, NULL, &supported_groups_count, + S2N_RECEIVED_SUPPORTED_GROUPS_MAX), + S2N_ERR_NULL); + EXPECT_EQUAL(supported_groups_count, 0); + + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, NULL, + S2N_RECEIVED_SUPPORTED_GROUPS_MAX), + S2N_ERR_NULL); + EXPECT_EQUAL(supported_groups_count, 0); + } + + /* Ensure that the maximum size of the provided supported groups list is respected. */ + { + struct s2n_client_hello client_hello = { 0 }; + + s2n_extension_type_id supported_groups_id = 0; + EXPECT_SUCCESS(s2n_extension_supported_iana_value_to_id(S2N_EXTENSION_SUPPORTED_GROUPS, &supported_groups_id)); + + uint8_t extension_data[TEST_SUPPORTED_GROUPS_EXTENSION_SIZE] = { 0 }; + struct s2n_blob extension_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&extension_blob, extension_data, sizeof(extension_data))); + + s2n_parsed_extension *supported_groups_extension = &client_hello.extensions.parsed_extensions[supported_groups_id]; + supported_groups_extension->extension_type = S2N_EXTENSION_SUPPORTED_GROUPS; + supported_groups_extension->extension = extension_blob; + + struct s2n_stuffer extension_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&extension_stuffer, &extension_blob)); + + EXPECT_SUCCESS(s2n_stuffer_write_uint16(&extension_stuffer, TEST_SUPPORTED_GROUPS_LIST_SIZE)); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + + /* s2n_client_hello_get_supported_groups should fail if the provided buffer is too small. */ + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, + &supported_groups_count, TEST_SUPPORTED_GROUPS_LIST_COUNT - 1), + S2N_ERR_SAFETY); + EXPECT_EQUAL(supported_groups_count, 0); + + EXPECT_SUCCESS(s2n_stuffer_reread(&extension_stuffer)); + + /* s2n_client_hello_get_supported_groups should succeed with a correctly sized buffer. */ + EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, &supported_groups_count, + S2N_RECEIVED_SUPPORTED_GROUPS_MAX)); + EXPECT_EQUAL(supported_groups_count, s2n_array_len(supported_groups)); + } + + /* Ensure that s2n_client_hello_get_supported_groups fails before the client hello is parsed. */ + { + struct s2n_client_hello client_hello = { 0 }; + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, &supported_groups_count, + s2n_array_len(supported_groups)), + S2N_ERR_EXTENSION_NOT_RECEIVED); + } + + /* Ensure that s2n_client_hello_get_supported_groups fails if a supported groups extension + * wasn't received. + */ + for (int disable_ecc = 0; disable_ecc <= 1; disable_ecc++) { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + if (disable_ecc) { + /* The 20150202 security policy doesn't contain any ECDHE cipher suites, so the + * supported groups extension won't be sent. + */ + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "20150202")); + } else { + /* The 20170210 security policy contains ECDHE cipher suites, so the supported groups + * extension will be sent. + */ + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "20170210")); + } + + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + EXPECT_SUCCESS(s2n_connection_set_config(client, config)); + + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_SUCCESS(s2n_connection_set_config(server, config)); + + EXPECT_SUCCESS(s2n_client_hello_send(client)); + EXPECT_SUCCESS(s2n_stuffer_copy(&client->handshake.io, &server->handshake.io, + s2n_stuffer_data_available(&client->handshake.io))); + EXPECT_SUCCESS(s2n_client_hello_recv(server)); + + struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(server); + EXPECT_NOT_NULL(client_hello); + + bool supported_groups_extension_exists = false; + EXPECT_SUCCESS(s2n_client_hello_has_extension(client_hello, S2N_EXTENSION_SUPPORTED_GROUPS, + &supported_groups_extension_exists)); + EXPECT_EQUAL(supported_groups_extension_exists, !disable_ecc); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + int ret = s2n_client_hello_get_supported_groups(client_hello, supported_groups, &supported_groups_count, + s2n_array_len(supported_groups)); + + if (disable_ecc) { + EXPECT_FAILURE_WITH_ERRNO(ret, S2N_ERR_EXTENSION_NOT_RECEIVED); + EXPECT_EQUAL(supported_groups_count, 0); + } else { + EXPECT_SUCCESS(ret); + + /* The 20170210 security policy contains 2 ECC curves. */ + EXPECT_EQUAL(supported_groups_count, 2); + } + } + + /* Test parsing a supported groups extension with a malformed groups list length. */ + { + struct s2n_client_hello client_hello = { 0 }; + + s2n_extension_type_id supported_groups_id = 0; + EXPECT_SUCCESS(s2n_extension_supported_iana_value_to_id(S2N_EXTENSION_SUPPORTED_GROUPS, &supported_groups_id)); + + s2n_parsed_extension *supported_groups_extension = &client_hello.extensions.parsed_extensions[supported_groups_id]; + supported_groups_extension->extension_type = S2N_EXTENSION_SUPPORTED_GROUPS; + + /* Test parsing a correct groups list length */ + { + uint8_t extension_data[TEST_SUPPORTED_GROUPS_EXTENSION_SIZE] = { 0 }; + struct s2n_blob extension_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&extension_blob, extension_data, sizeof(extension_data))); + supported_groups_extension->extension = extension_blob; + + struct s2n_stuffer extension_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&extension_stuffer, &extension_blob)); + + EXPECT_SUCCESS(s2n_stuffer_write_uint16(&extension_stuffer, TEST_SUPPORTED_GROUPS_LIST_SIZE)); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, + &supported_groups_count, s2n_array_len(supported_groups))); + + EXPECT_EQUAL(supported_groups_count, TEST_SUPPORTED_GROUPS_LIST_COUNT); + } + + /* Test parsing a groups list length that is larger than the extension length */ + { + uint8_t extension_data[TEST_SUPPORTED_GROUPS_EXTENSION_SIZE] = { 0 }; + struct s2n_blob extension_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&extension_blob, extension_data, sizeof(extension_data))); + supported_groups_extension->extension = extension_blob; + + struct s2n_stuffer extension_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&extension_stuffer, &extension_blob)); + + uint16_t too_many_groups_size = TEST_SUPPORTED_GROUPS_LIST_SIZE + 2; + EXPECT_SUCCESS(s2n_stuffer_write_uint16(&extension_stuffer, too_many_groups_size)); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, + &supported_groups_count, s2n_array_len(supported_groups)), + S2N_ERR_SAFETY); + + EXPECT_EQUAL(supported_groups_count, 0); + } + + /* Test parsing a groups list that contains a partial supported group */ + { + uint8_t extension_data[TEST_SUPPORTED_GROUPS_EXTENSION_SIZE] = { 0 }; + struct s2n_blob extension_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&extension_blob, extension_data, sizeof(extension_data))); + supported_groups_extension->extension = extension_blob; + + struct s2n_stuffer extension_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&extension_stuffer, &extension_blob)); + + uint16_t one_and_a_half_groups_size = 3; + EXPECT_SUCCESS(s2n_stuffer_write_uint16(&extension_stuffer, one_and_a_half_groups_size)); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, + &supported_groups_count, s2n_array_len(supported_groups)), + S2N_ERR_SAFETY); + + EXPECT_EQUAL(supported_groups_count, 0); + } + } + + /* Ensure that s2n_client_hello_get_supported_groups writes what is contained in the parsed + * supported groups extension in the client hello. + */ + { + struct s2n_client_hello client_hello = { 0 }; + + s2n_extension_type_id supported_groups_id = 0; + EXPECT_SUCCESS(s2n_extension_supported_iana_value_to_id(S2N_EXTENSION_SUPPORTED_GROUPS, &supported_groups_id)); + + s2n_parsed_extension *supported_groups_extension = &client_hello.extensions.parsed_extensions[supported_groups_id]; + supported_groups_extension->extension_type = S2N_EXTENSION_SUPPORTED_GROUPS; + + for (uint16_t test_groups_count = 0; test_groups_count < S2N_RECEIVED_SUPPORTED_GROUPS_MAX; test_groups_count++) { + uint16_t test_groups_list_size = test_groups_count * 2; + + uint8_t test_groups_list_data[TEST_SUPPORTED_GROUPS_LIST_SIZE] = { 0 }; + struct s2n_blob test_groups_list_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&test_groups_list_blob, test_groups_list_data, test_groups_list_size)); + EXPECT_OK(s2n_get_public_random_data(&test_groups_list_blob)); + + uint8_t extension_data[TEST_SUPPORTED_GROUPS_EXTENSION_SIZE] = { 0 }; + struct s2n_blob extension_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&extension_blob, extension_data, sizeof(extension_data))); + supported_groups_extension->extension = extension_blob; + + struct s2n_stuffer extension_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&extension_stuffer, &extension_blob)); + EXPECT_SUCCESS(s2n_stuffer_write_uint16(&extension_stuffer, test_groups_list_size)); + EXPECT_SUCCESS(s2n_stuffer_write(&extension_stuffer, &test_groups_list_blob)); + + uint16_t supported_groups[TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 }; + uint16_t supported_groups_count = 0; + EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, + &supported_groups_count, s2n_array_len(supported_groups))); + EXPECT_EQUAL(supported_groups_count, test_groups_count); + + struct s2n_stuffer test_groups_list_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init_written(&test_groups_list_stuffer, &test_groups_list_blob)); + + for (size_t i = 0; i < supported_groups_count; i++) { + uint16_t test_group = 0; + EXPECT_SUCCESS(s2n_stuffer_read_uint16(&test_groups_list_stuffer, &test_group)); + uint16_t written_group = supported_groups[i]; + + EXPECT_EQUAL(test_group, written_group); + } + } + } + + /* Self-talk: Ensure that the retrieved supported groups match what was sent by the client. + * + * This test also ensures that s2n_client_hello_get_supported_groups is usable from within the + * client hello callback. + */ + for (size_t policy_index = 0; security_policy_selection[policy_index].version != NULL; policy_index++) { + /* Skip the null policy */ + if (security_policy_selection[policy_index].security_policy->cipher_preferences == &cipher_preferences_null) { + continue; + } + + const char* version = security_policy_selection[policy_index].version; + + DEFER_CLEANUP(struct s2n_config *client_config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_NOT_NULL(client_config); + EXPECT_SUCCESS(s2n_config_set_unsafe_for_testing(client_config)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(client_config, version)); + + DEFER_CLEANUP(struct s2n_config *server_config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_NOT_NULL(server_config); + EXPECT_SUCCESS(s2n_config_set_unsafe_for_testing(server_config)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(server_config, version)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, ecdsa_chain_and_key)); + + struct s2n_client_hello_context context = { + .client_security_policy = client_config->security_policy, + .client_supports_pq = false, + .invoked_count = 0, + }; + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(server_config, s2n_check_received_supported_groups_cb, + &context)); + + DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); + EXPECT_NOT_NULL(client_conn); + EXPECT_SUCCESS(s2n_connection_set_config(client_conn, client_config)); + + DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); + EXPECT_NOT_NULL(server_conn); + EXPECT_SUCCESS(s2n_connection_set_config(server_conn, server_config)); + + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(client_conn, &io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(server_conn, &io_pair)); + + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + EXPECT_OK(s2n_negotiate_until_message(client_conn, &blocked, SERVER_HELLO)); + + /* PQ kem groups are only sent in the supported groups extension if the client supports + * TLS 1.3 and PQ is enabled. + */ + if (s2n_connection_get_protocol_version(client_conn) >= S2N_TLS13 && s2n_pq_is_enabled()) { + context.client_supports_pq = true; + } + + EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn)); + + EXPECT_EQUAL(context.invoked_count, 1); + } + + END_TEST(); +} \ No newline at end of file diff --git a/tests/unit/s2n_client_hello_test.c b/tests/unit/s2n_client_hello_test.c index 4d150223e81..815144a9821 100644 --- a/tests/unit/s2n_client_hello_test.c +++ b/tests/unit/s2n_client_hello_test.c @@ -1129,7 +1129,7 @@ int main(int argc, char **argv) EXPECT_EQUAL(s2n_client_hello_get_extension_length(client_hello, S2N_EXTENSION_CERTIFICATE_TRANSPARENCY), 0); EXPECT_NOT_NULL(ext_data = malloc(server_name_extension_len)); EXPECT_EQUAL(s2n_client_hello_get_extension_by_id(client_hello, S2N_EXTENSION_CERTIFICATE_TRANSPARENCY, ext_data, server_name_extension_len), 0); - EXPECT_EQUAL(s2n_errno, S2N_ERR_NULL); + EXPECT_EQUAL(s2n_errno, S2N_ERR_EXTENSION_NOT_RECEIVED); free(ext_data); ext_data = NULL; diff --git a/tls/extensions/s2n_client_supported_groups.c b/tls/extensions/s2n_client_supported_groups.c index 4f7e3e12c4e..94a86d8b1fc 100644 --- a/tls/extensions/s2n_client_supported_groups.c +++ b/tls/extensions/s2n_client_supported_groups.c @@ -78,6 +78,37 @@ static int s2n_client_supported_groups_send(struct s2n_connection *conn, struct return S2N_SUCCESS; } +S2N_RESULT s2n_client_supported_groups_parse_groups_count(struct s2n_stuffer *extension, uint16_t *count) +{ + RESULT_ENSURE_REF(count); + *count = 0; + RESULT_ENSURE_REF(extension); + + uint16_t supported_groups_list_size = 0; + RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &supported_groups_list_size)); + + RESULT_ENSURE_LTE(supported_groups_list_size, s2n_stuffer_data_available(extension)); + RESULT_ENSURE_EQ(supported_groups_list_size % sizeof(uint16_t), 0); + + *count = supported_groups_list_size / 2; + + return S2N_RESULT_OK; +} + +S2N_RESULT s2n_client_supported_groups_parse_groups(struct s2n_stuffer *extension, uint16_t supported_groups_count, + uint16_t *groups_list, uint16_t groups_list_capacity) +{ + RESULT_ENSURE_REF(extension); + RESULT_ENSURE_REF(groups_list); + RESULT_ENSURE_LTE(supported_groups_count, groups_list_capacity); + + for (size_t i = 0; i < supported_groups_count; i++) { + RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &groups_list[i])); + } + + return S2N_RESULT_OK; +} + /* Populates the appropriate index of either the mutually_supported_curves or * mutually_supported_kem_groups array based on the received IANA ID. Will * ignore unrecognized IANA IDs (and return success). */ @@ -165,17 +196,18 @@ static int s2n_client_supported_groups_recv(struct s2n_connection *conn, struct POSIX_ENSURE_REF(conn); POSIX_ENSURE_REF(extension); - uint16_t size_of_all; - POSIX_GUARD(s2n_stuffer_read_uint16(extension, &size_of_all)); - if (size_of_all > s2n_stuffer_data_available(extension) || (size_of_all % sizeof(uint16_t))) { + uint16_t supported_groups_count = 0; + if (s2n_result_is_error(s2n_client_supported_groups_parse_groups_count(extension, &supported_groups_count))) { /* Malformed length, ignore the extension */ return S2N_SUCCESS; } - for (size_t i = 0; i < (size_of_all / sizeof(uint16_t)); i++) { - uint16_t iana_id; - POSIX_GUARD(s2n_stuffer_read_uint16(extension, &iana_id)); - POSIX_GUARD(s2n_client_supported_groups_recv_iana_id(conn, iana_id)); + uint16_t supported_groups[S2N_RECEIVED_SUPPORTED_GROUPS_MAX] = { 0 }; + POSIX_GUARD_RESULT(s2n_client_supported_groups_parse_groups(extension, supported_groups_count, supported_groups, + S2N_RECEIVED_SUPPORTED_GROUPS_MAX)); + + for (size_t i = 0; i < supported_groups_count; i++) { + POSIX_GUARD(s2n_client_supported_groups_recv_iana_id(conn, supported_groups[i])); } POSIX_GUARD(s2n_choose_supported_group(conn)); diff --git a/tls/extensions/s2n_client_supported_groups.h b/tls/extensions/s2n_client_supported_groups.h index 322ac14813e..85b8a85965b 100644 --- a/tls/extensions/s2n_client_supported_groups.h +++ b/tls/extensions/s2n_client_supported_groups.h @@ -19,8 +19,14 @@ #include "tls/extensions/s2n_extension_type.h" #include "tls/s2n_connection.h" +#define S2N_RECEIVED_SUPPORTED_GROUPS_MAX 64 + extern const s2n_extension_type s2n_client_supported_groups_extension; bool s2n_extension_should_send_if_ecc_enabled(struct s2n_connection *conn); +S2N_RESULT s2n_client_supported_groups_parse_groups_count(struct s2n_stuffer *extension, uint16_t *count); +S2N_RESULT s2n_client_supported_groups_parse_groups(struct s2n_stuffer *extension, uint16_t supported_groups_count, + uint16_t *groups_list, uint16_t groups_list_capacity); + /* Old-style extension functions -- remove after extensions refactor is complete */ int s2n_recv_client_supported_groups(struct s2n_connection *conn, struct s2n_stuffer *extension); diff --git a/tls/s2n_client_hello.c b/tls/s2n_client_hello.c index dacfc0b9d88..5654689ce02 100644 --- a/tls/s2n_client_hello.c +++ b/tls/s2n_client_hello.c @@ -40,6 +40,7 @@ #include "utils/s2n_bitmap.h" #include "utils/s2n_random.h" #include "utils/s2n_safety.h" +#include "tls/extensions/s2n_client_supported_groups.h" struct s2n_client_hello *s2n_connection_get_client_hello(struct s2n_connection *conn) { @@ -863,7 +864,7 @@ int s2n_client_hello_get_parsed_extension(s2n_tls_extension_type extension_type, POSIX_GUARD(s2n_extension_supported_iana_value_to_id(extension_type, &extension_type_id)); s2n_parsed_extension *found_parsed_extension = &parsed_extension_list->parsed_extensions[extension_type_id]; - POSIX_ENSURE_REF(found_parsed_extension->extension.data); + POSIX_ENSURE(found_parsed_extension->extension.data, S2N_ERR_EXTENSION_NOT_RECEIVED); POSIX_ENSURE(found_parsed_extension->extension_type == extension_type, S2N_ERR_INVALID_PARSED_EXTENSIONS); *parsed_extension = found_parsed_extension; @@ -971,3 +972,28 @@ int s2n_client_hello_has_extension(struct s2n_client_hello *ch, uint16_t extensi } return S2N_SUCCESS; } + +int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *supported_groups, + uint16_t *supported_groups_count_out, uint16_t max_count) +{ + POSIX_ENSURE_REF(supported_groups_count_out); + *supported_groups_count_out = 0; + + POSIX_ENSURE_REF(ch); + POSIX_ENSURE_REF(supported_groups); + + s2n_parsed_extension *supported_groups_extension = NULL; + POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SUPPORTED_GROUPS, &ch->extensions, &supported_groups_extension)); + POSIX_ENSURE_REF(supported_groups_extension); + + struct s2n_stuffer extension_stuffer = { 0 }; + POSIX_GUARD(s2n_stuffer_init_written(&extension_stuffer, &supported_groups_extension->extension)); + + uint16_t supported_groups_count = 0; + POSIX_GUARD_RESULT(s2n_client_supported_groups_parse_groups_count(&extension_stuffer, &supported_groups_count)); + POSIX_GUARD_RESULT(s2n_client_supported_groups_parse_groups(&extension_stuffer, supported_groups_count, supported_groups, max_count)); + + *supported_groups_count_out = supported_groups_count; + + return S2N_SUCCESS; +}