Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support complex aggregation in Gen4's Operators #13326

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,18 @@ func TestMinMaxAcrossJoins(t *testing.T) {
`SELECT /*vt+ PLANNER=gen4 */ t1.name, max(t1.shardKey), t2.shardKey, min(t2.id) FROM t1 JOIN t2 ON t1.t1_id != t2.shardKey GROUP BY t1.name, t2.shardKey`,
`[[VARCHAR("name 2") INT64(2) INT64(10) INT64(1)] [VARCHAR("name 1") INT64(1) INT64(10) INT64(1)] [VARCHAR("name 2") INT64(2) INT64(20) INT64(2)] [VARCHAR("name 1") INT64(1) INT64(20) INT64(2)]]`)
}

func TestComplexAggregation(t *testing.T) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frouioui do we need to check the version of vtgate so these are not run on the upgrade/downgrade testing?

mcmp, closer := start(t)
defer closer()
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")

mcmp.Exec("set @@sql_mode = ' '")
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ 1+COUNT(t1_id) FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+1 FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+MAX(shardkey) FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/horizon_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (hp *horizonPlanning) planAggrUsingOA(
}
}

aggregationExprs, err := hp.qp.AggregationExpressions(ctx)
aggregationExprs, _, err := hp.qp.AggregationExpressions(ctx, false)
if err != nil {
return nil, err
}
Expand Down
36 changes: 34 additions & 2 deletions go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo
return nil, err
}

aggregations, err := qp.AggregationExpressions(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can call this as needsEvaluationAfterAggregation

aggregations, complexAggr, err := qp.AggregationExpressions(ctx, true)
if err != nil {
return nil, err
}
Expand All @@ -135,6 +135,13 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo
a.Alias = derived.Alias
}

if complexAggr {
return createProjectionForComplexAggregation(a, qp)
}
return createProjectionForSimpleAggregation(ctx, a, qp)
}

func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) (ops.Operator, error) {
outer:
for colIdx, expr := range qp.SelectExprs {
ae, err := expr.GetAliasedExpr()
Expand Down Expand Up @@ -165,10 +172,35 @@ outer:
}
return nil, vterrors.VT13001(fmt.Sprintf("Could not find the %s in aggregation in the original query", sqlparser.String(ae)))
}

return a, nil
}

func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) (ops.Operator, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add comments as how this is built

p := &Projection{
Source: a,
Alias: a.Alias,
TableID: a.TableID,
}

for _, expr := range qp.SelectExprs {
ae, err := expr.GetAliasedExpr()
if err != nil {
return nil, err
}
p.Columns = append(p.Columns, ae)
p.Projections = append(p.Projections, UnexploredExpression{E: ae.Expr})
}
for i, by := range a.Grouping {
a.Grouping[i].ColOffset = len(a.Columns)
a.Columns = append(a.Columns, aeWrap(by.SimplifiedExpr))
}
for i, aggregation := range a.Aggregations {
a.Aggregations[i].ColOffset = len(a.Columns)
a.Columns = append(a.Columns, aggregation.Original)
}
return p, nil
}

func createProjectionWithoutAggr(qp *QueryProjection, src ops.Operator) (*Projection, error) {
proj := &Projection{
Source: src,
Expand Down
66 changes: 45 additions & 21 deletions go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ type (
OriginalOpCode opcode.AggregateOpcode

Alias string

// The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it
Index *int
// Only used in the old Horizon Planner
Index *int

Distinct bool

// the offsets point to columns on the same aggregator
Expand Down Expand Up @@ -444,13 +447,9 @@ func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error {
}, exp.Expr)
}

func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr SelectExpr) bool {
func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool {
for _, groupByExpr := range qp.groupByExprs {
exp, err := expr.GetExpr()
if err != nil {
return false
}
if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, exp) {
if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) {
return true
}
}
Expand Down Expand Up @@ -623,7 +622,7 @@ func (qp *QueryProjection) NeedsDistinct() bool {
return true
}

func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext) (out []Aggr, err error) {
func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext, allowComplexExpression bool) (out []Aggr, complex bool, err error) {
orderBy:
for _, orderExpr := range qp.OrderExprs {
orderExpr := orderExpr.SimplifiedExpr
Expand All @@ -649,27 +648,54 @@ orderBy:
for idx, expr := range qp.SelectExprs {
aliasedExpr, err := expr.GetAliasedExpr()
if err != nil {
return nil, err
return nil, false, err
}

idxCopy := idx

if !sqlparser.ContainsAggregation(expr.Col) {
if !qp.isExprInGroupByExprs(ctx, expr) {
getExpr, err := expr.GetExpr()
if err != nil {
return nil, false, err
}
if !qp.isExprInGroupByExprs(ctx, getExpr) {
aggr := NewAggr(opcode.AggregateRandom, nil, aliasedExpr, aliasedExpr.ColumnName())
aggr.Index = &idxCopy
out = append(out, aggr)
}
continue
}
fnc, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc)
if !isAggregate {
return nil, vterrors.VT12001("in scatter query: complex aggregate expression")
_, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc)
if !isAggregate && !allowComplexExpression {
return nil, false, vterrors.VT12001("in scatter query: complex aggregate expression")
}

aggr := createAggrFromAggrFunc(fnc, aliasedExpr)
aggr.Index = &idxCopy
out = append(out, aggr)
sqlparser.CopyOnRewrite(aliasedExpr.Expr, func(node, parent sqlparser.SQLNode) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add some comments here.

ex, isExpr := node.(sqlparser.Expr)
if !isExpr {
return true
}
if aggr, isAggr := node.(sqlparser.AggrFunc); isAggr {
ae := aeWrap(aggr)
if aggr == aliasedExpr.Expr {
ae = aliasedExpr
}
aggrFunc := createAggrFromAggrFunc(aggr, ae)
aggrFunc.Index = &idxCopy
out = append(out, aggrFunc)
return false
}
if sqlparser.ContainsAggregation(node) {
complex = true
return true
}
if !qp.isExprInGroupByExprs(ctx, ex) {
aggr := NewAggr(opcode.AggregateRandom, nil, aeWrap(ex), "")
aggr.Index = &idxCopy
out = append(out, aggr)
}
return false
}, nil, nil)
}
return
}
Expand All @@ -683,9 +709,7 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias
}
}

