Skip to content

Commit

Permalink
apacheGH-43142: [C++][Parquet] Refactor Encryptor API to use arrow::u…
Browse files Browse the repository at this point in the history
…til::span instead of raw pointers (apache#43195)

### Rationale for this change

See apache#43142. This is a follow up to apache#43071 which refactored the Decryptor API and added extra checks to prevent segfaults. This PR makes similar changes to the Encryptor API for consistency and better maintainability.

### What changes are included in this PR?

* Change `AesEncryptor::Encrypt` and `Encryptor::Encrypt` to use `arrow::util::span` instead of raw pointers
* Replace the `AesEncryptor::CiphertextSizeDelta` method with a `CiphertextLength` method that checks for overflow and abstracts the size difference behaviour away from consumer code for improved readability.

### Are these changes tested?

* This is mostly a refactoring of existing code so is covered by existing tests.

### Are there any user-facing changes?

No
* GitHub Issue: apache#43142

Lead-authored-by: Adam Reeve <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
adamreeve and pitrou authored Jul 11, 2024
1 parent c777ac8 commit 6e438e6
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 164 deletions.
15 changes: 8 additions & 7 deletions cpp/src/parquet/column_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,10 @@ class SerializedPageWriter : public PageWriter {
if (data_encryptor_.get()) {
UpdateEncryption(encryption::kDictionaryPage);
PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len,
encryption_buffer_->mutable_data());
data_encryptor_->CiphertextLength(output_data_len), false));
output_data_len =
data_encryptor_->Encrypt(compressed_data->span_as<uint8_t>(),
encryption_buffer_->mutable_span_as<uint8_t>());
output_data_buffer = encryption_buffer_->data();
}

Expand Down Expand Up @@ -395,11 +396,11 @@ class SerializedPageWriter : public PageWriter {

if (data_encryptor_.get()) {
PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
data_encryptor_->CiphertextLength(output_data_len), false));
UpdateEncryption(encryption::kDataPage);
output_data_len = data_encryptor_->Encrypt(compressed_data->data(),
static_cast<int32_t>(output_data_len),
encryption_buffer_->mutable_data());
output_data_len =
data_encryptor_->Encrypt(compressed_data->span_as<uint8_t>(),
encryption_buffer_->mutable_span_as<uint8_t>());
output_data_buffer = encryption_buffer_->data();
}

Expand Down
231 changes: 140 additions & 91 deletions cpp/src/parquet/encryption/encryption_internal.cc

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions cpp/src/parquet/encryption/encryption_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,22 @@ class PARQUET_EXPORT AesEncryptor {

~AesEncryptor();

/// Size difference between plaintext and ciphertext, for this cipher.
int CiphertextSizeDelta();
/// The size of the ciphertext, for this cipher and the specified plaintext length.
[[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const;

/// Encrypts plaintext with the key and aad. Key length is passed only for validation.
/// If different from value in constructor, exception will be thrown.
int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
int Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext);

/// Encrypts plaintext footer, in order to compute footer signature (tag).
int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key,
int key_len, const uint8_t* aad, int aad_len,
const uint8_t* nonce, uint8_t* encrypted_footer);
int SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer);

void WipeOut();

Expand Down
18 changes: 10 additions & 8 deletions cpp/src/parquet/encryption/encryption_internal_nossl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,26 @@ class AesEncryptor::AesEncryptorImpl {};

AesEncryptor::~AesEncryptor() {}

int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len,
const uint8_t* key, int key_len, const uint8_t* aad,
int aad_len, const uint8_t* nonce,
uint8_t* encrypted_footer) {
int AesEncryptor::SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer) {
ThrowOpenSSLRequiredException();
return -1;
}

void AesEncryptor::WipeOut() { ThrowOpenSSLRequiredException(); }

int AesEncryptor::CiphertextSizeDelta() {
int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const {
ThrowOpenSSLRequiredException();
return -1;
}

int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
int key_len, const uint8_t* aad, int aad_len,
uint8_t* ciphertext) {
int AesEncryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext) {
ThrowOpenSSLRequiredException();
return -1;
}
Expand Down
20 changes: 8 additions & 12 deletions cpp/src/parquet/encryption/encryption_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ class TestAesEncryption : public ::testing::Test {

AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length);

int expected_ciphertext_len =
static_cast<int>(plain_text_.size()) + encryptor.CiphertextSizeDelta();
int32_t expected_ciphertext_len =
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');

int ciphertext_length =
encryptor.Encrypt(str2bytes(plain_text_), static_cast<int>(plain_text_.size()),
str2bytes(key_), static_cast<int>(key_.size()), str2bytes(aad_),
static_cast<int>(aad_.size()), ciphertext.data());
int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);

ASSERT_EQ(ciphertext_length, expected_ciphertext_len);

Expand Down Expand Up @@ -87,14 +85,12 @@ class TestAesEncryption : public ::testing::Test {

AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length);

int expected_ciphertext_len =
static_cast<int>(plain_text_.size()) + encryptor.CiphertextSizeDelta();
int32_t expected_ciphertext_len =
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');

int ciphertext_length =
encryptor.Encrypt(str2bytes(plain_text_), static_cast<int>(plain_text_.size()),
str2bytes(key_), static_cast<int>(key_.size()), str2bytes(aad_),
static_cast<int>(aad_.size()), ciphertext.data());
int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);

AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length);

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/parquet/encryption/internal_file_encryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ Encryptor::Encryptor(encryption::AesEncryptor* aes_encryptor, const std::string&
aad_(aad),
pool_(pool) {}

int Encryptor::CiphertextSizeDelta() { return aes_encryptor_->CiphertextSizeDelta(); }
int32_t Encryptor::CiphertextLength(int64_t plaintext_len) const {
return aes_encryptor_->CiphertextLength(plaintext_len);
}

int Encryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext) {
return aes_encryptor_->Encrypt(plaintext, plaintext_len, str2bytes(key_),
static_cast<int>(key_.size()), str2bytes(aad_),
static_cast<int>(aad_.size()), ciphertext);
int Encryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<uint8_t> ciphertext) {
return aes_encryptor_->Encrypt(plaintext, str2span(key_), str2span(aad_), ciphertext);
}

// InternalFileEncryptor
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/parquet/encryption/internal_file_encryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ class PARQUET_EXPORT Encryptor {
void UpdateAad(const std::string& aad) { aad_ = aad; }
::arrow::MemoryPool* pool() { return pool_; }

int CiphertextSizeDelta();
int Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext);
[[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const;

int Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<uint8_t> ciphertext);

bool EncryptColumnMetaData(
bool encrypted_footer,
Expand Down
15 changes: 7 additions & 8 deletions cpp/src/parquet/encryption/key_toolkit_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ std::string EncryptKeyLocally(const std::string& key_bytes, const std::string& m
static_cast<int>(master_key.size()), false,
false /*write_length*/);

int encrypted_key_len =
static_cast<int>(key_bytes.size()) + key_encryptor.CiphertextSizeDelta();
int32_t encrypted_key_len =
key_encryptor.CiphertextLength(static_cast<int64_t>(key_bytes.size()));
std::string encrypted_key(encrypted_key_len, '\0');
encrypted_key_len = key_encryptor.Encrypt(
reinterpret_cast<const uint8_t*>(key_bytes.data()),
static_cast<int>(key_bytes.size()),
reinterpret_cast<const uint8_t*>(master_key.data()),
static_cast<int>(master_key.size()), reinterpret_cast<const uint8_t*>(aad.data()),
static_cast<int>(aad.size()), reinterpret_cast<uint8_t*>(&encrypted_key[0]));
::arrow::util::span<uint8_t> encrypted_key_span(
reinterpret_cast<uint8_t*>(&encrypted_key[0]), encrypted_key_len);

encrypted_key_len = key_encryptor.Encrypt(str2span(key_bytes), str2span(master_key),
str2span(aad), encrypted_key_span);

return ::arrow::util::base64_encode(
::std::string_view(encrypted_key.data(), encrypted_key_len));
Expand Down
34 changes: 17 additions & 17 deletions cpp/src/parquet/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,11 +640,13 @@ class FileMetaData::FileMetaDataImpl {
uint32_t serialized_len = metadata_len_;
ThriftSerializer serializer;
serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
serialized_len);

// encrypt with nonce
auto nonce = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature));
auto tag = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature)) +
encryption::kNonceLength;
::arrow::util::span<const uint8_t> nonce(reinterpret_cast<const uint8_t*>(signature),
encryption::kNonceLength);
auto tag = reinterpret_cast<const uint8_t*>(signature) + encryption::kNonceLength;

std::string key = file_decryptor_->GetFooterKey();
std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad());
Expand All @@ -653,13 +655,11 @@ class FileMetaData::FileMetaDataImpl {
file_decryptor_->algorithm(), static_cast<int>(key.size()), true,
false /*write_length*/, nullptr);

std::shared_ptr<Buffer> encrypted_buffer = std::static_pointer_cast<ResizableBuffer>(
AllocateBuffer(file_decryptor_->pool(),
aes_encryptor->CiphertextSizeDelta() + serialized_len));
std::shared_ptr<Buffer> encrypted_buffer = AllocateBuffer(
file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len));
uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt(
serialized_data, serialized_len, str2bytes(key), static_cast<int>(key.size()),
str2bytes(aad), static_cast<int>(aad.size()), nonce,
encrypted_buffer->mutable_data());
serialized_data_span, str2span(key), str2span(aad), nonce,
encrypted_buffer->mutable_span_as<uint8_t>());
// Delete AES encryptor object. It was created only to verify the footer signature.
aes_encryptor->WipeOut();
delete aes_encryptor;
Expand Down Expand Up @@ -701,12 +701,12 @@ class FileMetaData::FileMetaDataImpl {
uint8_t* serialized_data;
uint32_t serialized_len;
serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
serialized_len);

// encrypt the footer key
std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
serialized_len);
unsigned encrypted_len =
encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
std::vector<uint8_t> encrypted_data(encryptor->CiphertextLength(serialized_len));
int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data);

// write unencrypted footer
PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len));
Expand Down Expand Up @@ -1559,11 +1559,11 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl {

serializer.SerializeToBuffer(&column_chunk_->meta_data, &serialized_len,
&serialized_data);
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
serialized_len);

std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
serialized_len);
unsigned encrypted_len =
encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
std::vector<uint8_t> encrypted_data(encryptor->CiphertextLength(serialized_len));
int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data);

const char* temp =
const_cast<const char*>(reinterpret_cast<char*>(encrypted_data.data()));
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/parquet/thrift_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ class ThriftDeserializer {
throw ParquetException(ss.str());
}
// decrypt
auto decrypted_buffer = std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
decryptor->pool(), decryptor->PlaintextLength(static_cast<int32_t>(clen))));
auto decrypted_buffer = AllocateBuffer(
decryptor->pool(), decryptor->PlaintextLength(static_cast<int32_t>(clen)));
::arrow::util::span<const uint8_t> cipher_buf(buf, clen);
uint32_t decrypted_buffer_len =
decryptor->Decrypt(cipher_buf, decrypted_buffer->mutable_span_as<uint8_t>());
Expand Down Expand Up @@ -525,13 +525,13 @@ class ThriftSerializer {
}
}

int64_t SerializeEncryptedObj(ArrowOutputStream* out, uint8_t* out_buffer,
int64_t SerializeEncryptedObj(ArrowOutputStream* out, const uint8_t* out_buffer,
uint32_t out_length, Encryptor* encryptor) {
auto cipher_buffer = std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
encryptor->pool(),
static_cast<int64_t>(encryptor->CiphertextSizeDelta() + out_length)));
auto cipher_buffer =
AllocateBuffer(encryptor->pool(), encryptor->CiphertextLength(out_length));
::arrow::util::span<const uint8_t> out_span(out_buffer, out_length);
int cipher_buffer_len =
encryptor->Encrypt(out_buffer, out_length, cipher_buffer->mutable_data());
encryptor->Encrypt(out_span, cipher_buffer->mutable_span_as<uint8_t>());

PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len));
return static_cast<int64_t>(cipher_buffer_len);
Expand Down

0 comments on commit 6e438e6

Please sign in to comment.