From 3cf7600a1b5067cea9d7607ded7415ba18c4f13e Mon Sep 17 00:00:00 2001 From: Andrew Ayer Date: Fri, 8 May 2015 16:12:54 -0700 Subject: [PATCH] Create Rsa_client class, in effort to eliminate global vars --- child.cpp | 2 +- common.cpp | 1 + common.hpp | 2 + rsa_client.cpp | 142 ++++++++++++++++++++++++++++--------------------- rsa_client.hpp | 25 ++++++++- titus.cpp | 2 +- 6 files changed, 108 insertions(+), 66 deletions(-) diff --git a/child.cpp b/child.cpp index 3bea3b9..df00c5a 100644 --- a/child.cpp +++ b/child.cpp @@ -310,7 +310,7 @@ try { throw System_error("connect", keyserver_sockaddr.sun_path, errno); } - rsa_client_set_socket(std::move(keyserver_client_sock)); + rsa_client.set_socket(std::move(keyserver_client_sock)); // Create the backend socket. Since setting transparency requires privilege, // we do it while we're still root. diff --git a/common.cpp b/common.cpp index 72358e1..a35bd06 100644 --- a/common.cpp +++ b/common.cpp @@ -43,3 +43,4 @@ int children_pipe[2]; struct sockaddr_un keyserver_sockaddr; socklen_t keyserver_sockaddr_len; Vhost* active_vhost; +Rsa_client rsa_client; diff --git a/common.hpp b/common.hpp index e95d501..4321662 100644 --- a/common.hpp +++ b/common.hpp @@ -29,6 +29,7 @@ #define COMMON_HHP #include "util.hpp" +#include "rsa_client.hpp" #include #include #include @@ -77,5 +78,6 @@ extern int children_pipe[2]; // Used by children to tell us when they accept extern struct sockaddr_un keyserver_sockaddr; extern socklen_t keyserver_sockaddr_len; extern Vhost* active_vhost; +extern Rsa_client rsa_client; #endif diff --git a/rsa_client.cpp b/rsa_client.cpp index 1e901a8..4b6f460 100644 --- a/rsa_client.cpp +++ b/rsa_client.cpp @@ -36,78 +36,76 @@ namespace { - filedesc sock; + struct Rsa_client_data { + Rsa_client* client = nullptr; + uintptr_t key_id = 0; + }; +} - void send_to_server (const void* data, size_t len) - { - write_all(sock, data, len); - } - void recv_from_server (void* data, size_t len) - { - if (!read_all(sock, data, len)) { - throw Key_protocol_error("Server ended connection prematurely"); - } +int Rsa_client::rsa_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding) +{ + const uint8_t command = 1; + Rsa_client_data* data = reinterpret_cast(RSA_get_app_data(rsa)); + + data->client->send_to_server(&command, sizeof(command)); + data->client->send_to_server(&data->key_id, sizeof(data->key_id)); + data->client->send_to_server(&padding, sizeof(padding)); + data->client->send_to_server(&flen, sizeof(flen)); + data->client->send_to_server(from, flen); + + int plain_len; + data->client->recv_from_server(&plain_len, sizeof(plain_len)); + if (plain_len > 0) { + data->client->recv_from_server(to, plain_len); } - int rsa_client_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding) - { - uint8_t command = 1; - uintptr_t key_id = reinterpret_cast(RSA_get_app_data(rsa)); - - send_to_server(&command, sizeof(command)); - send_to_server(&key_id, sizeof(key_id)); - send_to_server(&padding, sizeof(padding)); - send_to_server(&flen, sizeof(flen)); - send_to_server(from, flen); - - int plain_len; - recv_from_server(&plain_len, sizeof(plain_len)); - if (plain_len > 0) { - recv_from_server(to, plain_len); - } + return plain_len; +} - return plain_len; +int Rsa_client::rsa_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding) +{ + const uint8_t command = 2; + Rsa_client_data* data = reinterpret_cast(RSA_get_app_data(rsa)); + + data->client->send_to_server(&command, sizeof(command)); + data->client->send_to_server(&data->key_id, sizeof(data->key_id)); + data->client->send_to_server(&padding, sizeof(padding)); + data->client->send_to_server(&flen, sizeof(flen)); + data->client->send_to_server(from, flen); + + int sig_len; + data->client->recv_from_server(&sig_len, sizeof(sig_len)); + if (sig_len > 0) { + data->client->recv_from_server(to, sig_len); } - int rsa_client_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding) - { - uint8_t command = 2; - uintptr_t key_id = reinterpret_cast(RSA_get_app_data(rsa)); - - send_to_server(&command, sizeof(command)); - send_to_server(&key_id, sizeof(key_id)); - send_to_server(&padding, sizeof(padding)); - send_to_server(&flen, sizeof(flen)); - send_to_server(from, flen); - - int sig_len; - recv_from_server(&sig_len, sizeof(sig_len)); - if (sig_len > 0) { - recv_from_server(to, sig_len); - } - - return sig_len; - } + return sig_len; +} - const RSA_METHOD* get_rsa_client_method () - { - static RSA_METHOD ops; - if (!ops.rsa_priv_enc) { - ops = *RSA_get_default_method(); - ops.rsa_priv_enc = rsa_client_private_encrypt; - ops.rsa_priv_dec = rsa_client_private_decrypt; - } - return &ops; +int Rsa_client::rsa_finish (RSA* rsa) +{ + delete reinterpret_cast(RSA_get_app_data(rsa)); + if (const auto default_finish = RSA_get_default_method()->finish) { + return (*default_finish)(rsa); + } else { + return 1; } } -openssl_unique_ptr rsa_client_load_private_key (uintptr_t key_id, RSA* public_rsa) +const RSA_METHOD* Rsa_client::get_rsa_method () { - openssl_unique_ptr private_key(EVP_PKEY_new()); - if (!private_key) { - throw Openssl_error(ERR_get_error()); + static RSA_METHOD ops; + if (!ops.rsa_priv_enc) { + ops = *RSA_get_default_method(); + ops.rsa_priv_enc = rsa_private_encrypt; + ops.rsa_priv_dec = rsa_private_decrypt; + ops.finish = rsa_finish; } + return &ops; +} +openssl_unique_ptr Rsa_client::load_private_key (uintptr_t key_id, RSA* public_rsa) +{ openssl_unique_ptr rsa(RSA_new()); if (!rsa) { throw Openssl_error(ERR_get_error()); @@ -122,8 +120,17 @@ openssl_unique_ptr rsa_client_load_private_key (uintptr_t key_id, RSA* throw Openssl_error(ERR_get_error()); } - RSA_set_method(rsa.get(), get_rsa_client_method()); - if (!RSA_set_app_data(rsa.get(), reinterpret_cast(key_id))) { + std::unique_ptr client_data(new Rsa_client_data); + client_data->client = this; + client_data->key_id = key_id; + if (!RSA_set_app_data(rsa.get(), client_data.get())) { + throw Openssl_error(ERR_get_error()); + } + RSA_set_method(rsa.get(), get_rsa_method()); + client_data.release(); // After calling RSA_set_method, client_data is owned by rsa. + + openssl_unique_ptr private_key(EVP_PKEY_new()); + if (!private_key) { throw Openssl_error(ERR_get_error()); } @@ -136,8 +143,19 @@ openssl_unique_ptr rsa_client_load_private_key (uintptr_t key_id, RSA* return private_key; } -void rsa_client_set_socket (filedesc arg_sock) +void Rsa_client::set_socket (filedesc arg_sock) { sock = std::move(arg_sock); } +void Rsa_client::send_to_server (const void* data, size_t len) const +{ + write_all(sock, data, len); +} + +void Rsa_client::recv_from_server (void* data, size_t len) const +{ + if (!read_all(sock, data, len)) { + throw Key_protocol_error("Server ended connection prematurely"); + } +} diff --git a/rsa_client.hpp b/rsa_client.hpp index c8d4eb8..c7c88f3 100644 --- a/rsa_client.hpp +++ b/rsa_client.hpp @@ -32,8 +32,29 @@ #include #include #include "util.hpp" +#include "filedesc.hpp" -openssl_unique_ptr rsa_client_load_private_key (uintptr_t key_id, RSA* public_rsa); -void rsa_client_set_socket (filedesc fd); +class Rsa_client { + filedesc sock; + + void send_to_server (const void* data, size_t len) const; + void recv_from_server (void* data, size_t len) const; + + static int rsa_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding); + static int rsa_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding); + static int rsa_finish (RSA* rsa); + static const RSA_METHOD* get_rsa_method (); +public: + // Note: you can't move Rsa_clients because doing so would leave + // dangling pointers in all the private keys created from it. + Rsa_client () = default; + Rsa_client (const Rsa_client&) = delete; + Rsa_client (Rsa_client&&) = delete; + Rsa_client& operator= (const Rsa_client&) = delete; + Rsa_client& operator= (Rsa_client&&) = delete; + + openssl_unique_ptr load_private_key (uintptr_t key_id, RSA* public_rsa); + void set_socket (filedesc); +}; #endif diff --git a/titus.cpp b/titus.cpp index 43c2439..dc6260d 100644 --- a/titus.cpp +++ b/titus.cpp @@ -322,7 +322,7 @@ namespace { } // Create a RSA private key "client" - openssl_unique_ptr privkey(rsa_client_load_private_key(vhost.id, public_rsa.get())); + openssl_unique_ptr privkey(rsa_client.load_private_key(vhost.id, public_rsa.get())); public_rsa.reset(); // Use this private key for SSL: