Skip to content

Commit

Permalink
Refactored system function rewriting
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Mar 11, 2020
1 parent bdbcd38 commit cbef5f2
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 31 deletions.
58 changes: 41 additions & 17 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,29 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
return RewriteAST(in)
}

// BindVarNeeds represents the bind vars that need to be provided as the result of expression rewriting.
type BindVarNeeds struct {
NeedLastInsertID bool
NeedDatabase bool
}

// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
func RewriteAST(in Statement) (*RewriteASTResult, error) {
er := new(expressionRewriter)
er := newExpressionRewriter()
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
Rewrite(in, er.goingDown, nil)

return &RewriteASTResult{
AST: in,
NeedLastInsertID: er.lastInsertID,
NeedDatabase: er.database,
}, nil
r := &RewriteASTResult{
AST: in,
}
if _, ok := er.bindVars[LastInsertIDName]; ok {
r.NeedLastInsertID = true
}
if _, ok := er.bindVars[DBVarName]; ok {
r.NeedDatabase = true
}

return r, nil
}

func shouldRewriteDatabaseFunc(in Statement) bool {
Expand All @@ -63,20 +75,24 @@ func shouldRewriteDatabaseFunc(in Statement) bool {

// RewriteASTResult contains the rewritten ast and meta information about it
type RewriteASTResult struct {
AST Statement
NeedLastInsertID bool
NeedDatabase bool
BindVarNeeds
AST Statement // The rewritten AST
}

type expressionRewriter struct {
lastInsertID, database bool
bindVars map[string]struct{}
shouldRewriteDatabaseFunc bool
err error
}

func newExpressionRewriter() *expressionRewriter {
return &expressionRewriter{bindVars: make(map[string]struct{})}
}

const (
//LastInsertIDName is a reserved bind var name for last_insert_id()
LastInsertIDName = "__lastInsertId"

//DBVarName is a reserved bind var name for database()
DBVarName = "__vtdbname"
)
Expand All @@ -87,7 +103,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
if node.As.IsEmpty() {
buf := NewTrackedBuffer(nil)
node.Expr.Format(buf)
inner := new(expressionRewriter)
inner := newExpressionRewriter()
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.goingDown, nil)
newExpr, ok := tmp.(Expr)
Expand All @@ -96,11 +112,12 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
return false
}
node.Expr = newExpr
er.database = er.database || inner.database
er.lastInsertID = er.lastInsertID || inner.lastInsertID
if inner.didAnythingChange() {
node.As = NewColIdent(buf.String())
}
for k := range inner.bindVars {
er.needBindVarFor(k)
}
return false
}

Expand All @@ -111,24 +128,31 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
} else {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.lastInsertID = true
er.needBindVarFor(LastInsertIDName)
}
case er.shouldRewriteDatabaseFunc &&
(node.Name.EqualString("database") ||
node.Name.EqualString("schema")):
node.Name.EqualString("schema")):
if len(node.Exprs) > 0 {
er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String())
} else {
cursor.Replace(bindVarExpression(DBVarName))
er.database = true
er.needBindVarFor(DBVarName)
}
}
}
return true
}

// instead of creating new objects, we'll reuse this one
var token = struct{}{}

func (er *expressionRewriter) needBindVarFor(name string) {
er.bindVars[name] = token
}

func (er *expressionRewriter) didAnythingChange() bool {
return er.database || er.lastInsertID
return len(er.bindVars) > 0
}

func bindVarExpression(name string) *SQLVal {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"sync"
"time"

"vitess.io/vitess/go/vt/sqlparser"

"golang.org/x/net/context"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -94,10 +96,8 @@ type Plan struct {
Rows uint64 `json:",omitempty"`
// Total number of errors
Errors uint64 `json:",omitempty"`
// NeedsLastInsertID signals whether this plan will need to be provided with last_insert_id
NeedsLastInsertID bool `json:"-"` // don't include in the json representation
// NeedsDatabaseName signals whether this plan will need to be provided with the database name
NeedsDatabaseName bool `json:"-"` // don't include in the json representation

sqlparser.BindVarNeeds
}

// AddStats updates the plan execution statistics
Expand Down
9 changes: 5 additions & 4 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,11 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
return nil, err
}

if plan.NeedsLastInsertID {
bindVarNeeds := plan.BindVarNeeds
if bindVarNeeds.NeedLastInsertID {
bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(safeSession.GetLastInsertId())
}
if plan.NeedsDatabaseName {
if bindVarNeeds.NeedDatabase {
keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString)
if keyspace == "" {
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
Expand Down Expand Up @@ -1415,7 +1416,7 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser.
return nil, err
}
if !e.normalize {
plan, err := planbuilder.BuildFromStmt(sql, stmt, vcursor, false, false)
plan, err := planbuilder.BuildFromStmt(sql, stmt, vcursor, sqlparser.BindVarNeeds{})
if err != nil {
return nil, err
}
Expand All @@ -1442,7 +1443,7 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser.
if result, ok := e.plans.Get(planKey); ok {
return result.(*engine.Plan), nil
}
plan, err := planbuilder.BuildFromStmt(normalized, rewrittenStatement, vcursor, result.NeedLastInsertID, result.NeedDatabase)
plan, err := planbuilder.BuildFromStmt(normalized, rewrittenStatement, vcursor, result.BindVarNeeds)
if err != nil {
return nil, err
}
Expand Down
11 changes: 5 additions & 6 deletions go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,14 @@ func Build(query string, vschema ContextVSchema) (*engine.Plan, error) {
return nil, err
}

return BuildFromStmt(query, result.AST, vschema, result.NeedLastInsertID, result.NeedDatabase)
return BuildFromStmt(query, result.AST, vschema, result.BindVarNeeds)
}

// BuildFromStmt builds a plan based on the AST provided.
// TODO(sougou): The query input is trusted as the source
// of the AST. Maybe this function just returns instructions
// and engine.Plan can be built by the caller.
func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema, needsLastInsertID, needsDBName bool) (*engine.Plan, error) {
func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema, bindVarNeeds sqlparser.BindVarNeeds) (*engine.Plan, error) {
var err error
var instruction engine.Primitive
switch stmt := stmt.(type) {
Expand Down Expand Up @@ -309,10 +309,9 @@ func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchem
return nil, err
}
plan := &engine.Plan{
Original: query,
Instructions: instruction,
NeedsLastInsertID: needsLastInsertID,
NeedsDatabaseName: needsDBName,
Original: query,
Instructions: instruction,
}
plan.BindVarNeeds = bindVarNeeds
return plan, nil
}

0 comments on commit cbef5f2

Please sign in to comment.