Skip to content

Commit

Permalink
kTLS: implement recvmsg (#4154)
Browse files Browse the repository at this point in the history
Co-authored-by: Lindsay Stewart <[email protected]>
  • Loading branch information
toidiu and lrstewart authored Aug 25, 2023
1 parent 625ff98 commit b70868e
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 6 deletions.
249 changes: 243 additions & 6 deletions tests/unit/s2n_ktls_io_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,51 @@ S2N_RESULT s2n_ktls_set_control_data(struct msghdr *msg, char *buf, size_t buf_s
S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t *record_type);

/* Mock implementation used for validating failure behavior */
struct s2n_test_ktls_io_fail {
size_t invoked_count;
struct s2n_test_ktls_io_fail_ctx {
size_t errno_code;
size_t invoked_count;
};

static ssize_t s2n_test_ktls_sendmsg_fail(void *io_context, const struct msghdr *msg)
{
struct s2n_test_ktls_io_fail *io_ctx = (struct s2n_test_ktls_io_fail *) io_context;
struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context;
POSIX_ENSURE_REF(io_ctx);
io_ctx->invoked_count++;
errno = io_ctx->errno_code;
return -1;
}

static ssize_t s2n_test_ktls_recvmsg_fail(void *io_context, struct msghdr *msg)
{
POSIX_ENSURE_REF(msg);

struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context;
POSIX_ENSURE_REF(io_ctx);
io_ctx->invoked_count++;
errno = io_ctx->errno_code;
return -1;
}

static ssize_t s2n_test_ktls_recvmsg_eof(void *io_context, struct msghdr *msg)
{
struct s2n_test_ktls_io_fail_ctx *io_ctx = (struct s2n_test_ktls_io_fail_ctx *) io_context;
POSIX_ENSURE_REF(io_ctx);
io_ctx->invoked_count++;
return 0;
}

ssize_t s2n_test_ktls_recvmsg_io_stuffer_and_ctrunc(void *io_context, struct msghdr *msg)
{
POSIX_ENSURE_REF(msg);

/* The stuffer mock IO is used to ensure `cmsghdr` is otherwise properly constructed
* and that the failure occurs due to the MSG_CTRUNC flag. */
ssize_t ret = s2n_test_ktls_recvmsg_io_stuffer(io_context, msg);
POSIX_GUARD(ret);
msg->msg_flags = MSG_CTRUNC;
return ret;
}

int main(int argc, char **argv)
{
BEGIN_TEST();
Expand Down Expand Up @@ -107,7 +138,6 @@ int main(int argc, char **argv)
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);

struct iovec msg_iov_valid = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
size_t bytes_written = 0;
Expand Down Expand Up @@ -218,7 +248,7 @@ int main(int argc, char **argv)
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
struct s2n_test_ktls_io_fail io_ctx = { 0 };
struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 };
EXPECT_OK(s2n_ktls_set_sendmsg_cb(server, s2n_test_ktls_sendmsg_fail, &io_ctx));

struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND };
Expand All @@ -243,7 +273,7 @@ int main(int argc, char **argv)
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
struct s2n_test_ktls_io_fail io_ctx = {
struct s2n_test_ktls_io_fail_ctx io_ctx = {
.errno_code = EINVAL,
};
EXPECT_OK(s2n_ktls_set_sendmsg_cb(server, s2n_test_ktls_sendmsg_fail, &io_ctx));
Expand Down Expand Up @@ -286,5 +316,212 @@ int main(int argc, char **argv)
};
};

/* Test s2n_ktls_recvmsg */
{
/* Safety */
{
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
uint8_t recv_record_type = 0;
size_t bytes_read = 0;

EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(NULL, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_NULL);
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, NULL, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_NULL);
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, NULL, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_NULL);
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, NULL, &bytes_read),
S2N_ERR_NULL);
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, NULL),
S2N_ERR_NULL);

size_t to_recv_zero = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, to_recv_zero, &blocked, &bytes_read),
S2N_ERR_SAFETY);
};

/* Happy case: send/recv data using sendmsg/recvmsg */
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 },
s2n_ktls_io_stuffer_pair_free);
EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair));

struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
size_t bytes_written = 0;
EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written));
EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND);

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
uint8_t recv_record_type = 0;
size_t bytes_read = 0;
EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read));
EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read);
EXPECT_EQUAL(bytes_read, bytes_written);

EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1);
EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, 1);
};

/* Simulate blocked and handle a S2N_ERR_IO_BLOCKED error */
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 },
s2n_ktls_io_stuffer_pair_free);
EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair));

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
size_t blocked_invoked_count = 5;
uint8_t recv_record_type = 0;
size_t bytes_read = 0;
/* recv should block since there is no data */
for (size_t i = 0; i < blocked_invoked_count; i++) {
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_IO_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);
}

/* send data to unblock */
struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND };
size_t bytes_written = 0;
EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written));
EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND);

EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read));
EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read);
EXPECT_EQUAL(bytes_read, bytes_written);

/* recv should block again since we have read all the data */
for (size_t i = 0; i < blocked_invoked_count; i++) {
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_IO_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);
}

EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, (blocked_invoked_count * 2) + 1);
EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1);
};

/* Both EWOULDBLOCK and EAGAIN should return a S2N_ERR_IO_BLOCKED error */
{
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 };
EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_fail, &io_ctx));

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
uint8_t recv_record_type = 0;
size_t bytes_read = 0;

io_ctx.errno_code = EWOULDBLOCK;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_IO_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);

/* cppcheck-suppress redundantAssignment */
io_ctx.errno_code = EAGAIN;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_IO_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);

EXPECT_EQUAL(io_ctx.invoked_count, 2);
};

/* Handle a S2N_ERR_IO error */
{
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
struct s2n_test_ktls_io_fail_ctx io_ctx = {
.errno_code = EINVAL,
};
EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_fail, &io_ctx));

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
uint8_t recv_record_type = 0;
size_t bytes_read = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_IO);
/* Blocked status intentionally not reset to preserve legacy s2n_send behavior */
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);

EXPECT_EQUAL(io_ctx.invoked_count, 1);
};

/* Simulate EOF and handle a S2N_ERR_CLOSED error */
{
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
struct s2n_test_ktls_io_fail_ctx io_ctx = { 0 };
EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_eof, &io_ctx));

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
uint8_t recv_record_type = 0;
size_t bytes_read = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_CLOSED);
/* Blocked status intentionally not reset to preserve legacy s2n_send behavior */
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);

EXPECT_EQUAL(io_ctx.invoked_count, 1);
};

/* Simulate control message truncated via MSG_CTRUNC flag and handle a S2N_ERR_KTLS_BAD_CMSG error */
{
DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer_pair io_pair = { 0 },
s2n_ktls_io_stuffer_pair_free);
EXPECT_OK(s2n_test_init_ktls_io_stuffer(server, client, &io_pair));
/* override the client recvmsg callback to add a MSG_CTRUNC flag to msghdr before returning */
EXPECT_OK(s2n_ktls_set_recvmsg_cb(client, s2n_test_ktls_recvmsg_io_stuffer_and_ctrunc, &io_pair.client_in));

struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND };
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
size_t bytes_written = 0;
EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written));
EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND);

uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 };
uint8_t recv_record_type = 0;
size_t bytes_read = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read),
S2N_ERR_KTLS_BAD_CMSG);
/* Blocked status intentionally not reset to preserve legacy s2n_send behavior */
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ);

EXPECT_EQUAL(io_pair.client_in.sendmsg_invoked_count, 1);
EXPECT_EQUAL(io_pair.client_in.recvmsg_invoked_count, 1);
};
};

END_TEST();
}
2 changes: 2 additions & 0 deletions tls/s2n_ktls.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ S2N_RESULT s2n_ktls_get_file_descriptor(struct s2n_connection *conn, s2n_ktls_mo

S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov,
size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written);
S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf,
size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read);

/* These functions will be part of the public API. */
int s2n_connection_ktls_enable_send(struct s2n_connection *conn);
Expand Down
61 changes: 61 additions & 0 deletions tls/s2n_ktls_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t
RESULT_ENSURE_REF(msg);
RESULT_ENSURE_REF(record_type);

/* https://man7.org/linux/man-pages/man3/recvmsg.3p.html
* MSG_CTRUNC Control data was truncated.
*/
if (msg->msg_flags & MSG_CTRUNC) {
RESULT_BAIL(S2N_ERR_KTLS_BAD_CMSG);
}

/*
* https://man7.org/linux/man-pages/man3/cmsg.3.html
* To create ancillary data, first initialize the msg_controllen
Expand Down Expand Up @@ -185,6 +192,7 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co
RESULT_ENSURE_REF(conn);

*blocked = S2N_BLOCKED_ON_WRITE;
*bytes_written = 0;

struct msghdr msg = {
/* msghdr requires a non-const iovec. This is safe because s2n-tls does
Expand All @@ -210,3 +218,56 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co
*bytes_written = result;
return S2N_RESULT_OK;
}

S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf,
size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read)
{
RESULT_ENSURE_REF(record_type);
RESULT_ENSURE_REF(bytes_read);
RESULT_ENSURE_REF(blocked);
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(buf);
/* Ensure that buf_len is > 0 since trying to receive 0 bytes does not
* make sense and a return value of `0` from recvmsg is treated as EOF.
*/
RESULT_ENSURE_GT(buf_len, 0);

*blocked = S2N_BLOCKED_ON_READ;
*record_type = 0;
*bytes_read = 0;
struct iovec msg_iov = {
.iov_base = buf,
.iov_len = buf_len
};
struct msghdr msg = {
.msg_iov = &msg_iov,
.msg_iovlen = 1,
};

/*
* https://man7.org/linux/man-pages/man3/cmsg.3.html
* To create ancillary data, first initialize the msg_controllen
* member of the msghdr with the length of the control message
* buffer.
*/
char control_data[S2N_KTLS_CONTROL_BUFFER_SIZE] = { 0 };
msg.msg_controllen = sizeof(control_data);
msg.msg_control = control_data;

ssize_t result = s2n_recvmsg_fn(conn->recv_io_context, &msg);
if (result < 0) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
RESULT_BAIL(S2N_ERR_IO_BLOCKED);
}
RESULT_BAIL(S2N_ERR_IO);
} else if (result == 0) {
/* The return value will be 0 when the socket reads EOF. */
RESULT_BAIL(S2N_ERR_CLOSED);
}

RESULT_GUARD(s2n_ktls_get_control_data(&msg, S2N_TLS_GET_RECORD_TYPE, record_type));

*blocked = S2N_NOT_BLOCKED;
*bytes_read = result;
return S2N_RESULT_OK;
}

0 comments on commit b70868e

Please sign in to comment.