Skip to content

Commit

Permalink
apacheGH-41431: [C++][Parquet][Dataset] Fix repeated scan on encrypte…
Browse files Browse the repository at this point in the history
…d dataset (apache#41550)

### Rationale for this change

When parquet dataset is reused to create multiple scanners, `FileMetaData` objects are cached to avoid parsing them again. However, these caused issues on encrypted files since internal file decryptors were no longer created by cached `FileMetaData` objects.

### What changes are included in this PR?

Expose file_decryptor from FileMetaData and set it properly.

### Are these changes tested?

Yes, modify the test to reproduce the issue and assure fixed.

### Are there any user-facing changes?

No.
* GitHub Issue: apache#41431

Authored-by: Gang Wu <[email protected]>
Signed-off-by: Gang Wu <[email protected]>
  • Loading branch information
wgtmac authored May 8, 2024
1 parent 51689a0 commit 5385926
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 48 deletions.
25 changes: 15 additions & 10 deletions cpp/src/arrow/dataset/file_parquet_encryption_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,22 @@ class DatasetEncryptionTestBase : public ::testing::Test {
FileSystemDatasetFactory::Make(file_system_, selector,
file_format, factory_options));

// Read dataset into table
// Create the dataset
ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);

// Reuse the dataset above to scan it twice to make sure decryption works correctly.
for (size_t i = 0; i < 2; ++i) {
// Read dataset into table
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);
}
}

protected:
Expand Down
83 changes: 45 additions & 38 deletions cpp/src/parquet/file_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,14 @@ class SerializedRowGroup : public RowGroupReader::Contents {
std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source,
int64_t source_size, FileMetaData* file_metadata,
int row_group_number, ReaderProperties props,
std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap,
std::shared_ptr<InternalFileDecryptor> file_decryptor = nullptr)
std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap)
: source_(std::move(source)),
cached_source_(std::move(cached_source)),
source_size_(source_size),
file_metadata_(file_metadata),
properties_(std::move(props)),
row_group_ordinal_(row_group_number),
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)),
file_decryptor_(std::move(file_decryptor)) {
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)) {
row_group_metadata_ = file_metadata->RowGroup(row_group_number);
}

Expand Down Expand Up @@ -263,10 +261,10 @@ class SerializedRowGroup : public RowGroupReader::Contents {
}

// The column is encrypted
std::shared_ptr<Decryptor> meta_decryptor =
GetColumnMetaDecryptor(crypto_metadata.get(), file_decryptor_.get());
std::shared_ptr<Decryptor> data_decryptor =
GetColumnDataDecryptor(crypto_metadata.get(), file_decryptor_.get());
std::shared_ptr<Decryptor> meta_decryptor = GetColumnMetaDecryptor(
crypto_metadata.get(), file_metadata_->file_decryptor().get());
std::shared_ptr<Decryptor> data_decryptor = GetColumnDataDecryptor(
crypto_metadata.get(), file_metadata_->file_decryptor().get());
ARROW_DCHECK_NE(meta_decryptor, nullptr);
ARROW_DCHECK_NE(data_decryptor, nullptr);

Expand All @@ -291,7 +289,6 @@ class SerializedRowGroup : public RowGroupReader::Contents {
ReaderProperties properties_;
int row_group_ordinal_;
const std::shared_ptr<const Buffer> prebuffered_column_chunks_bitmap_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;
};

// ----------------------------------------------------------------------
Expand All @@ -316,7 +313,9 @@ class SerializedFile : public ParquetFileReader::Contents {
}

void Close() override {
if (file_decryptor_) file_decryptor_->WipeOutDecryptionKeys();
if (file_metadata_ && file_metadata_->file_decryptor()) {
file_metadata_->file_decryptor()->WipeOutDecryptionKeys();
}
}

std::shared_ptr<RowGroupReader> GetRowGroup(int i) override {
Expand All @@ -330,7 +329,7 @@ class SerializedFile : public ParquetFileReader::Contents {

std::unique_ptr<SerializedRowGroup> contents = std::make_unique<SerializedRowGroup>(
source_, cached_source_, source_size_, file_metadata_.get(), i, properties_,
std::move(prebuffered_column_chunks_bitmap), file_decryptor_);
std::move(prebuffered_column_chunks_bitmap));
return std::make_shared<RowGroupReader>(std::move(contents));
}

