From 2f92a5b6beabb6669776611b9d6530f8e7ce4615 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 25 Nov 2024 16:30:46 +0100 Subject: [PATCH] GH-44846: [C++] Fix thread-unsafe access in ConcurrentQueue::UnsyncFront --- cpp/src/arrow/acero/asof_join_node.cc | 9 ++++----- .../arrow/acero/concurrent_queue_internal.h | 20 ++++++++++--------- cpp/src/arrow/acero/sorted_merge_node.cc | 6 +++--- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 92e404b207c89..3ab976e671ccf 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -567,7 +567,7 @@ class InputState : public util::SerialSequencingQueue::Processor { // Gets latest batch (precondition: must not be empty) const std::shared_ptr& GetLatestBatch() const { - return queue_.UnsyncFront(); + return queue_.Front(); } #define LATEST_VAL_CASE(id, val) \ @@ -634,15 +634,14 @@ class InputState : public util::SerialSequencingQueue::Processor { } latest_time_ = next_time; // If we have an active batch - if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t)queue_.Front()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; have_active_batch &= !queue_.TryPop(); if (have_active_batch) { - DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed - memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), time_type_id_, - time_col_index_, + DCHECK_GT(queue_.Front()->num_rows(), 0); // empty batches disallowed + memo_.UpdateTime(GetTime(queue_.Front().get(), time_type_id_, time_col_index_, 0)); // time changed } } diff --git a/cpp/src/arrow/acero/concurrent_queue_internal.h b/cpp/src/arrow/acero/concurrent_queue_internal.h index f530394187299..20ec2089bee41 100644 --- a/cpp/src/arrow/acero/concurrent_queue_internal.h +++ b/cpp/src/arrow/acero/concurrent_queue_internal.h @@ -65,17 +65,19 @@ class ConcurrentQueue { return queue_.empty(); } - // Un-synchronized access to front - // For this to be "safe": - // 1) the caller logically guarantees that queue is not empty - // 2) pop/try_pop cannot be called concurrently with this - const T& UnsyncFront() const { return queue_.front(); } - - size_t UnsyncSize() const { return queue_.size(); } + const T& Front() const { + // Need to lock the queue because `front()` may be implemented in terms + // of `begin()`, which isn't safe with concurrent calls to e.g. `push()`. + // (see GH-44846) + std::unique_lock lock(mutex_); + return queue_.front(); + } protected: std::mutex& GetMutex() { return mutex_; } + size_t SizeUnlocked() const { return queue_.size(); } + T PopUnlocked() { auto item = queue_.front(); queue_.pop(); @@ -111,12 +113,12 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { private: struct DoHandle { explicit DoHandle(BackpressureConcurrentQueue& queue) - : queue_(queue), start_size_(queue_.UnsyncSize()) {} + : queue_(queue), start_size_(queue_.SizeUnlocked()) {} ~DoHandle() { // unsynced access is safe since DoHandle is internally only used when the // lock is held - size_t end_size = queue_.UnsyncSize(); + size_t end_size = queue_.SizeUnlocked(); queue_.handler_.Handle(start_size_, end_size); } diff --git a/cpp/src/arrow/acero/sorted_merge_node.cc b/cpp/src/arrow/acero/sorted_merge_node.cc index 2845383cee982..c49aca17fb20a 100644 --- a/cpp/src/arrow/acero/sorted_merge_node.cc +++ b/cpp/src/arrow/acero/sorted_merge_node.cc @@ -145,7 +145,7 @@ class InputState { // Gets latest batch (precondition: must not be empty) const std::shared_ptr& GetLatestBatch() const { - return queue_.UnsyncFront(); + return queue_.Front(); } #define LATEST_VAL_CASE(id, val) \ @@ -178,7 +178,7 @@ class InputState { row_index_t start = latest_ref_row_; row_index_t end = latest_ref_row_; time_unit_t startTime = GetLatestTime(); - std::shared_ptr batch = queue_.UnsyncFront(); + std::shared_ptr batch = queue_.Front(); auto rows_in_batch = (row_index_t)batch->num_rows(); while (GetLatestTime() == startTime) { @@ -190,7 +190,7 @@ class InputState { latest_ref_row_ = 0; active &= !queue_.TryPop(); if (active) { - DCHECK_GT(queue_.UnsyncFront()->num_rows(), + DCHECK_GT(queue_.Front()->num_rows(), 0); // empty batches disallowed, sanity check } break;