diff --git a/tests/unit/s2n_ktls_io_test.c b/tests/unit/s2n_ktls_io_test.c index a7c0e68d63d..7468aac6c06 100644 --- a/tests/unit/s2n_ktls_io_test.c +++ b/tests/unit/s2n_ktls_io_test.c @@ -27,20 +27,51 @@ S2N_RESULT s2n_ktls_set_control_data(struct msghdr *msg, char *buf, size_t buf_s S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t *record_type); /* Mock implementation used for validating failure behavior */ -struct s2n_test_ktls_io_fail { - size_t invoked_count; +struct s2n_test_ktls_io_fail_ctx { size_t errno_code; + size_t invoked_count; }; static ssize_t s2n_test_ktls_sendmsg_fail(void *io_context, const struct msghdr *msg) { - struct s2n_test_ktls_io_fail *io_ctx = (struct s2n_test_ktls_io_fail *) io_context; + struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context; + POSIX_ENSURE_REF(io_ctx); + io_ctx->invoked_count++; + errno = io_ctx->errno_code; + return -1; +} + +static ssize_t s2n_test_ktls_recvmsg_fail(void *io_context, struct msghdr *msg) +{ + POSIX_ENSURE_REF(msg); + + struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context; POSIX_ENSURE_REF(io_ctx); io_ctx->invoked_count++; errno = io_ctx->errno_code; return -1; } +static ssize_t s2n_test_ktls_recvmsg_eof(void *io_context, struct msghdr *msg) +{ + struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context; + POSIX_ENSURE_REF(io_ctx); + io_ctx->invoked_count++; + return 0; +} + +ssize_t s2n_test_ktls_recvmsg_io_stuffer_and_ctrunc(void *io_context, struct msghdr *msg) +{ + POSIX_ENSURE_REF(msg); + + /* The stuffer mock IO is used to ensure `cmsghdr` is otherwise properly constructed + * and that the failure occurs due to the MSG_CTRUNC flag. */ + ssize_t ret = s2n_test_ktls_recvmsg_io_stuffer(io_context, msg); + POSIX_GUARD(ret); + msg->msg_flags = MSG_CTRUNC; + return ret; +} + int main(int argc, char **argv) { BEGIN_TEST(); @@ -107,7 +138,6 @@ int main(int argc, char **argv) { DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); - struct iovec msg_iov_valid = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; @@ -218,7 +248,7 @@ int main(int argc, char **argv) { DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); - struct s2n_test_ktls_io_fail io_ctx = { 0 }; + struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 }; EXPECT_OK(s2n_ktls_set_sendmsg_cb(server, s2n_test_ktls_sendmsg_fail, &io_ctx)); struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; @@ -243,7 +273,7 @@ int main(int argc, char **argv) { DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); - struct s2n_test_ktls_io_fail io_ctx = { + struct s2n_test_ktls_io_fail_ctx io_ctx = { .errno_code = EINVAL, }; EXPECT_OK(s2n_ktls_set_sendmsg_cb(server, s2n_test_ktls_sendmsg_fail, &io_ctx)); @@ -286,5 +316,212 @@ int main(int argc, char **argv) }; }; + /* Test s2n_ktls_recvmsg */ + { + /* Safety */ + { + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(NULL, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_NULL); + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, NULL, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_NULL); + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, NULL, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_NULL); + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, NULL, &bytes_read), + S2N_ERR_NULL); + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, NULL), + S2N_ERR_NULL); + + size_t to_recv_zero = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, to_recv_zero, &blocked, &bytes_read), + S2N_ERR_SAFETY); + }; + + /* Happy case: send/recv data using sendmsg/recvmsg */ + { + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 }, + s2n_ktls_io_stuffer_pair_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair)); + + struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + size_t bytes_written = 0; + EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); + EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read); + EXPECT_EQUAL(bytes_read, bytes_written); + + EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1); + EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, 1); + }; + + /* Simulate blocked and handle a S2N_ERR_IO_BLOCKED error */ + { + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 }, + s2n_ktls_io_stuffer_pair_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair)); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + size_t blocked_invoked_count = 5; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + /* recv should block since there is no data */ + for (size_t i = 0; i < blocked_invoked_count; i++) { + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_IO_BLOCKED); + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + } + + /* send data to unblock */ + struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; + size_t bytes_written = 0; + EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); + + EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); + EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read); + EXPECT_EQUAL(bytes_read, bytes_written); + + /* recv should block again since we have read all the data */ + for (size_t i = 0; i < blocked_invoked_count; i++) { + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_IO_BLOCKED); + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + } + + EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, (blocked_invoked_count * 2) + 1); + EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1); + }; + + /* Both EWOULDBLOCK and EAGAIN should return a S2N_ERR_IO_BLOCKED error */ + { + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 }; + EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_fail, &io_ctx)); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + + io_ctx.errno_code = EWOULDBLOCK; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_IO_BLOCKED); + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + + /* cppcheck-suppress redundantAssignment */ + io_ctx.errno_code = EAGAIN; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_IO_BLOCKED); + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + + EXPECT_EQUAL(io_ctx.invoked_count, 2); + }; + + /* Handle a S2N_ERR_IO error */ + { + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + struct s2n_test_ktls_io_fail_ctx io_ctx = { + .errno_code = EINVAL, + }; + EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_fail, &io_ctx)); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_IO); + /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + + EXPECT_EQUAL(io_ctx.invoked_count, 1); + }; + + /* Simulate EOF and handle a S2N_ERR_CLOSED error */ + { + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 }; + EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_eof, &io_ctx)); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_CLOSED); + /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + + EXPECT_EQUAL(io_ctx.invoked_count, 1); + }; + + /* Simulate control message truncated via MSG_CTRUNC flag and handle a S2N_ERR_KTLS_BAD_CMSG error */ + { + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 }, + s2n_ktls_io_stuffer_pair_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair)); + /* override the client recvmsg callback to add a MSG_CTRUNC flag to msghdr before returning */ + EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_io_stuffer_and_ctrunc, &io_pair.client_in)); + + struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + size_t bytes_written = 0; + EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); + + uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; + uint8_t recv_record_type = 0; + size_t bytes_read = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + S2N_ERR_KTLS_BAD_CMSG); + /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); + + EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1); + EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, 1); + }; + }; + END_TEST(); } diff --git a/tls/s2n_ktls.h b/tls/s2n_ktls.h index abaac0f4c58..3fc588e2340 100644 --- a/tls/s2n_ktls.h +++ b/tls/s2n_ktls.h @@ -41,6 +41,8 @@ S2N_RESULT s2n_ktls_get_file_descriptor(struct s2n_connection *conn, s2n_ktls_mo S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov, size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written); +S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf, + size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read); /* These functions will be part of the public API. */ int s2n_connection_ktls_enable_send(struct s2n_connection *conn); diff --git a/tls/s2n_ktls_io.c b/tls/s2n_ktls_io.c index 5718a7a5437..b3d2d25cd71 100644 --- a/tls/s2n_ktls_io.c +++ b/tls/s2n_ktls_io.c @@ -146,6 +146,13 @@ S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t RESULT_ENSURE_REF(msg); RESULT_ENSURE_REF(record_type); + /* https://man7.org/linux/man-pages/man3/recvmsg.3p.html + * MSG_CTRUNC Control data was truncated. + */ + if (msg->msg_flags & MSG_CTRUNC) { + RESULT_BAIL(S2N_ERR_KTLS_BAD_CMSG); + } + /* * https://man7.org/linux/man-pages/man3/cmsg.3.html * To create ancillary data, first initialize the msg_controllen @@ -185,6 +192,7 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co RESULT_ENSURE_REF(conn); *blocked = S2N_BLOCKED_ON_WRITE; + *bytes_written = 0; struct msghdr msg = { /* msghdr requires a non-const iovec. This is safe because s2n-tls does @@ -210,3 +218,56 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co *bytes_written = result; return S2N_RESULT_OK; } + +S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf, + size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read) +{ + RESULT_ENSURE_REF(record_type); + RESULT_ENSURE_REF(bytes_read); + RESULT_ENSURE_REF(blocked); + RESULT_ENSURE_REF(conn); + RESULT_ENSURE_REF(buf); + /* Ensure that buf_len is > 0 since trying to receive 0 bytes does not + * make sense and a return value of `0` from recvmsg is treated as EOF. + */ + RESULT_ENSURE_GT(buf_len, 0); + + *blocked = S2N_BLOCKED_ON_READ; + *record_type = 0; + *bytes_read = 0; + struct iovec msg_iov = { + .iov_base = buf, + .iov_len = buf_len + }; + struct msghdr msg = { + .msg_iov = &msg_iov, + .msg_iovlen = 1, + }; + + /* + * https://man7.org/linux/man-pages/man3/cmsg.3.html + * To create ancillary data, first initialize the msg_controllen + * member of the msghdr with the length of the control message + * buffer. + */ + char control_data[S2N_KTLS_CONTROL_BUFFER_SIZE] = { 0 }; + msg.msg_controllen = sizeof(control_data); + msg.msg_control = control_data; + + ssize_t result = s2n_recvmsg_fn(conn->recv_io_context, &msg); + if (result < 0) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + RESULT_BAIL(S2N_ERR_IO_BLOCKED); + } + RESULT_BAIL(S2N_ERR_IO); + } else if (result == 0) { + /* The return value will be 0 when the socket reads EOF. */ + RESULT_BAIL(S2N_ERR_CLOSED); + } + + RESULT_GUARD(s2n_ktls_get_control_data(&msg, S2N_TLS_GET_RECORD_TYPE, record_type)); + + *blocked = S2N_NOT_BLOCKED; + *bytes_read = result; + return S2N_RESULT_OK; +}