Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
goatgoose committed Sep 21, 2023
1 parent c403c5b commit 2ad8699
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 40 deletions.
21 changes: 10 additions & 11 deletions api/s2n.h
Original file line number Diff line number Diff line change
Expand Up @@ -1524,13 +1524,12 @@ S2N_API extern int s2n_client_hello_get_session_id(struct s2n_client_hello *ch,
/**
* Retrieves the supported groups received from the peer in the supported groups extension.
*
* IANA values for each of the received supported groups are written to the provided
* `supported_groups` array, and `supported_groups_count` is set to the number of received
* supported groups.
* IANA values for each of the received supported groups are written to the provided `groups`
* array, and `groups_count` is set to the number of received supported groups.
*
* `max_count` should be set to the maximum capacity of the `supported_groups` array. If
* `max_count` is less than the number of received supported groups, this function will error. To
* determine how large `supported_groups` should be in advance, use
* `groups_count_max` should be set to the maximum capacity of the `groups` array. If
* `groups_count_max` is less than the number of received supported groups, this function will
* error. To determine how large `groups` should be in advance, use
* `s2n_client_hello_get_extension_length()` with the S2N_EXTENSION_SUPPORTED_GROUPS extension
* type, and divide the value by 2.
*
Expand All @@ -1539,13 +1538,13 @@ S2N_API extern int s2n_client_hello_get_session_id(struct s2n_client_hello *ch,
*
* @param ch A pointer to the ClientHello. Can be retrieved from a connection via
* `s2n_connection_get_client_hello()`.
* @param supported_groups The array to populate with the received supported groups.
* @param supported_groups_count Returns the number of received supported groups.
* @param max_count The maximum number of supported groups that can fit in the `supported_groups` array.
* @param groups The array to populate with the received supported groups.
* @param groups_count_max The maximum number of supported groups that can fit in the `groups` array.
* @param groups_count Returns the number of received supported groups.
* @returns S2N_SUCCESS on success. S2N_FAILURE on failure.
*/
S2N_API extern int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *supported_groups,
uint16_t *supported_groups_count, uint16_t max_count);
S2N_API extern int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *groups,
uint16_t groups_count_max, uint16_t *groups_count);

/**
* Sets the file descriptor for a s2n connection.
Expand Down
38 changes: 19 additions & 19 deletions tests/unit/s2n_client_hello_get_supported_groups_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#define S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT 16

/* Each supported group is 2 bytes. */
#define S2N_TEST_SUPPORTED_GROUPS_LIST_SIZE (S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT * 2)
#define S2N_TEST_SUPPORTED_GROUPS_LIST_SIZE (S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT * S2N_SUPPORTED_GROUP_SIZE)

/* 2 length bytes + space for the list of supported groups. */
#define S2N_TEST_SUPPORTED_GROUPS_EXTENSION_SIZE (2 + S2N_TEST_SUPPORTED_GROUPS_LIST_SIZE)
Expand All @@ -48,7 +48,7 @@ int s2n_check_received_supported_groups_cb(struct s2n_connection *conn, void *ct
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(client_hello, supported_groups,
&supported_groups_count, s2n_array_len(supported_groups)));
s2n_array_len(supported_groups), &supported_groups_count));

const struct s2n_security_policy *security_policy = context->client_security_policy;
uint16_t expected_groups_count = security_policy->ecc_preferences->count;
Expand Down Expand Up @@ -90,18 +90,18 @@ int main(int argc, char **argv)
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;

EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(NULL, supported_groups, &supported_groups_count,
s2n_array_len(supported_groups)),
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(NULL, supported_groups,
s2n_array_len(supported_groups), &supported_groups_count),
S2N_ERR_NULL);
EXPECT_EQUAL(supported_groups_count, 0);

EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, NULL, &supported_groups_count,
s2n_array_len(supported_groups)),
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, NULL,
s2n_array_len(supported_groups), &supported_groups_count),
S2N_ERR_NULL);
EXPECT_EQUAL(supported_groups_count, 0);

EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, NULL,
s2n_array_len(supported_groups)),
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
s2n_array_len(supported_groups), NULL),
S2N_ERR_NULL);
EXPECT_EQUAL(supported_groups_count, 0);
}
Expand Down Expand Up @@ -130,15 +130,15 @@ int main(int argc, char **argv)

/* Fail if the provided buffer is too small. */
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
&supported_groups_count, S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT - 1),
S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT - 1, &supported_groups_count),
S2N_ERR_SAFETY);
EXPECT_EQUAL(supported_groups_count, 0);

EXPECT_SUCCESS(s2n_stuffer_reread(&extension_stuffer));

/* Succeed with a correctly sized buffer. */
EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, &supported_groups_count,
S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT));
EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT, &supported_groups_count));
EXPECT_EQUAL(supported_groups_count, S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT);
}

Expand All @@ -148,8 +148,8 @@ int main(int argc, char **argv)

uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups, &supported_groups_count,
s2n_array_len(supported_groups)),
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
s2n_array_len(supported_groups), &supported_groups_count),
S2N_ERR_EXTENSION_NOT_RECEIVED);
}

Expand Down Expand Up @@ -192,8 +192,8 @@ int main(int argc, char **argv)

uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
int ret = s2n_client_hello_get_supported_groups(client_hello, supported_groups, &supported_groups_count,
s2n_array_len(supported_groups));
int ret = s2n_client_hello_get_supported_groups(client_hello, supported_groups, s2n_array_len(supported_groups),
&supported_groups_count);

if (disable_ecc) {
EXPECT_FAILURE_WITH_ERRNO(ret, S2N_ERR_EXTENSION_NOT_RECEIVED);
Expand Down Expand Up @@ -231,7 +231,7 @@ int main(int argc, char **argv)
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
&supported_groups_count, s2n_array_len(supported_groups)));
s2n_array_len(supported_groups), &supported_groups_count));

EXPECT_EQUAL(supported_groups_count, S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT);
}
Expand All @@ -251,7 +251,7 @@ int main(int argc, char **argv)
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
&supported_groups_count, s2n_array_len(supported_groups)),
s2n_array_len(supported_groups), &supported_groups_count),
S2N_ERR_SAFETY);

EXPECT_EQUAL(supported_groups_count, 0);
Expand All @@ -273,7 +273,7 @@ int main(int argc, char **argv)
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
&supported_groups_count, s2n_array_len(supported_groups)),
s2n_array_len(supported_groups), &supported_groups_count),
S2N_ERR_SAFETY);

EXPECT_EQUAL(supported_groups_count, 0);
Expand Down Expand Up @@ -311,7 +311,7 @@ int main(int argc, char **argv)
uint16_t supported_groups[S2N_TEST_SUPPORTED_GROUPS_LIST_COUNT] = { 0 };
uint16_t supported_groups_count = 0;
EXPECT_SUCCESS(s2n_client_hello_get_supported_groups(&client_hello, supported_groups,
&supported_groups_count, s2n_array_len(supported_groups)));
s2n_array_len(supported_groups), &supported_groups_count));
EXPECT_EQUAL(supported_groups_count, test_groups_count);

struct s2n_stuffer test_groups_list_stuffer = { 0 };
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/s2n_client_hello_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,38 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_connection_free(conn));
};

/* Test s2n_client_hello_has_extension with a zero-length extension */
for (int send_sct = 0; send_sct <= 1; send_sct++) {
DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free);
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key));

/* The SCT extension is zero-length. */
if (send_sct) {
EXPECT_SUCCESS(s2n_config_set_ct_support_level(config, S2N_CT_SUPPORT_REQUEST));
}

DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
EXPECT_SUCCESS(s2n_connection_set_config(client, config));

DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
EXPECT_SUCCESS(s2n_connection_set_config(server, config));

EXPECT_SUCCESS(s2n_client_hello_send(client));
EXPECT_SUCCESS(s2n_stuffer_copy(&client->handshake.io, &server->handshake.io,
s2n_stuffer_data_available(&client->handshake.io)));
EXPECT_SUCCESS(s2n_client_hello_recv(server));

struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(server);
EXPECT_NOT_NULL(client_hello);

/* Ensure that s2n_client_hello_has_extension knows that the SCT extension was received. */
bool exists = false;
EXPECT_SUCCESS(s2n_client_hello_has_extension(client_hello, S2N_EXTENSION_CERTIFICATE_TRANSPARENCY, &exists));
EXPECT_EQUAL(exists, send_sct);
}

/* Test s2n_client_hello_get_raw_extension */
{
uint8_t data[] = {
Expand Down
4 changes: 2 additions & 2 deletions tls/extensions/s2n_client_supported_groups.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ S2N_RESULT s2n_client_supported_groups_parse_groups_count(struct s2n_stuffer *ex
RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &supported_groups_list_size));

RESULT_ENSURE_LTE(supported_groups_list_size, s2n_stuffer_data_available(extension));
RESULT_ENSURE_EQ(supported_groups_list_size % sizeof(uint16_t), 0);
RESULT_ENSURE_EQ(supported_groups_list_size % S2N_SUPPORTED_GROUP_SIZE, 0);

*count = supported_groups_list_size / 2;
*count = supported_groups_list_size / S2N_SUPPORTED_GROUP_SIZE;

return S2N_RESULT_OK;
}
Expand Down
2 changes: 2 additions & 0 deletions tls/extensions/s2n_client_supported_groups.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "tls/extensions/s2n_extension_type.h"
#include "tls/s2n_connection.h"

#define S2N_SUPPORTED_GROUP_SIZE 2

extern const s2n_extension_type s2n_client_supported_groups_extension;
bool s2n_extension_should_send_if_ecc_enabled(struct s2n_connection *conn);

Expand Down
16 changes: 8 additions & 8 deletions tls/s2n_client_hello.c
Original file line number Diff line number Diff line change
Expand Up @@ -973,13 +973,13 @@ int s2n_client_hello_has_extension(struct s2n_client_hello *ch, uint16_t extensi
return S2N_SUCCESS;
}

int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *supported_groups,
uint16_t *supported_groups_count_out, uint16_t max_count)
int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *groups,
uint16_t groups_count_max, uint16_t *groups_count_out)
{
POSIX_ENSURE_REF(supported_groups_count_out);
*supported_groups_count_out = 0;
POSIX_ENSURE_REF(groups_count_out);
*groups_count_out = 0;
POSIX_ENSURE_REF(ch);
POSIX_ENSURE_REF(supported_groups);
POSIX_ENSURE_REF(groups);

s2n_parsed_extension *supported_groups_extension = NULL;
POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SUPPORTED_GROUPS, &ch->extensions, &supported_groups_extension));
Expand All @@ -989,13 +989,13 @@ int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t

uint16_t supported_groups_count = 0;
POSIX_GUARD_RESULT(s2n_client_supported_groups_parse_groups_count(&extension_stuffer, &supported_groups_count));
POSIX_ENSURE_LTE(supported_groups_count, max_count);
POSIX_ENSURE_LTE(supported_groups_count, groups_count_max);

for (size_t i = 0; i < supported_groups_count; i++) {
POSIX_GUARD(s2n_stuffer_read_uint16(&extension_stuffer, &supported_groups[i]));
POSIX_GUARD(s2n_stuffer_read_uint16(&extension_stuffer, &groups[i]));
}

*supported_groups_count_out = supported_groups_count;
*groups_count_out = supported_groups_count;

return S2N_SUCCESS;
}

0 comments on commit 2ad8699

Please sign in to comment.