diff --git a/go/vt/vtgate/engine/send.go b/go/vt/vtgate/engine/send.go index 8eab3a8d067..c812fa53ac5 100644 --- a/go/vt/vtgate/engine/send.go +++ b/go/vt/vtgate/engine/send.go @@ -73,7 +73,7 @@ func (s *Send) GetTableName() string { } // Execute implements Primitive interface -func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { +func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { rss, _, err := vcursor.ResolveDestinations(s.Keyspace.Name, nil, []key.Destination{s.TargetDestination}) if err != nil { return nil, vterrors.Wrap(err, "sendExecute") @@ -110,12 +110,43 @@ func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariabl } // StreamExecute implements Primitive interface -func (s *Send) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not reachable") // TODO: systay - this should work +func (s *Send) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + rss, _, err := vcursor.ResolveDestinations(s.Keyspace.Name, nil, []key.Destination{s.TargetDestination}) + if err != nil { + return vterrors.Wrap(err, "sendStreamExecute") + } + + if !s.Keyspace.Sharded && len(rss) != 1 { + return vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "Keyspace does not have exactly one shard: %v", rss) + } + + if s.SingleShardOnly && len(rss) != 1 { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unexpected error, DestinationKeyspaceID mapping to multiple shards: %s, got: %v", s.Query, s.TargetDestination) + } + + queries := make([]*querypb.BoundQuery, len(rss)) + for i := range rss { + queries[i] = &querypb.BoundQuery{ + Sql: s.Query, + BindVariables: bindVars, + } + } + + multiBindVars := make([]map[string]*querypb.BindVariable, len(rss)) + for i := range multiBindVars { + multiBindVars[i] = bindVars + } + return vcursor.StreamExecuteMulti(s.Query, rss, multiBindVars, callback) } // GetFields implements Primitive interface func (s *Send) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + // We don't need to worry about GetFields being needed for joins since the Send primitive currently doesn't + // get nested with other types of primitives. + // However, GetFields is used for prepared statements. Because of that, prepared statements are not yet + // compatible with the Send primitive. + // + // TODO: implement this return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not reachable") } diff --git a/go/vt/vtgate/engine/send_test.go b/go/vt/vtgate/engine/send_test.go index 4604ce0e488..162018aac23 100644 --- a/go/vt/vtgate/engine/send_test.go +++ b/go/vt/vtgate/engine/send_test.go @@ -34,7 +34,9 @@ func TestSendTable(t *testing.T) { shards []string destination key.Destination expectedQueryLog []string + expectedError string isDML bool + singleShardOnly bool } singleShard := []string{"0"} @@ -95,6 +97,18 @@ func TestSendTable(t *testing.T) { }, isDML: true, }, + { + testName: "sharded with multi shard destination", + sharded: true, + shards: twoShards, + destination: key.DestinationAllShards{}, + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + }, + expectedError: "Unexpected error, DestinationKeyspaceID mapping to multiple shards: dummy_query, got: DestinationAllShards()", + isDML: true, + singleShardOnly: true, + }, } for _, tc := range tests { @@ -107,10 +121,15 @@ func TestSendTable(t *testing.T) { Query: "dummy_query", TargetDestination: tc.destination, IsDML: tc.isDML, + SingleShardOnly: tc.singleShardOnly, } vc := &loggingVCursor{shards: tc.shards} _, err := send.Execute(vc, map[string]*querypb.BindVariable{}, false) - require.NoError(t, err) + if tc.expectedError != "" { + require.EqualError(t, err, tc.expectedError) + } else { + require.NoError(t, err) + } vc.ExpectLog(t, tc.expectedQueryLog) // Failure cases @@ -126,3 +145,122 @@ func TestSendTable(t *testing.T) { }) } } + +func TestSendTable_StreamExecute(t *testing.T) { + type testCase struct { + testName string + sharded bool + shards []string + destination key.Destination + expectedQueryLog []string + expectedError string + isDML bool + singleShardOnly bool + } + + singleShard := []string{"0"} + twoShards := []string{"-20", "20-"} + tests := []testCase{ + { + testName: "unsharded with no autocommit", + sharded: false, + shards: singleShard, + destination: key.DestinationAllShards{}, + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `StreamExecuteMulti dummy_query ks.0: {} `, + }, + isDML: false, + }, + { + testName: "sharded with no autocommit", + sharded: true, + shards: twoShards, + destination: key.DestinationShard("20-"), + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationShard(20-)`, + `StreamExecuteMulti dummy_query ks.DestinationShard(20-): {} `, + }, + isDML: false, + }, + { + testName: "unsharded", + sharded: false, + shards: singleShard, + destination: key.DestinationAllShards{}, + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `StreamExecuteMulti dummy_query ks.0: {} `, + }, + isDML: true, + }, + { + testName: "sharded with single shard destination", + sharded: true, + shards: twoShards, + destination: key.DestinationShard("20-"), + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationShard(20-)`, + `StreamExecuteMulti dummy_query ks.DestinationShard(20-): {} `, + }, + isDML: true, + }, + { + testName: "sharded with multi shard destination", + sharded: true, + shards: twoShards, + destination: key.DestinationAllShards{}, + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `StreamExecuteMulti dummy_query ks.-20: {} ks.20-: {} `, + }, + isDML: true, + }, + { + testName: "sharded with multi shard destination single shard setting", + sharded: true, + shards: twoShards, + destination: key.DestinationAllShards{}, + expectedQueryLog: []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + }, + expectedError: "Unexpected error, DestinationKeyspaceID mapping to multiple shards: dummy_query, got: DestinationAllShards()", + isDML: true, + singleShardOnly: true, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + send := &Send{ + Keyspace: &vindexes.Keyspace{ + Name: "ks", + Sharded: tc.sharded, + }, + Query: "dummy_query", + TargetDestination: tc.destination, + IsDML: tc.isDML, + SingleShardOnly: tc.singleShardOnly, + } + vc := &loggingVCursor{shards: tc.shards} + _, err := wrapStreamExecute(send, vc, map[string]*querypb.BindVariable{}, false) + if tc.expectedError != "" { + require.EqualError(t, err, tc.expectedError) + } else { + require.NoError(t, err) + } + vc.ExpectLog(t, tc.expectedQueryLog) + + // Failure cases + vc = &loggingVCursor{shardErr: errors.New("shard_error")} + _, err = wrapStreamExecute(send, vc, map[string]*querypb.BindVariable{}, false) + require.EqualError(t, err, "sendStreamExecute: shard_error") + + if !tc.sharded { + vc = &loggingVCursor{} + _, err = wrapStreamExecute(send, vc, map[string]*querypb.BindVariable{}, false) + require.EqualError(t, err, "Keyspace does not have exactly one shard: []") + } + }) + } +} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index fc83b8507bc..e9a78cae942 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -989,7 +989,8 @@ func (e *Executor) handleComment(sql string) (*sqltypes.Result, error) { // StreamExecute executes a streaming query. func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, target querypb.Target, callback func(*sqltypes.Result) error) (err error) { logStats := NewLogStats(ctx, method, sql, bindVars) - logStats.StmtType = sqlparser.Preview(sql).String() + stmtType := sqlparser.Preview(sql) + logStats.StmtType = stmtType.String() defer logStats.Send() if bindVars == nil { @@ -998,10 +999,22 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession query, comments := sqlparser.SplitMarginComments(sql) vcursor, _ := newVCursorImpl(ctx, safeSession, comments, e, logStats, e.vm, e.resolver.resolver) - // check if this is a stream statement for messaging - // TODO: support keyRange syntax - if logStats.StmtType == sqlparser.StmtStream.String() { + switch stmtType { + case sqlparser.StmtStream: + // this is a stream statement for messaging + // TODO: support keyRange syntax return e.handleMessageStream(ctx, sql, target, callback, vcursor, logStats) + case sqlparser.StmtSelect, sqlparser.StmtDDL, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete, + sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtComment: + // These may or may not all work, but getPlan() should either return a plan with instructions + // or an error, so it's safe to try. + break + case sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback: + // These statements don't populate plan.Instructions. We want to make sure we don't try to + // dereference nil Instructions which would panic. + fallthrough + default: + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported statement type for OLAP: %s", stmtType) } plan, err := e.getPlan( diff --git a/go/vt/vtgate/planbuilder/testdata/bypass_cases.txt b/go/vt/vtgate/planbuilder/testdata/bypass_cases.txt index a3b282848fb..4fc5d2d023f 100644 --- a/go/vt/vtgate/planbuilder/testdata/bypass_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/bypass_cases.txt @@ -73,3 +73,22 @@ "SingleShardOnly": false } } + +# insert bypass with sequence: sequences ignored +"insert into user(nonid) values (2)" +{ + "QueryType": "INSERT", + "Original": "insert into user(nonid) values (2)", + "Instructions": { + "OperatorType": "Send", + "Variant": "", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetDestination": "Shard(-80)", + "IsDML": true, + "Query": "insert into user(nonid) values (2)", + "SingleShardOnly": false + } +}