Skip to content

Commit

Permalink
[gen4 planner] Make sure to not push down expressions when not possible (#12607)
Browse files Browse the repository at this point in the history

* Fix random aggregation to not select Null column

Signed-off-by: Florent Poinsard <[email protected]>

* stop pushing down projections that should be evaluated at the vtgate level

Signed-off-by: Andres Taylor <[email protected]>

* undo changes to AggregateRandom

Signed-off-by: Andres Taylor <[email protected]>

* clean up code

Signed-off-by: Andres Taylor <[email protected]>

* fix executor test mock

Signed-off-by: Florent Poinsard <[email protected]>

---------

Signed-off-by: Florent Poinsard <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
frouioui and systay authored Mar 17, 2023
1 parent fb27743 commit d185e0e
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 21 deletions.
10 changes: 10 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,13 @@ func TestScalarAggregate(t *testing.T) {
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`)
}

// TestAggregationRandomOnAnAggregatedValue seeds t10 with two rows and then
// runs a query whose outer select list divides two sums coming from a derived
// table. Only the row with a = 100 passes the filter, so the expected result
// is the sums 100 and 10 together with their quotient.
func TestAggregationRandomOnAnAggregatedValue(t *testing.T) {
	const (
		seedRows = "insert into t10(k, a, b) values (0, 100, 10), (10, 200, 20);"
		query    = "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from t10 where a = 100) A;"
		expected = `[[DECIMAL(100) DECIMAL(10) DECIMAL(10.0000)]]`
	)

	comparer, cleanup := start(t)
	defer cleanup()

	comparer.Exec(seedRows)
	comparer.AssertMatchesNoOrder(query, expected)
}
5 changes: 5 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,8 @@ CREATE TABLE t2 (
PRIMARY KEY (id)
) ENGINE InnoDB;

-- t10 holds the rows inserted by TestAggregationRandomOnAnAggregatedValue.
-- The storage engine is spelled out to match the other tables in this schema.
CREATE TABLE t10 (
    k BIGINT PRIMARY KEY,
    a INT,
    b INT
) ENGINE InnoDB;
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@
"name": "hash"
}
]
},
"t10": {
"column_vindexes": [
{
"column": "k",
"name": "hash"
}
]
}
}
}
34 changes: 34 additions & 0 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3856,6 +3856,40 @@ func TestSelectAggregationData(t *testing.T) {
}
}

// TestSelectAggregationRandom runs a gen4 query that divides two sums coming
// from a derived table against eight fake shards. Every shard answers with a
// single null|null row except the first one, which answers 10|1, so the
// expected output is the two sums followed by their quotient.
func TestSelectAggregationRandom(t *testing.T) {
	cell := "aa"
	hc := discovery.NewFakeHealthCheck(nil)
	createSandbox(KsTestSharded).VSchema = executorVSchema
	getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
	serv := newSandboxForCells([]string{cell})
	resolver := newTestResolver(hc, serv, cell)
	shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}

	// singleRow builds a fresh one-row result over the columns a and b.
	singleRow := func(row string) []*sqltypes.Result {
		return []*sqltypes.Result{sqltypes.MakeTestResult(
			sqltypes.MakeTestFields("a|b", "int64|int64"),
			row,
		)}
	}

	conns := make([]*sandboxconn.SandboxConn, 0, len(shards))
	for _, shard := range shards {
		sbc := hc.AddTestTablet(cell, shard, 1, KsTestSharded, shard, topodatapb.TabletType_PRIMARY, true, 1, nil)
		conns = append(conns, sbc)

		sbc.SetResults(singleRow("null|null"))
	}

	// Only the first shard contributes actual values to the aggregation.
	conns[0].SetResults(singleRow("10|1"))

	executor := createExecutor(serv, cell, resolver)
	executor.pv = querypb.ExecuteOptions_Gen4
	session := NewAutocommitSession(&vtgatepb.Session{})

	// The method label names this test so the query is attributed correctly.
	rs, err := executor.Execute(context.Background(), "TestSelectAggregationRandom", session,
		"select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil)
	require.NoError(t, err)
	assert.Equal(t, `[[INT64(10) INT64(1) DECIMAL(10.0000)]]`, fmt.Sprintf("%v", rs.Rows))
}

func TestSelectHexAndBit(t *testing.T) {
executor, _, _, _ := createExecutorEnv()
executor.normalize = true
Expand Down
36 changes: 18 additions & 18 deletions go/vt/vtgate/planbuilder/gen4_planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func newBuildSelectPlan(
return nil, nil, nil, err
}

plan = optimizePlan(plan)
optimizePlan(plan)

sel, isSel := selStmt.(*sqlparser.Select)
if isSel {
Expand All @@ -238,25 +238,25 @@ func newBuildSelectPlan(
}

// optimizePlan removes unnecessary simpleProjections that have been created while planning
func optimizePlan(plan logicalPlan) logicalPlan {
newPlan, _ := visit(plan, func(plan logicalPlan) (bool, logicalPlan, error) {
this, ok := plan.(*simpleProjection)
if !ok {
return true, plan, nil
}
func optimizePlan(plan logicalPlan) {
for _, lp := range plan.Inputs() {
optimizePlan(lp)
}

input, ok := this.input.(*simpleProjection)
if !ok {
return true, plan, nil
}
this, ok := plan.(*simpleProjection)
if !ok {
return
}

for i, col := range this.eSimpleProj.Cols {
this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col]
}
this.input = input.input
return true, this, nil
})
return newPlan
input, ok := this.input.(*simpleProjection)
if !ok {
return
}

for i, col := range this.eSimpleProj.Cols {
this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col]
}
this.input = input.input
}

func gen4UpdateStmtPlanner(
Expand Down
26 changes: 24 additions & 2 deletions go/vt/vtgate/planbuilder/horizon_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
// a simpleProjection. We create a new Route that contains the derived table in the
// FROM clause. Meaning that, when we push expressions to the select list of this
// new Route, we do not want to rewrite them.
if _, isSimpleProj := plan.(*simpleProjection); isSimpleProj {
sp, derivedTable := plan.(*simpleProjection)
if derivedTable {
oldRewriteDerivedExpr := ctx.RewriteDerivedExpr
defer func() {
ctx.RewriteDerivedExpr = oldRewriteDerivedExpr
Expand All @@ -75,10 +76,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
}

needsOrdering := len(hp.qp.OrderExprs) > 0
canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering

// If we still have a HAVING clause, it's because it could not be pushed to the WHERE,
// so it probably has aggregations
canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering

switch {
case hp.qp.NeedsAggregation() || hp.sel.Having != nil:
plan, err = hp.planAggregations(ctx, plan)
Expand All @@ -92,6 +94,26 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
if err != nil {
return nil, err
}
case derivedTable:
pusher := func(ae *sqlparser.AliasedExpr) (int, error) {
offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false)
return offset, err
}
needsVtGate, projections, colNames, err := hp.qp.NeedsProjecting(ctx, pusher)
if err != nil {
return nil, err
}
if !needsVtGate {
break
}

// there were some expressions we could not push down entirely,
// so replace the simpleProjection with a real projection
plan = &projection{
source: sp.input,
columns: projections,
columnNames: colNames,
}
default:
err = pushProjections(ctx, plan, hp.qp.SelectExprs)
if err != nil {
Expand Down
80 changes: 79 additions & 1 deletion go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,85 @@ func (qp *QueryProjection) NeedsAggregation() bool {
return qp.HasAggr || len(qp.groupByExprs) > 0
}

func (qp QueryProjection) onlyAggr() bool {
// NeedsProjecting reports whether the query has projections that need to be
// evaluated at the vtgate level and can't be pushed down to MySQL.
//
// Each select expression is handled in order:
//   - a plain column is handed to pusher and represented by the returned offset;
//   - inside any other expression, every column that belongs to a derived table
//     and whose rewritten form contains an aggregation is handed to pusher on
//     its own and replaced by its offset. If at least one column was replaced,
//     the remaining expression is kept for vtgate evaluation;
//   - an expression in which nothing was replaced is handed to pusher as a whole.
//
// pusher receives the aliased expression to push and returns the offset used
// to reference it.
//
// The results are: needsVtGateEval, true when at least one expression was
// kept for vtgate evaluation; expressions, one entry per select expression;
// colNames, the column name of each select expression; and err. When err is
// non-nil the other results are always their zero values.
func (qp *QueryProjection) NeedsProjecting(
	ctx *plancontext.PlanningContext,
	pusher func(expr *sqlparser.AliasedExpr) (int, error),
) (needsVtGateEval bool, expressions []sqlparser.Expr, colNames []string, err error) {
	for _, se := range qp.SelectExprs {
		ae, aeErr := se.GetAliasedExpr()
		if aeErr != nil {
			return false, nil, nil, aeErr
		}

		expr := ae.Expr
		colNames = append(colNames, ae.ColumnName())

		if _, isCol := expr.(*sqlparser.ColName); isCol {
			offset, pushErr := pusher(ae)
			if pushErr != nil {
				return false, nil, nil, pushErr
			}
			expressions = append(expressions, sqlparser.NewOffset(offset, expr))
			continue
		}

		// rewriteErr carries the first failure out of the rewrite callbacks.
		// It is local to this expression so it can never be confused with
		// the results of this function.
		var rewriteErr error
		continueUnlessFailed := func(sqlparser.SQLNode, sqlparser.SQLNode) bool {
			return rewriteErr == nil
		}
		rewriter := func(cursor *sqlparser.CopyOnWriteCursor) {
			if rewriteErr != nil {
				// keep the first error instead of overwriting it
				return
			}
			col, isCol := cursor.Node().(*sqlparser.ColName)
			if !isCol {
				return
			}
			tableInfo, tErr := ctx.SemTable.TableInfoForExpr(col)
			if tErr != nil {
				rewriteErr = tErr
				return
			}
			dt, isDT := tableInfo.(*semantics.DerivedTable)
			if !isDT {
				return
			}

			rewritten := semantics.RewriteDerivedTableExpression(col, dt)
			if !sqlparser.ContainsAggregation(rewritten) {
				return
			}

			offset, pushErr := pusher(&sqlparser.AliasedExpr{Expr: col})
			if pushErr != nil {
				rewriteErr = pushErr
				return
			}
			cursor.Replace(sqlparser.NewOffset(offset, col))
		}
		newExpr := sqlparser.CopyOnRewrite(expr, continueUnlessFailed, rewriter, nil)
		if rewriteErr != nil {
			// do not hand partially built results back together with an error
			return false, nil, nil, rewriteErr
		}

		if newExpr != expr {
			// if we changed the expression, it means that we have to evaluate the rest at the vtgate level
			expressions = append(expressions, newExpr.(sqlparser.Expr))
			needsVtGateEval = true
			continue
		}

		// we did not need to push any parts of this expression down. Let's check if we can push all of it
		offset, pushErr := pusher(ae)
		if pushErr != nil {
			return false, nil, nil, pushErr
		}
		expressions = append(expressions, sqlparser.NewOffset(offset, expr))
	}

	return needsVtGateEval, expressions, colNames, nil
}

func (qp *QueryProjection) onlyAggr() bool {
if !qp.HasAggr {
return false
}
Expand Down
40 changes: 40 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -4964,5 +4964,45 @@
"user.user_extra"
]
}
},
{
"comment": "Aggregations from derived table used in arithmetic outside derived table",
"query": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A",
"v3-plan": "VT12001: unsupported: expression on results of a cross-shard subquery",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] as a",
"[COLUMN 1] as b",
"[COLUMN 0] / [COLUMN 1] as d"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum(0) AS a, sum(1) AS b",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select sum(a) as a, sum(b) as b from `user` where 1 != 1",
"Query": "select sum(a) as a, sum(b) as b from `user`",
"Table": "`user`"
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
}
]

0 comments on commit d185e0e

Please sign in to comment.