diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index acd55be20fc..b5f8fa630e3 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -1132,7 +1132,7 @@ var ( output: "alter table a", }, { input: "analyze table a", - output: "alter table a", + output: "otherread", }, { input: "flush tables", output: "flush", diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 40c3a374383..ec49904aecb 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -5095,7 +5095,7 @@ yydefault: yyDollar = yyS[yypt-3 : yypt+1] //line sql.y:1511 { - yyVAL.statement = &DDL{Action: AlterStr, Table: yyDollar[3].tableName} + yyVAL.statement = &OtherRead{} } case 265: yyDollar = yyS[yypt-4 : yypt+1] diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 55140b62936..a3b33370a3c 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -1509,7 +1509,7 @@ truncate_statement: analyze_statement: ANALYZE TABLE table_name { - $$ = &DDL{Action: AlterStr, Table: $3} + $$ = &OtherRead{} } show_statement: diff --git a/go/vt/vtexplain/vtexplain_flaky_test.go b/go/vt/vtexplain/vtexplain_flaky_test.go index ef35da06890..a6f62fa29ff 100644 --- a/go/vt/vtexplain/vtexplain_flaky_test.go +++ b/go/vt/vtexplain/vtexplain_flaky_test.go @@ -138,7 +138,7 @@ func TestErrors(t *testing.T) { }{ { SQL: "INVALID SQL", - Err: "vtexplain execute error in 'INVALID SQL': unrecognized statement: INVALID SQL", + Err: "vtexplain execute error in 'INVALID SQL': syntax error at position 8 near 'INVALID'", }, { diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index ce2fc4196cf..ddca4b06eca 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -185,6 +185,16 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st stmtType := sqlparser.Preview(sql) logStats.StmtType = stmtType.String() + stmt, err := sqlparser.Parse(sql) + if err != nil { + // If the DDL statement failed to be properly parsed, fall through anyway + if stmtType == sqlparser.StmtDDL { + stmt = &sqlparser.DDL{} + } else { + return nil, err + } + } + // Mysql warnings are scoped to the current session, but are // cleared when a "non-diagnostic statement" is executed: // https://dev.mysql.com/doc/refman/8.0/en/show-warnings.html @@ -192,14 +202,14 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st // To emulate this behavior, clear warnings from the session // for all statements _except_ SHOW, so that SHOW WARNINGS // can actually return them. - if stmtType != sqlparser.StmtShow { + if _, isShow := stmt.(*sqlparser.Show); !isShow { safeSession.ClearWarnings() } - switch stmtType { - case sqlparser.StmtSelect: + switch specStmt := stmt.(type) { + case *sqlparser.Select, *sqlparser.Union: return e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats, stmtType) - case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: + case *sqlparser.Insert, *sqlparser.Update, *sqlparser.Delete: safeSession := safeSession mustCommit := false @@ -236,24 +246,22 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st logStats.CommitTime = time.Since(commitStart) } return qr, nil - case sqlparser.StmtDDL: + case *sqlparser.DDL: return e.handleDDL(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) - case sqlparser.StmtBegin: + case *sqlparser.Begin: return e.handleBegin(ctx, safeSession, sql, bindVars, destTabletType, logStats) - case sqlparser.StmtCommit: + case *sqlparser.Commit: return e.handleCommit(ctx, safeSession, sql, bindVars, logStats) - case sqlparser.StmtRollback: + case *sqlparser.Rollback: return e.handleRollback(ctx, safeSession, sql, bindVars, logStats) - case sqlparser.StmtSet: + case *sqlparser.Set: return e.handleSet(ctx, safeSession, sql, bindVars, logStats) - case sqlparser.StmtShow: + case *sqlparser.Show: return e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) - case sqlparser.StmtUse: - return e.handleUse(ctx, safeSession, sql, bindVars) - case sqlparser.StmtOther: + case *sqlparser.Use: + return e.handleUse(ctx, safeSession, sql, bindVars, specStmt) + case *sqlparser.OtherAdmin, *sqlparser.OtherRead: return e.handleOther(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) - case sqlparser.StmtComment: - return e.handleComment(sql) } return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql) } @@ -1143,16 +1151,7 @@ func (e *Executor) handleShow(ctx context.Context, safeSession *SafeSession, sql return e.handleOther(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) } -func (e *Executor) handleUse(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - stmt, err := sqlparser.Parse(sql) - if err != nil { - return nil, err - } - use, ok := stmt.(*sqlparser.Use) - if !ok { - // This code is unreachable. - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unrecognized USE statement: %v", sql) - } +func (e *Executor) handleUse(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, use *sqlparser.Use) (*sqltypes.Result, error) { destKeyspace, destTabletType, _, err := e.ParseDestinationTarget(use.DBName.String()) if err != nil { return nil, err @@ -1199,12 +1198,6 @@ func (e *Executor) handleOther(ctx context.Context, safeSession *SafeSession, sq return result, err } -func (e *Executor) handleComment(sql string) (*sqltypes.Result, error) { - _, _ = sqlparser.ExtractMysqlComment(sql) - // Not sure if this is a good idea. - return &sqltypes.Result{}, nil -} - // StreamExecute executes a streaming query. func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, target querypb.Target, callback func(*sqltypes.Result) error) (err error) { logStats := NewLogStats(ctx, method, sql, bindVars) diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index 80ead36aac8..c87e1646c69 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1601,7 +1601,7 @@ func TestKeyDestRangeQuery(t *testing.T) { sbc2.Queries = nil // it does not work for inserts - _, err = executorExec(executor, "INSERT INTO sharded_user_msgs(message) VALUE('test')", nil) + _, err = executorExec(executor, "INSERT INTO sharded_user_msgs(message) VALUES('test')", nil) want := "range queries not supported for inserts: TestExecutor[-]" if err == nil || err.Error() != want { @@ -1686,7 +1686,7 @@ func TestKeyShardDestQuery(t *testing.T) { // it works for inserts - sql = "INSERT INTO sharded_user_msgs(message) VALUE('test')" + sql = "INSERT INTO sharded_user_msgs(message) VALUES('test')" _, err = executorExec(executor, sql, nil) wantQueries = []*querypb.BoundQuery{{ diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 0f382bde580..0204e473763 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -30,9 +30,11 @@ import ( "time" "context" + "vitess.io/vitess/go/vt/topo" "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/callerid" @@ -953,8 +955,8 @@ func TestExecutorComment(t *testing.T) { executor, _, _, _ := createExecutorEnv() stmts := []string{ - "/*! SET max_execution_time=5000*/", - "/*!50708 SET max_execution_time=5000*/", + "/*! SET autocommit=1*/", + "/*!50708 SET @x=5000*/", } wantResult := &sqltypes.Result{} @@ -972,56 +974,78 @@ func TestExecutorComment(t *testing.T) { func TestExecutorOther(t *testing.T) { executor, sbc1, sbc2, sbclookup := createExecutorEnv() - stmts := []string{ - "show other", - "analyze", - "describe", - "explain", - "repair", - "optimize", + type cnts struct { + Sbc1Cnt int64 + Sbc2Cnt int64 + SbcLookupCnt int64 } - wantCount := []int64{0, 0, 0} - for _, stmt := range stmts { - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount := []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[2]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount = []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[0]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } + tcs := []struct { + targetStr string + + hasNoKeyspaceErr bool + hasDestinationShardErr bool + wantCnts cnts + }{ + { + targetStr: "", + hasNoKeyspaceErr: true, + }, + { + targetStr: "TestExecutor[-]", + hasDestinationShardErr: true, + }, + { + targetStr: KsTestUnsharded, + wantCnts: cnts{ + Sbc1Cnt: 0, + Sbc2Cnt: 0, + SbcLookupCnt: 1, + }, + }, + { + targetStr: "TestExecutor", + wantCnts: cnts{ + Sbc1Cnt: 1, + Sbc2Cnt: 0, + SbcLookupCnt: 0, + }, + }, } - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "analyze", nil) - want := errNoKeyspace.Error() - if err == nil || err.Error() != want { - t.Errorf("show vschema tables: %v, want %v", err, want) + stmts := []string{ + "show tables", + "analyze table t1", + "describe t1", + "explain t1", + "repair table t1", + "optimize table t1", } - // Can't target a range with handle other - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor[-]"}), "analyze", nil) - want = "Destination can only be a single shard for statement: analyze, got: DestinationExactKeyRange(-)" - if err == nil || err.Error() != want { - t.Errorf("analyze: got %v, want %v", err, want) + for _, stmt := range stmts { + for _, tc := range tcs { + sbc1.ExecCount.Set(0) + sbc2.ExecCount.Set(0) + sbclookup.ExecCount.Set(0) + + _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + if tc.hasNoKeyspaceErr { + assert.Error(t, err, errNoKeyspace) + } else if tc.hasDestinationShardErr { + assert.Errorf(t, err, "Destination can only be a single shard for statement: %s, got: DestinationExactKeyRange(-)", stmt) + } else { + assert.NoError(t, err) + } + + diff := cmp.Diff(tc.wantCnts, cnts{ + Sbc1Cnt: sbc1.ExecCount.Get(), + Sbc2Cnt: sbc2.ExecCount.Get(), + SbcLookupCnt: sbclookup.ExecCount.Get(), + }) + if diff != "" { + t.Errorf("stmt: %s\ntc: %+v\n-want,+got:\n%s", stmt, tc, diff) + } + } } } @@ -1031,68 +1055,85 @@ func TestExecutorDDL(t *testing.T) { executor, sbc1, sbc2, sbclookup := createExecutorEnv() + type cnts struct { + Sbc1Cnt int64 + Sbc2Cnt int64 + SbcLookupCnt int64 + } + + tcs := []struct { + targetStr string + + hasNoKeyspaceErr bool + shardQueryCnt int + wantCnts cnts + }{ + { + targetStr: "", + hasNoKeyspaceErr: true, + }, + { + targetStr: KsTestUnsharded, + shardQueryCnt: 1, + wantCnts: cnts{ + Sbc1Cnt: 0, + Sbc2Cnt: 0, + SbcLookupCnt: 1, + }, + }, + { + targetStr: "TestExecutor", + shardQueryCnt: 8, + wantCnts: cnts{ + Sbc1Cnt: 1, + Sbc2Cnt: 1, + SbcLookupCnt: 0, + }, + }, + { + targetStr: "TestExecutor/-20", + shardQueryCnt: 1, + wantCnts: cnts{ + Sbc1Cnt: 1, + Sbc2Cnt: 0, + SbcLookupCnt: 0, + }, + }, + } + stmts := []string{ "create", - "alter", - "rename", - "drop", - "truncate", + "alter table t1 add primary key id", + "rename table t1 to t2", + "truncate table t2", + "drop table t2", } - wantCount := []int64{0, 0, 0} + for _, stmt := range stmts { - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount := []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[2]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 1) + for _, tc := range tcs { + sbc1.ExecCount.Set(0) + sbc2.ExecCount.Set(0) + sbclookup.ExecCount.Set(0) + + _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + if tc.hasNoKeyspaceErr { + assert.Error(t, err, errNoKeyspace) + } else { + assert.NoError(t, err) + } - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount = []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[0]++ - wantCount[1]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 8) + diff := cmp.Diff(tc.wantCnts, cnts{ + Sbc1Cnt: sbc1.ExecCount.Get(), + Sbc2Cnt: sbc2.ExecCount.Get(), + SbcLookupCnt: sbclookup.ExecCount.Get(), + }) + if diff != "" { + t.Errorf("stmt: %s\ntc: %+v\n-want,+got:\n%s", stmt, tc, diff) + } - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor/-20"}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount = []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[0]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) + testQueryLog(t, logChan, "TestExecute", "DDL", stmt, tc.shardQueryCnt) } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 1) - } - - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "create", nil) - want := errNoKeyspace.Error() - if err == nil || err.Error() != want { - t.Errorf("ddl with no keyspace: %v, want %v", err, want) } - testQueryLog(t, logChan, "TestExecute", "DDL", "create", 0) } func waitForVindex(t *testing.T, ks, name string, watch chan *vschemapb.SrvVSchema, executor *Executor) (*vschemapb.SrvVSchema, *vschemapb.Vindex) { @@ -1845,7 +1886,7 @@ func TestExecutorVindexDDLACL(t *testing.T) { func TestExecutorUnrecognized(t *testing.T) { executor, _, _, _ := createExecutorEnv() _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "invalid statement", nil) - want := "unrecognized statement: invalid statement" + want := "syntax error at position 8 near 'invalid'" if err == nil || err.Error() != want { t.Errorf("show vschema tables: %v, want %v", err, want) } diff --git a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt index d26044b6c73..dacde288127 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt +++ b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt @@ -2255,14 +2255,8 @@ options:PassthroughDMLs # analyze "analyze table a" { - "PlanID": "DDL", - "TableName": "", - "Permissions": [ - { - "TableName": "a", - "Role": 2 - } - ] + "PlanID": "OTHER_READ", + "TableName": "" } # reorganize partition with bind