From 1241ef68e7a5fbcd3918962f70f80f597c2ca0a8 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 17 Jan 2020 11:51:46 -0800 Subject: [PATCH] Add support for found_rows() Signed-off-by: Andres Taylor --- go/vt/proto/vtgate/vtgate.pb.go | 4 +- go/vt/sqlparser/expression_rewriting.go | 62 ++++++++++++++++++++ go/vt/sqlparser/expression_rewriting_test.go | 24 +++++--- go/vt/vtgate/executor.go | 6 ++ go/vt/vtgate/executor_select_test.go | 21 +++++++ go/vt/vtgate/executor_test.go | 6 +- go/vt/vtgate/vtgate_test.go | 1 + 7 files changed, 111 insertions(+), 13 deletions(-) diff --git a/go/vt/proto/vtgate/vtgate.pb.go b/go/vt/proto/vtgate/vtgate.pb.go index 74af5414ddd..61ba91e777b 100644 --- a/go/vt/proto/vtgate/vtgate.pb.go +++ b/go/vt/proto/vtgate/vtgate.pb.go @@ -137,7 +137,9 @@ type Session struct { // post_sessions contains sessions that have to be committed last. PostSessions []*Session_ShardSession `protobuf:"bytes,10,rep,name=post_sessions,json=postSessions,proto3" json:"post_sessions,omitempty"` // last_insert_id keeps track of the last seen insert_id for this session - LastInsertId uint64 `protobuf:"varint,11,opt,name=last_insert_id,json=lastInsertId,proto3" json:"last_insert_id,omitempty"` + LastInsertId uint64 `protobuf:"varint,11,opt,name=last_insert_id,json=lastInsertId,proto3" json:"last_insert_id,omitempty"` + // found_rows keeps track of how many rows the last query returned + FoundRows uint64 `protobuf:"varint,11,opt,name=found_rows,json=foundRows,proto3" json:"found_rows,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index d33491422a5..d2936973433 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -33,6 +33,7 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix type BindVarNeeds struct { NeedLastInsertID bool NeedDatabase bool + NeedFoundRows bool } // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries @@ -50,6 +51,9 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { if _, ok := er.bindVars[DBVarName]; ok { r.NeedDatabase = true } + if _, ok := er.bindVars[FoundRowsName]; ok { + r.NeedFoundRows = true + } return r, nil } @@ -95,8 +99,58 @@ const ( //DBVarName is a reserved bind var name for database() DBVarName = "__vtdbname" + + //FoundRowsName is a reserved bind var name for found_rows() + FoundRowsName = "__vtfrows" ) +//<<<<<<< HEAD +//======= +//type funcRewrite struct { +// checkValid func(f *FuncExpr) error +// bindVarName string +//} +// +//var lastInsertID = funcRewrite{ +// checkValid: func(f *FuncExpr) error { +// if len(f.Exprs) > 0 { +// return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported") +// } +// return nil +// }, +// bindVarName: LastInsertIDName, +//} +// +//var dbName = funcRewrite{ +// checkValid: func(f *FuncExpr) error { +// if len(f.Exprs) > 0 { +// return vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Argument to DATABASE() not supported") +// } +// return nil +// }, +// bindVarName: DBVarName, +//} +//var foundRows = funcRewrite{ +// checkValid: func(f *FuncExpr) error { +// if len(f.Exprs) > 0 { +// return vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Argument to FOUND_ROWS() not supported") +// } +// return nil +// }, +// bindVarName: FoundRowsName, +//} +// +//var functions = map[string]*funcRewrite{ +// "last_insert_id": &lastInsertID, +// "database": &dbName, +// "schema": &dbName, +// "found_rows": &foundRows, +//} +// +//// instead of creating new objects, we'll reuse this one +//var token = struct{}{} +// +//>>>>>>> Add support for found_rows() func (er *expressionRewriter) goingDown(cursor *Cursor) bool { switch node := cursor.Node().(type) { case *AliasedExpr: @@ -139,7 +193,15 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { cursor.Replace(bindVarExpression(DBVarName)) er.needBindVarFor(DBVarName) } + case node.Name.EqualString("found_rows"): + if len(node.Exprs) > 0 { + er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported") + } else { + cursor.Replace(bindVarExpression(FoundRowsName)) + er.needBindVarFor(FoundRowsName) + } } + } return true } diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 2951db2dc5c..b7dd524cbbb 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -23,8 +23,8 @@ import ( ) type myTestCase struct { - in, expected string - liid, db bool + in, expected string + liid, db, foundRows bool } func TestRewrites(in *testing.T) { @@ -32,7 +32,7 @@ func TestRewrites(in *testing.T) { { in: "SELECT 42", expected: "SELECT 42", - db: false, liid: false, + db: false, liid: false, foundRows: false, }, { in: "SELECT last_insert_id()", @@ -42,7 +42,7 @@ func TestRewrites(in *testing.T) { { in: "SELECT database()", expected: "SELECT :__vtdbname as `database()`", - db: true, liid: false, + db: true, liid: false, foundRows: false, }, { in: "SELECT database() from test", @@ -52,12 +52,12 @@ func TestRewrites(in *testing.T) { { in: "SELECT last_insert_id() as test", expected: "SELECT :__lastInsertId as test", - db: false, liid: true, + db: false, liid: true, foundRows: false, }, { in: "SELECT last_insert_id() + database()", expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", - db: true, liid: true, + db: true, liid: true, foundRows: false, }, { in: "select (select database()) from test", @@ -72,7 +72,7 @@ func TestRewrites(in *testing.T) { { in: "select (select database() from dual) from dual", expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual", - db: true, liid: false, + db: true, liid: false, foundRows: false, }, { in: "select id from user where database()", @@ -82,12 +82,17 @@ func TestRewrites(in *testing.T) { { in: "select table_name from information_schema.tables where table_schema = database()", expected: "select table_name from information_schema.tables where table_schema = database()", - db: false, liid: false, + db: false, liid: false, foundRows: false, }, { in: "select schema()", expected: "select :__vtdbname as 'schema()'", - db: true, liid: false, + db: true, liid: false, foundRows: false, + }, + { + in: "select found_rows()", + expected: "select :__vtfrows as 'found_rows()'", + db: false, liid: false, foundRows: true, }, } @@ -106,6 +111,7 @@ func TestRewrites(in *testing.T) { require.Equal(t, s, String(result.AST)) require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id") require.Equal(t, tc.db, result.NeedDatabase, "should need database name") + require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows") }) } } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index e68ed782f4b..3084c036416 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -145,6 +145,9 @@ func (e *Executor) Execute(ctx context.Context, method string, safeSession *Safe logStats := NewLogStats(ctx, method, sql, bindVars) result, err = e.execute(ctx, safeSession, sql, bindVars, logStats) + if err == nil { + safeSession.FoundRows = result.RowsAffected + } logStats.Error = err if result != nil && len(result.Rows) > *warnMemoryRows { warnings.Add("ResultsExceeded", 1) @@ -340,6 +343,9 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace) } } + if bindVarNeeds.NeedFoundRows { + bindVars[sqlparser.FoundRowsName] = sqltypes.Uint64BindVariable(safeSession.FoundRows) + } qr, err := plan.Instructions.Execute(vcursor, bindVars, true) logStats.ExecuteTime = time.Since(execStart) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 14363bb6449..3cc40311a57 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -235,6 +235,27 @@ func TestSelectLastInsertId(t *testing.T) { assert.Equal(t, wantQueries, sbc1.Queries) } +func TestFoundRows(t *testing.T) { + executor, sbc1, _, _ := createExecutorEnv() + executor.normalize = true + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + // run this extra query so we can assert on the number of rows found + _, err := executorExec(executor, "select 42", map[string]*querypb.BindVariable{}) + require.NoError(t, err) + + sql := "select found_rows()" + _, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + expected := &querypb.BoundQuery{ + Sql: "select :__vtfrows as `found_rows()` from dual", + BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)}, + } + + assert.Equal(t, expected, sbc1.Queries[1]) +} + func TestSelectLastInsertIdInUnion(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() executor.normalize = true diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index ec148f45594..d23b00131b2 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -609,7 +609,7 @@ func TestExecutorAutocommit(t *testing.T) { if err != nil { t.Fatal(err) } - wantSession := &vtgatepb.Session{TargetString: "@master", InTransaction: true} + wantSession := &vtgatepb.Session{TargetString: "@master", InTransaction: true, FoundRows: 1} testSession := *session.Session testSession.ShardSessions = nil if !proto.Equal(&testSession, wantSession) { @@ -643,7 +643,7 @@ func TestExecutorAutocommit(t *testing.T) { if err != nil { t.Fatal(err) } - wantSession = &vtgatepb.Session{Autocommit: true, TargetString: "@master"} + wantSession = &vtgatepb.Session{Autocommit: true, TargetString: "@master", FoundRows: 1} if !proto.Equal(session.Session, wantSession) { t.Errorf("autocommit=1: %v, want %v", session.Session, wantSession) } @@ -672,7 +672,7 @@ func TestExecutorAutocommit(t *testing.T) { if err != nil { t.Fatal(err) } - wantSession = &vtgatepb.Session{InTransaction: true, Autocommit: true, TargetString: "@master"} + wantSession = &vtgatepb.Session{InTransaction: true, Autocommit: true, TargetString: "@master", FoundRows: 1} testSession = *session.Session testSession.ShardSessions = nil if !proto.Equal(&testSession, wantSession) { diff --git a/go/vt/vtgate/vtgate_test.go b/go/vt/vtgate/vtgate_test.go index 560aca613b0..2705dc42008 100644 --- a/go/vt/vtgate/vtgate_test.go +++ b/go/vt/vtgate/vtgate_test.go @@ -254,6 +254,7 @@ func TestVTGateExecute(t *testing.T) { }, TransactionId: 1, }}, + FoundRows: 1, } if !proto.Equal(wantSession, session) { t.Errorf("want \n%+v, got \n%+v", wantSession, session)