Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-41431: [C++][Parquet][Dataset] Fix repeated scan on encrypted dataset #41550

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions cpp/src/arrow/dataset/file_parquet_encryption_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,22 @@ class DatasetEncryptionTestBase : public ::testing::Test {
FileSystemDatasetFactory::Make(file_system_, selector,
file_format, factory_options));

// Read dataset into table
// Create the dataset
ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);

// Reuse the dataset above to scan it twice to make sure decryption works correctly.
for (size_t i = 0; i < 2; ++i) {
// Read dataset into table
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);
}
}

protected:
Expand Down
83 changes: 45 additions & 38 deletions cpp/src/parquet/file_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,14 @@ class SerializedRowGroup : public RowGroupReader::Contents {
std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source,
int64_t source_size, FileMetaData* file_metadata,
int row_group_number, ReaderProperties props,
std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap,
std::shared_ptr<InternalFileDecryptor> file_decryptor = nullptr)
std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap)
: source_(std::move(source)),
cached_source_(std::move(cached_source)),
source_size_(source_size),
file_metadata_(file_metadata),
properties_(std::move(props)),
row_group_ordinal_(row_group_number),
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)),
file_decryptor_(std::move(file_decryptor)) {
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)) {
row_group_metadata_ = file_metadata->RowGroup(row_group_number);
}

Expand Down Expand Up @@ -263,10 +261,10 @@ class SerializedRowGroup : public RowGroupReader::Contents {
}

// The column is encrypted
std::shared_ptr<Decryptor> meta_decryptor =
GetColumnMetaDecryptor(crypto_metadata.get(), file_decryptor_.get());
std::shared_ptr<Decryptor> data_decryptor =
GetColumnDataDecryptor(crypto_metadata.get(), file_decryptor_.get());
std::shared_ptr<Decryptor> meta_decryptor = GetColumnMetaDecryptor(
crypto_metadata.get(), file_metadata_->file_decryptor().get());
std::shared_ptr<Decryptor> data_decryptor = GetColumnDataDecryptor(
crypto_metadata.get(), file_metadata_->file_decryptor().get());
ARROW_DCHECK_NE(meta_decryptor, nullptr);
ARROW_DCHECK_NE(data_decryptor, nullptr);

Expand All @@ -291,7 +289,6 @@ class SerializedRowGroup : public RowGroupReader::Contents {
ReaderProperties properties_;
int row_group_ordinal_;
const std::shared_ptr<const Buffer> prebuffered_column_chunks_bitmap_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;
};

// ----------------------------------------------------------------------
Expand All @@ -316,7 +313,9 @@ class SerializedFile : public ParquetFileReader::Contents {
}

void Close() override {
if (file_decryptor_) file_decryptor_->WipeOutDecryptionKeys();
if (file_metadata_ && file_metadata_->file_decryptor()) {
file_metadata_->file_decryptor()->WipeOutDecryptionKeys();
}
}

std::shared_ptr<RowGroupReader> GetRowGroup(int i) override {
Expand All @@ -330,7 +329,7 @@ class SerializedFile : public ParquetFileReader::Contents {

std::unique_ptr<SerializedRowGroup> contents = std::make_unique<SerializedRowGroup>(
source_, cached_source_, source_size_, file_metadata_.get(), i, properties_,
std::move(prebuffered_column_chunks_bitmap), file_decryptor_);
std::move(prebuffered_column_chunks_bitmap));
return std::make_shared<RowGroupReader>(std::move(contents));
}

