Skip to content

Commit

Permalink
Add support for found_rows()
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Mar 11, 2020
1 parent cbef5f2 commit 1241ef6
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 13 deletions.
4 changes: 3 additions & 1 deletion go/vt/proto/vtgate/vtgate.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

62 changes: 62 additions & 0 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
type BindVarNeeds struct {
NeedLastInsertID bool
NeedDatabase bool
NeedFoundRows bool
}

// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
Expand All @@ -50,6 +51,9 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
if _, ok := er.bindVars[DBVarName]; ok {
r.NeedDatabase = true
}
if _, ok := er.bindVars[FoundRowsName]; ok {
r.NeedFoundRows = true
}

return r, nil
}
Expand Down Expand Up @@ -95,8 +99,58 @@ const (

//DBVarName is a reserved bind var name for database()
DBVarName = "__vtdbname"

//FoundRowsName is a reserved bind var name for found_rows()
FoundRowsName = "__vtfrows"
)

//<<<<<<< HEAD
//=======
//type funcRewrite struct {
// checkValid func(f *FuncExpr) error
// bindVarName string
//}
//
//var lastInsertID = funcRewrite{
// checkValid: func(f *FuncExpr) error {
// if len(f.Exprs) > 0 {
// return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
// }
// return nil
// },
// bindVarName: LastInsertIDName,
//}
//
//var dbName = funcRewrite{
// checkValid: func(f *FuncExpr) error {
// if len(f.Exprs) > 0 {
// return vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Argument to DATABASE() not supported")
// }
// return nil
// },
// bindVarName: DBVarName,
//}
//var foundRows = funcRewrite{
// checkValid: func(f *FuncExpr) error {
// if len(f.Exprs) > 0 {
// return vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Argument to FOUND_ROWS() not supported")
// }
// return nil
// },
// bindVarName: FoundRowsName,
//}
//
//var functions = map[string]*funcRewrite{
// "last_insert_id": &lastInsertID,
// "database": &dbName,
// "schema": &dbName,
// "found_rows": &foundRows,
//}
//
//// instead of creating new objects, we'll reuse this one
//var token = struct{}{}
//
//>>>>>>> Add support for found_rows()
func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
switch node := cursor.Node().(type) {
case *AliasedExpr:
Expand Down Expand Up @@ -139,7 +193,15 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
cursor.Replace(bindVarExpression(DBVarName))
er.needBindVarFor(DBVarName)
}
case node.Name.EqualString("found_rows"):
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported")
} else {
cursor.Replace(bindVarExpression(FoundRowsName))
er.needBindVarFor(FoundRowsName)
}
}

}
return true
}
Expand Down
24 changes: 15 additions & 9 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ import (
)

type myTestCase struct {
in, expected string
liid, db bool
in, expected string
liid, db, foundRows bool
}