Expand All @@ -346,8 +345,9 @@ class SerializedFile : public ParquetFileReader::Contents {
"forget to call ParquetFileReader::Open() first?");
}
if (!page_index_reader_) {
page_index_reader_ = PageIndexReader::Make(source_.get(), file_metadata_,
properties_, file_decryptor_.get());
page_index_reader_ =
PageIndexReader::Make(source_.get(), file_metadata_, properties_,
file_metadata_->file_decryptor().get());
}
return page_index_reader_;
}
Expand All @@ -362,8 +362,8 @@ class SerializedFile : public ParquetFileReader::Contents {
"forget to call ParquetFileReader::Open() first?");
}
if (!bloom_filter_reader_) {
bloom_filter_reader_ =
BloomFilterReader::Make(source_, file_metadata_, properties_, file_decryptor_);
bloom_filter_reader_ = BloomFilterReader::Make(source_, file_metadata_, properties_,
file_metadata_->file_decryptor());
if (bloom_filter_reader_ == nullptr) {
throw ParquetException("Cannot create BloomFilterReader");
}
Expand Down Expand Up @@ -441,10 +441,12 @@ class SerializedFile : public ParquetFileReader::Contents {
// Parse the footer depending on encryption type
const bool is_encrypted_footer =
memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
std::shared_ptr<InternalFileDecryptor> file_decryptor;
if (is_encrypted_footer) {
// Encrypted file with Encrypted footer.
const std::pair<int64_t, uint32_t> read_size =
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len,
&file_decryptor);
// Read the actual footer
metadata_start = read_size.first;
metadata_len = read_size.second;
Expand All @@ -453,8 +455,8 @@ class SerializedFile : public ParquetFileReader::Contents {
// Fall through
}

const uint32_t read_metadata_len =
ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
metadata_buffer, metadata_len, std::move(file_decryptor));
auto file_decryption_properties = properties_.file_decryption_properties().get();
if (is_encrypted_footer) {
// Nothing else to do here.
Expand Down Expand Up @@ -550,34 +552,37 @@ class SerializedFile : public ParquetFileReader::Contents {
// Parse the footer depending on encryption type
const bool is_encrypted_footer =
memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
std::shared_ptr<InternalFileDecryptor> file_decryptor;
if (is_encrypted_footer) {
// Encrypted file with Encrypted footer.
std::pair<int64_t, uint32_t> read_size;
BEGIN_PARQUET_CATCH_EXCEPTIONS
read_size =
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
read_size = ParseMetaDataOfEncryptedFileWithEncryptedFooter(
metadata_buffer, metadata_len, &file_decryptor);
END_PARQUET_CATCH_EXCEPTIONS
// Read the actual footer
int64_t metadata_start = read_size.first;
metadata_len = read_size.second;
return source_->ReadAsync(metadata_start, metadata_len)
.Then([this, metadata_len, is_encrypted_footer](
.Then([this, metadata_len, is_encrypted_footer, file_decryptor](
const std::shared_ptr<::arrow::Buffer>& metadata_buffer) {
// Continue and read the file footer
return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer);
return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer,
file_decryptor);
});
}
return ParseMetaDataFinal(std::move(metadata_buffer), metadata_len,
is_encrypted_footer);
is_encrypted_footer, std::move(file_decryptor));
}

// Continuation
::arrow::Status ParseMetaDataFinal(std::shared_ptr<::arrow::Buffer> metadata_buffer,
uint32_t metadata_len,
const bool is_encrypted_footer) {
::arrow::Status ParseMetaDataFinal(
std::shared_ptr<::arrow::Buffer> metadata_buffer, uint32_t metadata_len,
const bool is_encrypted_footer,
std::shared_ptr<InternalFileDecryptor> file_decryptor) {
BEGIN_PARQUET_CATCH_EXCEPTIONS
const uint32_t read_metadata_len =
ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
metadata_buffer, metadata_len, std::move(file_decryptor));
auto file_decryption_properties = properties_.file_decryption_properties().get();
if (is_encrypted_footer) {
// Nothing else to do here.
Expand Down Expand Up @@ -608,11 +613,11 @@ class SerializedFile : public ParquetFileReader::Contents {
// Maps row group ordinal and prebuffer status of its column chunks in the form of a
// bitmap buffer.
std::unordered_map<int, std::shared_ptr<Buffer>> prebuffered_column_chunks_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;

// \return The true length of the metadata in bytes
uint32_t ParseUnencryptedFileMetadata(const std::shared_ptr<Buffer>& footer_buffer,
const uint32_t metadata_len);
uint32_t ParseUnencryptedFileMetadata(
const std::shared_ptr<Buffer>& footer_buffer, const uint32_t metadata_len,
std::shared_ptr<InternalFileDecryptor> file_decryptor);

std::string HandleAadPrefix(FileDecryptionProperties* file_decryption_properties,
EncryptionAlgorithm& algo);
Expand All @@ -624,11 +629,13 @@ class SerializedFile : public ParquetFileReader::Contents {

// \return The position and size of the actual footer
std::pair<int64_t, uint32_t> ParseMetaDataOfEncryptedFileWithEncryptedFooter(
const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t footer_len);
const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t footer_len,
std::shared_ptr<InternalFileDecryptor>* file_decryptor);
};

uint32_t SerializedFile::ParseUnencryptedFileMetadata(
const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t metadata_len) {
const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t metadata_len,
std::shared_ptr<InternalFileDecryptor> file_decryptor) {
if (metadata_buffer->size() != metadata_len) {
throw ParquetException("Failed reading metadata buffer (requested " +
std::to_string(metadata_len) + " bytes but got " +
Expand All @@ -637,15 +644,15 @@ uint32_t SerializedFile::ParseUnencryptedFileMetadata(
uint32_t read_metadata_len = metadata_len;
// The encrypted read path falls through to here, so pass in the decryptor
file_metadata_ = FileMetaData::Make(metadata_buffer->data(), &read_metadata_len,
properties_, file_decryptor_);
properties_, std::move(file_decryptor));
return read_metadata_len;
}

std::pair<int64_t, uint32_t>
SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer,
// both metadata & crypto metadata length
const uint32_t footer_len) {
const uint32_t footer_len, std::shared_ptr<InternalFileDecryptor>* file_decryptor) {
// encryption with encrypted footer
// Check if the footer_buffer contains the entire metadata
if (crypto_metadata_buffer->size() != footer_len) {
Expand All @@ -664,7 +671,7 @@ SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
// Handle AAD prefix
EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm();
std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
file_decryptor_ = std::make_shared<InternalFileDecryptor>(
*file_decryptor = std::make_shared<InternalFileDecryptor>(
file_decryption_properties, file_aad, algo.algorithm,
file_crypto_metadata->key_metadata(), properties_.memory_pool());

Expand All @@ -683,12 +690,12 @@ void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter(
EncryptionAlgorithm algo = file_metadata_->encryption_algorithm();
// Handle AAD prefix
std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
file_decryptor_ = std::make_shared<InternalFileDecryptor>(
auto file_decryptor = std::make_shared<InternalFileDecryptor>(
file_decryption_properties, file_aad, algo.algorithm,
file_metadata_->footer_signing_key_metadata(), properties_.memory_pool());
// set the InternalFileDecryptor in the metadata as well, as it's used
// for signature verification and for ColumnChunkMetaData creation.
file_metadata_->set_file_decryptor(file_decryptor_);
file_metadata_->set_file_decryptor(std::move(file_decryptor));

if (file_decryption_properties->check_plaintext_footer_integrity()) {
if (metadata_len - read_metadata_len !=
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/parquet/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,10 @@ class FileMetaData::FileMetaDataImpl {
file_decryptor_ = std::move(file_decryptor);
}

const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const {
return file_decryptor_;
}

private:
friend FileMetaDataBuilder;
uint32_t metadata_len_ = 0;
Expand Down Expand Up @@ -947,6 +951,10 @@ void FileMetaData::set_file_decryptor(
impl_->set_file_decryptor(std::move(file_decryptor));
}

const std::shared_ptr<InternalFileDecryptor>& FileMetaData::file_decryptor() const {
return impl_->file_decryptor();
}

ParquetVersion::type FileMetaData::version() const {
switch (impl_->version()) {
case 1:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/parquet/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,14 @@ class PARQUET_EXPORT FileMetaData {
private:
friend FileMetaDataBuilder;
friend class SerializedFile;
friend class SerializedRowGroup;

explicit FileMetaData(const void* serialized_metadata, uint32_t* metadata_len,
const ReaderProperties& properties,
std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);

void set_file_decryptor(std::shared_ptr<InternalFileDecryptor> file_decryptor);
const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const;

// PIMPL Idiom
FileMetaData();
Expand Down

0 comments on commit 5385926

Please sign in to comment.