Expand All @@ -346,8 +345,9 @@ class SerializedFile : public ParquetFileReader::Contents {
"forget to call ParquetFileReader::Open() first?");
}
if (!page_index_reader_) {
page_index_reader_ = PageIndexReader::Make(source_.get(), file_metadata_,
properties_, file_decryptor_.get());
page_index_reader_ =
PageIndexReader::Make(source_.get(), file_metadata_, properties_,
file_metadata_->file_decryptor().get());
}
return page_index_reader_;
}
Expand All @@ -362,8 +362,8 @@ class SerializedFile : public ParquetFileReader::Contents {
"forget to call ParquetFileReader::Open() first?");
}
if (!bloom_filter_reader_) {
bloom_filter_reader_ =
BloomFilterReader::Make(source_, file_metadata_, properties_, file_decryptor_);
bloom_filter_reader_ = BloomFilterReader::Make(source_, file_metadata_, properties_,
file_metadata_->file_decryptor());
if (bloom_filter_reader_ == nullptr) {
throw ParquetException("Cannot create BloomFilterReader");
}
Expand Down Expand Up @@ -441,10 +441,12 @@ class SerializedFile : public ParquetFileReader::Contents {
// Parse the footer depending on encryption type
const bool is_encrypted_footer =
memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
std::shared_ptr<InternalFileDecryptor> file_decryptor;
if (is_encrypted_footer) {
// Encrypted file with Encrypted footer.
const std::pair<int64_t, uint32_t> read_size =
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len,
&file_decryptor);
// Read the actual footer
metadata_start = read_size.first;
metadata_len = read_size.second;
Expand All @@ -453,8 +455,8 @@ class SerializedFile : public ParquetFileReader::Contents {
// Fall through
}

const uint32_t read_metadata_len =
ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
metadata_buffer, metadata_len, std::move(file_decryptor));
auto file_decryption_properties = properties_.file_decryption_properties().get();
if (is_encrypted_footer) {
// Nothing else to do here.
Expand Down Expand Up @@ -550,34 +552,37 @@ class SerializedFile : public ParquetFileReader::Contents {
// Parse the footer depending on encryption type
const bool is_encrypted_footer =
memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
std::shared_ptr<InternalFileDecryptor> file_decryptor;
if (is_encrypted_footer) {
// Encrypted file with Encrypted footer.
std::pair<int64_t, uint32_t> read_size;
BEGIN_PARQUET_CATCH_EXCEPTIONS
read_size =
ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
read_size = ParseMetaDataOfEncryptedFileWithEncryptedFooter(
metadata_buffer, metadata_len, &file_decryptor);
END_PARQUET_CATCH_EXCEPTIONS
// Read the actual footer
int64_t metadata_start = read_size.first;
metadata_len = read_size.second;
return source_->ReadAsync(metadata_start, metadata_len)
.Then([this, metadata_len, is_encrypted_footer](
.Then([this, metadata_len, is_encrypted_footer, file_decryptor](
const std::shared_ptr<::arrow::Buffer>& metadata_buffer) {
// Continue and read the file footer
return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer);
return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer,
file_decryptor);
});
}
return ParseMetaDataFinal(std::move(metadata_buffer), metadata_len,
is_encrypted_footer);
is_encrypted_footer, std::move(file_decryptor));
}

// Continuation
::arrow::Status ParseMetaDataFinal(std::shared_ptr<::arrow::Buffer> metadata_buffer,
uint32_t metadata_len,
const bool is_encrypted_footer) {
::arrow::Status ParseMetaDataFinal(
std::shared_ptr<::arrow::Buffer> metadata_buffer, uint32_t metadata_len,
const bool is_encrypted_footer,
std::shared_ptr<InternalFileDecryptor> file_decryptor) {
BEGIN_PARQUET_CATCH_EXCEPTIONS
const uint32_t read_metadata_len =
ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
metadata_buffer, metadata_len, std::move(file_decryptor));
auto file_decryption_properties = properties_.file_decryption_properties().get();
if (is_encrypted_footer) {
// Nothing else to do here.
Expand Down Expand Up @@ -608,11 +613,11 @@ class SerializedFile : public ParquetFileReader::Contents {
// Maps row group ordinal and prebuffer status of its column chunks in the form of a
// bitmap buffer.
std::unordered_map<int, std::shared_ptr<Buffer>> prebuffered_column_chunks_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;

// \return The true length of the metadata in bytes
uint32_t ParseUnencryptedFileMetadata(const std::shared_ptr<Buffer>& footer_buffer,
const uint32_t metadata_len);
uint32_t ParseUnencryptedFileMetadata(
const std::shared_ptr<Buffer>& footer_buffer, const uint32_t metadata_len,
std::shared_ptr<InternalFileDecryptor> file_decryptor);

std::string HandleAadPrefix(FileDecryptionProperties* file_decryption_properties,
EncryptionAlgorithm& algo);
Expand All @@ -624,11 +629,13 @@ class SerializedFile : public ParquetFileReader::Contents {

// \return The position and size of the actual footer
std::pair<int64_t, uint32_t> ParseMetaDataOfEncryptedFileWithEncryptedFooter(
const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t footer_len);
const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t footer_len,
std::shared_ptr<InternalFileDecryptor>* file_decryptor);
};

uint32_t SerializedFile::ParseUnencryptedFileMetadata(
const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t metadata_len) {
const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t metadata_len,
std::shared_ptr<InternalFileDecryptor> file_decryptor) {
if (metadata_buffer->size() != metadata_len) {
throw ParquetException("Failed reading metadata buffer (requested " +
std::to_string(metadata_len) + " bytes but got " +
Expand All @@ -637,15 +644,15 @@ uint32_t SerializedFile::ParseUnencryptedFileMetadata(
uint32_t read_metadata_len = metadata_len;
// The encrypted read path falls through to here, so pass in the decryptor
file_metadata_ = FileMetaData::Make(metadata_buffer->data(), &read_metadata_len,
properties_, file_decryptor_);
properties_, std::move(file_decryptor));
return read_metadata_len;
}

std::pair<int64_t, uint32_t>
SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer,
// both metadata & crypto metadata length
const uint32_t footer_len) {
const uint32_t footer_len, std::shared_ptr<InternalFileDecryptor>* file_decryptor) {
// encryption with encrypted footer
// Check if the footer_buffer contains the entire metadata
if (crypto_metadata_buffer->size() != footer_len) {
Expand All @@ -664,7 +671,7 @@ SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
// Handle AAD prefix
EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm();
std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
file_decryptor_ = std::make_shared<InternalFileDecryptor>(
*file_decryptor = std::make_shared<InternalFileDecryptor>(
file_decryption_properties, file_aad, algo.algorithm,
file_crypto_metadata->key_metadata(), properties_.memory_pool());

Expand All @@ -683,12 +690,12 @@ void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter(
EncryptionAlgorithm algo = file_metadata_->encryption_algorithm();
// Handle AAD prefix
std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
file_decryptor_ = std::make_shared<InternalFileDecryptor>(
auto file_decryptor = std::make_shared<InternalFileDecryptor>(
file_decryption_properties, file_aad, algo.algorithm,
file_metadata_->footer_signing_key_metadata(), properties_.memory_pool());
// set the InternalFileDecryptor in the metadata as well, as it's used
// for signature verification and for ColumnChunkMetaData creation.
file_metadata_->set_file_decryptor(file_decryptor_);
file_metadata_->set_file_decryptor(std::move(file_decryptor));

if (file_decryption_properties->check_plaintext_footer_integrity()) {
if (metadata_len - read_metadata_len !=
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/parquet/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,10 @@ class FileMetaData::FileMetaDataImpl {
file_decryptor_ = std::move(file_decryptor);
}

const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const {
return file_decryptor_;
}

private:
friend FileMetaDataBuilder;
uint32_t metadata_len_ = 0;
Expand Down Expand Up @@ -947,6 +951,10 @@ void FileMetaData::set_file_decryptor(
impl_->set_file_decryptor(std::move(file_decryptor));
}

const std::shared_ptr<InternalFileDecryptor>& FileMetaData::file_decryptor() const {
return impl_->file_decryptor();
}

ParquetVersion::type FileMetaData::version() const {
switch (impl_->version()) {
case 1:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/parquet/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,14 @@ class PARQUET_EXPORT FileMetaData {
private:
friend FileMetaDataBuilder;
friend class SerializedFile;
friend class SerializedRowGroup;

explicit FileMetaData(const void* serialized_metadata, uint32_t* metadata_len,
const ReaderProperties& properties,
std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);

void set_file_decryptor(std::shared_ptr<InternalFileDecryptor> file_decryptor);
const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const;

// PIMPL Idiom
FileMetaData();
Expand Down
Loading