diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 1eb1169b5a..f6334c5632 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -347,6 +347,12 @@ impl<'a> SQLPlanner<'a> { let has_aggs = projections.iter().any(has_agg) || !groupby_exprs.is_empty(); if has_aggs { + let having = selection + .having + .as_ref() + .map(|h| self.plan_expr(h)) + .transpose()?; + self.plan_aggregate_query( &projections, &schema, @@ -354,6 +360,7 @@ impl<'a> SQLPlanner<'a> { groupby_exprs, query, &projection_schema, + having, )?; } else { self.plan_non_agg_query(projections, schema, has_orderby, query, projection_schema)?; @@ -464,6 +471,7 @@ impl<'a> SQLPlanner<'a> { Ok(()) } + #[allow(clippy::too_many_arguments)] fn plan_aggregate_query( &mut self, projections: &Vec>, @@ -472,6 +480,7 @@ impl<'a> SQLPlanner<'a> { groupby_exprs: Vec>, query: &Query, projection_schema: &Schema, + having: Option>, ) -> Result<(), PlannerError> { let mut final_projection = Vec::with_capacity(projections.len()); let mut aggs = Vec::with_capacity(projections.len()); @@ -500,6 +509,15 @@ impl<'a> SQLPlanner<'a> { final_projection.push(p.clone()); } } + + if let Some(having) = &having { + if has_agg(having) { + let having = having.alias(having.semantic_id(schema).id); + + aggs.push(having); + } + } + let groupby_exprs = groupby_exprs .into_iter() .map(|e| { @@ -631,7 +649,7 @@ impl<'a> SQLPlanner<'a> { } let rel = self.relation_mut(); - rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?; + rel.inner = rel.inner.aggregate(aggs.clone(), groupby_exprs)?; let has_orderby_before_projection = !orderbys_before_projection.is_empty(); let has_orderby_after_projection = !orderbys_after_projection.is_empty(); @@ -650,6 +668,16 @@ impl<'a> SQLPlanner<'a> { )?; } + if let Some(having) = having { + // if it's an agg, it's already resolved during .agg, so we just reference the column name + let having = if has_agg(&having) { + col(having.semantic_id(schema).id) + } else { + having + }; + rel.inner = rel.inner.filter(having)?; + } + // apply the final projection rel.inner = rel.inner.select(final_projection)?; @@ -661,6 +689,7 @@ impl<'a> SQLPlanner<'a> { orderbys_after_projection_nulls_first, )?; } + Ok(()) } @@ -1999,9 +2028,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult if !selection.sort_by.is_empty() { unsupported_sql_err!("SORT BY"); } - if selection.having.is_some() { - unsupported_sql_err!("HAVING"); - } + if !selection.named_window.is_empty() { unsupported_sql_err!("WINDOW"); } diff --git a/tests/sql/test_aggs.py b/tests/sql/test_aggs.py index 6d64878070..9e69742b88 100644 --- a/tests/sql/test_aggs.py +++ b/tests/sql/test_aggs.py @@ -1,5 +1,8 @@ +import pytest + import daft from daft import col +from daft.sql import SQLCatalog def test_aggs_sql(): @@ -41,3 +44,63 @@ def test_aggs_sql(): ) assert actual == expected + + +@pytest.mark.parametrize( + "agg,cond,expected", + [ + ("sum(values)", "sum(values) > 10", {"values": [20.5, 29.5]}), + ("sum(values)", "values > 10", {"values": [20.5, 29.5]}), + ("sum(values) as sum_v", "sum(values) > 10", {"sum_v": [20.5, 29.5]}), + ("sum(values) as sum_v", "sum_v > 10", {"sum_v": [20.5, 29.5]}), + ("count(*) as cnt", "cnt > 2", {"cnt": [3, 5]}), + ("count(*) as cnt", "count(*) > 2", {"cnt": [3, 5]}), + ("count(*)", "count(*) > 2", {"count": [3, 5]}), + ("count(*) as cnt", "sum(values) > 10", {"cnt": [3, 5]}), + ("sum(values), count(*)", "id > 1", {"values": [10.0, 29.5], "count": [2, 5]}), + ], +) +def test_having(agg, cond, expected): + df = daft.from_pydict( + { + "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1], + "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5], + } + ) + catalog = SQLCatalog({"df": df}) + + actual = daft.sql( + f""" + SELECT + {agg}, + from df + group by id + having {cond} + """, + catalog, + ).to_pydict() + + assert actual == expected + + +def test_having_non_grouped(): + df = daft.from_pydict( + { + "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1], + "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5], + "floats": [0.01, 0.011, 0.01047, 0.02, 0.019, 0.018, 0.017, 0.016, 0.015, 0.014], + } + ) + catalog = SQLCatalog({"df": df}) + + actual = daft.sql( + """ + SELECT + count(*) , + from df + having sum(values) > 40 + """, + catalog, + ).to_pydict() + + assert actual == {"count": [10]}