aggrF, _ := aliasedExpr.Expr.(sqlparser.AggrFunc)

if aggrF.IsDistinct() {
if fnc.IsDistinct() {
switch code {
case opcode.AggregateCount:
code = opcode.AggregateCountDistinct
Expand All @@ -694,8 +718,8 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias
}
}

aggr := NewAggr(code, aggrF, aliasedExpr, aliasedExpr.ColumnName())
aggr.Distinct = aggrF.IsDistinct()
aggr := NewAggr(code, fnc, aliasedExpr, aliasedExpr.ColumnName())
aggr.Distinct = fnc.IsDistinct()
return aggr
}

Expand Down
108 changes: 108 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -6626,5 +6626,113 @@
"user.user_extra"
]
}
},
{
"comment": "Complex aggregate expression on scatter",
"query": "select 1+count(*) from user",
"v3-plan": "VT12001: unsupported: in scatter query: complex aggregate expression",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select 1+count(*) from user",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] + [COLUMN 1] as 1 + count(*)"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "random(0), sum_count_star(1) AS count(*)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select 1, count(*) from `user` where 1 != 1",
"Query": "select 1, count(*) from `user`",
"Table": "`user`"
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
},
{
"comment": "combine the output of two aggregations in the final result",
"query": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col",
"v3-plan": "VT12001: unsupported: cross-shard query with aggregates",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"GREATEST([COLUMN 0], [COLUMN 1]) as greatest(sum(`user`.foo), sum(user_extra.bar))"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum(0) AS sum(`user`.foo), sum(1) AS sum(user_extra.bar)",
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] * [COLUMN 1] as sum(`user`.foo)",
"[COLUMN 3] * [COLUMN 2] as sum(user_extra.bar)"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,R:1,L:1",
"JoinVars": {
"user_col": 2
},
"TableName": "`user`_user_extra",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select sum(`user`.foo), count(*), `user`.col from `user` where 1 != 1 group by `user`.col",
"Query": "select sum(`user`.foo), count(*), `user`.col from `user` group by `user`.col",
"Table": "`user`"
},
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*), sum(user_extra.bar) from user_extra where 1 != 1 group by .0",
"Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col group by .0",
"Table": "user_extra"
}
]
}
]
}
]
}
]
},
"TablesUsed": [
"user.user",
"user.user_extra"
]
}
}
]
63 changes: 62 additions & 1 deletion go/vt/vtgate/planbuilder/testdata/tpch_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,68 @@
"comment": "TPC-H query 14",
"query": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"v3-plan": "VT12001: unsupported: cross-shard query with aggregates",
"gen4-plan": "VT12001: unsupported: in scatter query: complex aggregate expression"
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"([COLUMN 0] * [COLUMN 1]) / [COLUMN 2] as promo_revenue"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "random(0), sum(1) AS sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end), sum(2) AS sum(l_extendedprice * (1 - l_discount))",
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,L:3",
"JoinVars": {
"l_discount": 2,
"l_extendedprice": 1,
"l_partkey": 4
},
"TableName": "lineitem_part",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "main",
"Sharded": true
},
"FieldQuery": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where 1 != 1",
"Query": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"Table": "lineitem"
},
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "main",
"Sharded": true
},
"FieldQuery": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where 1 != 1",
"Query": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where p_partkey = :l_partkey",
"Table": "part",
"Values": [
":l_partkey"
],
"Vindex": "hash"
}
]
}
]
}
]
},
"TablesUsed": [
"main.lineitem",
"main.part"
]
}
},
{
"comment": "TPC-H query 15 view\n#\"with revenue0(supplier_no, total_revenue) as (select l_suppkey, sum(l_extendedprice * (1 - l_discount)) from lineitem where l_shipdate >= date('1996-01-01') and l_shipdate < date('1996-01-01') + interval '3' month group by l_suppkey )\"\n#\"syntax error at position 236\"\n#Gen4 plan same as above\n# TPC-H query 15",
Expand Down
Loading