diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 0f382bde580..06838eb02f6 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -33,6 +33,7 @@ import ( "vitess.io/vitess/go/vt/topo" "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/callerid" @@ -1031,6 +1032,52 @@ func TestExecutorDDL(t *testing.T) { executor, sbc1, sbc2, sbclookup := createExecutorEnv() + type cnts struct { + Sbc1Cnt int64 + Sbc2Cnt int64 + SbcLookupCnt int64 + } + + tcs := []struct { + targetStr string + + hasNoKeyspaceErr bool + shardQueryCnt int + wantCnts cnts + }{ + { + targetStr: "", + hasNoKeyspaceErr: true, + }, + { + targetStr: KsTestUnsharded, + shardQueryCnt: 1, + wantCnts: cnts{ + Sbc1Cnt: 0, + Sbc2Cnt: 0, + SbcLookupCnt: 1, + }, + }, + { + targetStr: "TestExecutor", + shardQueryCnt: 8, + wantCnts: cnts{ + Sbc1Cnt: 1, + Sbc2Cnt: 1, + SbcLookupCnt: 0, + }, + }, + { + targetStr: "TestExecutor/-20", + shardQueryCnt: 1, + wantCnts: cnts{ + Sbc1Cnt: 1, + Sbc2Cnt: 0, + SbcLookupCnt: 0, + }, + }, + } + stmts := []string{ "create", "alter", @@ -1038,61 +1085,32 @@ func TestExecutorDDL(t *testing.T) { "drop", "truncate", } - wantCount := []int64{0, 0, 0} + for _, stmt := range stmts { - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount := []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[2]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 1) + for _, tc := range tcs { + sbc1.ExecCount.Set(0) + sbc2.ExecCount.Set(0) + sbclookup.ExecCount.Set(0) + + _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + if tc.hasNoKeyspaceErr { + assert.Error(t, err, errNoKeyspace) + } else { + assert.NoError(t, err) + } - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount = []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[0]++ - wantCount[1]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) - } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 8) + diff := cmp.Diff(tc.wantCnts, cnts{ + Sbc1Cnt: sbc1.ExecCount.Get(), + Sbc2Cnt: sbc2.ExecCount.Get(), + SbcLookupCnt: sbclookup.ExecCount.Get(), + }) + if diff != "" { + t.Errorf("stmt: %s\ntc: %+v\n-want,+got:\n%s", stmt, tc, diff) + } - _, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor/-20"}), stmt, nil) - if err != nil { - t.Error(err) - } - gotCount = []int64{ - sbc1.ExecCount.Get(), - sbc2.ExecCount.Get(), - sbclookup.ExecCount.Get(), - } - wantCount[0]++ - if !reflect.DeepEqual(gotCount, wantCount) { - t.Errorf("Exec %s: %v, want %v", stmt, gotCount, wantCount) + testQueryLog(t, logChan, "TestExecute", "DDL", stmt, tc.shardQueryCnt) } - testQueryLog(t, logChan, "TestExecute", "DDL", stmt, 1) - } - - _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "create", nil) - want := errNoKeyspace.Error() - if err == nil || err.Error() != want { - t.Errorf("ddl with no keyspace: %v, want %v", err, want) } - testQueryLog(t, logChan, "TestExecute", "DDL", "create", 0) } func waitForVindex(t *testing.T, ks, name string, watch chan *vschemapb.SrvVSchema, executor *Executor) (*vschemapb.SrvVSchema, *vschemapb.Vindex) {