Patch summary (commentary ahead of the first "diff --git" header; not part of any hunk):
  * MemoStore: current_time_ changes from std::atomic to a plain OnType. Swap() now
    uses std::swap for it, and UpdateTime() replaces its compare-exchange loop with a
    plain compare-and-assign.
  * KeyHasher: batch_ changes from std::atomic to a plain const RecordBatch*.
  * InputState: the key-hasher invalidation and the MemoStore time update are removed
    from Push() and added to the block that follows queue_.TryPop(), where they act on
    the batch now at the front of the queue.
NOTE(review): template argument lists appear to be missing from this copy (for example
"std::atomic current_time_" and "static_cast(current_time_)"), and the leading
indentation below was re-created from brace nesting and the hunk line counts. Confirm
both against the upstream file before applying.

diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc
index f8dee5aac8815..b168ec438e216 100644
--- a/cpp/src/arrow/acero/asof_join_node.cc
+++ b/cpp/src/arrow/acero/asof_join_node.cc
@@ -344,7 +344,7 @@ struct MemoStore {
   // the time of the current entry, defaulting to 0.
   // when entries with a time less than T are removed, the current time is updated to the
   // time of the next (by-time) and now-current entry or to T if no such entry exists.
-  std::atomic current_time_;
+  OnType current_time_;
   // current entry per key
   std::unordered_map entries_;
   // future entries per key
@@ -364,21 +364,16 @@ struct MemoStore {
     std::swap(index_, memo.index_);
 #endif
     std::swap(no_future_, memo.no_future_);
-    current_time_ = memo.current_time_.exchange(static_cast(current_time_));
+    std::swap(current_time_, memo.current_time_);
     entries_.swap(memo.entries_);
     future_entries_.swap(memo.future_entries_);
     times_.swap(memo.times_);
   }
 
-  // Updates the current time to `ts` if it is less. A different thread may win the race
-  // to update the current time to more than `ts` but not to less. Returns whether the
-  // current time was changed from its value at the beginning of this invocation.
+  // Updates the current time to `ts` if it is less. Returns true if updated.
   bool UpdateTime(OnType ts) {
-    OnType prev_time = current_time_;
-    bool update = prev_time < ts;
-    while (prev_time < ts && !current_time_.compare_exchange_weak(prev_time, ts)) {
-      // intentionally empty - standard CAS loop
-    }
+    bool update = current_time_ < ts;
+    if (update) current_time_ = ts;
     return update;
   }
 
@@ -529,7 +524,7 @@ class KeyHasher {
   size_t index_;
   std::vector indices_;
   std::vector metadata_;
-  std::atomic batch_;
+  const RecordBatch* batch_;
   std::vector hashes_;
   LightContext ctx_;
   std::vector column_arrays_;
@@ -821,8 +816,11 @@ class InputState {
         ++batches_processed_;
         latest_ref_row_ = 0;
         have_active_batch &= !queue_.TryPop();
-        if (have_active_batch)
+        if (have_active_batch) {
           DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed
+          key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
+          memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), 0)); // time changed
+        }
       }
     }
     return have_active_batch;
@@ -898,8 +896,6 @@ class InputState {
 
   Status Push(const std::shared_ptr& rb) {
     if (rb->num_rows() > 0) {
-      key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
-      memo_.UpdateTime(GetTime(rb.get(), 0)); // time changed - update in MemoStore
       queue_.Push(rb); // only after above updates - push batch for processing
     } else {
       ++batches_processed_; // don't enqueue empty batches, just record as processed