From dd6d7288e41d5d5c8bc73dbcf96ddc601db009cc Mon Sep 17 00:00:00 2001
From: Donald Tolley
Date: Wed, 13 Mar 2024 09:54:02 -0400
Subject: [PATCH] GH-39444: [C++][Parquet] Fix crash in Modular Encryption (#39623)

**Rationale for this change:**

This pull request addresses a critical issue (GH-39444) in the C++ and Python
components of Parquet: a segmentation fault when processing encrypted datasets
of more than 2^15 rows. The fix modifies
`cpp/src/parquet/encryption/internal_file_decryptor.cc`, in particular
`InternalFileDecryptor::GetColumnDecryptor`: the caching of `Decryptor`
objects is removed, which resolves the multithreading race that caused the
segmentation fault and decryption failures.

**What changes are included in this PR?**

- Removal of `Decryptor` object caching in
  `InternalFileDecryptor::GetColumnDecryptor`; creation of decryptors and
  access to the shared `all_decryptors_` list are now guarded by a mutex.
- Addition of two unit tests: `LargeRowEncryptionTest` in
  `file_parquet_encryption_test.cc` for C++ and
  `test_large_row_encryption_decryption` in `test_dataset_encryption.py` for
  Python.

**Are these changes tested?**

Yes. The unit tests (`LargeRowEncryptionTest` in
`file_parquet_encryption_test.cc` and `test_large_row_encryption_decryption`
in `test_dataset_encryption.py`) have been added to ensure the reliability and
effectiveness of these changes.

**Are there any user-facing changes?**

No significant user-facing changes; the update improves the backend stability
and reliability of encrypted Parquet file handling.

Note that calling `DecryptionKeyRetriever::GetKey` could be an expensive
operation, potentially involving network calls to key management servers, so
with the cache removed, keys may be retrieved more often than before.

* Closes: #39444
* GitHub Issue: #39444

Lead-authored-by: Donald Tolley
Co-authored-by: Antoine Pitrou
Co-authored-by: Antoine Pitrou
Co-authored-by: Adam Reeve
Co-authored-by: Gang Wu
Signed-off-by: Antoine Pitrou
---
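Note on the failure mode: as the diff below shows, `GetColumnDecryptor` used
to cache `Decryptor` objects in per-column maps and then call `UpdateAad()` on
the cached object for each unit it decrypted. With the multi-threaded dataset
scanner, several threads receive the same cached object, so the unsynchronized
map insertions and the concurrent AAD updates race. The following minimal
sketch reduces the bug and the fix to that pattern; `ColumnDecryptorCache`,
`GetCached`, and `GetFresh` are illustrative names only, not Arrow/Parquet
APIs.

// Minimal sketch of the race fixed here (hypothetical names).
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Decryptor {
  std::string aad;  // per-call state: the AAD differs for every module read
  void UpdateAad(const std::string& aad_in) { aad = aad_in; }
};

class ColumnDecryptorCache {
 public:
  // Old, broken pattern: hand out one shared object and mutate it in place.
  // Two scanner threads requesting the same column clobber each other's AAD
  // and race on the unsynchronized map insertion.
  std::shared_ptr<Decryptor> GetCached(const std::string& path,
                                       const std::string& aad) {
    auto& slot = cache_[path];
    if (slot == nullptr) slot = std::make_shared<Decryptor>();
    slot->UpdateAad(aad);
    return slot;
  }

  // Fixed pattern, as in this patch: build a fresh Decryptor per call and
  // lock only the bookkeeping that is genuinely shared.
  std::shared_ptr<Decryptor> GetFresh(const std::string& aad) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto decryptor = std::make_shared<Decryptor>();
    decryptor->UpdateAad(aad);
    return decryptor;
  }

 private:
  std::map<std::string, std::shared_ptr<Decryptor>> cache_;
  std::mutex mutex_;
};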
 .../dataset/file_parquet_encryption_test.cc   | 182 +++++++++++-------
 .../parquet/encryption/encryption_internal.cc |  14 +-
 .../encryption/internal_file_decryptor.cc     |  50 ++---
 .../encryption/internal_file_decryptor.h      |   7 +-
 .../pyarrow/tests/test_dataset_encryption.py  |  62 ++++++
 5 files changed, 196 insertions(+), 119 deletions(-)

diff --git a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
index 87028eb6e2fac..307017fd67e06 100644
--- a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
+++ b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
@@ -19,10 +19,10 @@
 
 #include "gtest/gtest.h"
 
-#include <arrow/dataset/dataset.h>
-#include <arrow/dataset/file_base.h>
-#include <arrow/dataset/file_parquet.h>
 #include "arrow/array.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/file_parquet.h"
 #include "arrow/dataset/parquet_encryption_config.h"
 #include "arrow/dataset/partition.h"
 #include "arrow/filesystem/mockfs.h"
@@ -30,10 +30,10 @@
 #include "arrow/status.h"
 #include "arrow/table.h"
 #include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
 #include "arrow/type.h"
 #include "parquet/arrow/reader.h"
 #include "parquet/encryption/crypto_factory.h"
-#include "parquet/encryption/encryption.h"
 #include "parquet/encryption/encryption_internal.h"
 #include "parquet/encryption/kms_client.h"
 #include "parquet/encryption/test_in_memory_kms.h"
@@ -51,14 +51,14 @@ using arrow::internal::checked_pointer_cast;
 namespace arrow {
 namespace dataset {
 
-class DatasetEncryptionTest : public ::testing::Test {
+// Base class to test writing and reading encrypted dataset.
+class DatasetEncryptionTestBase : public ::testing::Test {
  public:
   // This function creates a mock file system using the current time point, creates a
   // directory with the given base directory path, and writes a dataset to it using
-  // provided Parquet file write options. The dataset is partitioned using a Hive
-  // partitioning scheme. The function also checks if the written files exist in the file
-  // system.
-  static void SetUpTestSuite() {
+  // provided Parquet file write options. The function also checks if the written files
+  // exist in the file system.
+  void SetUp() override {
 #ifdef ARROW_VALGRIND
     // Not necessary otherwise, but prevents a Valgrind leak by making sure
     // OpenSSL initialization is done from the main thread
@@ -71,24 +71,8 @@ class DatasetEncryptionTest : public ::testing::Test {
                                            std::chrono::system_clock::now(), {}));
     ASSERT_OK(file_system_->CreateDir(std::string(kBaseDir)));
 
-    // Prepare table data.
-    auto table_schema = schema({field("a", int64()), field("c", int64()),
-                                field("e", int64()), field("part", utf8())});
-    table_ = TableFromJSON(table_schema, {R"([
-          [ 0, 9, 1, "a" ],
-          [ 1, 8, 2, "a" ],
-          [ 2, 7, 1, "c" ],
-          [ 3, 6, 2, "c" ],
-          [ 4, 5, 1, "e" ],
-          [ 5, 4, 2, "e" ],
-          [ 6, 3, 1, "g" ],
-          [ 7, 2, 2, "g" ],
-          [ 8, 1, 1, "i" ],
-          [ 9, 0, 2, "i" ]
-        ])"});
-
-    // Use a Hive-style partitioning scheme.
-    partitioning_ = std::make_shared<HivePartitioning>(schema({field("part", utf8())}));
+    // Init dataset and partitioning.
+    ASSERT_NO_FATAL_FAILURE(PrepareTableAndPartitioning());
 
     // Prepare encryption properties.
     std::unordered_map<std::string, std::string> key_map;
@@ -133,13 +117,81 @@ class DatasetEncryptionTest : public ::testing::Test {
     ASSERT_OK(FileSystemDataset::Write(write_options, std::move(scanner)));
   }
 
+  virtual void PrepareTableAndPartitioning() = 0;
+
+  void TestScanDataset() {
+    // Create decryption properties.
+    auto decryption_config =
+        std::make_shared<parquet::encryption::DecryptionConfiguration>();
+    auto parquet_decryption_config = std::make_shared<ParquetDecryptionConfig>();
+    parquet_decryption_config->crypto_factory = crypto_factory_;
+    parquet_decryption_config->kms_connection_config = kms_connection_config_;
+    parquet_decryption_config->decryption_config = std::move(decryption_config);
+
+    // Set scan options.
+    auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
+    parquet_scan_options->parquet_decryption_config =
+        std::move(parquet_decryption_config);
+
+    auto file_format = std::make_shared<ParquetFileFormat>();
+    file_format->default_fragment_scan_options = std::move(parquet_scan_options);
+
+    // Get FileInfo objects for all files under the base directory
+    fs::FileSelector selector;
+    selector.base_dir = kBaseDir;
+    selector.recursive = true;
+
+    FileSystemFactoryOptions factory_options;
+    factory_options.partitioning = partitioning_;
+    factory_options.partition_base_dir = kBaseDir;
+    ASSERT_OK_AND_ASSIGN(auto dataset_factory,
+                         FileSystemDatasetFactory::Make(file_system_, selector,
+                                                        file_format, factory_options));
+
+    // Read dataset into table
+    ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
+    ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
+    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+    ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
+
+    // Verify the data was read correctly
+    ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
+    // Validate the table
+    ASSERT_OK(combined_table->ValidateFull());
+    AssertTablesEqual(*combined_table, *table_);
+  }
+
  protected:
-  inline static std::shared_ptr<fs::FileSystem> file_system_;
-  inline static std::shared_ptr<Table> table_;
-  inline static std::shared_ptr<Partitioning> partitioning_;
-  inline static std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
-  inline static std::shared_ptr<parquet::encryption::KmsConnectionConfig>
-      kms_connection_config_;
+  std::shared_ptr<fs::FileSystem> file_system_;
+  std::shared_ptr<Table> table_;
+  std::shared_ptr<Partitioning> partitioning_;
+  std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
+  std::shared_ptr<parquet::encryption::KmsConnectionConfig> kms_connection_config_;
+};
+
+class DatasetEncryptionTest : public DatasetEncryptionTestBase {
+ public:
+  // The dataset is partitioned using a Hive partitioning scheme.
+  void PrepareTableAndPartitioning() override {
+    // Prepare table data.
+    auto table_schema = schema({field("a", int64()), field("c", int64()),
+                                field("e", int64()), field("part", utf8())});
+    table_ = TableFromJSON(table_schema, {R"([
+          [ 0, 9, 1, "a" ],
+          [ 1, 8, 2, "a" ],
+          [ 2, 7, 1, "c" ],
+          [ 3, 6, 2, "c" ],
+          [ 4, 5, 1, "e" ],
+          [ 5, 4, 2, "e" ],
+          [ 6, 3, 1, "g" ],
+          [ 7, 2, 2, "g" ],
+          [ 8, 1, 1, "i" ],
+          [ 9, 0, 2, "i" ]
+        ])"});
+
+    // Use a Hive-style partitioning scheme.
+    partitioning_ = std::make_shared<HivePartitioning>(schema({field("part", utf8())}));
+  }
 };
 
 // This test demonstrates the process of writing a partitioned Parquet file with the same
@@ -148,44 +200,7 @@ class DatasetEncryptionTest : public ::testing::Test {
 // test reads the data back and verifies that it can be successfully decrypted and
 // scanned.
 TEST_F(DatasetEncryptionTest, WriteReadDatasetWithEncryption) {
-  // Create decryption properties.
-  auto decryption_config =
-      std::make_shared<parquet::encryption::DecryptionConfiguration>();
-  auto parquet_decryption_config = std::make_shared<ParquetDecryptionConfig>();
-  parquet_decryption_config->crypto_factory = crypto_factory_;
-  parquet_decryption_config->kms_connection_config = kms_connection_config_;
-  parquet_decryption_config->decryption_config = std::move(decryption_config);
-
-  // Set scan options.
-  auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
-  parquet_scan_options->parquet_decryption_config = std::move(parquet_decryption_config);
-
-  auto file_format = std::make_shared<ParquetFileFormat>();
-  file_format->default_fragment_scan_options = std::move(parquet_scan_options);
-
-  // Get FileInfo objects for all files under the base directory
-  fs::FileSelector selector;
-  selector.base_dir = kBaseDir;
-  selector.recursive = true;
-
-  FileSystemFactoryOptions factory_options;
-  factory_options.partitioning = partitioning_;
-  factory_options.partition_base_dir = kBaseDir;
-  ASSERT_OK_AND_ASSIGN(auto dataset_factory,
-                       FileSystemDatasetFactory::Make(file_system_, selector, file_format,
-                                                      factory_options));
-
-  // Read dataset into table
-  ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
-  ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
-  ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
-  ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
-
-  // Verify the data was read correctly
-  ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
-  // Validate the table
-  ASSERT_OK(combined_table->ValidateFull());
-  AssertTablesEqual(*combined_table, *table_);
+  ASSERT_NO_FATAL_FAILURE(TestScanDataset());
 }
 
 // Read a single parquet file with and without decryption properties.
@@ -220,5 +235,30 @@ TEST_F(DatasetEncryptionTest, ReadSingleFile) {
   ASSERT_EQ(checked_pointer_cast<Int64Array>(table->column(2)->chunk(0))->GetView(0), 1);
 }
 
+// GH-39444: This test covers the case where parquet dataset scanner crashes when
+// processing encrypted datasets over 2^15 rows in multi-threaded mode.
+class LargeRowEncryptionTest : public DatasetEncryptionTestBase {
+ public:
+  // The dataset is partitioned using a Hive partitioning scheme.
+  void PrepareTableAndPartitioning() override {
+    // Specifically chosen to be greater than batch size for triggering prefetch.
+    constexpr int kRowCount = 32769;
+
+    // Create a random floating-point array with large number of rows.
+    arrow::random::RandomArrayGenerator rand_gen(0);
+    auto array = rand_gen.Float32(kRowCount, 0.0, 1.0, false);
+    auto table_schema = schema({field("a", float32())});
+
+    // Prepare table and partitioning.
+    table_ = arrow::Table::Make(table_schema, {array});
+    partitioning_ = std::make_shared<HivePartitioning>(arrow::schema({}));
+  }
+};
+
+// Test for writing and reading encrypted dataset with large row count.
+TEST_F(LargeRowEncryptionTest, ReadEncryptLargeRows) {
+  ASSERT_NO_FATAL_FAILURE(TestScanDataset());
+}
+
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc
index b1770be53358d..465b14793219f 100644
--- a/cpp/src/parquet/encryption/encryption_internal.cc
+++ b/cpp/src/parquet/encryption/encryption_internal.cc
@@ -55,12 +55,7 @@ class AesEncryptor::AesEncryptorImpl {
   explicit AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata,
                             bool write_length);
 
-  ~AesEncryptorImpl() {
-    if (nullptr != ctx_) {
-      EVP_CIPHER_CTX_free(ctx_);
-      ctx_ = nullptr;
-    }
-  }
+  ~AesEncryptorImpl() { WipeOut(); }
 
   int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
               int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
@@ -318,12 +313,7 @@ class AesDecryptor::AesDecryptorImpl {
   explicit AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata,
                             bool contains_length);
 
-  ~AesDecryptorImpl() {
-    if (nullptr != ctx_) {
-      EVP_CIPHER_CTX_free(ctx_);
-      ctx_ = nullptr;
-    }
-  }
+  ~AesDecryptorImpl() { WipeOut(); }
 
   int Decrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
               int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.cc b/cpp/src/parquet/encryption/internal_file_decryptor.cc
index 19e4845c8732d..c4416df90b121 100644
--- a/cpp/src/parquet/encryption/internal_file_decryptor.cc
+++ b/cpp/src/parquet/encryption/internal_file_decryptor.cc
@@ -61,6 +61,7 @@ InternalFileDecryptor::InternalFileDecryptor(FileDecryptionProperties* propertie
 
 void InternalFileDecryptor::WipeOutDecryptionKeys() {
+  std::lock_guard<std::mutex> lock(mutex_);
   properties_->WipeOutDecryptionKeys();
   for (auto const& i : all_decryptors_) {
     if (auto aes_decryptor = i.lock()) {
@@ -139,10 +140,16 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(
   // Create both data and metadata decryptors to avoid redundant retrieval of key
   // from the key_retriever.
   int key_len = static_cast<int>(footer_key.size());
-  auto aes_metadata_decryptor = encryption::AesDecryptor::Make(
-      algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
-  auto aes_data_decryptor = encryption::AesDecryptor::Make(
-      algorithm_, key_len, /*metadata=*/false, &all_decryptors_);
+  std::shared_ptr<encryption::AesDecryptor> aes_metadata_decryptor;
+  std::shared_ptr<encryption::AesDecryptor> aes_data_decryptor;
+
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    aes_metadata_decryptor = encryption::AesDecryptor::Make(
+        algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
+    aes_data_decryptor = encryption::AesDecryptor::Make(
+        algorithm_, key_len, /*metadata=*/false, &all_decryptors_);
+  }
 
   footer_metadata_decryptor_ = std::make_shared<Decryptor>(
       aes_metadata_decryptor, footer_key, file_aad_, aad, pool_);
@@ -168,21 +175,7 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDataDecryptor(
 std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
     const std::string& column_path, const std::string& column_key_metadata,
     const std::string& aad, bool metadata) {
-  std::string column_key;
-  // first look if we already got the decryptor from before
-  if (metadata) {
-    if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) {
-      auto res(column_metadata_map_.at(column_path));
-      res->UpdateAad(aad);
-      return res;
-    }
-  } else {
-    if (column_data_map_.find(column_path) != column_data_map_.end()) {
-      auto res(column_data_map_.at(column_path));
-      res->UpdateAad(aad);
-      return res;
-    }
-  }
+  std::string column_key = properties_->column_key(column_path);
 
-  column_key = properties_->column_key(column_path);
   // No explicit column key given via API. Retrieve via key metadata.
@@ -200,21 +193,12 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
     throw HiddenColumnException("HiddenColumnException, path=" + column_path);
   }
 
-  // Create both data and metadata decryptors to avoid redundant retrieval of key
-  // using the key_retriever.
   int key_len = static_cast<int>(column_key.size());
-  auto aes_metadata_decryptor = encryption::AesDecryptor::Make(
-      algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
-  auto aes_data_decryptor = encryption::AesDecryptor::Make(
-      algorithm_, key_len, /*metadata=*/false, &all_decryptors_);
-
-  column_metadata_map_[column_path] = std::make_shared<Decryptor>(
-      aes_metadata_decryptor, column_key, file_aad_, aad, pool_);
-  column_data_map_[column_path] =
-      std::make_shared<Decryptor>(aes_data_decryptor, column_key, file_aad_, aad, pool_);
-
-  if (metadata) return column_metadata_map_[column_path];
-  return column_data_map_[column_path];
+  std::lock_guard<std::mutex> lock(mutex_);
+  auto aes_decryptor =
+      encryption::AesDecryptor::Make(algorithm_, key_len, metadata, &all_decryptors_);
+  return std::make_shared<Decryptor>(std::move(aes_decryptor), column_key, file_aad_,
+                                     aad, pool_);
 }
 
 namespace {
diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.h b/cpp/src/parquet/encryption/internal_file_decryptor.h
index 0b27effda8822..f12cdefbe67a7 100644
--- a/cpp/src/parquet/encryption/internal_file_decryptor.h
+++ b/cpp/src/parquet/encryption/internal_file_decryptor.h
@@ -19,6 +19,7 @@
 
 #include <map>
 #include <memory>
+#include <mutex>
 #include <string>
 #include <vector>
@@ -91,15 +92,15 @@ class InternalFileDecryptor {
   FileDecryptionProperties* properties_;
   // Concatenation of aad_prefix (if exists) and aad_file_unique
   std::string file_aad_;
-  std::map<std::string, std::shared_ptr<Decryptor>> column_data_map_;
-  std::map<std::string, std::shared_ptr<Decryptor>> column_metadata_map_;
 
   std::shared_ptr<Decryptor> footer_metadata_decryptor_;
   std::shared_ptr<Decryptor> footer_data_decryptor_;
   ParquetCipher::type algorithm_;
   std::string footer_key_metadata_;
+  // Mutex to guard access to all_decryptors_
+  mutable std::mutex mutex_;
   // A weak reference to all decryptors that need to be wiped out when decryption is
-  // finished
+  // finished, guarded by mutex_ for thread safety
   std::vector<std::weak_ptr<encryption::AesDecryptor>> all_decryptors_;
 
   ::arrow::MemoryPool* pool_;
diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py
index d25b22990abfb..fadbb6108d440 100644
--- a/python/pyarrow/tests/test_dataset_encryption.py
+++ b/python/pyarrow/tests/test_dataset_encryption.py
@@ -15,9 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
 from datetime import timedelta
+import numpy as np
 import pyarrow.fs as fs
 import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 
 encryption_unavailable = False
@@ -151,3 +154,62 @@ def test_write_dataset_parquet_without_encryption():
 
     with pytest.raises(NotImplementedError):
         _ = pformat.make_write_options(encryption_config="some value")
+
+
+@pytest.mark.skipif(
+    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
+)
+def test_large_row_encryption_decryption():
+    """Test encryption and decryption of a large number of rows."""
+
+    class NoOpKmsClient(pe.KmsClient):
+        def wrap_key(self, key_bytes: bytes, _: str) -> bytes:
+            b = base64.b64encode(key_bytes)
+            return b
+
+        def unwrap_key(self, wrapped_key: bytes, _: str) -> bytes:
+            b = base64.b64decode(wrapped_key)
+            return b
+
+    row_count = 2**15 + 1
+    table = pa.Table.from_arrays(
+        [pa.array(np.random.rand(row_count), type=pa.float32())], names=["foo"]
+    )
+
+    kms_config = pe.KmsConnectionConfig()
+    crypto_factory = pe.CryptoFactory(lambda _: NoOpKmsClient())
+    encryption_config = pe.EncryptionConfiguration(
+        footer_key="UNIMPORTANT_KEY",
+        column_keys={"UNIMPORTANT_KEY": ["foo"]},
+        double_wrapping=True,
+        plaintext_footer=False,
+        data_key_length_bits=128,
+    )
+    pqe_config = ds.ParquetEncryptionConfig(
+        crypto_factory, kms_config, encryption_config
+    )
+    pqd_config = ds.ParquetDecryptionConfig(
+        crypto_factory, kms_config, pe.DecryptionConfiguration()
+    )
+    scan_options = ds.ParquetFragmentScanOptions(decryption_config=pqd_config)
+    file_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options)
+    write_options = file_format.make_write_options(encryption_config=pqe_config)
+    file_decryption_properties = crypto_factory.file_decryption_properties(kms_config)
+
+    mockfs = fs._MockFileSystem()
+    mockfs.create_dir("/")
+
+    path = "large-row-test-dataset"
+    ds.write_dataset(table, path, format=file_format,
+                     file_options=write_options, filesystem=mockfs)
+
+    file_path = path + "/part-0.parquet"
+    new_table = pq.ParquetFile(
+        file_path, decryption_properties=file_decryption_properties,
+        filesystem=mockfs
+    ).read()
+    assert table == new_table
+
+    dataset = ds.dataset(path, format=file_format, filesystem=mockfs)
+    new_table = dataset.to_table()
+    assert table == new_table
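A note on the `all_decryptors_` bookkeeping touched above: the file decryptor
keeps only weak references to the AES decryptors it creates, so that
`WipeOutDecryptionKeys()` can scrub key material from every decryptor still
alive without extending any object lifetimes, and the new mutex makes
registration and wipe-out safe to call from concurrent scanner threads. Below
is a minimal sketch of that pattern in isolation; `DecryptorFactory` is a
hypothetical name for illustration, not the Parquet class.

#include <memory>
#include <mutex>
#include <vector>

class AesDecryptor {
 public:
  void WipeOut() { /* zero out key material here */ }
};

class DecryptorFactory {
 public:
  // Hand out an owning pointer but remember it only weakly.
  std::shared_ptr<AesDecryptor> Make() {
    std::lock_guard<std::mutex> lock(mutex_);  // guard the shared vector
    auto decryptor = std::make_shared<AesDecryptor>();
    all_decryptors_.push_back(decryptor);
    return decryptor;
  }

  // Scrub secrets from every decryptor that is still alive.
  void WipeOutAll() {
    std::lock_guard<std::mutex> lock(mutex_);
    for (const auto& weak : all_decryptors_) {
      if (auto decryptor = weak.lock()) {
        decryptor->WipeOut();
      }
    }
    all_decryptors_.clear();
  }

 private:
  std::mutex mutex_;
  std::vector<std::weak_ptr<AesDecryptor>> all_decryptors_;
};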