diff --git a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc index 307017fd67e06..0287d593d12d3 100644 --- a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc @@ -148,17 +148,22 @@ class DatasetEncryptionTestBase : public ::testing::Test { FileSystemDatasetFactory::Make(file_system_, selector, file_format, factory_options)); - // Read dataset into table + // Create the dataset ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish()); - ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan()); - ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); - ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable()); - - // Verify the data was read correctly - ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks()); - // Validate the table - ASSERT_OK(combined_table->ValidateFull()); - AssertTablesEqual(*combined_table, *table_); + + // Reuse the dataset above to scan it twice to make sure decryption works correctly. + for (size_t i = 0; i < 2; ++i) { + // Read dataset into table + ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan()); + ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); + ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable()); + + // Verify the data was read correctly + ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks()); + // Validate the table + ASSERT_OK(combined_table->ValidateFull()); + AssertTablesEqual(*combined_table, *table_); + } } protected: diff --git a/cpp/src/parquet/file_reader.cc b/cpp/src/parquet/file_reader.cc index b3dd1d6054ac8..8fcb0870ce4b6 100644 --- a/cpp/src/parquet/file_reader.cc +++ b/cpp/src/parquet/file_reader.cc @@ -215,16 +215,14 @@ class SerializedRowGroup : public RowGroupReader::Contents { std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source, int64_t source_size, FileMetaData* file_metadata, int row_group_number, ReaderProperties props, - std::shared_ptr prebuffered_column_chunks_bitmap, - std::shared_ptr file_decryptor = nullptr) + std::shared_ptr prebuffered_column_chunks_bitmap) : source_(std::move(source)), cached_source_(std::move(cached_source)), source_size_(source_size), file_metadata_(file_metadata), properties_(std::move(props)), row_group_ordinal_(row_group_number), - prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)), - file_decryptor_(std::move(file_decryptor)) { + prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)) { row_group_metadata_ = file_metadata->RowGroup(row_group_number); } @@ -263,10 +261,10 @@ class SerializedRowGroup : public RowGroupReader::Contents { } // The column is encrypted - std::shared_ptr meta_decryptor = - GetColumnMetaDecryptor(crypto_metadata.get(), file_decryptor_.get()); - std::shared_ptr data_decryptor = - GetColumnDataDecryptor(crypto_metadata.get(), file_decryptor_.get()); + std::shared_ptr meta_decryptor = GetColumnMetaDecryptor( + crypto_metadata.get(), file_metadata_->file_decryptor().get()); + std::shared_ptr data_decryptor = GetColumnDataDecryptor( + crypto_metadata.get(), file_metadata_->file_decryptor().get()); ARROW_DCHECK_NE(meta_decryptor, nullptr); ARROW_DCHECK_NE(data_decryptor, nullptr); @@ -291,7 +289,6 @@ class SerializedRowGroup : public RowGroupReader::Contents { ReaderProperties properties_; int row_group_ordinal_; const std::shared_ptr prebuffered_column_chunks_bitmap_; - std::shared_ptr file_decryptor_; }; // ---------------------------------------------------------------------- @@ -316,7 +313,9 @@ class SerializedFile : public ParquetFileReader::Contents { } void Close() override { - if (file_decryptor_) file_decryptor_->WipeOutDecryptionKeys(); + if (file_metadata_ && file_metadata_->file_decryptor()) { + file_metadata_->file_decryptor()->WipeOutDecryptionKeys(); + } } std::shared_ptr GetRowGroup(int i) override { @@ -330,7 +329,7 @@ class SerializedFile : public ParquetFileReader::Contents { std::unique_ptr contents = std::make_unique( source_, cached_source_, source_size_, file_metadata_.get(), i, properties_, - std::move(prebuffered_column_chunks_bitmap), file_decryptor_); + std::move(prebuffered_column_chunks_bitmap)); return std::make_shared(std::move(contents)); } @@ -346,8 +345,9 @@ class SerializedFile : public ParquetFileReader::Contents { "forget to call ParquetFileReader::Open() first?"); } if (!page_index_reader_) { - page_index_reader_ = PageIndexReader::Make(source_.get(), file_metadata_, - properties_, file_decryptor_.get()); + page_index_reader_ = + PageIndexReader::Make(source_.get(), file_metadata_, properties_, + file_metadata_->file_decryptor().get()); } return page_index_reader_; } @@ -362,8 +362,8 @@ class SerializedFile : public ParquetFileReader::Contents { "forget to call ParquetFileReader::Open() first?"); } if (!bloom_filter_reader_) { - bloom_filter_reader_ = - BloomFilterReader::Make(source_, file_metadata_, properties_, file_decryptor_); + bloom_filter_reader_ = BloomFilterReader::Make(source_, file_metadata_, properties_, + file_metadata_->file_decryptor()); if (bloom_filter_reader_ == nullptr) { throw ParquetException("Cannot create BloomFilterReader"); } @@ -441,10 +441,12 @@ class SerializedFile : public ParquetFileReader::Contents { // Parse the footer depending on encryption type const bool is_encrypted_footer = memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0; + std::shared_ptr file_decryptor; if (is_encrypted_footer) { // Encrypted file with Encrypted footer. const std::pair read_size = - ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len); + ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len, + &file_decryptor); // Read the actual footer metadata_start = read_size.first; metadata_len = read_size.second; @@ -453,8 +455,8 @@ class SerializedFile : public ParquetFileReader::Contents { // Fall through } - const uint32_t read_metadata_len = - ParseUnencryptedFileMetadata(metadata_buffer, metadata_len); + const uint32_t read_metadata_len = ParseUnencryptedFileMetadata( + metadata_buffer, metadata_len, std::move(file_decryptor)); auto file_decryption_properties = properties_.file_decryption_properties().get(); if (is_encrypted_footer) { // Nothing else to do here. @@ -550,34 +552,37 @@ class SerializedFile : public ParquetFileReader::Contents { // Parse the footer depending on encryption type const bool is_encrypted_footer = memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0; + std::shared_ptr file_decryptor; if (is_encrypted_footer) { // Encrypted file with Encrypted footer. std::pair read_size; BEGIN_PARQUET_CATCH_EXCEPTIONS - read_size = - ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len); + read_size = ParseMetaDataOfEncryptedFileWithEncryptedFooter( + metadata_buffer, metadata_len, &file_decryptor); END_PARQUET_CATCH_EXCEPTIONS // Read the actual footer int64_t metadata_start = read_size.first; metadata_len = read_size.second; return source_->ReadAsync(metadata_start, metadata_len) - .Then([this, metadata_len, is_encrypted_footer]( + .Then([this, metadata_len, is_encrypted_footer, file_decryptor]( const std::shared_ptr<::arrow::Buffer>& metadata_buffer) { // Continue and read the file footer - return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer); + return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer, + file_decryptor); }); } return ParseMetaDataFinal(std::move(metadata_buffer), metadata_len, - is_encrypted_footer); + is_encrypted_footer, std::move(file_decryptor)); } // Continuation - ::arrow::Status ParseMetaDataFinal(std::shared_ptr<::arrow::Buffer> metadata_buffer, - uint32_t metadata_len, - const bool is_encrypted_footer) { + ::arrow::Status ParseMetaDataFinal( + std::shared_ptr<::arrow::Buffer> metadata_buffer, uint32_t metadata_len, + const bool is_encrypted_footer, + std::shared_ptr file_decryptor) { BEGIN_PARQUET_CATCH_EXCEPTIONS - const uint32_t read_metadata_len = - ParseUnencryptedFileMetadata(metadata_buffer, metadata_len); + const uint32_t read_metadata_len = ParseUnencryptedFileMetadata( + metadata_buffer, metadata_len, std::move(file_decryptor)); auto file_decryption_properties = properties_.file_decryption_properties().get(); if (is_encrypted_footer) { // Nothing else to do here. @@ -608,11 +613,11 @@ class SerializedFile : public ParquetFileReader::Contents { // Maps row group ordinal and prebuffer status of its column chunks in the form of a // bitmap buffer. std::unordered_map> prebuffered_column_chunks_; - std::shared_ptr file_decryptor_; // \return The true length of the metadata in bytes - uint32_t ParseUnencryptedFileMetadata(const std::shared_ptr& footer_buffer, - const uint32_t metadata_len); + uint32_t ParseUnencryptedFileMetadata( + const std::shared_ptr& footer_buffer, const uint32_t metadata_len, + std::shared_ptr file_decryptor); std::string HandleAadPrefix(FileDecryptionProperties* file_decryption_properties, EncryptionAlgorithm& algo); @@ -624,11 +629,13 @@ class SerializedFile : public ParquetFileReader::Contents { // \return The position and size of the actual footer std::pair ParseMetaDataOfEncryptedFileWithEncryptedFooter( - const std::shared_ptr& crypto_metadata_buffer, uint32_t footer_len); + const std::shared_ptr& crypto_metadata_buffer, uint32_t footer_len, + std::shared_ptr* file_decryptor); }; uint32_t SerializedFile::ParseUnencryptedFileMetadata( - const std::shared_ptr& metadata_buffer, const uint32_t metadata_len) { + const std::shared_ptr& metadata_buffer, const uint32_t metadata_len, + std::shared_ptr file_decryptor) { if (metadata_buffer->size() != metadata_len) { throw ParquetException("Failed reading metadata buffer (requested " + std::to_string(metadata_len) + " bytes but got " + @@ -637,7 +644,7 @@ uint32_t SerializedFile::ParseUnencryptedFileMetadata( uint32_t read_metadata_len = metadata_len; // The encrypted read path falls through to here, so pass in the decryptor file_metadata_ = FileMetaData::Make(metadata_buffer->data(), &read_metadata_len, - properties_, file_decryptor_); + properties_, std::move(file_decryptor)); return read_metadata_len; } @@ -645,7 +652,7 @@ std::pair SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter( const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer, // both metadata & crypto metadata length - const uint32_t footer_len) { + const uint32_t footer_len, std::shared_ptr* file_decryptor) { // encryption with encrypted footer // Check if the footer_buffer contains the entire metadata if (crypto_metadata_buffer->size() != footer_len) { @@ -664,7 +671,7 @@ SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter( // Handle AAD prefix EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm(); std::string file_aad = HandleAadPrefix(file_decryption_properties, algo); - file_decryptor_ = std::make_shared( + *file_decryptor = std::make_shared( file_decryption_properties, file_aad, algo.algorithm, file_crypto_metadata->key_metadata(), properties_.memory_pool()); @@ -683,12 +690,12 @@ void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter( EncryptionAlgorithm algo = file_metadata_->encryption_algorithm(); // Handle AAD prefix std::string file_aad = HandleAadPrefix(file_decryption_properties, algo); - file_decryptor_ = std::make_shared( + auto file_decryptor = std::make_shared( file_decryption_properties, file_aad, algo.algorithm, file_metadata_->footer_signing_key_metadata(), properties_.memory_pool()); // set the InternalFileDecryptor in the metadata as well, as it's used // for signature verification and for ColumnChunkMetaData creation. - file_metadata_->set_file_decryptor(file_decryptor_); + file_metadata_->set_file_decryptor(std::move(file_decryptor)); if (file_decryption_properties->check_plaintext_footer_integrity()) { if (metadata_len - read_metadata_len != diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 3f101b5ae3ac6..b24883cdc160b 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -826,6 +826,10 @@ class FileMetaData::FileMetaDataImpl { file_decryptor_ = std::move(file_decryptor); } + const std::shared_ptr& file_decryptor() const { + return file_decryptor_; + } + private: friend FileMetaDataBuilder; uint32_t metadata_len_ = 0; @@ -947,6 +951,10 @@ void FileMetaData::set_file_decryptor( impl_->set_file_decryptor(std::move(file_decryptor)); } +const std::shared_ptr& FileMetaData::file_decryptor() const { + return impl_->file_decryptor(); +} + ParquetVersion::type FileMetaData::version() const { switch (impl_->version()) { case 1: diff --git a/cpp/src/parquet/metadata.h b/cpp/src/parquet/metadata.h index 640b898024346..9fc30df58e0d3 100644 --- a/cpp/src/parquet/metadata.h +++ b/cpp/src/parquet/metadata.h @@ -399,12 +399,14 @@ class PARQUET_EXPORT FileMetaData { private: friend FileMetaDataBuilder; friend class SerializedFile; + friend class SerializedRowGroup; explicit FileMetaData(const void* serialized_metadata, uint32_t* metadata_len, const ReaderProperties& properties, std::shared_ptr file_decryptor = NULLPTR); void set_file_decryptor(std::shared_ptr file_decryptor); + const std::shared_ptr& file_decryptor() const; // PIMPL Idiom FileMetaData();