Skip to content

Commit

Permalink
Create Rsa_client class, in effort to eliminate global vars
Browse files Browse the repository at this point in the history
  • Loading branch information
AGWA committed May 8, 2015
1 parent aed241d commit 3cf7600
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 66 deletions.
2 changes: 1 addition & 1 deletion child.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ try {
throw System_error("connect", keyserver_sockaddr.sun_path, errno);
}

rsa_client_set_socket(std::move(keyserver_client_sock));
rsa_client.set_socket(std::move(keyserver_client_sock));

// Create the backend socket. Since setting transparency requires privilege,
// we do it while we're still root.
Expand Down
1 change: 1 addition & 0 deletions common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ int children_pipe[2];
struct sockaddr_un keyserver_sockaddr;
socklen_t keyserver_sockaddr_len;
Vhost* active_vhost;
Rsa_client rsa_client;
2 changes: 2 additions & 0 deletions common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#define COMMON_HHP

#include "util.hpp"
#include "rsa_client.hpp"
#include <string>
#include <vector>
#include <sys/types.h>
Expand Down Expand Up @@ -77,5 +78,6 @@ extern int children_pipe[2]; // Used by children to tell us when they accept
extern struct sockaddr_un keyserver_sockaddr;
extern socklen_t keyserver_sockaddr_len;
extern Vhost* active_vhost;
extern Rsa_client rsa_client;

#endif
142 changes: 80 additions & 62 deletions rsa_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,78 +36,76 @@


namespace {
filedesc sock;
struct Rsa_client_data {
Rsa_client* client = nullptr;
uintptr_t key_id = 0;
};
}

void send_to_server (const void* data, size_t len)
{
write_all(sock, data, len);
}
void recv_from_server (void* data, size_t len)
{
if (!read_all(sock, data, len)) {
throw Key_protocol_error("Server ended connection prematurely");
}
int Rsa_client::rsa_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding)
{
const uint8_t command = 1;
Rsa_client_data* data = reinterpret_cast<Rsa_client_data*>(RSA_get_app_data(rsa));

data->client->send_to_server(&command, sizeof(command));
data->client->send_to_server(&data->key_id, sizeof(data->key_id));
data->client->send_to_server(&padding, sizeof(padding));
data->client->send_to_server(&flen, sizeof(flen));
data->client->send_to_server(from, flen);

int plain_len;
data->client->recv_from_server(&plain_len, sizeof(plain_len));
if (plain_len > 0) {
data->client->recv_from_server(to, plain_len);
}

int rsa_client_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding)
{
uint8_t command = 1;
uintptr_t key_id = reinterpret_cast<uintptr_t>(RSA_get_app_data(rsa));

send_to_server(&command, sizeof(command));
send_to_server(&key_id, sizeof(key_id));
send_to_server(&padding, sizeof(padding));
send_to_server(&flen, sizeof(flen));
send_to_server(from, flen);

int plain_len;
recv_from_server(&plain_len, sizeof(plain_len));
if (plain_len > 0) {
recv_from_server(to, plain_len);
}
return plain_len;
}

return plain_len;
int Rsa_client::rsa_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding)
{
const uint8_t command = 2;
Rsa_client_data* data = reinterpret_cast<Rsa_client_data*>(RSA_get_app_data(rsa));

data->client->send_to_server(&command, sizeof(command));
data->client->send_to_server(&data->key_id, sizeof(data->key_id));
data->client->send_to_server(&padding, sizeof(padding));
data->client->send_to_server(&flen, sizeof(flen));
data->client->send_to_server(from, flen);

int sig_len;
data->client->recv_from_server(&sig_len, sizeof(sig_len));
if (sig_len > 0) {
data->client->recv_from_server(to, sig_len);
}

int rsa_client_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding)
{
uint8_t command = 2;
uintptr_t key_id = reinterpret_cast<uintptr_t>(RSA_get_app_data(rsa));

send_to_server(&command, sizeof(command));
send_to_server(&key_id, sizeof(key_id));
send_to_server(&padding, sizeof(padding));
send_to_server(&flen, sizeof(flen));
send_to_server(from, flen);

int sig_len;
recv_from_server(&sig_len, sizeof(sig_len));
if (sig_len > 0) {
recv_from_server(to, sig_len);
}

return sig_len;
}
return sig_len;
}

