diff --git a/Quotient/e2ee/cryptoutils.cpp b/Quotient/e2ee/cryptoutils.cpp index 5a855af69..718cf9e02 100644 --- a/Quotient/e2ee/cryptoutils.cpp +++ b/Quotient/e2ee/cryptoutils.cpp @@ -22,6 +22,23 @@ using namespace Quotient; static_assert(std::is_same_v); static_assert(SslErrorUserOffset == ERR_LIB_USER); +//! \brief A wrapper for `std::unique_ptr` for use with OpenSSL context functions +//! +//! This class and the deduction guide for it are merely to remove +//! the boilerplate necessary to pass custom deleter to `std::unique_ptr`. +//! Usage: `const ContextHolder ctx(CTX_new(), &CTX_free);`, where `CTX_new` and +//! `CTX_free` are the matching allocation and deallocation functions from +//! OpenSSL API. You can pass additional parameters to the allocation function +//! as needed; the deallocation function is assumed to take exactly one +//! parameter of the same type that is returned by the allocation function. +template +class ContextHolder : public std::unique_ptr { +public: + using std::unique_ptr::unique_ptr; +}; +template +ContextHolder(CryptoContext*, Deleter) -> ContextHolder; + SslExpected Quotient::pbkdf2HmacSha512(const QByteArray& password, const QByteArray& salt, int iterations, int keyLength) @@ -39,49 +56,42 @@ SslExpected Quotient::aesCtr256Encrypt(const QByteArray& plaintext, const QByteArray& key, const QByteArray& iv) { - EVP_CIPHER_CTX* ctx = nullptr; int length = 0; int ciphertextLength = 0; auto encrypted = QByteArray(plaintext.size() + AES_BLOCK_SIZE, u'\0'); - int status = RAND_bytes(reinterpret_cast(encrypted.data()), encrypted.size()); - if (status != 1) { + if (RAND_bytes(reinterpret_cast(encrypted.data()), encrypted.size()) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } auto data = encrypted.data(); constexpr auto mask = static_cast(~(1U << (63 / 8))); data[15 - 63 % 8] &= mask; - if (ctx = EVP_CIPHER_CTX_new(); !ctx) { + + const ContextHolder ctx(EVP_CIPHER_CTX_new(), &EVP_CIPHER_CTX_free); + if (!ctx) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } - status = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), nullptr, reinterpret_cast(key.data()), reinterpret_cast(iv.data())); - if (status != 1) { + if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_256_ctr(), nullptr, reinterpret_cast(key.data()), reinterpret_cast(iv.data())) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); - EVP_CIPHER_CTX_free(ctx); return ERR_get_error(); } - status = EVP_EncryptUpdate(ctx, reinterpret_cast(encrypted.data()), &length, reinterpret_cast(&plaintext.data()[0]), (int) plaintext.size()); - if (status != 1) { + if (EVP_EncryptUpdate(ctx.get(), reinterpret_cast(encrypted.data()), &length, reinterpret_cast(&plaintext.data()[0]), (int) plaintext.size()) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); - EVP_CIPHER_CTX_free(ctx); return ERR_get_error(); } ciphertextLength = length; - status = EVP_EncryptFinal_ex(ctx, reinterpret_cast(encrypted.data()) + length, &length); - if (status != 1) { + if (EVP_EncryptFinal_ex(ctx.get(), reinterpret_cast(encrypted.data()) + length, &length) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); - EVP_CIPHER_CTX_free(ctx); return ERR_get_error(); } ciphertextLength += length; encrypted.resize(ciphertextLength); - EVP_CIPHER_CTX_free(ctx); return encrypted; } @@ -89,7 +99,6 @@ SslExpected Quotient::aesCtr256Encrypt(const QByteArray& plaintext, do { \ if (Call_ != 1) { \ qWarning() << ERR_error_string(ERR_get_error(), nullptr); \ - EVP_PKEY_CTX_free(context); \ return ERR_get_error(); \ } \ } while (false) // End of macro @@ -99,21 +108,24 @@ SslExpected Quotient::hkdfSha256(const QByteArray& key, const QByteArray& info) { QByteArray result(64, u'\0'); - auto context = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr); + const ContextHolder context(EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr), + &EVP_PKEY_CTX_free); - CALL_OPENSSL(EVP_PKEY_derive_init(context)); - CALL_OPENSSL(EVP_PKEY_CTX_set_hkdf_md(context, EVP_sha256())); + CALL_OPENSSL(EVP_PKEY_derive_init(context.get())); + CALL_OPENSSL(EVP_PKEY_CTX_set_hkdf_md(context.get(), EVP_sha256())); CALL_OPENSSL(EVP_PKEY_CTX_set1_hkdf_salt( - context, reinterpret_cast(salt.data()), + context.get(), reinterpret_cast(salt.data()), salt.size())); CALL_OPENSSL(EVP_PKEY_CTX_set1_hkdf_key( - context, reinterpret_cast(key.data()), + context.get(), reinterpret_cast(key.data()), key.size())); CALL_OPENSSL(EVP_PKEY_CTX_add1_hkdf_info( - context, reinterpret_cast(info.data()), + context.get(), reinterpret_cast(info.data()), info.size())); auto outputLength = unsignedSize(result); - CALL_OPENSSL(EVP_PKEY_derive(context, reinterpret_cast(result.data()), &outputLength)); + CALL_OPENSSL(EVP_PKEY_derive(context.get(), + reinterpret_cast(result.data()), + &outputLength)); if (outputLength != 64) { qCCritical(E2EE) << "hkdfSha256: the derived key is" << outputLength << "bytes instead of 64"; @@ -121,8 +133,6 @@ SslExpected Quotient::hkdfSha256(const QByteArray& key, return WrongDerivedKeyLength; } - EVP_PKEY_CTX_free(context); - auto macKey = result.mid(32); result.resize(32); return HkdfKeys{std::move(result), std::move(macKey)}; @@ -133,8 +143,7 @@ SslExpected Quotient::hmacSha256(const QByteArray& hmacKey, { uint32_t len = SHA256_DIGEST_LENGTH; QByteArray output(SHA256_DIGEST_LENGTH, u'\0'); - auto status = HMAC(EVP_sha256(), hmacKey.data(), hmacKey.size(), reinterpret_cast(data.data()), data.size(), reinterpret_cast(output.data()), &len); - if (!status) { + if (!HMAC(EVP_sha256(), hmacKey.data(), hmacKey.size(), reinterpret_cast(data.data()), data.size(), reinterpret_cast(output.data()), &len)) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } @@ -145,35 +154,31 @@ SslExpected Quotient::aesCtr256Decrypt(const QByteArray& ciphertext, const QByteArray& aes256Key, const QByteArray& iv) { - auto context = EVP_CIPHER_CTX_new(); + const ContextHolder context(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); Q_ASSERT(context); int length = 0; int plaintextLength = 0; QByteArray decrypted(ciphertext.size(), u'\0'); - int status = EVP_DecryptInit_ex(context, EVP_aes_256_ctr(), nullptr, reinterpret_cast(aes256Key.data()), reinterpret_cast(iv.data())); - if (status != 1) { + if (EVP_DecryptInit_ex(context.get(), EVP_aes_256_ctr(), nullptr, reinterpret_cast(aes256Key.data()), reinterpret_cast(iv.data())) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } - status = EVP_DecryptUpdate(context, reinterpret_cast(decrypted.data()), &length, reinterpret_cast(&ciphertext.data()[0]), (int) ciphertext.size()); - if (status != 1) { + if (EVP_DecryptUpdate(context.get(), reinterpret_cast(decrypted.data()), &length, reinterpret_cast(&ciphertext.data()[0]), (int) ciphertext.size()) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } plaintextLength = length; - status = EVP_DecryptFinal_ex(context, reinterpret_cast(decrypted.data()) + length, &length); - if (status != 1) { + if (EVP_DecryptFinal_ex(context.get(), reinterpret_cast(decrypted.data()) + length, &length) != 1) { qWarning() << ERR_error_string(ERR_get_error(), nullptr); return ERR_get_error(); } plaintextLength += length; decrypted.resize(plaintextLength); - EVP_CIPHER_CTX_free(context); return decrypted; } @@ -185,17 +190,14 @@ QOlmExpected Quotient::curve25519AesSha2Decrypt( Q_ASSERT(context); QByteArray publicKey(olm_pk_key_length(), u'\0'); - auto status = olm_pk_key_from_private(context.get(), publicKey.data(), publicKey.size(), privateKey.data(), privateKey.size()); - if (status == olm_error()) { + if (olm_pk_key_from_private(context.get(), publicKey.data(), publicKey.size(), privateKey.data(), privateKey.size()) == olm_error()) return olm_pk_decryption_last_error_code(context.get()); - } QByteArray plaintext(olm_pk_max_plaintext_length(context.get(), ciphertext.size()), u'\0'); auto result = olm_pk_decrypt(context.get(), ephemeral.data(), ephemeral.size(), mac.data(), mac.size(), ciphertext.data(), ciphertext.size(), plaintext.data(), plaintext.size()); - - if (result == olm_error()) { + if (result == olm_error()) return olm_pk_decryption_last_error_code(context.get()); - } + plaintext.resize(result); return plaintext; } @@ -205,24 +207,20 @@ QOlmExpected Quotient::curve25519AesSha2Encrypt( { auto context = makeCStruct(olm_pk_encryption, olm_pk_encryption_size, olm_clear_pk_encryption); - auto status = olm_pk_encryption_set_recipient_key(context.get(), publicKey.data(), publicKey.size()); - if (status == olm_error()) { + if (olm_pk_encryption_set_recipient_key(context.get(), publicKey.data(), publicKey.size()) == olm_error()) return olm_pk_encryption_last_error_code(context.get()); - } QByteArray ephemeral(olm_pk_key_length(), 0); QByteArray mac(olm_pk_mac_length(context.get()), 0); QByteArray ciphertext(olm_pk_ciphertext_length(context.get(), plaintext.size()), 0); const auto random = getRandom(olm_pk_encrypt_random_length(context.get())); - auto result = olm_pk_encrypt(context.get(), plaintext.data(), - plaintext.size(), ciphertext.data(), - ciphertext.size(), mac.data(), mac.size(), - ephemeral.data(), ephemeral.size(), - random.data(), random.size()); - if (result == olm_error()) { + if (olm_pk_encrypt(context.get(), plaintext.data(), plaintext.size(), + ciphertext.data(), ciphertext.size(), mac.data(), + mac.size(), ephemeral.data(), ephemeral.size(), + random.data(), random.size()) + == olm_error()) return olm_pk_encryption_last_error_code(context.get()); - } return Curve25519Encrypted { .ciphertext = ciphertext,