From de411b0ad07f31314a57efcf3fd962c16f12baa4 Mon Sep 17 00:00:00 2001 From: ekoby <7406535+ekoby@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:27:11 -0500 Subject: [PATCH] Read start: check ssl engine data (#200) * tlsuv: tlsuv_read_start check for data buffered in TLS engine application may stop reading when SSL packet(s) was read from the wire but not delivered to application. * test: ignore empty reads --- sample/common.c | 2 +- sample/common.h | 2 +- src/mbedtls/engine.c | 4 ++++ src/tlsuv.c | 37 ++++++++++++++++++++++++++++++++----- src/util.h | 4 ++++ tests/stream_tests.cpp | 19 +++++++++++++------ 6 files changed, 55 insertions(+), 13 deletions(-) diff --git a/sample/common.c b/sample/common.c index 2d928417..81d4203e 100644 --- a/sample/common.c +++ b/sample/common.c @@ -42,7 +42,7 @@ void resp_cb(tlsuv_http_resp_t *resp, void *data) { printf("\n"); } -void body_cb(tlsuv_http_req_t *req, const char *body, ssize_t len) { +void body_cb(tlsuv_http_req_t *req, char *body, ssize_t len) { if (len == UV_EOF) { printf("\n\n====================\nRequest completed\n"); } diff --git a/sample/common.h b/sample/common.h index e00cc2d0..ddfd9974 100644 --- a/sample/common.h +++ b/sample/common.h @@ -19,6 +19,6 @@ #include void resp_cb(tlsuv_http_resp_t *resp, void *data); -void body_cb(tlsuv_http_req_t *req, const char *body, ssize_t len); +void body_cb(tlsuv_http_req_t *req, char *body, ssize_t len); void logger(int level, const char *file, unsigned int line, const char *msg); #endif //UV_MBED_COMMON_H diff --git a/src/mbedtls/engine.c b/src/mbedtls/engine.c index ec3bb2bc..bd6568e8 100644 --- a/src/mbedtls/engine.c +++ b/src/mbedtls/engine.c @@ -765,6 +765,10 @@ static int mbedtls_read(tlsuv_engine_t engine, char *out, size_t *out_bytes, siz if (rc < 0) { if (rc == MBEDTLS_ERR_SSL_WANT_READ || rc == MBEDTLS_ERR_SSL_WANT_WRITE) { err = TLS_AGAIN; + } else if (rc == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + UM_LOG(DEBG, "mbedTLS: peer close notify"); + eng->error = rc; + err = TLS_EOF; } else { UM_LOG(ERR, "mbedTLS: %0x(%s)", rc, mbedtls_error(rc)); eng->error = rc; diff --git a/src/tlsuv.c b/src/tlsuv.c index f1ee18f2..963e24c2 100644 --- a/src/tlsuv.c +++ b/src/tlsuv.c @@ -14,6 +14,7 @@ #include "tlsuv/tlsuv.h" #include "um_debug.h" +#include "util.h" #include "tlsuv/queue.h" #include #include @@ -42,6 +43,7 @@ static uv_os_sock_t new_socket(const struct addrinfo *addr); static void on_clt_io(uv_poll_t *, int, int); static void fail_pending_reqs(tlsuv_stream_t *clt, int err); +static void check_read(uv_idle_t *idle); static tls_context *DEFAULT_TLS = NULL; @@ -81,8 +83,6 @@ int tlsuv_stream_init(uv_loop_t *l, tlsuv_stream_t *clt, tls_context *tls) { clt->queue_len = 0; TAILQ_INIT(&clt->queue); - clt->watcher.data = clt; - return 0; } @@ -110,13 +110,19 @@ static int start_io(tlsuv_stream_t *clt) { } static void on_internal_close(uv_handle_t *h) { - tlsuv_stream_t *clt = h->data; + tlsuv_stream_t *clt = container_of(h, tlsuv_stream_t, watcher); if (clt->conn_req) { uv_connect_t *req = clt->conn_req; clt->conn_req = NULL; req->cb(req, UV_ECANCELED); } + if (h->data) { + uv_idle_t *idle = h->data; + assert(idle->type == UV_IDLE); + uv_close((uv_handle_t *) idle, (uv_close_cb) free); + } + // error handling // fail all pending requests fail_pending_reqs(clt, UV_ECANCELED); @@ -337,6 +343,15 @@ static void process_inbound(tlsuv_stream_t *clt) { int attempts = 16; + // got IO or idle check, can clear the handle + if (clt->watcher.data) { + uv_idle_t *idler = clt->watcher.data; + assert(idler->type == UV_IDLE); + + clt->watcher.data = NULL; + uv_close((uv_handle_t *) idler, (uv_close_cb) free); + } + while(clt->read_cb && (attempts-- > 0)) { assert(clt->alloc_cb != NULL); @@ -380,7 +395,7 @@ static void process_inbound(tlsuv_stream_t *clt) { } static void on_clt_io(uv_poll_t *p, int status, int events) { - tlsuv_stream_t *clt = p->data; + tlsuv_stream_t *clt = container_of(p, tlsuv_stream_t, watcher); if (clt->conn_req) { UM_LOG(VERB, "processing connect: events=%d status=%d", events, status); process_connect(clt, status); @@ -423,7 +438,6 @@ int tlsuv_stream_open(uv_connect_t *req, tlsuv_stream_t *clt, uv_os_fd_t fd, uv_ clt->sock = fd; uv_poll_init_socket(clt->loop, &clt->watcher, clt->sock); - clt->watcher.data = clt; return uv_poll_start(&clt->watcher, UV_READABLE | UV_WRITABLE | UV_DISCONNECT, on_clt_io); } @@ -523,6 +537,14 @@ int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_c if (rc != 0) { clt->alloc_cb = NULL; clt->read_cb = NULL; + } else { + // schedule idle read (if nothing on the wire) + // in case reading was stopped with data buffered in TLS engine + uv_idle_t *idle = calloc(1, sizeof(*idle)); + clt->watcher.data = idle; + uv_idle_init(clt->loop, idle); + idle->data = clt; + uv_idle_start(idle, check_read); } return rc; } @@ -629,3 +651,8 @@ uv_os_sock_t new_socket(const struct addrinfo *addr) { return sock; } +void check_read(uv_idle_t *idler) { + tlsuv_stream_t *clt = idler->data; + // this will clean up idle handle + process_inbound(clt); +} diff --git a/src/util.h b/src/util.h index f2fb6c77..956e6e79 100644 --- a/src/util.h +++ b/src/util.h @@ -7,6 +7,10 @@ #include + +#define container_of(ptr, type, member) \ + ((type *) ((char *) (ptr) - offsetof(type, member))) + /** * wrap-around buffer */ diff --git a/tests/stream_tests.cpp b/tests/stream_tests.cpp index f54a5923..b06d4e2a 100644 --- a/tests/stream_tests.cpp +++ b/tests/stream_tests.cpp @@ -153,10 +153,13 @@ TEST_CASE("read/write","[stream]") { auto ctx = (struct test_ctx *) c->data; if (status == UV_EOF) { tlsuv_stream_close(c, nullptr); + } else if (status >= 0) { + if (status > 0) { + REQUIRE_THAT(b->base, Catch::Matchers::StartsWith("HTTP/1.1 200 OK")); + fprintf(stderr, "%.*s\n", (int) status, b->base); + } } else { - REQUIRE(status > 0); - REQUIRE_THAT(b->base, Catch::Matchers::StartsWith("HTTP/1.1 200 OK")); - fprintf(stderr, "%.*s\n", (int) status, b->base); + FAIL("status: " << status << " " << uv_strerror(status)); } free(b->base); }); @@ -322,9 +325,13 @@ static void read_cb(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { tlsuv_stream_t *clt = reinterpret_cast(stream); test_result *result = static_cast(clt->data); - REQUIRE(nread > 0); - result->read_count++; - result->read_data.append(buf->base, nread); + + REQUIRE(nread >= 0); + + if (nread > 0) { + result->read_count++; + result->read_data.append(buf->base, nread); + } free(buf->base); }