const RSA_METHOD* get_rsa_client_method ()
{
static RSA_METHOD ops;
if (!ops.rsa_priv_enc) {
ops = *RSA_get_default_method();
ops.rsa_priv_enc = rsa_client_private_encrypt;
ops.rsa_priv_dec = rsa_client_private_decrypt;
}
return &ops;
int Rsa_client::rsa_finish (RSA* rsa)
{
delete reinterpret_cast<Rsa_client_data*>(RSA_get_app_data(rsa));
if (const auto default_finish = RSA_get_default_method()->finish) {
return (*default_finish)(rsa);
} else {
return 1;
}
}

openssl_unique_ptr<EVP_PKEY> rsa_client_load_private_key (uintptr_t key_id, RSA* public_rsa)
const RSA_METHOD* Rsa_client::get_rsa_method ()
{
openssl_unique_ptr<EVP_PKEY> private_key(EVP_PKEY_new());
if (!private_key) {
throw Openssl_error(ERR_get_error());
static RSA_METHOD ops;
if (!ops.rsa_priv_enc) {
ops = *RSA_get_default_method();
ops.rsa_priv_enc = rsa_private_encrypt;
ops.rsa_priv_dec = rsa_private_decrypt;
ops.finish = rsa_finish;
}
return &ops;
}

openssl_unique_ptr<EVP_PKEY> Rsa_client::load_private_key (uintptr_t key_id, RSA* public_rsa)
{
openssl_unique_ptr<RSA> rsa(RSA_new());
if (!rsa) {
throw Openssl_error(ERR_get_error());
Expand All @@ -122,8 +120,17 @@ openssl_unique_ptr<EVP_PKEY> rsa_client_load_private_key (uintptr_t key_id, RSA*
throw Openssl_error(ERR_get_error());
}

RSA_set_method(rsa.get(), get_rsa_client_method());
if (!RSA_set_app_data(rsa.get(), reinterpret_cast<void*>(key_id))) {
std::unique_ptr<Rsa_client_data> client_data(new Rsa_client_data);
client_data->client = this;
client_data->key_id = key_id;
if (!RSA_set_app_data(rsa.get(), client_data.get())) {
throw Openssl_error(ERR_get_error());
}
RSA_set_method(rsa.get(), get_rsa_method());
client_data.release(); // After calling RSA_set_method, client_data is owned by rsa.

openssl_unique_ptr<EVP_PKEY> private_key(EVP_PKEY_new());
if (!private_key) {
throw Openssl_error(ERR_get_error());
}

Expand All @@ -136,8 +143,19 @@ openssl_unique_ptr<EVP_PKEY> rsa_client_load_private_key (uintptr_t key_id, RSA*
return private_key;
}

void rsa_client_set_socket (filedesc arg_sock)
void Rsa_client::set_socket (filedesc arg_sock)
{
sock = std::move(arg_sock);
}

void Rsa_client::send_to_server (const void* data, size_t len) const
{
write_all(sock, data, len);
}

void Rsa_client::recv_from_server (void* data, size_t len) const
{
if (!read_all(sock, data, len)) {
throw Key_protocol_error("Server ended connection prematurely");
}
}
25 changes: 23 additions & 2 deletions rsa_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,29 @@
#include <openssl/evp.h>
#include <stdint.h>
#include "util.hpp"
#include "filedesc.hpp"

openssl_unique_ptr<EVP_PKEY> rsa_client_load_private_key (uintptr_t key_id, RSA* public_rsa);
void rsa_client_set_socket (filedesc fd);
class Rsa_client {
filedesc sock;

void send_to_server (const void* data, size_t len) const;
void recv_from_server (void* data, size_t len) const;

static int rsa_private_decrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding);
static int rsa_private_encrypt (int flen, const unsigned char* from, unsigned char* to, RSA* rsa, int padding);
static int rsa_finish (RSA* rsa);
static const RSA_METHOD* get_rsa_method ();
public:
// Note: you can't move Rsa_clients because doing so would leave
// dangling pointers in all the private keys created from it.
Rsa_client () = default;
Rsa_client (const Rsa_client&) = delete;
Rsa_client (Rsa_client&&) = delete;
Rsa_client& operator= (const Rsa_client&) = delete;
Rsa_client& operator= (Rsa_client&&) = delete;

openssl_unique_ptr<EVP_PKEY> load_private_key (uintptr_t key_id, RSA* public_rsa);
void set_socket (filedesc);
};

#endif
2 changes: 1 addition & 1 deletion titus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ namespace {
}

// Create a RSA private key "client"
openssl_unique_ptr<EVP_PKEY> privkey(rsa_client_load_private_key(vhost.id, public_rsa.get()));
openssl_unique_ptr<EVP_PKEY> privkey(rsa_client.load_private_key(vhost.id, public_rsa.get()));
public_rsa.reset();

// Use this private key for SSL:
Expand Down

0 comments on commit 3cf7600

Please sign in to comment.