diff --git a/src/ds/thread_messaging.h b/src/ds/thread_messaging.h index 9f404c9e8e3c..84983cb18e10 100644 --- a/src/ds/thread_messaging.h +++ b/src/ds/thread_messaging.h @@ -18,50 +18,32 @@ extern std::map thread_ids; namespace enclave { - struct ThreadMsg + const uint64_t magic_const = 0xba5eball; + struct alignas(8) ThreadMsg { void (*cb)(std::unique_ptr); std::atomic next = nullptr; - uint64_t padding[14]; + uint64_t magic = magic_const; + + ThreadMsg(void (*_cb)(std::unique_ptr)) : cb(_cb) {} + + virtual ~ThreadMsg() + { + assert(magic == magic_const); + } }; template - struct Tmsg + struct alignas(8) Tmsg : public ThreadMsg { + Payload data; + Tmsg(void (*_cb)(std::unique_ptr>)) : - cb(reinterpret_cast)>(_cb)), - next(nullptr) - { - check_invariants(); - } + ThreadMsg(reinterpret_cast)>(_cb)) - void (*cb)(std::unique_ptr); - std::atomic next; - union - { - Payload data; - uint64_t padding[14]; - }; + {} - static void check_invariants() - { - static_assert( - sizeof(ThreadMsg) == sizeof(Tmsg), "message is too large"); - static_assert( - sizeof(Payload) <= sizeof(ThreadMsg::padding), - "message payload is too large"); - static_assert(std::is_pod::value, "data should be a pod"); - - static_assert( - offsetof(Tmsg, cb) == offsetof(ThreadMsg, cb), - "Expected cb at start of struct"); - static_assert( - offsetof(Tmsg, next) == offsetof(ThreadMsg, next), - "Expected next after cb in struct"); - static_assert( - offsetof(Tmsg, data) == offsetof(ThreadMsg, padding), - "Expected payload after next in struct"); - } + virtual ~Tmsg() = default; }; static void init_cb(std::unique_ptr m) @@ -168,6 +150,7 @@ namespace enclave public: static ThreadMessaging thread_messaging; static std::atomic thread_count; + static const uint16_t main_thread = 0; static const uint16_t max_num_threads = 64; diff --git a/src/enclave/endpoint.h b/src/enclave/endpoint.h index 0b5140d2c54c..4c630e56c9c0 100644 --- a/src/enclave/endpoint.h +++ b/src/enclave/endpoint.h @@ -6,7 +6,7 @@ namespace enclave { - class Endpoint + class Endpoint : public std::enable_shared_from_this { public: virtual ~Endpoint() {} diff --git a/src/enclave/httpendpoint.h b/src/enclave/httpendpoint.h index 162c38d8837f..00f420757a5c 100644 --- a/src/enclave/httpendpoint.h +++ b/src/enclave/httpendpoint.h @@ -27,7 +27,23 @@ namespace enclave p(parser_type, *this) {} + static void recv_cb(std::unique_ptr> msg) + { + reinterpret_cast(msg->data.self.get()) + ->recv_(msg->data.data.data(), msg->data.data.size()); + } + void recv(const uint8_t* data, size_t size) override + { + auto msg = std::make_unique>(&recv_cb); + msg->data.self = this->shared_from_this(); + msg->data.data.assign(data, data + size); + + enclave::ThreadMessaging::thread_messaging.add_task( + execution_thread, std::move(msg)); + } + + void recv_(const uint8_t* data, size_t size) { recv_buffered(data, size); @@ -95,7 +111,23 @@ namespace enclave session_id(session_id) {} + static void send_cb(std::unique_ptr> msg) + { + reinterpret_cast(msg->data.self.get()) + ->send_(msg->data.data); + } + void send(const std::vector& data) override + { + auto msg = std::make_unique>(&send_cb); + msg->data.self = this->shared_from_this(); + msg->data.data = data; + + enclave::ThreadMessaging::thread_messaging.add_task( + execution_thread, std::move(msg)); + } + + void send_(const std::vector& data) { // This should be called with raw body of response - we will wrap it with // header then transmit @@ -133,6 +165,76 @@ namespace enclave flush(); } + struct HandleProcessCbMsg + { + std::shared_ptr self; + std::shared_ptr rpc_ctx; + std::shared_ptr frontend; + }; + + static void handle_process( + std::unique_ptr> msg) + { + reinterpret_cast(msg->data.self.get()) + ->handle_message_main_thread(msg->data.rpc_ctx, msg->data.frontend); + } + + struct SendResponseVectCbMsg + { + std::vector d = {}; + http_status status; + std::shared_ptr self; + }; + + static void send_response_vect( + std::unique_ptr> msg) + { + reinterpret_cast(msg->data.self.get()) + ->send_response(msg->data.d); + } + + void handle_message_main_thread( + std::shared_ptr& rpc_ctx, + std::shared_ptr& search) + { + try + { + auto response = search->process(rpc_ctx); + + if (!response.has_value()) + { + // If the RPC is pending, hold the connection. + LOG_TRACE_FMT("Pending"); + return; + } + else + { + // Otherwise, reply to the client synchronously. + LOG_TRACE_FMT("Responding"); + auto msg = std::make_unique>( + &send_response_vect); + msg->data.status = HTTP_STATUS_OK; + msg->data.self = this->shared_from_this(); + msg->data.d = response.value(); + enclave::ThreadMessaging::thread_messaging + .add_task(execution_thread, std::move(msg)); + } + } + catch (const std::exception& e) + { + auto msg = std::make_unique>( + &send_response_vect); + msg->data.status = HTTP_STATUS_INTERNAL_SERVER_ERROR; + msg->data.self = this->shared_from_this(); + + std::string err_msg = fmt::format("Exception:\n{}\n", e.what()); + msg->data.d.assign(err_msg.begin(), err_msg.end()); + + enclave::ThreadMessaging::thread_messaging + .add_task(execution_thread, std::move(msg)); + } + } + void handle_message( http_method verb, const std::string& path, @@ -246,20 +348,13 @@ namespace enclave rpc_ctx->method = method_s; rpc_ctx->actor = actor; - auto response = search.value()->process(rpc_ctx); - - if (!response.has_value()) - { - // If the RPC is pending, hold the connection. - LOG_TRACE_FMT("Pending"); - return; - } - else - { - // Otherwise, reply to the client synchronously. - LOG_TRACE_FMT("Responding"); - send_response(response.value()); - } + auto msg = + std::make_unique>(&handle_process); + msg->data.self = this->shared_from_this(); + msg->data.rpc_ctx = rpc_ctx; + msg->data.frontend = search.value(); + enclave::ThreadMessaging::thread_messaging.add_task( + enclave::ThreadMessaging::main_thread, std::move(msg)); } catch (const std::exception& e) { diff --git a/src/enclave/tlsendpoint.h b/src/enclave/tlsendpoint.h index fdc5f4997ad9..a8c0b37d17aa 100644 --- a/src/enclave/tlsendpoint.h +++ b/src/enclave/tlsendpoint.h @@ -9,6 +9,8 @@ #include "tls/context.h" #include "tls/msg_types.h" +#include + namespace enclave { class TLSEndpoint : public Endpoint @@ -16,6 +18,7 @@ namespace enclave protected: ringbuffer::WriterPtr to_host; size_t session_id; + size_t execution_thread; enum Status { @@ -62,6 +65,15 @@ namespace enclave ctx(move(ctx_)), status(handshake) { + if (enclave::ThreadMessaging::thread_count > 1) + { + execution_thread = + (session_id_ % (enclave::ThreadMessaging::thread_count - 1)) + 1; + } + else + { + execution_thread = 0; + } ctx->set_bio(this, send_callback, recv_callback, dbg_callback); } @@ -195,12 +207,42 @@ namespace enclave void recv_buffered(const uint8_t* data, size_t size) { + if (thread_ids[std::this_thread::get_id()] != execution_thread) + { + throw std::exception(); + } pending_read.insert(pending_read.end(), data, data + size); do_handshake(); } + struct SendRecvMsg + { + std::vector data; + std::shared_ptr self; + }; + + static void send_raw_cb(std::unique_ptr> msg) + { + reinterpret_cast(msg->data.self.get()) + ->send_raw_thread(msg->data.data); + } + void send_raw(const std::vector& data) { + auto msg = std::make_unique>(&send_raw_cb); + msg->data.self = this->shared_from_this(); + msg->data.data = data; + + enclave::ThreadMessaging::thread_messaging.add_task( + execution_thread, std::move(msg)); + } + + void send_raw_thread(std::vector& data) + { + if (thread_ids[std::this_thread::get_id()] != execution_thread) + { + throw std::runtime_error("running from incorrect thread"); + } // Writes as much of the data as possible. If the data cannot all // be written now, we store the remainder. We // will try to send pending writes again whenever write() is called. @@ -222,11 +264,21 @@ namespace enclave void send_buffered(const std::vector& data) { + if (thread_ids[std::this_thread::get_id()] != execution_thread) + { + throw std::runtime_error("running from incorrect thread"); + } + pending_write.insert(pending_write.end(), data.begin(), data.end()); } void flush() { + if (thread_ids[std::this_thread::get_id()] != execution_thread) + { + throw std::runtime_error("running from incorrect thread"); + } + do_handshake(); if (status != ready) @@ -444,6 +496,10 @@ namespace enclave int handle_recv(uint8_t* buf, size_t len) { + if (thread_ids[std::this_thread::get_id()] != execution_thread) + { + throw std::runtime_error("running from incorrect thread"); + } if (pending_read.size() > 0) { // Use the pending data vector. This is populated when the host