Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: sql HAVING #3364

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,20 @@ impl<'a> SQLPlanner<'a> {
let has_aggs = projections.iter().any(has_agg) || !groupby_exprs.is_empty();

if has_aggs {
let having = selection
.having
.as_ref()
.map(|h| self.plan_expr(h))
.transpose()?;

self.plan_aggregate_query(
&projections,
&schema,
has_orderby,
groupby_exprs,
query,
&projection_schema,
having,
)?;
} else {
self.plan_non_agg_query(projections, schema, has_orderby, query, projection_schema)?;
Expand Down Expand Up @@ -464,6 +471,7 @@ impl<'a> SQLPlanner<'a> {
Ok(())
}

#[allow(clippy::too_many_arguments)]
fn plan_aggregate_query(
&mut self,
projections: &Vec<Arc<Expr>>,
Expand All @@ -472,6 +480,7 @@ impl<'a> SQLPlanner<'a> {
groupby_exprs: Vec<Arc<Expr>>,
query: &Query,
projection_schema: &Schema,
having: Option<Arc<Expr>>,
) -> Result<(), PlannerError> {
let mut final_projection = Vec::with_capacity(projections.len());
let mut aggs = Vec::with_capacity(projections.len());
Expand Down Expand Up @@ -500,6 +509,15 @@ impl<'a> SQLPlanner<'a> {
final_projection.push(p.clone());
}
}

if let Some(having) = &having {
if has_agg(having) {
let having = having.alias(having.semantic_id(schema).id);

aggs.push(having);
}
}

let groupby_exprs = groupby_exprs
.into_iter()
.map(|e| {
Expand Down Expand Up @@ -631,7 +649,7 @@ impl<'a> SQLPlanner<'a> {
}

let rel = self.relation_mut();
rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?;
rel.inner = rel.inner.aggregate(aggs.clone(), groupby_exprs)?;

let has_orderby_before_projection = !orderbys_before_projection.is_empty();
let has_orderby_after_projection = !orderbys_after_projection.is_empty();
Expand All @@ -650,6 +668,16 @@ impl<'a> SQLPlanner<'a> {
)?;
}

if let Some(having) = having {
// if it's an agg, it's already resolved during .agg, so we just reference the column name
let having = if has_agg(&having) {
col(having.semantic_id(schema).id)
} else {
having
};
rel.inner = rel.inner.filter(having)?;
}

// apply the final projection
rel.inner = rel.inner.select(final_projection)?;

Expand All @@ -661,6 +689,7 @@ impl<'a> SQLPlanner<'a> {
orderbys_after_projection_nulls_first,
)?;
}

Ok(())
}

Expand Down Expand Up @@ -1999,9 +2028,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}

if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
Expand Down
63 changes: 63 additions & 0 deletions tests/sql/test_aggs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

import daft
from daft import col
from daft.sql import SQLCatalog


def test_aggs_sql():
Expand Down Expand Up @@ -41,3 +44,63 @@ def test_aggs_sql():
)

assert actual == expected


@pytest.mark.parametrize(
    "agg,cond,expected",
    [
        # With the fixture below, the per-id aggregates are:
        #   id=1 -> sum 20.5, count 3
        #   id=2 -> sum 10.0, count 2
        #   id=3 -> sum 29.5, count 5
        #
        # HAVING repeats the projected aggregate expression.
        ("sum(values)", "sum(values) > 10", {"values": [20.5, 29.5]}),
        # HAVING refers to the unaliased aggregate by its output column name.
        ("sum(values)", "values > 10", {"values": [20.5, 29.5]}),
        # Projection is aliased; HAVING still uses the raw aggregate expression.
        ("sum(values) as sum_v", "sum(values) > 10", {"sum_v": [20.5, 29.5]}),
        # HAVING refers to the projection's alias.
        ("sum(values) as sum_v", "sum_v > 10", {"sum_v": [20.5, 29.5]}),
        ("count(*) as cnt", "cnt > 2", {"cnt": [3, 5]}),
        ("count(*) as cnt", "count(*) > 2", {"cnt": [3, 5]}),
        ("count(*)", "count(*) > 2", {"count": [3, 5]}),
        # HAVING uses an aggregate that is not part of the projection.
        ("count(*) as cnt", "sum(values) > 10", {"cnt": [3, 5]}),
        # HAVING filters on the group-by key rather than on an aggregate.
        ("sum(values), count(*)", "id > 1", {"values": [10.0, 29.5], "count": [2, 5]}),
    ],
)
def test_having(agg, cond, expected):
    """Check that ``HAVING`` filters the groups of a ``GROUP BY id`` query.

    Args:
        agg: The select-list text interpolated into the query.
        cond: The ``HAVING`` predicate interpolated into the query.
        expected: The exact ``to_pydict()`` result the query must produce.
    """
    df = daft.from_pydict(
        {
            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
        }
    )
    catalog = SQLCatalog({"df": df})

    # The trailing comma after the select list is part of the original query
    # text and is intentionally kept.
    actual = daft.sql(
        f"""
        SELECT
            {agg},
        from df
        group by id
        having {cond}
        """,
        catalog,
    ).to_pydict()

    # NOTE(review): this comparison is order-sensitive; the expected lists
    # assume groups come back in first-appearance order of `id` -- confirm
    # that ordering is guaranteed, or sort before comparing.
    assert actual == expected


def test_having_non_grouped():
    """Check ``HAVING`` on an aggregate query that has no ``GROUP BY``.

    The whole table is a single group: ``sum(values)`` is 60.0, which
    satisfies ``> 40``, so the one row with ``count(*) == 10`` is kept.
    """
    df = daft.from_pydict(
        {
            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
            # `floats` is not referenced by the query below.
            "floats": [0.01, 0.011, 0.01047, 0.02, 0.019, 0.018, 0.017, 0.016, 0.015, 0.014],
        }
    )
    catalog = SQLCatalog({"df": df})

    # The HAVING aggregate (sum) differs from the projected aggregate (count).
    actual = daft.sql(
        """
        SELECT
            count(*) ,
        from df
        having sum(values) > 40
        """,
        catalog,
    ).to_pydict()

    assert actual == {"count": [10]}