Skip to content

Commit

Permalink
Feature duckdb#1272: Window Parallel Sink
Browse files Browse the repository at this point in the history
Implement thread-safer ValidityArray class and remove lock.
  • Loading branch information
Richard Wesley committed Jul 10, 2024
1 parent c419f89 commit e0de8ee
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/execution/window_segment_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class WindowAggregatorGlobalState : public WindowAggregatorState {
}
if (aggregator.aggr.filter) {
// Start with all invalid and set the ones that pass
filter_bits.resize(ValidityMask::ValidityMaskSize(group_count), 0);
filter_mask.Initialize(filter_bits.data());
filter_mask.Initialize(group_count, false);
}
}

Expand All @@ -41,8 +40,7 @@ class WindowAggregatorGlobalState : public WindowAggregatorState {
WindowDataChunk winputs;

//! The filtered rows in inputs.
vector<validity_t> filter_bits;
ValidityMask filter_mask;
ValidityArray filter_mask;

//! Lock for single threading
mutex lock;
Expand Down Expand Up @@ -70,9 +68,6 @@ void WindowAggregator::Sink(WindowAggregatorState &gsink, DataChunk &payload_chu
winputs.Copy(payload_chunk, input_idx);
}
if (filter_sel) {
// Single threaded for now.
// TODO: Check for mask boundaries.
lock_guard<mutex> sink_guard(gasink.lock);
for (idx_t f = 0; f < filtered; ++f) {
filter_mask.SetValid(input_idx + filter_sel->get_index(f));
}
Expand Down Expand Up @@ -389,6 +384,8 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState {
gcstate = make_uniq<WindowCustomAggregatorState>(aggregator.aggr, aggregator.exclude_mode);
}

//! Traditional packed filter mask for API
ValidityMask filter_packed;
//! Data pointer that contains a single local state, used for global custom window execution state
unique_ptr<WindowCustomAggregatorState> gcstate;
//! Partition description for custom window APIs
Expand Down Expand Up @@ -423,8 +420,11 @@ void WindowCustomAggregator::Finalize(WindowAggregatorState &gsink, const FrameS
auto &gcsink = gsink.Cast<WindowCustomAggregatorGlobalState>();
auto &inputs = gcsink.inputs;
auto &filter_mask = gcsink.filter_mask;
auto &filter_packed = gcsink.filter_packed;
filter_mask.Pack(filter_packed, filter_mask.target_count);

gcsink.partition_input =
make_uniq<WindowPartitionInput>(inputs.data.data(), inputs.ColumnCount(), inputs.size(), filter_mask, stats);
make_uniq<WindowPartitionInput>(inputs.data.data(), inputs.ColumnCount(), inputs.size(), filter_packed, stats);

if (aggr.function.window_init) {
auto &gcstate = *gcsink.gcstate;
Expand Down Expand Up @@ -791,7 +791,7 @@ class WindowSegmentTreePart {
enum FramePart : uint8_t { FULL = 0, LEFT = 1, RIGHT = 2 };

WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, const DataChunk &inputs,
const ValidityMask &filter_mask);
const ValidityArray &filter_mask);
~WindowSegmentTreePart();

unique_ptr<WindowSegmentTreePart> Copy() const {
Expand Down Expand Up @@ -829,7 +829,7 @@ class WindowSegmentTreePart {
//! The partition arguments
const DataChunk &inputs;
//! The filtered rows in inputs
const ValidityMask &filter_mask;
const ValidityArray &filter_mask;
//! The size of a single aggregate state
const idx_t state_size;
//! Data pointer that contains a vector of states, used for intermediate window segment aggregation
Expand Down Expand Up @@ -864,7 +864,7 @@ class WindowSegmentTreeState : public WindowAggregatorState {
};

WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr,
const DataChunk &inputs, const ValidityMask &filter_mask)
const DataChunk &inputs, const ValidityArray &filter_mask)
: allocator(allocator), aggr(aggr),
order_insensitive(aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT), inputs(inputs),
filter_mask(filter_mask), state_size(aggr.function.state_size()), state(state_size * STANDARD_VECTOR_SIZE),
Expand Down
84 changes: 84 additions & 0 deletions src/include/duckdb/common/types/validity_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,88 @@ struct ValidityMask : public TemplatedValidityMask<validity_t> {
void Read(ReadStream &reader, idx_t count);
};

//===--------------------------------------------------------------------===//
// ValidityArray
//===--------------------------------------------------------------------===//
struct ValidityArray {
inline ValidityArray() {
}

inline bool AllValid() const {
return !validity_mask;
}

inline void Initialize(idx_t count, bool initial = true) {
target_count = count;
validity_data = make_unsafe_uniq_array<bool>(count);
validity_mask = validity_data.get();
memset(validity_mask, initial, sizeof(bool) * count);
}

//! RowIsValidUnsafe should only be used if AllValid() is false: it achieves the same as RowIsValid but skips a
//! not-null check
inline bool RowIsValidUnsafe(idx_t row_idx) const {
D_ASSERT(validity_mask);
return validity_mask[row_idx];
}

//! Returns true if a row is valid (i.e. not null), false otherwise
inline bool RowIsValid(idx_t row_idx) const {
if (!validity_mask) {
return true;
}
return RowIsValidUnsafe(row_idx);
}

//! Same as SetValid, but skips a null check on validity_mask
inline void SetValidUnsafe(idx_t row_idx) {
D_ASSERT(validity_mask);
validity_mask[row_idx] = true;
}

//! Marks the entry at the specified row index as valid (i.e. not-null)
inline void SetValid(idx_t row_idx) {
if (!validity_mask) {
// if AllValid() we don't need to do anything
// the row is already valid
return;
}

SetValidUnsafe(row_idx);
}

inline void Pack(ValidityMask &mask, const idx_t count) const {
if (AllValid()) {
mask.Reset();
return;
}
mask.Initialize(count);

const auto entire_entries = count / ValidityMask::BITS_PER_VALUE;
const auto ragged = count % ValidityMask::BITS_PER_VALUE;
auto bits = mask.GetData();
idx_t row_idx = 0;
for (idx_t i = 0; i < entire_entries; ++i) {
validity_t entry = 0;
for (idx_t j = 0; j < ValidityMask::BITS_PER_VALUE; ++j) {
if (RowIsValidUnsafe(row_idx++)) {
entry |= validity_t(1) << j;
}
}
*bits++ = entry;
}
validity_t entry = 0;
for (idx_t j = 0; j < ragged; ++j) {
if (RowIsValidUnsafe(row_idx++)) {
entry |= validity_t(1) << j;
}
}
*bits++ = entry;
}

bool *validity_mask = nullptr;
unsafe_unique_array<bool> validity_data;
idx_t target_count = 0;
};

} // namespace duckdb

0 comments on commit e0de8ee

Please sign in to comment.