Skip to content

Commit

Permalink
cryptoutils: ContextHolder; drop extra status variables
Browse files Browse the repository at this point in the history
Using ContextHolder plugs quite a few memory leaks on error conditions
occurring in cryptoutils. Most of those conditions are grave but that
doesn't mean we should go lax on resource management.
  • Loading branch information
Alexey Rusakov committed Dec 17, 2023
1 parent 5da2f6d commit 3c33dc8
Showing 1 changed file with 48 additions and 50 deletions.
98 changes: 48 additions & 50 deletions Quotient/e2ee/cryptoutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ using namespace Quotient;
static_assert(std::is_same_v<SslErrorCode, decltype(ERR_get_error())>);
static_assert(SslErrorUserOffset == ERR_LIB_USER);

//! \brief A wrapper for `std::unique_ptr` for use with OpenSSL context functions
//!
//! This class and the deduction guide for it are merely to remove
//! the boilerplate necessary to pass custom deleter to `std::unique_ptr`.
//! Usage: `const ContextHolder ctx(CTX_new(), &CTX_free);`, where `CTX_new` and
//! `CTX_free` are the matching allocation and deallocation functions from
//! OpenSSL API. You can pass additional parameters to the allocation function
//! as needed; the deallocation function is assumed to take exactly one
//! parameter of the same type that is returned by the allocation function.
template <class Context>
class ContextHolder : public std::unique_ptr<Context, void (*)(Context*)> {
public:
using std::unique_ptr<Context, void (*)(Context*)>::unique_ptr;
};
template <class CryptoContext, typename Deleter>
ContextHolder(CryptoContext*, Deleter) -> ContextHolder<CryptoContext>;

SslExpected<QByteArray> Quotient::pbkdf2HmacSha512(const QByteArray& password,
const QByteArray& salt,
int iterations, int keyLength)
Expand All @@ -39,57 +56,49 @@ SslExpected<QByteArray> Quotient::aesCtr256Encrypt(const QByteArray& plaintext,
const QByteArray& key,
const QByteArray& iv)
{
EVP_CIPHER_CTX* ctx = nullptr;
int length = 0;
int ciphertextLength = 0;

auto encrypted = QByteArray(plaintext.size() + AES_BLOCK_SIZE, u'\0');
int status = RAND_bytes(reinterpret_cast<unsigned char*>(encrypted.data()), encrypted.size());
if (status != 1) {
if (RAND_bytes(reinterpret_cast<unsigned char*>(encrypted.data()), encrypted.size()) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}
auto data = encrypted.data();
constexpr auto mask = static_cast<std::uint8_t>(~(1U << (63 / 8)));
data[15 - 63 % 8] &= mask;
if (ctx = EVP_CIPHER_CTX_new(); !ctx) {

const ContextHolder ctx(EVP_CIPHER_CTX_new(), &EVP_CIPHER_CTX_free);
if (!ctx) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}

status = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), nullptr, reinterpret_cast<const unsigned char*>(key.data()), reinterpret_cast<const unsigned char*>(iv.data()));
if (status != 1) {
if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_256_ctr(), nullptr, reinterpret_cast<const unsigned char*>(key.data()), reinterpret_cast<const unsigned char*>(iv.data())) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
EVP_CIPHER_CTX_free(ctx);
return ERR_get_error();
}

