
Improve: Drop iovecs from posix engine.
The result is decreased memory usage and better code.
ishkhan42 committed Apr 11, 2023
1 parent 5e6a7b8 commit 4ccb182
Showing 1 changed file with 106 additions and 140 deletions.
246 changes: 106 additions & 140 deletions src/engine_posix.cpp
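
For orientation before the diff: the change replaces scatter-gather replies, where every response fragment was referenced by its own iovec and flushed with sendmsg(), with a single growable byte buffer flushed through one send() (or mbedtls_ssl_write() under TLS). Below is a minimal sketch of the before/after shape, not the repository's code: std::vector<char> stands in for its array_gt<char>, and the socket descriptor is assumed to be already connected.

// Illustrative sketch only; `std::vector<char>` stands in for `array_gt<char>`,
// and `connection` is a hypothetical already-connected socket descriptor.
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <string_view>
#include <vector>

// Before: each reply fragment gets its own iovec, handed to the kernel in one
// scatter-gather sendmsg() call, so every fragment must stay alive until then.
ssize_t reply_with_iovecs(int connection, std::string_view header, std::string_view body) {
    iovec parts[2]{};
    parts[0] = {const_cast<char*>(header.data()), header.size()};
    parts[1] = {const_cast<char*>(body.data()), body.size()};
    msghdr message{};
    message.msg_iov = parts;
    message.msg_iovlen = 2;
    return sendmsg(connection, &message, 0);
}

// After: fragments are appended into one contiguous buffer, so a plain send()
// suffices and no per-fragment bookkeeping or heap copies are needed.
ssize_t reply_with_buffer(int connection, std::string_view header, std::string_view body) {
    std::vector<char> buffer;
    buffer.reserve(header.size() + body.size());
    buffer.insert(buffer.end(), header.begin(), header.end());
    buffer.insert(buffer.end(), body.begin(), body.end());
    return send(connection, buffer.data(), buffer.size(), 0);
}

The trade-off visible in the rest of the diff is that batch replies no longer need per-response heap copies or an iovec table preallocated from max_batch_size; a single buffer simply grows as responses are appended.
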
@@ -32,6 +32,8 @@ using namespace unum::ujrpc;
using time_clock_t = std::chrono::steady_clock;
using time_point_t = std::chrono::time_point<time_clock_t>;

static constexpr std::size_t initial_buffer_size_k = ram_page_size_k * 4;

struct ujrpc_ssl_context_t {
ujrpc_ssl_context_t() noexcept = default;

@@ -109,7 +111,6 @@ struct engine_t {
~engine_t() noexcept { delete ssl_ctx; }

descriptor_t socket;
std::size_t max_batch_size;

/// @brief Establishes an SSL connection if SSL is enabled, otherwise the `ssl_ctx` is unused and uninitialized.
ujrpc_ssl_context_t* ssl_ctx = nullptr;
@@ -124,10 +125,9 @@ struct engine_t {
scratch_space_t scratch;
/// @brief For batch-requests in synchronous connections we need a place to
struct batch_response_t {
buffer_gt<struct iovec> iovecs;
buffer_gt<char*> copies;
std::size_t iovecs_count;
std::size_t copies_count;
array_gt<char> buffer;
// buffer_gt<char*> copies;
// std::size_t copies_count;
} batch_response;

stats_t stats;
@@ -149,15 +149,13 @@ sj::simdjson_result<sjd::element> param_at(ujrpc_call_t call, size_t position) n
return scratch.point_to_param(position);
}

void send_message(engine_t& engine, struct msghdr& message) noexcept {
void send_message(engine_t& engine, array_gt<char> const& message) noexcept {
long bytes_sent = 0;
if (engine.ssl_ctx) {
size_t sz = 0;
for (size_t i = 0; i < message.msg_iovlen; ++i)
sz += message.msg_iov[i].iov_len;
bytes_sent = mbedtls_ssl_write(&engine.ssl_ctx->ssl, (uint8_t*)message.msg_iov->iov_base, sz);
} else
bytes_sent = sendmsg(engine.connection, &message, 0);
if (engine.ssl_ctx)
bytes_sent =
mbedtls_ssl_write(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t const*>(message.data()), message.size());
else
bytes_sent = send(engine.connection, message.data(), message.size(), 0);

if (bytes_sent < 0) {
if (errno == EMSGSIZE)
@@ -168,14 +166,11 @@ void send_message(engine_t& engine, struct msghdr& message) noexcept {
engine.stats.packets_sent++;
}

void send_reply(engine_t& engine) noexcept {
if (!engine.batch_response.iovecs_count)
void send_reply(engine_t& engine) noexcept { // TODO Is this required?
if (!engine.batch_response.buffer.size())
return;

struct msghdr message {};
message.msg_iov = engine.batch_response.iovecs.data();
message.msg_iovlen = engine.batch_response.iovecs_count;
send_message(engine, message);
send_message(engine, engine.batch_response.buffer);
}

void forward_call(engine_t& engine) noexcept {
@@ -199,8 +194,6 @@ void forward_call_or_calls(engine_t& engine) noexcept {
if (one_or_many.error() != sj::SUCCESS)
return ujrpc_call_reply_error(&engine, -32700, "Invalid JSON was received by the server.", 40);

engine.batch_response.iovecs_count = 0;
engine.batch_response.copies_count = 0;
// The major difference between batch and single-request paths is that
// in the first case we need to keep a copy of the data somewhere,
// until answers to all requests are accumulated and we can submit them
@@ -210,53 +203,40 @@
if (one_or_many.is_array()) {
sjd::array many = one_or_many.get_array().value_unsafe();
scratch.is_batch = false;
if (many.size() > engine.max_batch_size)
return ujrpc_call_reply_error(&engine, -32603, "Too many requests in the batch.", 31);

// Start a JSON array.
scratch.is_batch = true;
engine.batch_response.iovecs[scratch.is_http].iov_base = const_cast<char*>("[");
engine.batch_response.iovecs[scratch.is_http].iov_len = 1;
engine.batch_response.iovecs_count += scratch.is_http + 1;
bool res = true;
if (scratch.is_http)
res &= engine.batch_response.buffer.append_n(http_header_k, http_header_size_k);

res &= engine.batch_response.buffer.append_n("[", 1);

for (sjd::element const one : many) {
scratch.tree = one;
forward_call(engine);
}

// Drop the last comma. Yeah, it's ugly.
auto last_bucket = (char*)engine.batch_response.iovecs[engine.batch_response.iovecs_count - 1].iov_base;
if (last_bucket[engine.batch_response.iovecs[engine.batch_response.iovecs_count - 1].iov_len - 1] == ',')
engine.batch_response.iovecs[engine.batch_response.iovecs_count - 1].iov_len--;

// Close the last bracket of the JSON array.
engine.batch_response.iovecs[engine.batch_response.iovecs_count].iov_base = (void*)"]";
engine.batch_response.iovecs[engine.batch_response.iovecs_count].iov_len = 1;
engine.batch_response.iovecs_count++;
if (engine.batch_response.buffer[engine.batch_response.buffer.size() - 1] == ',')
engine.batch_response.buffer.pop_back();

if (scratch.is_http) {
size_t body_len = 0;
for (size_t i = 1; i < engine.batch_response.iovecs_count; ++i)
body_len += engine.batch_response.iovecs[i].iov_len;
res &= engine.batch_response.buffer.append_n("]", 1);

char headers[http_header_size_k] = {};
std::memcpy(headers, http_header_k, http_header_size_k);
set_http_content_length(headers, body_len);
if (!res)
return ujrpc_call_reply_error_out_of_memory(&engine);

engine.batch_response.iovecs[0].iov_base = headers;
engine.batch_response.iovecs[0].iov_len = http_header_size_k;
}
if (scratch.is_http)
set_http_content_length(engine.batch_response.buffer.data(),
engine.batch_response.buffer.size() - http_header_size_k);

send_reply(engine);

// Deallocate copies of received responses:
for (std::size_t response_idx = 0; response_idx != engine.batch_response.copies_count; ++response_idx)
std::free(std::exchange(engine.batch_response.copies[response_idx], nullptr));
engine.batch_response.buffer.reset();
} else {
scratch.is_batch = false;
scratch.tree = one_or_many.value_unsafe();
forward_call(engine);
send_reply(engine);
engine.batch_response.buffer.reset();
}
}

@@ -355,11 +335,13 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {
// or allocate dynamically, if the message is too long.
if (bytes_expected <= ram_page_size_k) {
bytes_expected -= bytes_received;
if (engine.ssl_ctx)
bytes_received += mbedtls_ssl_read(&engine.ssl_ctx->ssl,
reinterpret_cast<uint8_t*>(buffer_ptr + bytes_received), bytes_expected);
else
bytes_received += recv(engine.connection, buffer_ptr + bytes_received, bytes_expected, MSG_WAITALL);
if (bytes_expected > 0) {
if (engine.ssl_ctx)
bytes_received += mbedtls_ssl_read(
&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t*>(buffer_ptr + bytes_received), bytes_expected);
else
bytes_received += recv(engine.connection, buffer_ptr + bytes_received, bytes_expected, MSG_WAITALL);
}
scratch.dynamic_parser = &scratch.parser;
scratch.dynamic_packet = std::string_view(buffer_ptr, bytes_received);
engine.stats.bytes_received += bytes_received;
@@ -419,8 +401,6 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
config.queue_depth = 128u;
if (!config.max_callbacks)
config.max_callbacks = 128u;
if (!config.max_batch_size)
config.max_batch_size = 1024u;
if (!config.interface)
config.interface = "0.0.0.0";
if (config.use_ssl && !(config.ssl_pk_path || config.ssl_crts_cnt))
@@ -436,8 +416,7 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
int socket_options{1};
int socket_descriptor{-1};
engine_t* server_ptr = nullptr;
buffer_gt<struct iovec> embedded_iovecs;
buffer_gt<char*> embedded_copies;
array_gt<char> buffer;
array_gt<named_callback_t> embedded_callbacks;
ujrpc_ssl_context_t* ssl_context = nullptr;
sjd::parser parser;
@@ -452,13 +431,7 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
server_ptr = (engine_t*)std::malloc(sizeof(engine_t));
if (!server_ptr)
goto cleanup;
// In the worst case we may have `max_batch_size` requests, where each will
// need `iovecs_for_content_k` or `iovecs_for_error_k` of `iovec` structures,
// plus two for the opening and closing bracket of JSON.
if (!embedded_iovecs.resize(config.max_batch_size * std::max(iovecs_for_content_k, iovecs_for_error_k) + 2 +
iovecs_for_http_response_k))
goto cleanup;
if (!embedded_copies.resize(config.max_batch_size))
if (!buffer.reserve(initial_buffer_size_k))
goto cleanup;
if (!embedded_callbacks.reserve(config.max_callbacks))
goto cleanup;
@@ -484,11 +457,9 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
// Initialize all the members.
new (server_ptr) engine_t();
server_ptr->socket = descriptor_t{socket_descriptor};
server_ptr->max_batch_size = config.max_batch_size;
server_ptr->callbacks = std::move(embedded_callbacks);
server_ptr->scratch.parser = std::move(parser);
server_ptr->batch_response.copies = std::move(embedded_copies);
server_ptr->batch_response.iovecs = std::move(embedded_iovecs);
server_ptr->batch_response.buffer = std::move(buffer);
server_ptr->logs_file_descriptor = config.logs_file_descriptor;
server_ptr->logs_format = config.logs_format ? std::string_view(config.logs_format) : std::string_view();
server_ptr->log_last_time = time_clock_t::now();
@@ -526,11 +497,54 @@ void ujrpc_free(ujrpc_server_t server) {
delete engine;
}

void prepend_http_headers(iovec* buffers, size_t content_len, char* http_buffer) {
std::memcpy(http_buffer, http_header_k, http_header_size_k);
set_http_content_length(http_buffer, content_len);
buffers[0].iov_base = const_cast<char*>(http_buffer);
buffers[0].iov_len = http_header_size_k;
bool fill_with_content(array_gt<char>& buffer, std::string_view request_id, std::string_view body,
bool add_http = false, bool append_comma = false) {

// Communication example would be:
// --> {"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1}
// <-- {"jsonrpc": "2.0", "id": 1, "result": 19}
bool res = true;
if (add_http)
res &= buffer.append_n(http_header_k, http_header_size_k);

size_t initial_sz = buffer.size();
res &= buffer.append_n(R"({"jsonrpc":"2.0","id":)", 22);
res &= buffer.append_n(request_id.data(), request_id.size());
res &= buffer.append_n(R"(,"result":)", 10);
res &= buffer.append_n(body.data(), body.size());
res &= buffer.append_n(R"(},)", 1 + append_comma);
size_t body_len = buffer.size() - initial_sz;

if (add_http)
set_http_content_length(buffer.end() - (body_len + http_header_size_k), body_len);

return res;
}

bool fill_with_error(array_gt<char>& buffer, std::string_view request_id, std::string_view error_code,
std::string_view error_message, bool add_http = false, bool append_comma = false) {

// Communication example would be:
// --> {"jsonrpc": "2.0", "method": "foobar", "id": "1"}
// <-- {"jsonrpc": "2.0", "id": "1", "error": {"code": -32601, "message": "Method not found"}}
bool res = true;
if (add_http)
res &= buffer.append_n(http_header_k, http_header_size_k);

size_t initial_sz = buffer.size();
res &= buffer.append_n(R"({"jsonrpc":"2.0","id":)", 22);
res &= buffer.append_n(request_id.data(), request_id.size());
res &= buffer.append_n(R"(,"error":{"code":)", 17);
res &= buffer.append_n(error_code.data(), error_code.size());
res &= buffer.append_n(R"(,"message":")", 12);
res &= buffer.append_n(error_message.data(), error_message.size());
res &= buffer.append_n(R"("}},)", 3 + append_comma);
size_t body_len = buffer.size() - initial_sz;

if (add_http)
set_http_content_length(buffer.end() - (body_len + http_header_size_k), body_len);

return res;
}

void ujrpc_call_reply_content(ujrpc_call_t call, ujrpc_str_t body, size_t body_len) {
@@ -543,36 +557,16 @@ void ujrpc_call_reply_content(ujrpc_call_t call, ujrpc_str_t body, size_t body_l
body_len = std::strlen(body);

// In case of a single request - immediately push into the socket.
if (!scratch.is_batch) {
struct msghdr message {};
if (scratch.is_http) {
struct iovec iovecs[iovecs_for_content_k + 1]{};
size_t content_len = fill_with_content(iovecs + 1, scratch.dynamic_id, std::string_view(body, body_len));
message.msg_iov = iovecs;
message.msg_iovlen = iovecs_for_content_k + 1;
char headers[http_header_size_k];
prepend_http_headers(iovecs, content_len, headers);
send_message(engine, message);
} else {
struct iovec iovecs[iovecs_for_content_k] {};
fill_with_content(iovecs, scratch.dynamic_id, std::string_view(body, body_len));
message.msg_iov = iovecs;
message.msg_iovlen = iovecs_for_content_k;
send_message(engine, message);
}
}

// In case of a batch or async request, preserve a copy of data on the heap.
else {
auto body_copy = (char*)std::malloc(body_len);
if (!body_copy)
if (!scratch.is_batch)
if (fill_with_content(engine.batch_response.buffer, scratch.dynamic_id, //
std::string_view(body, body_len), scratch.is_http))
send_message(engine, engine.batch_response.buffer);
else
return ujrpc_call_reply_error_out_of_memory(call);
std::memcpy(body_copy, body, body_len);
engine.batch_response.copies[engine.batch_response.copies_count++] = body_copy;
fill_with_content(engine.batch_response.iovecs.data() + engine.batch_response.iovecs_count, scratch.dynamic_id,
std::string_view(body_copy, body_len), true);
engine.batch_response.iovecs_count += iovecs_for_content_k;
}

else if (!fill_with_content(engine.batch_response.buffer, scratch.dynamic_id, //
std::string_view(body, body_len), false, true))
return ujrpc_call_reply_error_out_of_memory(call);
}

void ujrpc_call_reply_error(ujrpc_call_t call, int code_int, ujrpc_str_t note, size_t note_len) {
@@ -591,45 +585,17 @@ void ujrpc_call_reply_error(ujrpc_call_t call, int code_int, ujrpc_str_t note, s
return ujrpc_call_reply_error_unknown(call);

// In case of a single request - immediately push into the socket.
if (!scratch.is_batch) {
struct msghdr message {};
if (scratch.is_http) {
struct iovec iovecs[iovecs_for_error_k + 1]{};
size_t content_len = fill_with_error(iovecs + 1, scratch.dynamic_id, //
std::string_view(code, code_len), //
std::string_view(note, note_len));
message.msg_iov = iovecs;
message.msg_iovlen = iovecs_for_error_k + 1;
char headers[http_header_size_k];
prepend_http_headers(iovecs, content_len, headers);
send_message(engine, message);
} else {
struct iovec iovecs[iovecs_for_error_k] {};
fill_with_error(iovecs, scratch.dynamic_id, //
std::string_view(code, code_len), //
std::string_view(note, note_len));

message.msg_iov = iovecs;
message.msg_iovlen = iovecs_for_error_k;
send_message(engine, message);
}

}

// In case of a batch or async request, preserve a copy of data on the heap.
else {
auto code_and_node = (char*)std::malloc(code_len + note_len);
if (!code_and_node)
if (!scratch.is_batch)
if (fill_with_error(engine.batch_response.buffer, scratch.dynamic_id, //
std::string_view(code, code_len), std::string_view(note, note_len), scratch.is_http))
send_message(engine, engine.batch_response.buffer);
else
return ujrpc_call_reply_error_out_of_memory(call);
std::memcpy(code_and_node, code, code_len);
std::memcpy(code_and_node + code_len, note, note_len);
engine.batch_response.copies[engine.batch_response.copies_count++] = code_and_node;
fill_with_error(engine.batch_response.iovecs.data() + engine.batch_response.iovecs_count,
scratch.dynamic_id, //
std::string_view(code_and_node, code_len), //
std::string_view(code_and_node + code_len, note_len), true);
engine.batch_response.iovecs_count += iovecs_for_error_k;
}

else if (!fill_with_error(engine.batch_response.buffer, scratch.dynamic_id, //
std::string_view(code, code_len), //
std::string_view(note, note_len), false, true))
return ujrpc_call_reply_error_out_of_memory(call);
}

void ujrpc_call_reply_error_invalid_params(ujrpc_call_t call) {

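As a closing note on the batch path in forward_call_or_calls above: a reply is framed by appending an opening bracket, one comma-terminated JSON object per request, then trimming the trailing comma, closing the bracket, and finally patching the Content-Length digits reserved inside the fixed-size HTTP header prefix. The following is a standalone sketch of that framing under assumed names; the header layout, patch_content_length, and make_http_batch are illustrative stand-ins, not the repository's http_header_k or set_http_content_length.

// Standalone sketch of the batch framing above; header layout and helper names
// are illustrative stand-ins, not the repository's actual definitions.
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

static constexpr char http_prefix_k[] =
    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length:";
static constexpr char http_padding_k[] = "                    \r\n\r\n"; // twenty blanks reserved for the length digits
static constexpr std::size_t http_header_size = sizeof(http_prefix_k) - 1 + sizeof(http_padding_k) - 1;

// Overwrite the reserved blanks once the body size is known, just as the
// engine patches its header after the whole batch has been appended.
void patch_content_length(char* header, std::size_t content_length) {
    char digits[21];
    int count = std::snprintf(digits, sizeof(digits), "%zu", content_length);
    std::memcpy(header + sizeof(http_prefix_k) - 1, digits, static_cast<std::size_t>(count));
}

// Assemble "[resp1,resp2,...]" behind a fixed-size HTTP header in one contiguous buffer.
std::string make_http_batch(std::vector<std::string> const& responses) {
    std::string buffer;
    buffer.append(http_prefix_k).append(http_padding_k);
    buffer += '[';
    for (std::string const& response : responses) {
        buffer += response;
        buffer += ','; // every reply is followed by a comma ...
    }
    if (buffer.back() == ',')
        buffer.pop_back(); // ... and the trailing one is dropped, mirroring the diff
    buffer += ']';
    patch_content_length(&buffer[0], buffer.size() - http_header_size);
    return buffer;
}

In the diff itself, fill_with_content and fill_with_error append each object directly into engine.batch_response.buffer with append_n, and forward_call_or_calls pops the trailing comma before appending the closing bracket and fixing the Content-Length.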