Skip to content

Commit

Permalink
Issue duckdb#12171: Streaming Window FILTER
Browse files Browse the repository at this point in the history
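Compute the FILTER mask before the COUNT(*) shortcut so that COUNT(*) FILTER(...) is handled.
Also refactor so that each aggregate's data structures are kept in the operator state and shared between calls.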
  • Loading branch information
hawkfish committed May 26, 2024
1 parent 3dc237e commit c226f37
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 62 deletions.
131 changes: 71 additions & 60 deletions src/execution/operator/aggregate/physical_streaming_window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,60 @@ class StreamingWindowState : public OperatorState {
public:
using StateBuffer = vector<data_t>;

StreamingWindowState()
: initialized(false), allocator(Allocator::DefaultAllocator()),
statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)) {
}
struct AggregateState {
explicit AggregateState(BoundWindowExpression &wexpr)
: arena_allocator(Allocator::DefaultAllocator()), statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)) {
D_ASSERT(wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE);
auto &aggregate = *wexpr.aggregate;
bind_data = wexpr.bind_info.get();
dtor = aggregate.destructor;
state.resize(aggregate.state_size());
state_ptr = state.data();
aggregate.initialize(state.data());
for (auto &child : wexpr.children) {
arg_types.push_back(child->return_type);
}
if (wexpr.filter_expr) {
filter_sel.Initialize();
}
}

~StreamingWindowState() override {
for (size_t i = 0; i < aggregate_dtors.size(); ++i) {
auto dtor = aggregate_dtors[i];
~AggregateState() {
if (dtor) {
AggregateInputData aggr_input_data(aggregate_bind_data[i], allocator);
state_ptr = aggregate_states[i].data();
AggregateInputData aggr_input_data(bind_data, arena_allocator);
state_ptr = state.data();
dtor(statev, aggr_input_data, 1);
}
}

ArenaAllocator arena_allocator;
StateBuffer state;
data_ptr_t state_ptr = nullptr;
Vector statev;
FunctionData *bind_data = nullptr;
aggregate_destructor_t dtor = nullptr;
SelectionVector filter_sel;
int64_t unfiltered = 0;
vector<LogicalType> arg_types;
};

explicit StreamingWindowState(ClientContext &client) : initialized(false), allocator(Allocator::Get(client)) {
}

~StreamingWindowState() override {
}

void Initialize(ClientContext &context, DataChunk &input, const vector<unique_ptr<Expression>> &expressions) {
const_vectors.resize(expressions.size());
aggregate_states.resize(expressions.size());
aggregate_bind_data.resize(expressions.size(), nullptr);
aggregate_dtors.resize(expressions.size(), nullptr);

for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) {
auto &expr = *expressions[expr_idx];
auto &wexpr = expr.Cast<BoundWindowExpression>();
switch (expr.GetExpressionType()) {
case ExpressionType::WINDOW_AGGREGATE: {
auto &aggregate = *wexpr.aggregate;
auto &state = aggregate_states[expr_idx];
aggregate_bind_data[expr_idx] = wexpr.bind_info.get();
aggregate_dtors[expr_idx] = aggregate.destructor;
state.resize(aggregate.state_size());
aggregate.initialize(state.data());
case ExpressionType::WINDOW_AGGREGATE:
aggregate_states[expr_idx] = make_uniq<AggregateState>(wexpr);
break;
}
case ExpressionType::WINDOW_FIRST_VALUE: {
// Just execute the expression once
ExpressionExecutor executor(context);
Expand Down Expand Up @@ -115,22 +134,18 @@ class StreamingWindowState : public OperatorState {
public:
bool initialized;
vector<unique_ptr<Vector>> const_vectors;
ArenaAllocator allocator;

// Aggregation
vector<StateBuffer> aggregate_states;
vector<FunctionData *> aggregate_bind_data;
vector<aggregate_destructor_t> aggregate_dtors;
data_ptr_t state_ptr;
Vector statev;
vector<unique_ptr<AggregateState>> aggregate_states;
Allocator &allocator;
};

// Creates the global operator state for the streaming window operator:
// a freshly constructed StreamingWindowGlobalState.
// The `context` argument is not used by this function.
unique_ptr<GlobalOperatorState> PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const {
return make_uniq<StreamingWindowGlobalState>();
}

unique_ptr<OperatorState> PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const {
return make_uniq<StreamingWindowState>();
return make_uniq<StreamingWindowState>(context.client);
}

OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk,
Expand All @@ -156,64 +171,60 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D
// Establish the aggregation environment
auto &wexpr = expr.Cast<BoundWindowExpression>();
auto &aggregate = *wexpr.aggregate;
auto &statev = state.statev;
state.state_ptr = state.aggregate_states[expr_idx].data();
AggregateInputData aggr_input_data(wexpr.bind_info.get(), state.allocator);
auto &aggr_state = *state.aggregate_states[expr_idx];
auto &statev = aggr_state.statev;
AggregateInputData aggr_input_data(wexpr.bind_info.get(), aggr_state.arena_allocator);

// Compute the FILTER mask (if any)
ValidityMask filter_mask;
if (wexpr.filter_expr) {
auto &filter_sel = aggr_state.filter_sel;
ExpressionExecutor filter_executor(context.client, *wexpr.filter_expr);
const auto filtered = filter_executor.SelectExpression(input, filter_sel);
if (filtered < count) {
filter_mask.Initialize(count);
filter_mask.SetAllInvalid(count);
for (idx_t f = 0; f < filtered; ++f) {
filter_mask.SetValid(filter_sel.get_index(f));
}
}
}

// Check for COUNT(*)
if (wexpr.children.empty()) {
D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t));
auto data = FlatVector::GetData<int64_t>(result);
int64_t start_row = gstate.row_number;
auto &unfiltered = aggr_state.unfiltered;
for (idx_t i = 0; i < count; ++i) {
data[i] = NumericCast<int64_t>(start_row + NumericCast<int64_t>(i));
unfiltered += int64_t(filter_mask.RowIsValid(i));
data[i] = unfiltered;
}
break;
}

// Compute the arguments
auto &allocator = Allocator::Get(context.client);
ExpressionExecutor executor(context.client);
vector<LogicalType> payload_types;
for (auto &child : wexpr.children) {
payload_types.push_back(child->return_type);
executor.AddExpression(*child);
}

DataChunk payload;
payload.Initialize(allocator, payload_types);
executor.Execute(input, payload);

// Compute the FILTER mask (if any)
vector<validity_t> filter_bits;
ValidityMask filter_mask;
if (wexpr.filter_expr) {
ExpressionExecutor filter_executor(context.client);
filter_executor.AddExpression(*wexpr.filter_expr);
SelectionVector filter_sel(count);
const auto filtered = filter_executor.SelectExpression(input, filter_sel);
if (filtered < count) {
filter_bits.resize(ValidityMask::ValidityMaskSize(count), 0);
filter_mask.Initialize(filter_bits.data());
for (idx_t f = 0; f < filtered; ++f) {
filter_mask.SetValid(filter_sel.get_index(f));
}
}
}
DataChunk arg_chunk;
arg_chunk.Initialize(state.allocator, aggr_state.arg_types);
executor.Execute(input, arg_chunk);

// Iterate through them using a single SV
payload.Flatten();
DataChunk row;
row.Initialize(allocator, payload_types);
arg_chunk.Flatten();
sel_t s = 0;
SelectionVector sel(&s);
DataChunk row;
row.Initialize(state.allocator, aggr_state.arg_types);
row.Slice(sel, 1);
// This doesn't work for STRUCTs because the SV
// is not copied to the children when you slice
vector<column_t> structs;
for (column_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) {
for (column_t col_idx = 0; col_idx < arg_chunk.ColumnCount(); ++col_idx) {
auto &col_vec = row.data[col_idx];
DictionaryVector::Child(col_vec).Reference(payload.data[col_idx]);
DictionaryVector::Child(col_vec).Reference(arg_chunk.data[col_idx]);
if (col_vec.GetType().InternalType() == PhysicalType::STRUCT) {
structs.emplace_back(col_idx);
}
Expand All @@ -223,7 +234,7 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D
for (idx_t i = 0; i < count; ++i) {
sel.set_index(0, i);
for (const auto struct_idx : structs) {
row.data[struct_idx].Slice(payload.data[struct_idx], sel, 1);
row.data[struct_idx].Slice(arg_chunk.data[struct_idx], sel, 1);
}
if (filter_mask.RowIsValid(i)) {
aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1);
Expand Down
23 changes: 21 additions & 2 deletions test/sql/window/test_streaming_window.test
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,19 @@ SELECT j, COUNT(j) FILTER(WHERE i = 2) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND
----
physical_plan <REGEX>:.*STREAMING_WINDOW.*

query TT
EXPLAIN
SELECT j, COUNT(*) FILTER(WHERE i = 2) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM integers;
----
physical_plan <REGEX>:.*STREAMING_WINDOW.*

query TT
EXPLAIN
SELECT j, SUM(j) FILTER(WHERE i = 2) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM integers;
----
physical_plan <REGEX>:.*STREAMING_WINDOW.*

# DISTINCT is not supported for streaming windows
# DISTINCT is not supported for streaming windows
query TT
EXPLAIN
SELECT
Expand All @@ -53,7 +59,6 @@ FROM (VALUES ({'key': 'A'}), ({'key': 'B'}), ({'key': 'A'}))
----
physical_plan <!REGEX>:.*STREAMING_WINDOW.*

# DISTINCT is not supported for streaming windows
query TT
EXPLAIN
SELECT
Expand Down Expand Up @@ -130,6 +135,20 @@ SELECT i, COUNT(*) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM i
1 3
1 4

query TT
EXPLAIN
SELECT i, COUNT(*) FILTER(WHERE i = 2) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM integers;
----
physical_plan <REGEX>:.*STREAMING_WINDOW.*

query II
SELECT i, COUNT(*) FILTER(WHERE i = 2) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM integers;
----
2 1
2 2
1 2
1 2

query TT
EXPLAIN
SELECT j, COUNT(j) OVER(ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM integers;
Expand Down

0 comments on commit c226f37

Please sign in to comment.