status = EVP_EncryptUpdate(ctx, reinterpret_cast<unsigned char*>(encrypted.data()), &length, reinterpret_cast<const unsigned char *>(&plaintext.data()[0]), (int) plaintext.size());
if (status != 1) {
if (EVP_EncryptUpdate(ctx.get(), reinterpret_cast<unsigned char*>(encrypted.data()), &length, reinterpret_cast<const unsigned char *>(&plaintext.data()[0]), (int) plaintext.size()) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
EVP_CIPHER_CTX_free(ctx);
return ERR_get_error();
}

ciphertextLength = length;
status = EVP_EncryptFinal_ex(ctx, reinterpret_cast<unsigned char*>(encrypted.data()) + length, &length);
if (status != 1) {
if (EVP_EncryptFinal_ex(ctx.get(), reinterpret_cast<unsigned char*>(encrypted.data()) + length, &length) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
EVP_CIPHER_CTX_free(ctx);
return ERR_get_error();
}

ciphertextLength += length;
encrypted.resize(ciphertextLength);
EVP_CIPHER_CTX_free(ctx);
return encrypted;
}

#define CALL_OPENSSL(Call_) \
do { \
if (Call_ != 1) { \
qWarning() << ERR_error_string(ERR_get_error(), nullptr); \
EVP_PKEY_CTX_free(context); \
return ERR_get_error(); \
} \
} while (false) // End of macro
Expand All @@ -99,30 +108,31 @@ SslExpected<HkdfKeys> Quotient::hkdfSha256(const QByteArray& key,
const QByteArray& info)
{
QByteArray result(64, u'\0');
auto context = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr);
const ContextHolder context(EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr),
&EVP_PKEY_CTX_free);

CALL_OPENSSL(EVP_PKEY_derive_init(context));
CALL_OPENSSL(EVP_PKEY_CTX_set_hkdf_md(context, EVP_sha256()));
CALL_OPENSSL(EVP_PKEY_derive_init(context.get()));
CALL_OPENSSL(EVP_PKEY_CTX_set_hkdf_md(context.get(), EVP_sha256()));
CALL_OPENSSL(EVP_PKEY_CTX_set1_hkdf_salt(
context, reinterpret_cast<const unsigned char*>(salt.data()),
context.get(), reinterpret_cast<const unsigned char*>(salt.data()),
salt.size()));
CALL_OPENSSL(EVP_PKEY_CTX_set1_hkdf_key(
context, reinterpret_cast<const unsigned char*>(key.data()),
context.get(), reinterpret_cast<const unsigned char*>(key.data()),
key.size()));
CALL_OPENSSL(EVP_PKEY_CTX_add1_hkdf_info(
context, reinterpret_cast<const unsigned char*>(info.data()),
context.get(), reinterpret_cast<const unsigned char*>(info.data()),
info.size()));
auto outputLength = unsignedSize(result);
CALL_OPENSSL(EVP_PKEY_derive(context, reinterpret_cast<unsigned char *>(result.data()), &outputLength));
CALL_OPENSSL(EVP_PKEY_derive(context.get(),
reinterpret_cast<unsigned char*>(result.data()),
&outputLength));
if (outputLength != 64) {
qCCritical(E2EE) << "hkdfSha256: the derived key is" << outputLength
<< "bytes instead of 64";
Q_ASSERT(false);
return WrongDerivedKeyLength;
}

EVP_PKEY_CTX_free(context);

