diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index b7ef4c4a78d..f07fb734df8 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -425,3 +425,13 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`) } + +func TestAggregationRandomOnAnAggregatedValue(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t10(k, a, b) values (0, 100, 10), (10, 200, 20);") + + mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from t10 where a = 100) A;", + `[[DECIMAL(100) DECIMAL(10) DECIMAL(10.0000)]]`) +} diff --git a/go/test/endtoend/vtgate/queries/aggregation/schema.sql b/go/test/endtoend/vtgate/queries/aggregation/schema.sql index a538a3dafed..0375bdb8499 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/schema.sql +++ b/go/test/endtoend/vtgate/queries/aggregation/schema.sql @@ -71,3 +71,8 @@ CREATE TABLE t2 ( PRIMARY KEY (id) ) ENGINE InnoDB; +CREATE TABLE t10 ( + k BIGINT PRIMARY KEY, + a INT, + b INT +); \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/aggregation/vschema.json b/go/test/endtoend/vtgate/queries/aggregation/vschema.json index c2d3f133a35..4d1623d5633 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/vschema.json +++ b/go/test/endtoend/vtgate/queries/aggregation/vschema.json @@ -123,6 +123,14 @@ "name": "hash" } ] + }, + "t10": { + "column_vindexes": [ + { + "column": "k", + "name": "hash" + } + ] } } } \ No newline at end of file diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index f76bd742d03..8ba046146a7 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3856,6 +3856,40 @@ func TestSelectAggregationData(t *testing.T) { } } +func TestSelectAggregationRandom(t *testing.T) { + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + createSandbox(KsTestSharded).VSchema = executorVSchema + getSandbox(KsTestUnsharded).VSchema = unshardedVSchema + serv := newSandboxForCells([]string{cell}) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + var conns []*sandboxconn.SandboxConn + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, KsTestSharded, shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + conns = append(conns, sbc) + + sbc.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b", "int64|int64"), + "null|null", + )}) + } + + conns[0].SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b", "int64|int64"), + "10|1", + )}) + + executor := createExecutor(serv, cell, resolver) + executor.pv = querypb.ExecuteOptions_Gen4 + session := NewAutocommitSession(&vtgatepb.Session{}) + + rs, err := executor.Execute(context.Background(), "TestSelectCFC", session, + "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil) + require.NoError(t, err) + assert.Equal(t, `[[INT64(10) INT64(1) DECIMAL(10.0000)]]`, fmt.Sprintf("%v", rs.Rows)) +} + func TestSelectHexAndBit(t *testing.T) { executor, _, _, _ := createExecutorEnv() executor.normalize = true diff --git a/go/vt/vtgate/planbuilder/gen4_planner.go b/go/vt/vtgate/planbuilder/gen4_planner.go index dc49ae0a700..6822dcff642 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner.go +++ b/go/vt/vtgate/planbuilder/gen4_planner.go @@ -216,7 +216,7 @@ func newBuildSelectPlan( return nil, nil, nil, err } - plan = optimizePlan(plan) + optimizePlan(plan) sel, isSel := selStmt.(*sqlparser.Select) if isSel { @@ -238,25 +238,25 @@ func newBuildSelectPlan( } // optimizePlan removes unnecessary simpleProjections that have been created while planning -func optimizePlan(plan logicalPlan) logicalPlan { - newPlan, _ := visit(plan, func(plan logicalPlan) (bool, logicalPlan, error) { - this, ok := plan.(*simpleProjection) - if !ok { - return true, plan, nil - } +func optimizePlan(plan logicalPlan) { + for _, lp := range plan.Inputs() { + optimizePlan(lp) + } - input, ok := this.input.(*simpleProjection) - if !ok { - return true, plan, nil - } + this, ok := plan.(*simpleProjection) + if !ok { + return + } - for i, col := range this.eSimpleProj.Cols { - this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] - } - this.input = input.input - return true, this, nil - }) - return newPlan + input, ok := this.input.(*simpleProjection) + if !ok { + return + } + + for i, col := range this.eSimpleProj.Cols { + this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] + } + this.input = input.input } func gen4UpdateStmtPlanner( diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index eea1400b916..4e33f62ebe5 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -60,7 +60,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo // a simpleProjection. We create a new Route that contains the derived table in the // FROM clause. Meaning that, when we push expressions to the select list of this // new Route, we do not want them to rewrite them. - if _, isSimpleProj := plan.(*simpleProjection); isSimpleProj { + sp, derivedTable := plan.(*simpleProjection) + if derivedTable { oldRewriteDerivedExpr := ctx.RewriteDerivedExpr defer func() { ctx.RewriteDerivedExpr = oldRewriteDerivedExpr @@ -75,10 +76,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo } needsOrdering := len(hp.qp.OrderExprs) > 0 - canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering // If we still have a HAVING clause, it's because it could not be pushed to the WHERE, // so it probably has aggregations + canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering + switch { case hp.qp.NeedsAggregation() || hp.sel.Having != nil: plan, err = hp.planAggregations(ctx, plan) @@ -92,6 +94,26 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo if err != nil { return nil, err } + case derivedTable: + pusher := func(ae *sqlparser.AliasedExpr) (int, error) { + offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false) + return offset, err + } + needsVtGate, projections, colNames, err := hp.qp.NeedsProjecting(ctx, pusher) + if err != nil { + return nil, err + } + if !needsVtGate { + break + } + + // there were some expressions we could not push down entirely, + // so replace the simpleProjection with a real projection + plan = &projection{ + source: sp.input, + columns: projections, + columnNames: colNames, + } default: err = pushProjections(ctx, plan, hp.qp.SelectExprs) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 29e356c6650..8de53a762be 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -418,7 +418,85 @@ func (qp *QueryProjection) NeedsAggregation() bool { return qp.HasAggr || len(qp.groupByExprs) > 0 } -func (qp QueryProjection) onlyAggr() bool { +// NeedsProjecting returns true if we have projections that need to be evaluated at the vtgate level +// and can't be pushed down to MySQL +func (qp *QueryProjection) NeedsProjecting( + ctx *plancontext.PlanningContext, + pusher func(expr *sqlparser.AliasedExpr) (int, error), +) (needsVtGateEval bool, expressions []sqlparser.Expr, colNames []string, err error) { + for _, se := range qp.SelectExprs { + var ae *sqlparser.AliasedExpr + ae, err = se.GetAliasedExpr() + if err != nil { + return false, nil, nil, err + } + + expr := ae.Expr + colNames = append(colNames, ae.ColumnName()) + + if _, isCol := expr.(*sqlparser.ColName); isCol { + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + continue + } + + stopOnError := func(sqlparser.SQLNode, sqlparser.SQLNode) bool { + return err == nil + } + rewriter := func(cursor *sqlparser.CopyOnWriteCursor) { + col, isCol := cursor.Node().(*sqlparser.ColName) + if !isCol { + return + } + var tableInfo semantics.TableInfo + tableInfo, err = ctx.SemTable.TableInfoForExpr(col) + if err != nil { + return + } + dt, isDT := tableInfo.(*semantics.DerivedTable) + if !isDT { + return + } + + rewritten := semantics.RewriteDerivedTableExpression(col, dt) + if sqlparser.ContainsAggregation(rewritten) { + offset, tErr := pusher(&sqlparser.AliasedExpr{Expr: col}) + if tErr != nil { + err = tErr + return + } + + cursor.Replace(sqlparser.NewOffset(offset, col)) + } + } + newExpr := sqlparser.CopyOnRewrite(expr, stopOnError, rewriter, nil) + + if err != nil { + return + } + + if newExpr != expr { + // if we changed the expression, it means that we have to evaluate the rest at the vtgate level + expressions = append(expressions, newExpr.(sqlparser.Expr)) + needsVtGateEval = true + continue + } + + // we did not need to push any parts of this expression down. Let's check if we can push all of it + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + } + + return +} + +func (qp *QueryProjection) onlyAggr() bool { if !qp.HasAggr { return false } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index f8e6c7fcde1..e0cee664828 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -4964,5 +4964,45 @@ "user.user_extra" ] } + }, + { + "comment": "Aggregations from derived table used in arithmetic outside derived table", + "query": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "v3-plan": "VT12001: unsupported: expression on results of a cross-shard subquery", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] as a", + "[COLUMN 1] as b", + "[COLUMN 0] / [COLUMN 1] as d" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS a, sum(1) AS b", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(a) as a, sum(b) as b from `user` where 1 != 1", + "Query": "select sum(a) as a, sum(b) as b from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ]