From ec3a451d5231252daed0f8c724b1e16d08c22047 Mon Sep 17 00:00:00 2001 From: Justin Jaffray Date: Wed, 16 Jan 2019 02:19:39 -0500 Subject: [PATCH] opt: add support for aggregate FILTER This commit introduces support for FILTERing individual aggregates in a group by. One unintuitive change is that if we have a COUNT(*) with a FILTER, we need to synthesize an input for the aggregation to be over, because otherwise we have no input to hang the AggFilter off of. Thus, we convert COUNT(*) FILTER (WHERE foo) to COUNT(true) FILTER (WHERE foo). This isn't a problem for DISTINCT because COUNT(DISTINCT *) is not valid. Release note (sql change): FILTER expressions are now supported by the cost-based optimizer. --- .../logictest/testdata/logic_test/optimizer | 13 +- .../exec/execbuilder/relational_builder.go | 11 ++ .../opt/exec/execbuilder/testdata/aggregate | 83 ++++++----- .../opt/exec/execbuilder/testdata/distsql_agg | 8 +- pkg/sql/opt/exec/factory.go | 4 + pkg/sql/opt/memo/check_expr.go | 5 + pkg/sql/opt/memo/expr_format.go | 11 ++ pkg/sql/opt/memo/extract.go | 25 +++- pkg/sql/opt/memo/typing.go | 1 + pkg/sql/opt/ops/scalar.opt | 10 ++ pkg/sql/opt/optbuilder/groupby.go | 51 +++++-- pkg/sql/opt/optbuilder/scope.go | 56 +++++-- pkg/sql/opt/optbuilder/testdata/aggregate | 137 +++++++++++++++++- pkg/sql/opt_exec_factory.go | 8 + 14 files changed, 346 insertions(+), 77 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/optimizer b/pkg/sql/logictest/testdata/logic_test/optimizer index a1924d4ca226..5568b4fa18f9 100644 --- a/pkg/sql/logictest/testdata/logic_test/optimizer +++ b/pkg/sql/logictest/testdata/logic_test/optimizer @@ -74,10 +74,13 @@ SELECT * FROM test.t WHERE 2*v > (SELECT max(v) FROM test.t) # Fall back to heuristic planner when feature is not support in cost-based # optimizer. -query I rowsort -SELECT count(*) FILTER (WHERE v>10) FROM t +query T +SELECT avg(k) OVER () FROM t ORDER BY 1 ---- -3 +2.5 +2.5 +2.5 +2.5 query II rowsort SELECT * FROM tview @@ -101,8 +104,8 @@ query error pq: unsupported statement: \*tree\.AlterTable ALTER TABLE test DROP COLUMN v; # Don't fall back to heuristic planner in ALWAYS mode. -query error pq: aggregates with FILTER are not supported yet -SELECT count(*) FILTER (WHERE v>10) FROM t +query error pq: window functions are not supported +SELECT avg(k) OVER () FROM t ORDER BY 1 query error pq: sequences are not supported SELECT * FROM seq diff --git a/pkg/sql/opt/exec/execbuilder/relational_builder.go b/pkg/sql/opt/exec/execbuilder/relational_builder.go index 765712b7b3db..1bf71c54f78d 100644 --- a/pkg/sql/opt/exec/execbuilder/relational_builder.go +++ b/pkg/sql/opt/exec/execbuilder/relational_builder.go @@ -634,10 +634,20 @@ func (b *Builder) buildGroupBy(groupBy memo.RelExpr) (execPlan, error) { distinct := false var argIdx []exec.ColumnOrdinal + var filterOrd exec.ColumnOrdinal = -1 if item.Agg.ChildCount() > 0 { child := item.Agg.Child(0) + if aggFilter, ok := child.(*memo.AggFilterExpr); ok { + filter, ok := aggFilter.Filter.(*memo.VariableExpr) + if !ok { + return execPlan{}, errors.Errorf("only VariableOp args supported") + } + filterOrd = input.getColumnOrdinal(filter.Col) + child = aggFilter.Input + } + if aggDistinct, ok := child.(*memo.AggDistinctExpr); ok { distinct = true child = aggDistinct.Input @@ -658,6 +668,7 @@ func (b *Builder) buildGroupBy(groupBy memo.RelExpr) (execPlan, error) { ResultType: item.Agg.DataType(), ArgCols: argIdx, ConstArgs: constArgs, + Filter: filterOrd, } ep.outputCols.Set(int(item.Col), len(groupingColIdx)+i) } diff --git a/pkg/sql/opt/exec/execbuilder/testdata/aggregate b/pkg/sql/opt/exec/execbuilder/testdata/aggregate index 5c93a83a992f..f390fbd4f93d 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/aggregate +++ b/pkg/sql/opt/exec/execbuilder/testdata/aggregate @@ -584,44 +584,51 @@ group · · (v, count) · · spans ALL · · · filter v > 10 · · -# TODO(radu): FILTER not yet supported. -## Verify that FILTER works. -# -#exec-raw -#CREATE TABLE filter_test ( -# k INT, -# v INT, -# mark BOOL -#) -#---- -# -## Check that filter expressions are only rendered once. -#exec nodist -#EXPLAIN (EXPRS) SELECT count(*) FILTER (WHERE k>5), max(k>5) FILTER(WHERE k>5) FROM filter_test GROUP BY v -#---- -#group · · -# │ aggregate 0 count_rows() FILTER (WHERE k > 5) -# │ aggregate 1 max(k > 5) FILTER (WHERE k > 5) -# │ group by @1 -# └── render · · -# │ render 0 v -# │ render 1 k > 5 -# └── scan · · -#· table filter_test@primary -#· spans ALL -# -#exec nodist -#EXPLAIN (TYPES) SELECT count(*) FILTER (WHERE k > 5) FROM filter_test GROUP BY v -#---- -#group 0 group · · (count int) · -# │ 0 · aggregate 0 count_rows() FILTER (WHERE k > 5) · · -# │ 0 · group by @1 · · -# └── render 1 render · · (v int, "k > 5" bool) · -# │ 1 · render 0 (v)[int] · · -# │ 1 · render 1 ((k)[int] > (5)[int])[bool] · · -# └── scan 2 scan · · (k int, v int, mark[omitted] bool, rowid[hidden,omitted] int) rowid!=NULL; key(rowid) -#· 2 · table filter_test@primary · · -#· 2 · spans ALL · · +# Verify that FILTER works. + +statement ok +CREATE TABLE filter_test ( + k INT, + v INT, + mark BOOL +) + +# Check that filter expressions are only rendered once. +query TTTTT +EXPLAIN (VERBOSE) SELECT count(*) FILTER (WHERE k>5), max(k>5) FILTER(WHERE k>5) FROM filter_test GROUP BY v +---- +render · · (count, max) · + │ render 0 agg0 · · + │ render 1 agg1 · · + └── group · · (v, agg0, agg1) · + │ aggregate 0 v · · + │ aggregate 1 count(column5) FILTER (WHERE column6) · · + │ aggregate 2 max(column6) FILTER (WHERE column6) · · + │ group by @3 · · + └── render · · (column5, column6, v) · + │ render 0 true · · + │ render 1 k > 5 · · + │ render 2 v · · + └── scan · · (k, v) · +· table filter_test@primary · · +· spans ALL · · + +query TTTTT +EXPLAIN (VERBOSE) SELECT count(*) FILTER (WHERE k > 5) FROM filter_test GROUP BY v +---- +render · · (count) · + │ render 0 agg0 · · + └── group · · (v, agg0) · + │ aggregate 0 v · · + │ aggregate 1 count(column5) FILTER (WHERE column6) · · + │ group by @3 · · + └── render · · (column5, column6, v) · + │ render 0 true · · + │ render 1 k > 5 · · + │ render 2 v · · + └── scan · · (k, v) · +· table filter_test@primary · · +· spans ALL · · # Tests with * inside GROUP BY. query TTTTT diff --git a/pkg/sql/opt/exec/execbuilder/testdata/distsql_agg b/pkg/sql/opt/exec/execbuilder/testdata/distsql_agg index e39949c4f91d..02fc715c0850 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/distsql_agg +++ b/pkg/sql/opt/exec/execbuilder/testdata/distsql_agg @@ -107,12 +107,12 @@ https://cockroachdb.github.io/distsqlplan/decode.html#eJzUVlFr4lwQff9-RZiniCPJvT query T SELECT url FROM [EXPLAIN (DISTSQL) SELECT sum(a), stddev(a), avg(a) FILTER (WHERE a > 5), count(b), avg(b), variance(b) FILTER (WHERE b < 8), sum(b) FILTER (WHERE b < 8), stddev(b) FILTER (WHERE b > 2) FROM data] ---- -https://cockroachdb.github.io/distsqlplan/decode.html#eJzslkGL2kAUx-_9FMM7bWEkmZlEY04jq0Jg67bR3UsrS2oeVnATmURoWfzuJYnVxCaTgAcvHpZ13rz_-897-Q3MB0RxiLPgHRNwvwMDChwoCKBgAQUblhR2Kl5hksQqSykEXvgbXJPCJtrt0yy8pLCKFYL7Aekm3SK4sAh-btHHIERlmEAhxDTYbHObndq8B-qPDIM0AArzXRAlLukZjARRSBiJ01-ogIKPUYjKJZJRItmPvWkKtF3X9WYLhxLJs788unJK0SKPHyOwPFCI9-n5oEkarBFcdqDdmxmt1wrXQRorw672Mn_58iDZ56yNb_7Ym06Pq8fnl9ni306RQ6be02LiE8nL-3VRcVKJSmVxzrbOGZXYqUI1s6aCXVPBrq1gNw6RNw7xPLt9FKsQFYaVwS0P-jEzs2nOU282enqbL8bjyeuD5FQyWhqYdfz15mVN2JVV_5Q1OBd6HfneaPY4eZAOlUMqmXnKGv5vxxiVLPPMHSuAZjgKQ1qUSJsS2TfyfwNKZIblsHGGojJD1v1WsS63ymA9g9_uXrW0U_rg_fu9arpXvDsTvBMTvGeI2zHR0k6JicGdiSYmRHcmRCcmRM-wbsdESzslJpw7E01MWN2ZsDoxYfUM-3ZMtLRTYmJ4Z6LLm6xmiD4muzhK8OJtVl_ZzN5sGK6xeOAl8V6t8KuKV7lNsXzOdXkgxCQtdlmx8KJiKztgWcy0Yl4Rs0sx1zu3WAut2tKLrWvObWvFfb1z_xrngVbs6J2da5yH-m9ltmCih-zSe3n49DcAAP__yWRJgQ== +https://cockroachdb.github.io/distsqlplan/decode.html#eJzslkGP2jwQhu_fr7DmtJ9klNhOIORktIAUacu2gd1Li1YpGVEkNkFOkFqt-O9VkpYkNHEiceDCCTyemdfz5rHkD4jiEBfBOybgfgUGFDhQEEDBAgo2rCkcVLzBJIlVllIUeOFPcE0Ku-hwTLPwmsImVgjuB6S7dI_gwir4vkcfgxCVYQKFENNgt89lDmr3HqhfMgzSACgsD0GUuGRgMBJEIWEkTn-gAgo-RiEql0j27WiaAm3Xdb3FyqFE8jyycS4jyMsIy6KwPlGIj2l50CQNtgguO9H-w0y2W4XbII2VYddnWb58epDW_9kYX_ypN5__WT0-vyxWf3eKHDL3nlYzn0hW3W-K2ucqu9bZLrN5mVGLnTvUMxs6iIYOorGDaDWRt5pYeneMYhWiwrBm3Pqkt5mZDT6zzI25t5g8vS1X0-ns9UFyKhmVomJz8e_NqxpZrIbnrFHZ6HXie5PF4-xBOlSOqWTmOWv8rxxjVLJMM1csAc1Qo0QKQ1qUSJsSOTTynxElMoNx3OqhqHnI-t8q1udWGWxg8Nvdq45xKh98eL9XbfeK92eC92KCDwxxOyY6xqkwMboz0caE6M-E6MWEGBjW7ZjoGKfChHNnoo0Jqz8TVi8mrIFh346JjnEqTIzvTPR5kzWY6GNyiKMEL95mzZ3N7M2G4RaLB14SH9UGP6t4k8sUy-e8Lg-EmKTFLisWXlRsZQesFjNtMa8Vs8tirlfukBbaaktfbF1zbltbPNQrD69RHmmLHb2yc43yWP-tzA5M9JBdaq9P__0OAAD__wXcSdY= query T SELECT url FROM [EXPLAIN (DISTSQL) SELECT sum(a), avg(DISTINCT a), variance(a) FILTER (WHERE a > 0) FROM data] ---- -https://cockroachdb.github.io/distsqlplan/decode.html#eJy0k0GL1DAUx-9-ivBOK6SkSTsiOXVYVynoKJ26F-0hTh61MJuUJAVl6XeXtoe1y4xEZvaYl_zyfu_B_xGM1bhTD-hBfgMOFARQyIBCDhQ20FDonT2g99ZNTxag1L9AphQ60w9hKjcUDtYhyEcIXTgiSKjVjyNWqDQ6lgIFjUF1x7lN77oH5X4XWgUFFPa9Ml6ShHGijCac2PATHVCo0Gh0khSckoJ_H9I0w1RKWe7qt9CMFOwQngR8UC2C5CM9I_nkNhjrNDrUK69mPDHGtm0dtipYxzbrKfZfP90U_DVQ2N5_uHlX7utyd1uTpXS_rcrt7vZuekHelx_ru4oU4qy0WEnz-M3ymM0ynjBx_d2KeE0RpSkSll1fM4vXzKI0s4Tl19fM4zXzKM08YZuXDdQJzQp9b43HZ8E6_XM6BQ51i0s6vR3cAb84e5jbLMfPMzcXNPqw3PLlUJr5ahb8G-b_hN-s4PQ5LC7pnF0C55fAm_-Cm_HVnwAAAP__MOHnFQ== +https://cockroachdb.github.io/distsqlplan/decode.html#eJy0k0GL1DAUx-9-ivBOK6SkSTsiOc2wrlLQUTp1L9pDbB61MJuUJAVl6XeXtoe1yzhEZvaYl_zyfu_B_xGM1bhXD-hBfgMOFARQyIBCDhQ2UFPonW3Qe-umJwtQ6F8gUwqd6YcwlWsKjXUI8hFCF44IEir144glKo2OpUBBY1DdcW7Tu-5Bud9brYICCodeGS9JwjhRRhNObPiJDiiUaDQ6Sbb8-5CmGaZSymJfvaVky6EeKdghPAn4oFoEyUf6D8knt8FYp9GhXnnV44kxdm3rsFXBOrZZT3H4-ulmK14Dhd39h5t3xaEq9rcVWUr3u7LY7W_vphfkffGxuivPSYuVNI_fLI_ZLOMJE9ffrYjXFFGaImHZ9TWzeM0sSjNLWH59zTxeM4_SzBO2edlAndAs0ffWeHwWrNM_p1PgULe4pNPbwTX4xdlmbrMcP8_cXNDow3LLl0Nh5qtZ8G-Yn4XfrOD0OSwu6ZxdAueXwJv_guvx1Z8AAAD__1Jn5xc= query T SELECT url FROM [EXPLAIN (DISTSQL) SELECT sum(a), avg(a), count(a), stddev(a), variance(a) FROM data] @@ -188,13 +188,13 @@ https://cockroachdb.github.io/distsqlplan/decode.html#eJzMlmFro04Qxt__P4XMq5b_ip query T SELECT url FROM [EXPLAIN (DISTSQL) SELECT sum(a+b), sum(a+b) FILTER (WHERE a < d), sum(a+b) FILTER (WHERE a = c) FROM data GROUP BY d] ---- -https://cockroachdb.github.io/distsqlplan/decode.html#eJzMlk9r20AQxe_9FGJOCV0h7R85iaCgHFowuE5xnENpTdhYi21wtGIlQ0Pwdy-yCpL_dEdhLfDJ1ko_zZt5j0HvkOlUjeWrKiD-BRQIMCDAgYAAAhHMCORGz1VRaFM9UgPD9A_EIYFVlm_K6nhGYK6NgvgdylW5VhDDVL6s1UTJVJkgBAKpKuVqvSuTm9WrNG9JKksJBB5zmRWx5wfUk1nqUU-XS2WAwERlqTKxlwjiJfRzwqqf35sw5PP66EvCYbYloDdlI6Qo5UJBTLeku9j7xcKohSy1CaJ9rUk1k8en71cJu27-ed-Go-nXiZfwE2fV5O7HP5_HD9Pn8dNodJXQ6wOZTeWXN28pi-VBUQGzbdMK-28rzXs2mTapMirde9PuLZZmaXhU-F8_9Pq4b37YRiORn1fiWPs6D2h08OTp2mKvNu0eS9ollgH1A9ZfMBG5La8Glx5Mes5g0l6CybqHg3UKB_MD3l84ELmted1cejjYOcPBegkH7x4O3ikc3A9Ef-FA5LbmdXvp4eDnDAfvJRyiezhEp3AIP4j6CwcitzWvu0sPhzhnOETv3zsnBExUkeusUJ2-ZsKqBZUuVN1voTdmrn4YPd-VqS8fdtzuIFVFWd-l9cUwq29VAtswPYRpG2Z7MP0YPHCB71xg6qSbRnaaWefN7TC3mzWwuyWsdGSHIxer7TBitR1GrLbDmNUIjVg9cLH6xgrf2s26dTHLDiNm2WHELDuMmYXQiFl3LmZRZItia9Rtj7otUrdN6rhK3XYpdVqmFNmmAjHtaJ1-yDQ7jZlmpzHT7DRqGoJjph0tVatps-2nvwEAAP__N5l2wQ== +https://cockroachdb.github.io/distsqlplan/decode.html#eJzMll1r2z4Uxu__n8Kcq_6ZjH0kOW0NA_Vig0KXjqy9GFsoaizSQGsZ2YaVku8-bEPjvExyUUJyJ7_8fJ5znoeD3yDXmRrLF1VC-gsQCFAgwIAABwIJTAkURs9UWWrTvNIB19kfSGMCi7yoq-b2lMBMGwXpG1SL6llBCnfy8VlNlMyUiWIgkKlKLp7bMoVZvEjzKjJZSSDwo5B5mQZhhIHMswADXT0pAwQmKs-USQOBnwQlgcDfdRyzmeDN-bNgJBAcpksCuq5WQspKzhWkuCTDxV7N50bNZaVNlKxrFc0crsY_H8a3dw_j-5ubM8H_b0TffzsT2DsFX69v7r5MAkF33GMbMleVH1-DJ1k-bRRFmC5XrdB_trL6Tp1rkymjsrUvtV-xNIvxVuHNblc90vcTez-1s7itqzQQlAhGLIaw_XYx1qEuIkw23txdm6_VxuHJxSHJjTCM6OGy65Dbs3N06tnFfWYXj5VdOjw_dFB-aBixw-XHIbc30vNTzw_dZ37osfLDhueHDcoPCyN-uPw45PZGenHq-WH7zA87Vn748PzwQfnhYZQcLj8Oub2RXp56fvg-88NP4d9rh8aJKgudl2rQn1XcdKmyuepGUurazNR3o2dtme7ytuXaG5kqq-4pdhfXefeoEdiHcRPGPkzXYPwYPPKBL31g9NKNiZ2m1nkzO8zsZo3sbnErndjhxMdqO-yw2g47rLbDLqsdtMPqkY_V51b4wm7WhY9Zdthhlh12mGWHXWY5aIdZlz5moWOLutao3x71W6R-m9RzlfrtUvRapujYptxh2tY6_ZBpdtplmp12mWannaY5cJdpW0vVatp0-d_fAAAA___qTKJK # Same query but restricted to a single range; no local aggregation stage. query T SELECT url FROM [EXPLAIN (DISTSQL) SELECT sum(a+b), sum(a+b) FILTER (WHERE a < d), sum(a+b) FILTER (WHERE a = c) FROM data WHERE a = 1 GROUP BY d] ---- -https://cockroachdb.github.io/distsqlplan/decode.html#eJyUkUFLwzAUx-9-ivJOihlN0p0CQjwoDHRKN0_aQ2wepbAlIUlBGf3u0uawVVzVU_v-7_1efvAOYKzGtdpjAPEKDAhwqAg4b2sMwfohTkMr_QGCEmiN6-IQVwRq6xHEAWIbdwgC1nZhXc6BgMao2t041hOwXTxCIaoGQRQ9OVnM5hdv1fsOS1QafU4n68H5dq_8p9QqKiCwccoEkeVsMXqUaDR6kcklySS7lnz4vHWUFnWKbmQB5xTZfxRvm8Zjo6L1OZsayqHevDxeSn51_MvuVw_buzKTxQ_Z8qwTnzj9co8Sg7Mm4J8OQvuKAOoG082D7XyNz97W4zOpfBq5MdAYYuoWqViZ1BoET2E2C_N5mM_C9Btc9RdfAQAA__-VFuTI +https://cockroachdb.github.io/distsqlplan/decode.html#eJyUksFq4zAQhu_7FGJOWXaCLdknwYJy2IVA6hQ3OZTWBNUajCGxjCRDS_C7F9uHJqV125vmn_nGH4zP0FhDmT6RB_kAHBAEFAitsyV5b90QT0Nr8wwyRqibtgtDXCCU1hHIM4Q6HAkkZHZp20gAgqGg6-M41iPYLrxBPuiKQCY9Xizm84t3-ulIOWlDLoqv1kPr6pN2L8rooAHhrtWNlyziy9Ejp8aQk0zxP0ogU_yxi-OkVOnw_qsSZCqFzxT5TxRXVeWo0sG6iF8bqhQQVtn9IdvuDtl-s1mo9Pegur9ZKH7xYv_Xm92_nCnxQZYAwrYLkimBKsEZbXGl_cXJcvKtbTx962ZxXyCQqWj6LbztXEm3zpbjZ6ZyO3JjYMiHqZtMxbqZWoPgJcxnYTEPi1k4fgcX_a_XAAAA___19-2A # Verify the XOR execution plan query T diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index 8366e6acc9e2..320dfb2452e4 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -375,4 +375,8 @@ type AggInfo struct { // ConstArgs is the list of any constant arguments to the aggregate, // for instance, the separator in string_agg. ConstArgs []tree.Datum + + // Filter is the index of the column, if any, which should be used as the + // FILTER condition for the aggregate. If there is no filter, Filter is -1. + Filter ColumnOrdinal } diff --git a/pkg/sql/opt/memo/check_expr.go b/pkg/sql/opt/memo/check_expr.go index 93e106c01b3f..af1ddcd7e324 100644 --- a/pkg/sql/opt/memo/check_expr.go +++ b/pkg/sql/opt/memo/check_expr.go @@ -192,6 +192,11 @@ func (m *Memo) checkExpr(e opt.Expr) { panic(fmt.Sprintf("zigzag join with mismatching eq columns")) } + case *AggDistinctExpr: + if t.Input.Op() == opt.AggFilterOp { + panic("AggFilter should always be on top of AggDistinct") + } + default: if !opt.IsListOp(e) { for i := 0; i < e.ChildCount(); i++ { diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index b45390c52125..481005f0b306 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -448,6 +448,17 @@ func (f *ExprFmtCtx) formatScalar(scalar opt.ScalarExpr, tp treeprinter.Node) { f.formatExpr(scalar.Child(2), tp.Child("err-code")) } + return + + case opt.AggFilterOp: + f.Buffer.Reset() + fmt.Fprintf(f.Buffer, "%v", scalar.Op()) + f.FormatScalarProps(scalar) + tp = tp.Child(f.Buffer.String()) + + f.formatExpr(scalar.Child(0), tp) + f.formatExpr(scalar.Child(1), tp.Child("filter")) + return } diff --git a/pkg/sql/opt/memo/extract.go b/pkg/sql/opt/memo/extract.go index 5215e1cdc459..4bcffa699823 100644 --- a/pkg/sql/opt/memo/extract.go +++ b/pkg/sql/opt/memo/extract.go @@ -111,23 +111,38 @@ func ExtractAggSingleInputColumn(e opt.ScalarExpr) opt.ColumnID { return ExtractVarFromAggInput(e.Child(0).(opt.ScalarExpr)).Col } -// ExtractAggInputColumns returns the input columns of an aggregate (which can -// be empty). +// ExtractAggInputColumns returns the set of columns the aggregate depends on. func ExtractAggInputColumns(e opt.ScalarExpr) opt.ColSet { if !opt.IsAggregateOp(e) { panic("not an Aggregate") } + if e.ChildCount() == 0 { + return opt.ColSet{} + } + + arg := e.Child(0) var res opt.ColSet - if e.ChildCount() > 0 { - res.Add(int(ExtractVarFromAggInput(e.Child(0).(opt.ScalarExpr)).Col)) + if filter, ok := arg.(*AggFilterExpr); ok { + res.Add(int(filter.Filter.(*VariableExpr).Col)) + arg = filter.Input + } + if distinct, ok := arg.(*AggDistinctExpr); ok { + arg = distinct.Input + } + if variable, ok := arg.(*VariableExpr); ok { + res.Add(int(variable.Col)) + return res } - return res + panic(fmt.Sprintf("unhandled aggregate input %T", arg)) } // ExtractVarFromAggInput is given an argument to an Aggregate and returns the // inner Variable expression, stripping out modifiers like AggDistinct. func ExtractVarFromAggInput(arg opt.ScalarExpr) *VariableExpr { + if filter, ok := arg.(*AggFilterExpr); ok { + arg = filter.Input + } if distinct, ok := arg.(*AggDistinctExpr); ok { arg = distinct.Input } diff --git a/pkg/sql/opt/memo/typing.go b/pkg/sql/opt/memo/typing.go index 3cc57d022e5d..015da81c8b98 100644 --- a/pkg/sql/opt/memo/typing.go +++ b/pkg/sql/opt/memo/typing.go @@ -152,6 +152,7 @@ func init() { // Modifiers for aggregations pass through their argument. typingFuncMap[opt.AggDistinctOp] = typeAsFirstArg + typingFuncMap[opt.AggFilterOp] = typeAsFirstArg for _, op := range opt.BinaryOperators { typingFuncMap[op] = typeAsBinary diff --git a/pkg/sql/opt/ops/scalar.opt b/pkg/sql/opt/ops/scalar.opt index b356eb7fa1a9..f8b8244bbbeb 100644 --- a/pkg/sql/opt/ops/scalar.opt +++ b/pkg/sql/opt/ops/scalar.opt @@ -873,6 +873,16 @@ define AggDistinct { Input ScalarExpr } +# AggFilter is used as a modifier that wraps the input of an aggregate +# function. It causes only rows for which the filter expression is true +# to be processed. AggFilter should always occur on top of AggDistinct +# if they are both present. +[Scalar] +define AggFilter { + Input ScalarExpr + Filter ScalarExpr +} + # ScalarList is a list expression that has scalar expression items of type # opt.ScalarExpr. opt.ScalarExpr is an external type that is defined outside of # Optgen. It is hard-coded in the code generator to be the item type for diff --git a/pkg/sql/opt/optbuilder/groupby.go b/pkg/sql/opt/optbuilder/groupby.go index ffaf78efac1b..3fb2f9d2b774 100644 --- a/pkg/sql/opt/optbuilder/groupby.go +++ b/pkg/sql/opt/optbuilder/groupby.go @@ -88,6 +88,7 @@ type aggregateInfo struct { def memo.FunctionPrivate distinct bool args memo.ScalarListExpr + filter opt.ScalarExpr // col is the output column of the aggregation. col *scopeColumn @@ -248,6 +249,7 @@ func (b *Builder) buildAggregation( args := make([]opt.ScalarExpr, 0, 2) if len(agg.args) > 0 { colID := argCols[0].id + argCols = argCols[1:] args = append(args, b.factory.ConstructVariable(colID)) if agg.distinct { // Wrap the argument with AggDistinct. @@ -255,11 +257,22 @@ func (b *Builder) buildAggregation( } // Append any constant arguments without further processing. - args = append(args, agg.args[1:]...) + constArgs := agg.args[1:] + args = append(args, constArgs...) + argCols = argCols[len(constArgs):] + + // If the aggregate had a filter, there's an extra column in argCols + // corresponding to the filter. + // TODO(justin): add a norm rule to push these filters below group bys where appropriate. + if agg.filter != nil { + colID := argCols[0].id + argCols = argCols[1:] + // Wrap the argument with AggFilter. + args[0] = b.factory.ConstructAggFilter(args[0], b.factory.ConstructVariable(colID)) + } } aggCols[i].scalar = b.constructAggregate(agg.def.Name, args).(opt.ScalarExpr) - argCols = argCols[len(agg.args):] if opt.AggregateIsOrderingSensitive(aggCols[i].scalar.Op()) { haveOrderingSensitiveAgg = true @@ -403,6 +416,22 @@ func (b *Builder) buildGrouping( } } +// buildAggArg builds a scalar expression which is used as an input in some form +// to an aggregate expression. The scopeColumn for the built expression will +// be added to tempScope. +func (b *Builder) buildAggArg( + e tree.TypedExpr, info *aggregateInfo, tempScope, inScope *scope, +) opt.ScalarExpr { + // This synthesizes a new tempScope column, unless the argument is a + // simple VariableOp. + col := b.addColumn(tempScope, "" /* alias */, e) + b.buildScalar(e, inScope, tempScope, col, &info.colRefs) + if col.scalar != nil { + return col.scalar + } + return b.factory.ConstructVariable(col.id) +} + // buildAggregateFunction is called when we are building a function which is an // aggregate. Any non-trivial parameters (i.e. not column reference) to the // aggregate function are extracted and added to aggInScope. The aggregate @@ -436,17 +465,13 @@ func (b *Builder) buildAggregateFunction( defer func() { b.subquery = subq }() for i, pexpr := range f.Exprs { - // This synthesizes a new tempScope column, unless the argument is a - // simple VariableOp. - texpr := pexpr.(tree.TypedExpr) - - col := b.addColumn(tempScope, "" /* alias */, texpr) - b.buildScalar(texpr, inScope, tempScope, col, &info.colRefs) - if col.scalar != nil { - info.args[i] = col.scalar - } else { - info.args[i] = b.factory.ConstructVariable(col.id) - } + info.args[i] = b.buildAggArg(pexpr.(tree.TypedExpr), &info, tempScope, inScope) + } + + // If we have a filter, add it to tempScope after all the arguments. We'll + // later extract the column that gets added here in buildAggregation. + if f.Filter != nil { + info.filter = b.buildAggArg(f.Filter.(tree.TypedExpr), &info, tempScope, inScope) } // Find the appropriate aggregation scopes for this aggregate now that we diff --git a/pkg/sql/opt/optbuilder/scope.go b/pkg/sql/opt/optbuilder/scope.go index 4e2b7901deb8..59fe77cf33b3 100644 --- a/pkg/sql/opt/optbuilder/scope.go +++ b/pkg/sql/opt/optbuilder/scope.go @@ -432,7 +432,8 @@ func (s *scope) getAggregateCols() []scopeColumn { // getAggregateArgCols returns the columns in this scope corresponding // to arguments to aggregate functions. This call is only valid on an -// aggInScope. +// aggInScope. If the aggregate has a filter, the column corresponding +// to its input will immediately follow its inputs. func (s *scope) getAggregateArgCols(groupingsLen int) []scopeColumn { // Aggregate args are always clustered at the beginning of the column list. return s.cols[:len(s.cols)-groupingsLen] @@ -460,7 +461,7 @@ func (s *scope) findAggregate(agg aggregateInfo) *scopeColumn { for i, a := range s.groupby.aggs { // Find an existing aggregate that uses the same function overload. - if a.def.Overload == agg.def.Overload && a.distinct == agg.distinct { + if a.def.Overload == agg.def.Overload && a.distinct == agg.distinct && a.filter == agg.filter { // Now check that the arguments are identical. if len(a.args) == len(agg.args) { match := true @@ -929,10 +930,6 @@ func (s *scope) replaceSRF(f *tree.FuncExpr, def *tree.FunctionDefinition) *srf // aggregate references no variables). The aggOutScope.groupby.aggs slice is // used later by the Builder to build aggregations in the aggregation scope. func (s *scope) replaceAggregate(f *tree.FuncExpr, def *tree.FunctionDefinition) tree.Expr { - if f.Filter != nil { - panic(unimplementedf("aggregates with FILTER are not supported yet")) - } - f, def = s.replaceCount(f, def) // We need to save and restore the previous value of the field in @@ -944,6 +941,23 @@ func (s *scope) replaceAggregate(f *tree.FuncExpr, def *tree.FunctionDefinition) tree.RejectNestedAggregates|tree.RejectWindowApplications) expr := f.Walk(s) + + // We need to do this check here to ensure that we check the usage of special + // functions with the right error message. + if f.Filter != nil { + func() { + oldProps := s.builder.semaCtx.Properties + defer func() { s.builder.semaCtx.Properties.Restore(oldProps) }() + + s.builder.semaCtx.Properties.Require("FILTER", tree.RejectSpecial) + _, err := tree.TypeCheck(expr.(*tree.FuncExpr).Filter, s.builder.semaCtx, types.Any) + if err != nil { + panic(builderError{err}) + } + + }() + } + typedFunc, err := tree.TypeCheck(expr, s.builder.semaCtx, types.Any) if err != nil { panic(builderError{err}) @@ -982,11 +996,31 @@ func (s *scope) replaceCount( if strings.EqualFold(def.Name, "count") && f.Type == 0 { if _, ok := vn.(tree.UnqualifiedStar); ok { - // Special case handling for COUNT(*). This is a special construct to - // count the number of rows; in this case * does NOT refer to a set of - // columns. A * is invalid elsewhere (and will be caught by TypeCheck()). - // Replace the function with COUNT_ROWS (which doesn't take any - // arguments). + if f.Filter != nil { + // If we have a COUNT(*) with a FILTER, we need to synthesize an input + // for the aggregation to be over, because otherwise we have no input + // to hang the AggFilter off of. + // Thus, we convert + // COUNT(*) FILTER (WHERE foo) + // to + // COUNT(true) FILTER (WHERE foo). + cpy := *f + e := &cpy + e.Exprs = tree.Exprs{tree.DBoolTrue} + + newDef, err := e.Func.Resolve(s.builder.semaCtx.SearchPath) + if err != nil { + panic(builderError{err}) + } + + return e, newDef + } + + // Special case handling for COUNT(*) with no FILTER. This is a special + // construct to count the number of rows; in this case * does NOT refer + // to a set of columns. A * is invalid elsewhere (and will be caught by + // TypeCheck()). Replace the function with COUNT_ROWS (which doesn't + // take any arguments). e := &tree.FuncExpr{ Func: tree.ResolvableFunctionReference{ FunctionReference: &tree.UnresolvedName{ diff --git a/pkg/sql/opt/optbuilder/testdata/aggregate b/pkg/sql/opt/optbuilder/testdata/aggregate index a210dd7de494..102477ff0960 100644 --- a/pkg/sql/opt/optbuilder/testdata/aggregate +++ b/pkg/sql/opt/optbuilder/testdata/aggregate @@ -2622,10 +2622,145 @@ scalar-group-by └── agg-distinct [type=decimal] └── variable: d [type=decimal] +# FILTER. + build SELECT sum(abc.d) FILTER (WHERE abc.d > 0) FROM abc ---- -error (0A000): aggregates with FILTER are not supported yet +scalar-group-by + ├── columns: sum:6(decimal) + ├── project + │ ├── columns: column5:5(bool) d:4(decimal) + │ ├── scan abc + │ │ └── columns: a:1(string!null) b:2(float) c:3(bool) d:4(decimal) + │ └── projections + │ └── gt [type=bool] + │ ├── variable: d [type=decimal] + │ └── const: 0 [type=decimal] + └── aggregations + └── sum [type=decimal] + └── agg-filter [type=decimal] + ├── variable: d [type=decimal] + └── filter + └── variable: column5 [type=bool] + +# Ensure aggregates with FILTER coexist properly with non-FILTER aggregates. +build +SELECT + sum(x) FILTER (WHERE y > 0), + avg(DISTINCT z), + avg(DISTINCT z) FILTER (WHERE y > 0) +FROM xyz +---- +scalar-group-by + ├── columns: sum:5(decimal) avg:6(float) avg:7(float) + ├── project + │ ├── columns: column4:4(bool) x:1(int!null) z:3(float) + │ ├── scan xyz + │ │ └── columns: x:1(int!null) y:2(int) z:3(float) + │ └── projections + │ └── gt [type=bool] + │ ├── variable: y [type=int] + │ └── const: 0 [type=int] + └── aggregations + ├── sum [type=decimal] + │ └── agg-filter [type=int] + │ ├── variable: x [type=int] + │ └── filter + │ └── variable: column4 [type=bool] + ├── avg [type=float] + │ └── agg-distinct [type=float] + │ └── variable: z [type=float] + └── avg [type=float] + └── agg-filter [type=float] + ├── agg-distinct [type=float] + │ └── variable: z [type=float] + └── filter + └── variable: column4 [type=bool] + +# Ensure aggregates involving FILTER are deduplicated. +build +SELECT + avg(DISTINCT x), + avg(DISTINCT x), + avg(DISTINCT x) FILTER (WHERE y > 0), + avg(DISTINCT x) FILTER (WHERE y > 0) +FROM xyz +---- +scalar-group-by + ├── columns: avg:4(decimal) avg:4(decimal) avg:6(decimal) avg:6(decimal) + ├── project + │ ├── columns: column5:5(bool) x:1(int!null) + │ ├── scan xyz + │ │ └── columns: x:1(int!null) y:2(int) z:3(float) + │ └── projections + │ └── gt [type=bool] + │ ├── variable: y [type=int] + │ └── const: 0 [type=int] + └── aggregations + ├── avg [type=decimal] + │ └── agg-distinct [type=int] + │ └── variable: x [type=int] + └── avg [type=decimal] + └── agg-filter [type=int] + ├── agg-distinct [type=int] + │ └── variable: x [type=int] + └── filter + └── variable: column5 [type=bool] + +build +SELECT + string_agg(x::string, 'foo') FILTER (WHERE y > 0) +FROM xyz +---- +scalar-group-by + ├── columns: string_agg:7(string) + ├── project + │ ├── columns: column4:4(string) column5:5(string!null) column6:6(bool) + │ ├── scan xyz + │ │ └── columns: x:1(int!null) y:2(int) z:3(float) + │ └── projections + │ ├── cast: STRING [type=string] + │ │ └── variable: x [type=int] + │ ├── const: 'foo' [type=string] + │ └── gt [type=bool] + │ ├── variable: y [type=int] + │ └── const: 0 [type=int] + └── aggregations + └── string-agg [type=string] + ├── agg-filter [type=string] + │ ├── variable: column4 [type=string] + │ └── filter + │ └── variable: column6 [type=bool] + └── const: 'foo' [type=string] + +build +SELECT y, count(*) FILTER (WHERE x > 5) FROM xyz GROUP BY y +---- +group-by + ├── columns: y:2(int) count:6(int) + ├── grouping columns: y:2(int) + ├── project + │ ├── columns: column4:4(bool!null) column5:5(bool) y:2(int) + │ ├── scan xyz + │ │ └── columns: x:1(int!null) y:2(int) z:3(float) + │ └── projections + │ ├── true [type=bool] + │ └── gt [type=bool] + │ ├── variable: x [type=int] + │ └── const: 5 [type=int] + └── aggregations + └── count [type=int] + └── agg-filter [type=bool] + ├── variable: column4 [type=bool] + └── filter + └── variable: column5 [type=bool] + +build +SELECT y, count(*) FILTER (WHERE count(*) > 5) FROM xyz GROUP BY y +---- +error: count_rows(): aggregate functions are not allowed in FILTER + # Check that ordering by an alias of an aggregate works. build diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index e5eeb0f1baf0..4ae6f6b6925b 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -438,6 +438,14 @@ func (ef *execFactory) addAggregations(n *groupNode, aggregations []exec.AggInfo if agg.Distinct { f.setDistinct() } + + if agg.Filter == -1 { + // A value of -1 means the aggregate had no filter. + f.filterRenderIdx = noRenderIdx + } else { + f.filterRenderIdx = int(agg.Filter) + } + n.funcs = append(n.funcs, f) n.columns = append(n.columns, sqlbase.ResultColumn{ Name: fmt.Sprintf("agg%d", i),