auto macKey = result.mid(32);
result.resize(32);
return HkdfKeys{std::move(result), std::move(macKey)};
Expand All @@ -133,8 +143,7 @@ SslExpected<QByteArray> Quotient::hmacSha256(const QByteArray& hmacKey,
{
uint32_t len = SHA256_DIGEST_LENGTH;
QByteArray output(SHA256_DIGEST_LENGTH, u'\0');
auto status = HMAC(EVP_sha256(), hmacKey.data(), hmacKey.size(), reinterpret_cast<const unsigned char *>(data.data()), data.size(), reinterpret_cast<unsigned char *>(output.data()), &len);
if (!status) {
if (!HMAC(EVP_sha256(), hmacKey.data(), hmacKey.size(), reinterpret_cast<const unsigned char *>(data.data()), data.size(), reinterpret_cast<unsigned char *>(output.data()), &len)) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}
Expand All @@ -145,35 +154,31 @@ SslExpected<QByteArray> Quotient::aesCtr256Decrypt(const QByteArray& ciphertext,
const QByteArray& aes256Key,
const QByteArray& iv)
{
auto context = EVP_CIPHER_CTX_new();
const ContextHolder context(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
Q_ASSERT(context);

int length = 0;
int plaintextLength = 0;
QByteArray decrypted(ciphertext.size(), u'\0');

int status = EVP_DecryptInit_ex(context, EVP_aes_256_ctr(), nullptr, reinterpret_cast<const unsigned char *>(aes256Key.data()), reinterpret_cast<const unsigned char *>(iv.data()));
if (status != 1) {
if (EVP_DecryptInit_ex(context.get(), EVP_aes_256_ctr(), nullptr, reinterpret_cast<const unsigned char *>(aes256Key.data()), reinterpret_cast<const unsigned char *>(iv.data())) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}

status = EVP_DecryptUpdate(context, reinterpret_cast<unsigned char*>(decrypted.data()), &length, reinterpret_cast<const unsigned char*>(&ciphertext.data()[0]), (int) ciphertext.size());
if (status != 1) {
if (EVP_DecryptUpdate(context.get(), reinterpret_cast<unsigned char*>(decrypted.data()), &length, reinterpret_cast<const unsigned char*>(&ciphertext.data()[0]), (int) ciphertext.size()) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}

plaintextLength = length;
status = EVP_DecryptFinal_ex(context, reinterpret_cast<unsigned char*>(decrypted.data()) + length, &length);
if (status != 1) {
if (EVP_DecryptFinal_ex(context.get(), reinterpret_cast<unsigned char*>(decrypted.data()) + length, &length) != 1) {
qWarning() << ERR_error_string(ERR_get_error(), nullptr);
return ERR_get_error();
}

plaintextLength += length;
decrypted.resize(plaintextLength);
EVP_CIPHER_CTX_free(context);
return decrypted;
}

Expand All @@ -185,17 +190,14 @@ QOlmExpected<QByteArray> Quotient::curve25519AesSha2Decrypt(
Q_ASSERT(context);

QByteArray publicKey(olm_pk_key_length(), u'\0');
auto status = olm_pk_key_from_private(context.get(), publicKey.data(), publicKey.size(), privateKey.data(), privateKey.size());
if (status == olm_error()) {
if (olm_pk_key_from_private(context.get(), publicKey.data(), publicKey.size(), privateKey.data(), privateKey.size()) == olm_error())
return olm_pk_decryption_last_error_code(context.get());
}

QByteArray plaintext(olm_pk_max_plaintext_length(context.get(), ciphertext.size()), u'\0');
auto result = olm_pk_decrypt(context.get(), ephemeral.data(), ephemeral.size(), mac.data(), mac.size(), ciphertext.data(), ciphertext.size(), plaintext.data(), plaintext.size());

if (result == olm_error()) {
if (result == olm_error())
return olm_pk_decryption_last_error_code(context.get());
}

plaintext.resize(result);
return plaintext;
}
Expand All @@ -205,24 +207,20 @@ QOlmExpected<Curve25519Encrypted> Quotient::curve25519AesSha2Encrypt(
{
auto context = makeCStruct(olm_pk_encryption, olm_pk_encryption_size, olm_clear_pk_encryption);

auto status = olm_pk_encryption_set_recipient_key(context.get(), publicKey.data(), publicKey.size());
if (status == olm_error()) {
if (olm_pk_encryption_set_recipient_key(context.get(), publicKey.data(), publicKey.size()) == olm_error())
return olm_pk_encryption_last_error_code(context.get());
}

QByteArray ephemeral(olm_pk_key_length(), 0);
QByteArray mac(olm_pk_mac_length(context.get()), 0);
QByteArray ciphertext(olm_pk_ciphertext_length(context.get(), plaintext.size()), 0);
const auto random = getRandom(olm_pk_encrypt_random_length(context.get()));

auto result = olm_pk_encrypt(context.get(), plaintext.data(),
plaintext.size(), ciphertext.data(),
ciphertext.size(), mac.data(), mac.size(),
ephemeral.data(), ephemeral.size(),
random.data(), random.size());
if (result == olm_error()) {
if (olm_pk_encrypt(context.get(), plaintext.data(), plaintext.size(),
ciphertext.data(), ciphertext.size(), mac.data(),
mac.size(), ephemeral.data(), ephemeral.size(),
random.data(), random.size())
== olm_error())
return olm_pk_encryption_last_error_code(context.get());
}

return Curve25519Encrypted {
.ciphertext = ciphertext,
Expand Down

0 comments on commit 3c33dc8

Please sign in to comment.