Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-36092: [C++] Simplify concurrency in as-of-join node #36094

Merged
merged 1 commit into from
Jun 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ struct MemoStore {
// the time of the current entry, defaulting to 0.
// when entries with a time less than T are removed, the current time is updated to the
// time of the next (by-time) and now-current entry or to T if no such entry exists.
std::atomic<OnType> current_time_;
OnType current_time_;
// current entry per key
std::unordered_map<ByType, Entry> entries_;
// future entries per key
Expand All @@ -364,21 +364,16 @@ struct MemoStore {
std::swap(index_, memo.index_);
#endif
std::swap(no_future_, memo.no_future_);
current_time_ = memo.current_time_.exchange(static_cast<OnType>(current_time_));
std::swap(current_time_, memo.current_time_);
entries_.swap(memo.entries_);
future_entries_.swap(memo.future_entries_);
times_.swap(memo.times_);
}

// Updates the current time to `ts` if it is less. A different thread may win the race
// to update the current time to more than `ts` but not to less. Returns whether the
// current time was changed from its value at the beginning of this invocation.
// Advances the current time to `ts` if the current time is less than `ts` (it never moves backwards). Returns true if updated.
bool UpdateTime(OnType ts) {
OnType prev_time = current_time_;
bool update = prev_time < ts;
while (prev_time < ts && !current_time_.compare_exchange_weak(prev_time, ts)) {
// intentionally empty - standard CAS loop
}
bool update = current_time_ < ts;
if (update) current_time_ = ts;
Comment on lines +375 to +376
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool update = current_time_ < ts;
if (update) current_time_ = ts;
current_time_ = std::max(current_time_, ts);

Minor nit: Maybe simpler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method still needs to return a Boolean, so we can't remove the bool update line, but we can replace the if-statement.

return update;
}

Expand Down Expand Up @@ -529,7 +524,7 @@ class KeyHasher {
size_t index_;
std::vector<col_index_t> indices_;
std::vector<KeyColumnMetadata> metadata_;
std::atomic<const RecordBatch*> batch_;
const RecordBatch* batch_;
std::vector<HashType> hashes_;
LightContext ctx_;
std::vector<KeyColumnArray> column_arrays_;
Expand Down Expand Up @@ -821,8 +816,11 @@ class InputState {
++batches_processed_;
latest_ref_row_ = 0;
have_active_batch &= !queue_.TryPop();
if (have_active_batch)
if (have_active_batch) {
DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed
key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), 0)); // time changed
}
}
}
return have_active_batch;
Expand Down Expand Up @@ -898,8 +896,6 @@ class InputState {

Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
if (rb->num_rows() > 0) {
key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
memo_.UpdateTime(GetTime(rb.get(), 0)); // time changed - update in MemoStore
queue_.Push(rb); // only after above updates - push batch for processing
} else {
++batches_processed_; // don't enqueue empty batches, just record as processed
Expand Down