func TestRewrites(in *testing.T) {
tests := []myTestCase{
{
in: "SELECT 42",
expected: "SELECT 42",
db: false, liid: false,
db: false, liid: false, foundRows: false,
},
{
in: "SELECT last_insert_id()",
Expand All @@ -42,7 +42,7 @@ func TestRewrites(in *testing.T) {
{
in: "SELECT database()",
expected: "SELECT :__vtdbname as `database()`",
db: true, liid: false,
db: true, liid: false, foundRows: false,
},
{
in: "SELECT database() from test",
Expand All @@ -52,12 +52,12 @@ func TestRewrites(in *testing.T) {
{
in: "SELECT last_insert_id() as test",
expected: "SELECT :__lastInsertId as test",
db: false, liid: true,
db: false, liid: true, foundRows: false,
},
{
in: "SELECT last_insert_id() + database()",
expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`",
db: true, liid: true,
db: true, liid: true, foundRows: false,
},
{
in: "select (select database()) from test",
Expand All @@ -72,7 +72,7 @@ func TestRewrites(in *testing.T) {
{
in: "select (select database() from dual) from dual",
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
db: true, liid: false,
db: true, liid: false, foundRows: false,
},
{
in: "select id from user where database()",
Expand All @@ -82,12 +82,17 @@ func TestRewrites(in *testing.T) {
{
in: "select table_name from information_schema.tables where table_schema = database()",
expected: "select table_name from information_schema.tables where table_schema = database()",
db: false, liid: false,
db: false, liid: false, foundRows: false,
},
{
in: "select schema()",
expected: "select :__vtdbname as 'schema()'",
db: true, liid: false,
db: true, liid: false, foundRows: false,
},
{
in: "select found_rows()",
expected: "select :__vtfrows as 'found_rows()'",
db: false, liid: false, foundRows: true,
},
}

Expand All @@ -106,6 +111,7 @@ func TestRewrites(in *testing.T) {
require.Equal(t, s, String(result.AST))
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows")
})
}
}
6 changes: 6 additions & 0 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ func (e *Executor) Execute(ctx context.Context, method string, safeSession *Safe

logStats := NewLogStats(ctx, method, sql, bindVars)
result, err = e.execute(ctx, safeSession, sql, bindVars, logStats)
if err == nil {
safeSession.FoundRows = result.RowsAffected
}
logStats.Error = err
if result != nil && len(result.Rows) > *warnMemoryRows {
warnings.Add("ResultsExceeded", 1)
Expand Down Expand Up @@ -340,6 +343,9 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace)
}
}
if bindVarNeeds.NeedFoundRows {
bindVars[sqlparser.FoundRowsName] = sqltypes.Uint64BindVariable(safeSession.FoundRows)
}

qr, err := plan.Instructions.Execute(vcursor, bindVars, true)
logStats.ExecuteTime = time.Since(execStart)
Expand Down
21 changes: 21 additions & 0 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,27 @@ func TestSelectLastInsertId(t *testing.T) {
assert.Equal(t, wantQueries, sbc1.Queries)
}

func TestFoundRows(t *testing.T) {
executor, sbc1, _, _ := createExecutorEnv()
executor.normalize = true
logChan := QueryLogger.Subscribe("Test")
defer QueryLogger.Unsubscribe(logChan)

// run this extra query so we can assert on the number of rows found
_, err := executorExec(executor, "select 42", map[string]*querypb.BindVariable{})
require.NoError(t, err)

sql := "select found_rows()"
_, err = executorExec(executor, sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
expected := &querypb.BoundQuery{
Sql: "select :__vtfrows as `found_rows()` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)},
}

assert.Equal(t, expected, sbc1.Queries[1])
}

func TestSelectLastInsertIdInUnion(t *testing.T) {
executor, sbc1, _, _ := createExecutorEnv()
executor.normalize = true
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ func TestExecutorAutocommit(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wantSession := &vtgatepb.Session{TargetString: "@master", InTransaction: true}
wantSession := &vtgatepb.Session{TargetString: "@master", InTransaction: true, FoundRows: 1}
testSession := *session.Session
testSession.ShardSessions = nil
if !proto.Equal(&testSession, wantSession) {
Expand Down Expand Up @@ -643,7 +643,7 @@ func TestExecutorAutocommit(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wantSession = &vtgatepb.Session{Autocommit: true, TargetString: "@master"}
wantSession = &vtgatepb.Session{Autocommit: true, TargetString: "@master", FoundRows: 1}
if !proto.Equal(session.Session, wantSession) {
t.Errorf("autocommit=1: %v, want %v", session.Session, wantSession)
}
Expand Down Expand Up @@ -672,7 +672,7 @@ func TestExecutorAutocommit(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wantSession = &vtgatepb.Session{InTransaction: true, Autocommit: true, TargetString: "@master"}
wantSession = &vtgatepb.Session{InTransaction: true, Autocommit: true, TargetString: "@master", FoundRows: 1}
testSession = *session.Session
testSession.ShardSessions = nil
if !proto.Equal(&testSession, wantSession) {
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/vtgate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func TestVTGateExecute(t *testing.T) {
},
TransactionId: 1,
}},
FoundRows: 1,
}
if !proto.Equal(wantSession, session) {
t.Errorf("want \n%+v, got \n%+v", wantSession, session)
Expand Down

0 comments on commit 1241ef6

Please sign in to comment.