GH-39444: [C++][Parquet] Fix crash in Modular Encryption (#39623)
**Rationale for this change:**

This pull request addresses a critical issue (GH-39444) in the C++ and Python components of Parquet: a segmentation fault when processing encrypted datasets with more than 2^15 rows. The fix modifies `cpp/src/parquet/encryption/internal_file_decryptor.cc`, specifically `InternalFileDecryptor::GetColumnDecryptor`: the caching of `Decryptor` objects is removed, because sharing cached decryptors across threads is what caused the segmentation fault and decryption failures.
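
The failure mode can be pictured with a minimal, self-contained sketch (`ToyDecryptor`, `MakeDecryptor` and `g_cache` below are illustrative stand-ins, not the real Parquet types): when a cached decryptor is shared between reader threads, each thread mutates the decryptor's AAD before decrypting, so concurrent readers stomp on each other's state. Building a fresh decryptor per call, as this PR does, removes the shared mutable state.

```cpp
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// Toy stand-ins, not the real Parquet classes.
struct ToyDecryptor {
  std::string aad;  // per-module AAD, swapped by UpdateAad() before each use
  void UpdateAad(const std::string& a) { aad = a; }
  std::string Decrypt(const std::string& ciphertext) {
    // A real decryptor feeds `aad` into the AES-GCM context; if another thread
    // swapped the AAD in the meantime, decryption fails or state is corrupted.
    return ciphertext + "/" + aad;
  }
};

// Old pattern (simplified): one decryptor per column, cached and shared.
std::unordered_map<std::string, ToyDecryptor> g_cache = {{"col", ToyDecryptor{}}};

// New pattern (what the fix does): build a fresh decryptor per call, so no
// shared mutable state escapes the function.
ToyDecryptor MakeDecryptor() { return ToyDecryptor{}; }

int main() {
  // The shared, cached decryptor (old behaviour).
  ToyDecryptor& shared = g_cache["col"];

  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    threads.emplace_back([&shared, t] {
      // Racy: every thread calls UpdateAad() on the *same* cached object and
      // then decrypts with whichever AAD happened to win.
      shared.UpdateAad("row-group-" + std::to_string(t));
      shared.Decrypt("page");

      // Safe: each call gets its own decryptor, as in the patched
      // GetColumnDecryptor().
      ToyDecryptor own = MakeDecryptor();
      own.UpdateAad("row-group-" + std::to_string(t));
      own.Decrypt("page");
    });
  }
  for (auto& th : threads) th.join();
  return 0;
}
```

In the diff below, the shared `column_data_map_`/`column_metadata_map_` caches are removed from `GetColumnDecryptor`, and the remaining mutex only guards the `all_decryptors_` bookkeeping list.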

**What changes are included in this PR?**

- Removal of `Decryptor` object caching in `InternalFileDecryptor::GetColumnDecryptor`.
- Addition of two unit tests: `large_row_parquet_encrypt_test.cc` for C++, and `test_large_row_encryption_decryption` added to `test_dataset_encryption.py` for Python. (A rough sketch of such an encrypted round trip is shown after this list.)
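
For orientation, here is a rough sketch of such an encrypted large-row round trip using the low-level parquet-cpp properties API. It is not the test added by the PR; the key value, chunk size and in-memory I/O are illustrative assumptions.

```cpp
#include <memory>
#include <string>

#include "arrow/api.h"
#include "arrow/io/memory.h"
#include "arrow/testing/random.h"
#include "parquet/arrow/reader.h"
#include "parquet/arrow/writer.h"
#include "parquet/encryption/encryption.h"
#include "parquet/properties.h"

// Write >2^15 random rows to an encrypted in-memory Parquet file, read them
// back with matching decryption properties, and check the round trip.
arrow::Status RoundTripEncryptedLargeRows() {
  constexpr int64_t kRowCount = 32769;                // just over 2^15
  const std::string kFooterKey = "0123456789012345";  // 16-byte AES key

  // Single float32 column of random values.
  ::arrow::random::RandomArrayGenerator gen(/*seed=*/0);
  auto array = gen.Float32(kRowCount, /*min=*/0.0, /*max=*/1.0, /*null_probability=*/0.0);
  auto table = ::arrow::Table::Make(
      ::arrow::schema({::arrow::field("a", ::arrow::float32())}), {array});

  // Uniform encryption: the footer key encrypts the footer and all columns.
  auto file_encryption = parquet::FileEncryptionProperties::Builder(kFooterKey).build();
  auto writer_props =
      parquet::WriterProperties::Builder().encryption(file_encryption)->build();

  ARROW_ASSIGN_OR_RAISE(auto sink, ::arrow::io::BufferOutputStream::Create());
  ARROW_RETURN_NOT_OK(parquet::arrow::WriteTable(*table, ::arrow::default_memory_pool(),
                                                 sink, /*chunk_size=*/1024, writer_props));
  ARROW_ASSIGN_OR_RAISE(auto buffer, sink->Finish());

  // Read back with the matching footer key.
  auto file_decryption =
      parquet::FileDecryptionProperties::Builder().footer_key(kFooterKey)->build();
  parquet::ReaderProperties reader_props;
  reader_props.file_decryption_properties(file_decryption);

  parquet::arrow::FileReaderBuilder reader_builder;
  ARROW_RETURN_NOT_OK(reader_builder.Open(
      std::make_shared<::arrow::io::BufferReader>(buffer), reader_props));
  std::unique_ptr<parquet::arrow::FileReader> reader;
  ARROW_RETURN_NOT_OK(reader_builder.Build(&reader));

  std::shared_ptr<::arrow::Table> read_table;
  ARROW_RETURN_NOT_OK(reader->ReadTable(&read_table));
  return read_table->Equals(*table)
             ? ::arrow::Status::OK()
             : ::arrow::Status::Invalid("round-trip mismatch");
}
```

The actual tests in this PR instead go through the dataset API and the KMS-backed `CryptoFactory` configuration visible in the diff below.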

**Are these changes tested?**

Yes. The unit tests listed above (`large_row_parquet_encrypt_test.cc` and `test_large_row_encryption_decryption` in `test_dataset_encryption.py`) cover the previously crashing large-row case.

**Are there any user-facing changes?**

No significant user-facing changes, but the update significantly improves the stability and reliability of encrypted Parquet file handling. Note, however, that calling `DecryptionKeyRetriever::GetKey` can be an expensive operation, potentially involving network calls to key management servers.
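
Because `Decryptor` objects are no longer cached, key lookups may happen more than once per column. If that ever matters for a given deployment, caching can be applied at the key level rather than the decryptor level; the `CachingKeyRetriever` wrapper below is a hypothetical sketch, not part of this PR or of Arrow.

```cpp
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

#include "parquet/encryption/encryption.h"

// Hypothetical wrapper: memoize GetKey() results so the expensive lookup
// (e.g. a network round trip to a KMS) happens at most once per distinct
// key_metadata value, regardless of how many decryptors are created.
class CachingKeyRetriever : public parquet::DecryptionKeyRetriever {
 public:
  explicit CachingKeyRetriever(std::shared_ptr<parquet::DecryptionKeyRetriever> inner)
      : inner_(std::move(inner)) {}

  std::string GetKey(const std::string& key_metadata) override {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key_metadata);
    if (it != cache_.end()) {
      return it->second;  // cache hit: no KMS call
    }
    std::string key = inner_->GetKey(key_metadata);  // potentially slow call
    cache_.emplace(key_metadata, key);
    return key;
  }

 private:
  std::shared_ptr<parquet::DecryptionKeyRetriever> inner_;
  std::mutex mutex_;
  std::unordered_map<std::string, std::string> cache_;
};
```

Such a wrapper would be registered like any other user-provided retriever, via the file decryption properties builder.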

* Closes: #39444
* GitHub Issue: #39444

Lead-authored-by: Donald Tolley <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Co-authored-by: Adam Reeve <[email protected]>
Co-authored-by: Gang Wu <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
5 people authored Mar 13, 2024
1 parent 788200a commit dd6d728
Showing 5 changed files with 196 additions and 119 deletions.
182 changes: 111 additions & 71 deletions cpp/src/arrow/dataset/file_parquet_encryption_test.cc
@@ -19,21 +19,21 @@

#include "gtest/gtest.h"

#include <arrow/dataset/dataset.h>
#include <arrow/dataset/file_base.h>
#include <arrow/dataset/file_parquet.h>
#include "arrow/array.h"
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_parquet.h"
#include "arrow/dataset/parquet_encryption_config.h"
#include "arrow/dataset/partition.h"
#include "arrow/filesystem/mockfs.h"
#include "arrow/io/api.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
#include "parquet/arrow/reader.h"
#include "parquet/encryption/crypto_factory.h"
#include "parquet/encryption/encryption.h"
#include "parquet/encryption/encryption_internal.h"
#include "parquet/encryption/kms_client.h"
#include "parquet/encryption/test_in_memory_kms.h"
@@ -51,14 +51,14 @@ using arrow::internal::checked_pointer_cast;
namespace arrow {
namespace dataset {

class DatasetEncryptionTest : public ::testing::Test {
// Base class to test writing and reading encrypted dataset.
class DatasetEncryptionTestBase : public ::testing::Test {
public:
// This function creates a mock file system using the current time point, creates a
// directory with the given base directory path, and writes a dataset to it using
// provided Parquet file write options. The dataset is partitioned using a Hive
// partitioning scheme. The function also checks if the written files exist in the file
// system.
static void SetUpTestSuite() {
// provided Parquet file write options. The function also checks if the written files
// exist in the file system.
void SetUp() override {
#ifdef ARROW_VALGRIND
// Not necessary otherwise, but prevents a Valgrind leak by making sure
// OpenSSL initialization is done from the main thread
@@ -71,24 +71,8 @@ class DatasetEncryptionTest : public ::testing::Test {
std::chrono::system_clock::now(), {}));
ASSERT_OK(file_system_->CreateDir(std::string(kBaseDir)));

// Prepare table data.
auto table_schema = schema({field("a", int64()), field("c", int64()),
field("e", int64()), field("part", utf8())});
table_ = TableFromJSON(table_schema, {R"([
[ 0, 9, 1, "a" ],
[ 1, 8, 2, "a" ],
[ 2, 7, 1, "c" ],
[ 3, 6, 2, "c" ],
[ 4, 5, 1, "e" ],
[ 5, 4, 2, "e" ],
[ 6, 3, 1, "g" ],
[ 7, 2, 2, "g" ],
[ 8, 1, 1, "i" ],
[ 9, 0, 2, "i" ]
])"});

// Use a Hive-style partitioning scheme.
partitioning_ = std::make_shared<HivePartitioning>(schema({field("part", utf8())}));
// Init dataset and partitioning.
ASSERT_NO_FATAL_FAILURE(PrepareTableAndPartitioning());

// Prepare encryption properties.
std::unordered_map<std::string, std::string> key_map;
@@ -133,13 +117,81 @@ class DatasetEncryptionTest : public ::testing::Test {
ASSERT_OK(FileSystemDataset::Write(write_options, std::move(scanner)));
}

virtual void PrepareTableAndPartitioning() = 0;

void TestScanDataset() {
// Create decryption properties.
auto decryption_config =
std::make_shared<parquet::encryption::DecryptionConfiguration>();
auto parquet_decryption_config = std::make_shared<ParquetDecryptionConfig>();
parquet_decryption_config->crypto_factory = crypto_factory_;
parquet_decryption_config->kms_connection_config = kms_connection_config_;
parquet_decryption_config->decryption_config = std::move(decryption_config);

// Set scan options.
auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
parquet_scan_options->parquet_decryption_config =
std::move(parquet_decryption_config);

auto file_format = std::make_shared<ParquetFileFormat>();
file_format->default_fragment_scan_options = std::move(parquet_scan_options);

// Get FileInfo objects for all files under the base directory
fs::FileSelector selector;
selector.base_dir = kBaseDir;
selector.recursive = true;

FileSystemFactoryOptions factory_options;
factory_options.partitioning = partitioning_;
factory_options.partition_base_dir = kBaseDir;
ASSERT_OK_AND_ASSIGN(auto dataset_factory,
FileSystemDatasetFactory::Make(file_system_, selector,
file_format, factory_options));

// Read dataset into table
ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);
}

protected:
inline static std::shared_ptr<fs::FileSystem> file_system_;
inline static std::shared_ptr<Table> table_;
inline static std::shared_ptr<HivePartitioning> partitioning_;
inline static std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
inline static std::shared_ptr<parquet::encryption::KmsConnectionConfig>
kms_connection_config_;
std::shared_ptr<fs::FileSystem> file_system_;
std::shared_ptr<Table> table_;
std::shared_ptr<Partitioning> partitioning_;
std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
std::shared_ptr<parquet::encryption::KmsConnectionConfig> kms_connection_config_;
};

class DatasetEncryptionTest : public DatasetEncryptionTestBase {
public:
// The dataset is partitioned using a Hive partitioning scheme.
void PrepareTableAndPartitioning() override {
// Prepare table data.
auto table_schema = schema({field("a", int64()), field("c", int64()),
field("e", int64()), field("part", utf8())});
table_ = TableFromJSON(table_schema, {R"([
[ 0, 9, 1, "a" ],
[ 1, 8, 2, "a" ],
[ 2, 7, 1, "c" ],
[ 3, 6, 2, "c" ],
[ 4, 5, 1, "e" ],
[ 5, 4, 2, "e" ],
[ 6, 3, 1, "g" ],
[ 7, 2, 2, "g" ],
[ 8, 1, 1, "i" ],
[ 9, 0, 2, "i" ]
])"});

// Use a Hive-style partitioning scheme.
partitioning_ = std::make_shared<HivePartitioning>(schema({field("part", utf8())}));
}
};

// This test demonstrates the process of writing a partitioned Parquet file with the same
@@ -148,44 +200,7 @@ class DatasetEncryptionTest : public ::testing::Test {
// test reads the data back and verifies that it can be successfully decrypted and
// scanned.
TEST_F(DatasetEncryptionTest, WriteReadDatasetWithEncryption) {
// Create decryption properties.
auto decryption_config =
std::make_shared<parquet::encryption::DecryptionConfiguration>();
auto parquet_decryption_config = std::make_shared<ParquetDecryptionConfig>();
parquet_decryption_config->crypto_factory = crypto_factory_;
parquet_decryption_config->kms_connection_config = kms_connection_config_;
parquet_decryption_config->decryption_config = std::move(decryption_config);

// Set scan options.
auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
parquet_scan_options->parquet_decryption_config = std::move(parquet_decryption_config);

auto file_format = std::make_shared<ParquetFileFormat>();
file_format->default_fragment_scan_options = std::move(parquet_scan_options);

// Get FileInfo objects for all files under the base directory
fs::FileSelector selector;
selector.base_dir = kBaseDir;
selector.recursive = true;

FileSystemFactoryOptions factory_options;
factory_options.partitioning = partitioning_;
factory_options.partition_base_dir = kBaseDir;
ASSERT_OK_AND_ASSIGN(auto dataset_factory,
FileSystemDatasetFactory::Make(file_system_, selector, file_format,
factory_options));

// Read dataset into table
ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());

// Verify the data was read correctly
ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
// Validate the table
ASSERT_OK(combined_table->ValidateFull());
AssertTablesEqual(*combined_table, *table_);
ASSERT_NO_FATAL_FAILURE(TestScanDataset());
}

// Read a single parquet file with and without decryption properties.
@@ -220,5 +235,30 @@ TEST_F(DatasetEncryptionTest, ReadSingleFile) {
ASSERT_EQ(checked_pointer_cast<Int64Array>(table->column(2)->chunk(0))->GetView(0), 1);
}

// GH-39444: This test covers the case where parquet dataset scanner crashes when
// processing encrypted datasets over 2^15 rows in multi-threaded mode.
class LargeRowEncryptionTest : public DatasetEncryptionTestBase {
public:
// The dataset is partitioned using a Hive partitioning scheme.
void PrepareTableAndPartitioning() override {
// Specifically chosen to be greater than batch size for triggering prefetch.
constexpr int kRowCount = 32769;

// Create a random floating-point array with large number of rows.
arrow::random::RandomArrayGenerator rand_gen(0);
auto array = rand_gen.Float32(kRowCount, 0.0, 1.0, false);
auto table_schema = schema({field("a", float32())});

// Prepare table and partitioning.
table_ = arrow::Table::Make(table_schema, {array});
partitioning_ = std::make_shared<dataset::DirectoryPartitioning>(arrow::schema({}));
}
};

// Test for writing and reading encrypted dataset with large row count.
TEST_F(LargeRowEncryptionTest, ReadEncryptLargeRows) {
ASSERT_NO_FATAL_FAILURE(TestScanDataset());
}

} // namespace dataset
} // namespace arrow
14 changes: 2 additions & 12 deletions cpp/src/parquet/encryption/encryption_internal.cc
@@ -55,12 +55,7 @@ class AesEncryptor::AesEncryptorImpl {
explicit AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata,
bool write_length);

~AesEncryptorImpl() {
if (nullptr != ctx_) {
EVP_CIPHER_CTX_free(ctx_);
ctx_ = nullptr;
}
}
~AesEncryptorImpl() { WipeOut(); }

int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
@@ -318,12 +313,7 @@ class AesDecryptor::AesDecryptorImpl {
explicit AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata,
bool contains_length);

~AesDecryptorImpl() {
if (nullptr != ctx_) {
EVP_CIPHER_CTX_free(ctx_);
ctx_ = nullptr;
}
}
~AesDecryptorImpl() { WipeOut(); }

int Decrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
50 changes: 17 additions & 33 deletions cpp/src/parquet/encryption/internal_file_decryptor.cc
@@ -61,6 +61,7 @@ InternalFileDecryptor::InternalFileDecryptor(FileDecryptionProperties* propertie
}

void InternalFileDecryptor::WipeOutDecryptionKeys() {
std::lock_guard<std::mutex> lock(mutex_);
properties_->WipeOutDecryptionKeys();
for (auto const& i : all_decryptors_) {
if (auto aes_decryptor = i.lock()) {
@@ -139,10 +140,16 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(
// Create both data and metadata decryptors to avoid redundant retrieval of key
// from the key_retriever.
int key_len = static_cast<int>(footer_key.size());
auto aes_metadata_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
auto aes_data_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/false, &all_decryptors_);
std::shared_ptr<encryption::AesDecryptor> aes_metadata_decryptor;
std::shared_ptr<encryption::AesDecryptor> aes_data_decryptor;

{
std::lock_guard<std::mutex> lock(mutex_);
aes_metadata_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
aes_data_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/false, &all_decryptors_);
}

footer_metadata_decryptor_ = std::make_shared<Decryptor>(
aes_metadata_decryptor, footer_key, file_aad_, aad, pool_);
@@ -168,21 +175,7 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDataDecryptor(
std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
const std::string& column_path, const std::string& column_key_metadata,
const std::string& aad, bool metadata) {
std::string column_key;
// first look if we already got the decryptor from before
if (metadata) {
if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) {
auto res(column_metadata_map_.at(column_path));
res->UpdateAad(aad);
return res;
}
} else {
if (column_data_map_.find(column_path) != column_data_map_.end()) {
auto res(column_data_map_.at(column_path));
res->UpdateAad(aad);
return res;
}
}
std::string column_key = properties_->column_key(column_path);

column_key = properties_->column_key(column_path);
// No explicit column key given via API. Retrieve via key metadata.
@@ -200,21 +193,12 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
throw HiddenColumnException("HiddenColumnException, path=" + column_path);
}

// Create both data and metadata decryptors to avoid redundant retrieval of key
// using the key_retriever.
int key_len = static_cast<int>(column_key.size());
auto aes_metadata_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/true, &all_decryptors_);
auto aes_data_decryptor = encryption::AesDecryptor::Make(
algorithm_, key_len, /*metadata=*/false, &all_decryptors_);

column_metadata_map_[column_path] = std::make_shared<Decryptor>(
aes_metadata_decryptor, column_key, file_aad_, aad, pool_);
column_data_map_[column_path] =
std::make_shared<Decryptor>(aes_data_decryptor, column_key, file_aad_, aad, pool_);

if (metadata) return column_metadata_map_[column_path];
return column_data_map_[column_path];
std::lock_guard<std::mutex> lock(mutex_);
auto aes_decryptor =
encryption::AesDecryptor::Make(algorithm_, key_len, metadata, &all_decryptors_);
return std::make_shared<Decryptor>(std::move(aes_decryptor), column_key, file_aad_, aad,
pool_);
}

namespace {
7 changes: 4 additions & 3 deletions cpp/src/parquet/encryption/internal_file_decryptor.h
@@ -19,6 +19,7 @@

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

@@ -91,15 +92,15 @@ class InternalFileDecryptor {
FileDecryptionProperties* properties_;
// Concatenation of aad_prefix (if exists) and aad_file_unique
std::string file_aad_;
std::map<std::string, std::shared_ptr<Decryptor>> column_data_map_;
std::map<std::string, std::shared_ptr<Decryptor>> column_metadata_map_;

std::shared_ptr<Decryptor> footer_metadata_decryptor_;
std::shared_ptr<Decryptor> footer_data_decryptor_;
ParquetCipher::type algorithm_;
std::string footer_key_metadata_;
// Mutex to guard access to all_decryptors_
mutable std::mutex mutex_;
// A weak reference to all decryptors that need to be wiped out when decryption is
// finished
// finished, guarded by mutex_ for thread safety
std::vector<std::weak_ptr<encryption::AesDecryptor>> all_decryptors_;

::arrow::MemoryPool* pool_;