Threading - TLS connections on multiple threads #773

Merged · 20 commits · Feb 3, 2020
51 changes: 17 additions & 34 deletions src/ds/thread_messaging.h
@@ -18,50 +18,32 @@ extern std::map<std::thread::id, uint16_t> thread_ids;

 namespace enclave
 {
-  struct ThreadMsg
+  const uint64_t magic_const = 0xba5eball;
+  struct alignas(8) ThreadMsg
   {
     void (*cb)(std::unique_ptr<ThreadMsg>);
     std::atomic<ThreadMsg*> next = nullptr;
     uint64_t padding[14];
+    uint64_t magic = magic_const;

     ThreadMsg(void (*_cb)(std::unique_ptr<ThreadMsg>)) : cb(_cb) {}
+
+    virtual ~ThreadMsg()
+    {
+      assert(magic == magic_const);
+    }
   };

   template <typename Payload>
-  struct Tmsg
+  struct alignas(8) Tmsg : public ThreadMsg
   {
+    Payload data;
+
     Tmsg(void (*_cb)(std::unique_ptr<Tmsg<Payload>>)) :
-      cb(reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb)),
-      next(nullptr)
-    {
-      check_invariants();
-    }
+      ThreadMsg(reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb))
+    {}

-    void (*cb)(std::unique_ptr<ThreadMsg>);
-    std::atomic<ThreadMsg*> next;
-    union
-    {
-      Payload data;
-      uint64_t padding[14];
-    };
-
-    static void check_invariants()
-    {
-      static_assert(
-        sizeof(ThreadMsg) == sizeof(Tmsg<Payload>), "message is too large");
-      static_assert(
-        sizeof(Payload) <= sizeof(ThreadMsg::padding),
-        "message payload is too large");
-      static_assert(std::is_pod<Payload>::value, "data should be a pod");
-
-      static_assert(
-        offsetof(Tmsg, cb) == offsetof(ThreadMsg, cb),
-        "Expected cb at start of struct");
-      static_assert(
-        offsetof(Tmsg, next) == offsetof(ThreadMsg, next),
-        "Expected next after cb in struct");
-      static_assert(
-        offsetof(Tmsg, data) == offsetof(ThreadMsg, padding),
-        "Expected payload after next in struct");
-    }
+    virtual ~Tmsg() = default;
   };

static void init_cb(std::unique_ptr<ThreadMsg> m)
@@ -168,6 +150,7 @@ namespace enclave
public:
static ThreadMessaging thread_messaging;
static std::atomic<uint16_t> thread_count;
static const uint16_t main_thread = 0;

static const uint16_t max_num_threads = 64;

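For orientation, here is a minimal sketch of how a payload travels through this API after the change. PingMsg, ping_cb, and send_ping are illustrative names only; Tmsg, ThreadMessaging, and add_task are used exactly as in the diff above.

struct PingMsg
{
  uint64_t seq;
};

static void ping_cb(std::unique_ptr<enclave::Tmsg<PingMsg>> msg)
{
  // Runs on the target thread. When the unique_ptr goes out of scope, ~Tmsg
  // and then ~ThreadMsg run, and ~ThreadMsg asserts the magic value survived.
  LOG_TRACE_FMT("ping {}", msg->data.seq);
}

static void send_ping(uint16_t target_thread, uint64_t seq)
{
  auto msg = std::make_unique<enclave::Tmsg<PingMsg>>(&ping_cb);
  msg->data.seq = seq;
  enclave::ThreadMessaging::thread_messaging.add_task<PingMsg>(
    target_thread, std::move(msg));
}

Because Tmsg now inherits from ThreadMsg instead of mirroring its layout, the old offsetof/sizeof invariant checks become unnecessary; the virtual destructor plus the magic field catch corruption at teardown instead.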
2 changes: 1 addition & 1 deletion src/enclave/endpoint.h
@@ -6,7 +6,7 @@

namespace enclave
{
-class Endpoint
+class Endpoint : public std::enable_shared_from_this<Endpoint>
{
public:
virtual ~Endpoint() {}
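This one-line change is what makes every msg->data.self = this->shared_from_this() in the following files legal. A minimal illustration of the lifetime problem it solves - Session, post, and enqueue are hypothetical stand-ins for Endpoint and add_task, and shared_from_this is only valid once the object is already owned by a std::shared_ptr:

#include <functional>
#include <memory>

struct Session : std::enable_shared_from_this<Session>
{
  void post(std::function<void(Session&)> work)
  {
    // Copying a shared_ptr into the queued task pins this Session until the
    // task has run, even if the connection is torn down in the meantime.
    auto self = shared_from_this();
    enqueue([self, work]() { work(*self); });
  }

  void enqueue(std::function<void()> task); // stand-in for add_task
};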
123 changes: 109 additions & 14 deletions src/enclave/httpendpoint.h
@@ -27,7 +27,23 @@ namespace enclave
p(parser_type, *this)
{}

static void recv_cb(std::unique_ptr<enclave::Tmsg<SendRecvMsg>> msg)
{
reinterpret_cast<HTTPEndpoint*>(msg->data.self.get())
->recv_(msg->data.data.data(), msg->data.data.size());
}

void recv(const uint8_t* data, size_t size) override
{
auto msg = std::make_unique<enclave::Tmsg<SendRecvMsg>>(&recv_cb);
msg->data.self = this->shared_from_this();
msg->data.data.assign(data, data + size);

enclave::ThreadMessaging::thread_messaging.add_task<SendRecvMsg>(
execution_thread, std::move(msg));
}

void recv_(const uint8_t* data, size_t size)
{
recv_buffered(data, size);

@@ -95,7 +111,23 @@ namespace enclave
session_id(session_id)
{}

static void send_cb(std::unique_ptr<enclave::Tmsg<SendRecvMsg>> msg)
{
reinterpret_cast<HTTPServerEndpoint*>(msg->data.self.get())
->send_(msg->data.data);
}

void send(const std::vector<uint8_t>& data) override
{
auto msg = std::make_unique<enclave::Tmsg<SendRecvMsg>>(&send_cb);
msg->data.self = this->shared_from_this();
msg->data.data = data;

enclave::ThreadMessaging::thread_messaging.add_task<SendRecvMsg>(
execution_thread, std::move(msg));
}

void send_(const std::vector<uint8_t>& data)
{
// This should be called with the raw body of the response - we will wrap it
// with a header and then transmit it
@@ -133,6 +165,76 @@ namespace enclave
flush();
}

struct HandleProcessCbMsg
{
std::shared_ptr<Endpoint> self;
std::shared_ptr<JsonRpcContext> rpc_ctx;
std::shared_ptr<RpcHandler> frontend;
};

static void handle_process(
std::unique_ptr<enclave::Tmsg<HandleProcessCbMsg>> msg)
{
reinterpret_cast<HTTPServerEndpoint*>(msg->data.self.get())
->handle_message_main_thread(msg->data.rpc_ctx, msg->data.frontend);
}

struct SendResponseVectCbMsg
{
std::vector<uint8_t> d = {};
http_status status;
std::shared_ptr<Endpoint> self;
};

static void send_response_vect(
std::unique_ptr<enclave::Tmsg<SendResponseVectCbMsg>> msg)
{
reinterpret_cast<HTTPServerEndpoint*>(msg->data.self.get())
->send_response(msg->data.d);
}

void handle_message_main_thread(
std::shared_ptr<JsonRpcContext>& rpc_ctx,
std::shared_ptr<RpcHandler>& search)
{
try
{
auto response = search->process(rpc_ctx);

if (!response.has_value())
{
// If the RPC is pending, hold the connection.
LOG_TRACE_FMT("Pending");
return;
}
else
{
// Otherwise, reply to the client synchronously.
LOG_TRACE_FMT("Responding");
auto msg = std::make_unique<enclave::Tmsg<SendResponseVectCbMsg>>(
&send_response_vect);
msg->data.status = HTTP_STATUS_OK;
msg->data.self = this->shared_from_this();
msg->data.d = response.value();
enclave::ThreadMessaging::thread_messaging
.add_task<SendResponseVectCbMsg>(execution_thread, std::move(msg));
}
}
catch (const std::exception& e)
{
auto msg = std::make_unique<enclave::Tmsg<SendResponseVectCbMsg>>(
&send_response_vect);
msg->data.status = HTTP_STATUS_INTERNAL_SERVER_ERROR;
msg->data.self = this->shared_from_this();

std::string err_msg = fmt::format("Exception:\n{}\n", e.what());
msg->data.d.assign(err_msg.begin(), err_msg.end());

enclave::ThreadMessaging::thread_messaging
.add_task<SendResponseVectCbMsg>(execution_thread, std::move(msg));
}
}

void handle_message(
http_method verb,
const std::string& path,
@@ -246,20 +348,13 @@ namespace enclave
rpc_ctx->method = method_s;
rpc_ctx->actor = actor;

-auto response = search.value()->process(rpc_ctx);
-
-if (!response.has_value())
-{
-  // If the RPC is pending, hold the connection.
-  LOG_TRACE_FMT("Pending");
-  return;
-}
-else
-{
-  // Otherwise, reply to the client synchronously.
-  LOG_TRACE_FMT("Responding");
-  send_response(response.value());
-}
+auto msg =
+  std::make_unique<enclave::Tmsg<HandleProcessCbMsg>>(&handle_process);
+msg->data.self = this->shared_from_this();
+msg->data.rpc_ctx = rpc_ctx;
+msg->data.frontend = search.value();
+enclave::ThreadMessaging::thread_messaging.add_task<HandleProcessCbMsg>(
+  enclave::ThreadMessaging::main_thread, std::move(msg));
}
catch (const std::exception& e)
{
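Taken together, the handlers above split one request across two threads. A plausible trace, inferred from the callbacks in this file (thread assignment assumes thread_count > 1):

recv(data)                    // execution_thread: decrypt and parse
  handle_message(...)         // posts HandleProcessCbMsg to main_thread
handle_message_main_thread    // main_thread: frontend->process(rpc_ctx)
  send_response_vect          // posts SendResponseVectCbMsg back
send_response(...)            // execution_thread: frame and transmit

Only the frontend call runs on the main thread; all per-session TLS and HTTP work stays pinned to the session's execution thread.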
56 changes: 56 additions & 0 deletions src/enclave/tlsendpoint.h
@@ -9,13 +9,16 @@
#include "tls/context.h"
#include "tls/msg_types.h"

#include <stdexcept>

namespace enclave
{
class TLSEndpoint : public Endpoint
{
protected:
ringbuffer::WriterPtr to_host;
size_t session_id;
size_t execution_thread;

enum Status
{
@@ -62,6 +65,15 @@ namespace enclave
ctx(move(ctx_)),
status(handshake)
{
if (enclave::ThreadMessaging::thread_count > 1)
{
execution_thread =
(session_id_ % (enclave::ThreadMessaging::thread_count - 1)) + 1;
}
else
{
execution_thread = 0;
}
ctx->set_bio(this, send_callback, recv_callback, dbg_callback);
}
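
A worked example of the session-to-thread mapping established in this constructor. pick_execution_thread is an illustrative free function, not part of the diff; it just restates the arithmetic above:

#include <cstddef>
#include <cstdint>

// Thread 0 is reserved as the main thread; sessions are spread by id across
// threads 1..thread_count-1.
uint16_t pick_execution_thread(size_t session_id, uint16_t thread_count)
{
  if (thread_count > 1)
  {
    return static_cast<uint16_t>((session_id % (thread_count - 1)) + 1);
  }
  return 0; // single-threaded: everything runs on the main thread
}

// With thread_count == 4, sessions 0,1,2,3,4 land on threads 1,2,3,1,2, so
// with more than one thread no TLS session ever runs on thread 0.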

@@ -195,12 +207,42 @@

void recv_buffered(const uint8_t* data, size_t size)
{
if (thread_ids[std::this_thread::get_id()] != execution_thread)
{
  throw std::runtime_error("running from incorrect thread");
}
pending_read.insert(pending_read.end(), data, data + size);
do_handshake();
}

struct SendRecvMsg
{
std::vector<uint8_t> data;
std::shared_ptr<Endpoint> self;
};

static void send_raw_cb(std::unique_ptr<enclave::Tmsg<SendRecvMsg>> msg)
{
reinterpret_cast<TLSEndpoint*>(msg->data.self.get())
->send_raw_thread(msg->data.data);
}

void send_raw(const std::vector<uint8_t>& data)
{
auto msg = std::make_unique<enclave::Tmsg<SendRecvMsg>>(&send_raw_cb);
msg->data.self = this->shared_from_this();
msg->data.data = data;

enclave::ThreadMessaging::thread_messaging.add_task<SendRecvMsg>(
execution_thread, std::move(msg));
}

void send_raw_thread(std::vector<uint8_t>& data)
{
if (thread_ids[std::this_thread::get_id()] != execution_thread)
{
throw std::runtime_error("running from incorrect thread");
}
// Writes as much of the data as possible. If the data cannot all
// be written now, we store the remainder. We
// will try to send pending writes again whenever write() is called.
@@ -222,11 +264,21 @@

void send_buffered(const std::vector<uint8_t>& data)
{
if (thread_ids[std::this_thread::get_id()] != execution_thread)
{
throw std::runtime_error("running from incorrect thread");
}

pending_write.insert(pending_write.end(), data.begin(), data.end());
}

void flush()
{
if (thread_ids[std::this_thread::get_id()] != execution_thread)
{
throw std::runtime_error("running from incorrect thread");
}

do_handshake();

if (status != ready)
@@ -444,6 +496,10 @@

int handle_recv(uint8_t* buf, size_t len)
{
if (thread_ids[std::this_thread::get_id()] != execution_thread)
{
throw std::runtime_error("running from incorrect thread");
}
if (pending_read.size() > 0)
{
// Use the pending data vector. This is populated when the host
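The same affinity guard now appears in recv_buffered, send_raw_thread, send_buffered, flush, and handle_recv. A small member-function helper in the same style would keep those call sites to one line; this is a sketch, not part of the diff, assuming the global thread_ids map declared in thread_messaging.h:

void check_execution_thread(const char* where) const
{
  if (thread_ids[std::this_thread::get_id()] != execution_thread)
  {
    throw std::runtime_error(
      fmt::format("{} called from incorrect thread", where));
  }
}

// e.g. at the top of flush():  check_execution_thread("flush");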