diff --git a/storage/replica.go b/storage/replica.go index 89773ae6024f..037466079910 100644 --- a/storage/replica.go +++ b/storage/replica.go @@ -135,6 +135,7 @@ func usesTimestampCache(r roachpb.Request) bool { // via the done channel. type pendingCmd struct { ctx context.Context + idKey cmdIDKey raftCmd roachpb.RaftCommand done chan roachpb.ResponseWithError // Used to signal waiting RPC handler } @@ -635,7 +636,6 @@ func (r *Replica) RaftStatus() *raft.Status { // range's leadership is confirmed. The command is then dispatched // either along the read-only execution path or the read-write Raft // command queue. -// TODO(tschottdorf): use BatchRequest w/o pointer receiver. func (r *Replica) Send(ctx context.Context, ba roachpb.BatchRequest) (*roachpb.BatchResponse, *roachpb.Error) { var br *roachpb.BatchResponse @@ -963,11 +963,32 @@ func (r *Replica) addWriteCmd(ctx context.Context, ba roachpb.BatchRequest, wg * if err == nil { // If the command was accepted by raft, wait for the range to apply it. - select { - case respWithErr := <-pendingCmd.done: - br, pErr = respWithErr.Reply, respWithErr.Err - case <-ctx.Done(): - pErr = roachpb.NewError(ctx.Err()) + ctxDone := ctx.Done() + for br == nil && pErr == nil { + select { + case respWithErr := <-pendingCmd.done: + br, pErr = respWithErr.Reply, respWithErr.Err + case <-ctxDone: + // Cancellation is somewhat tricky since we can't prevent the + // Raft command from executing at some point in the future. + // We try to remove the pending command, but if the processRaft + // goroutine has already grabbed it (as would typically be the + // case right as it executes), it's too late and we're still + // going to have to wait until the command returns (which would + // typically be right away). + // A typical outcome of a bug here would be use-after-free of + // the trace of this client request; we finish it when + // returning from here, but the Raft execution also uses it. 
+ ctxDone = nil + if r.tryAbandon(pendingCmd.idKey) { + // TODO(tschottdorf): the command will still execute at + // some point, so maybe this should be a structured error + // which can be interpreted appropriately upstream. + pErr = roachpb.NewError(ctx.Err()) + } else { + log.Warningf("unable to cancel expired Raft command %s", ba) + } + } } } else { pErr = roachpb.NewError(err) @@ -975,6 +996,18 @@ func (r *Replica) addWriteCmd(ctx context.Context, ba roachpb.BatchRequest, wg * return br, pErr } +// tryAbandon attempts to remove a pending command from the internal commands +// map. This is possible until execution of the command at the local replica +// has already begun, in which case false is returned and the client needs to +// continue waiting for successful execution. +func (r *Replica) tryAbandon(idKey cmdIDKey) bool { + r.mu.Lock() + _, ok := r.mu.pendingCmds[idKey] + delete(r.mu.pendingCmds, idKey) + r.mu.Unlock() + return ok +} + // proposeRaftCommand prepares necessary pending command struct and // initializes a client command ID if one hasn't been. 
It then // proposes the command to Raft and returns the error channel and @@ -990,8 +1023,9 @@ func (r *Replica) proposeRaftCommand(ctx context.Context, ba roachpb.BatchReques idKeyBuf = encoding.EncodeUint64Ascending(idKeyBuf, uint64(rand.Int63())) idKey := cmdIDKey(idKeyBuf) pendingCmd := &pendingCmd{ - ctx: ctx, - done: make(chan roachpb.ResponseWithError, 1), + ctx: ctx, + idKey: idKey, + done: make(chan roachpb.ResponseWithError, 1), raftCmd: roachpb.RaftCommand{ RangeID: r.RangeID, OriginReplica: *replica, diff --git a/storage/replica_test.go b/storage/replica_test.go index e4b204dc73a9..afbe2e05c4c5 100644 --- a/storage/replica_test.go +++ b/storage/replica_test.go @@ -4005,3 +4005,52 @@ func TestGCIncorrectRange(t *testing.T) { t.Errorf("expected value at key %s to no longer exist after GC to correct range, found value %v", key, resVal) } } + +// TestReplicaCancelRaft checks that it is possible to safely abandon Raft +// commands via a canceable context.Context. +func TestReplicaCancelRaft(t *testing.T) { + defer leaktest.AfterTest(t)() + for _, cancelEarly := range []bool{true, false} { + func() { + // Pick a key unlikely to be used by background processes. + key := []byte("acdfg") + ctx, cancel := context.WithCancel(context.Background()) + TestingCommandFilter = func(_ roachpb.StoreID, args roachpb.Request, _ roachpb.Header) error { + if !args.Header().Key.Equal(key) { + return nil + } + if cancelEarly { + return errors.New("expected client to abandon this request") + } + cancel() + return nil + } + defer func() { TestingCommandFilter = nil }() + tc := testContext{} + tc.Start(t) + defer tc.Stop() + if cancelEarly { + cancel() + } + var ba roachpb.BatchRequest + ba.Add(&roachpb.GetRequest{ + Span: roachpb.Span{Key: key}, + }) + br, pErr := tc.rng.addWriteCmd(ctx, ba, nil /* wg */) + if pErr == nil { + if !cancelEarly { + // We cancelled the context while the command was already + // being processed, so the client had to wait for successful + // execution. 
+ return + } + t.Fatalf("expected an error, but got successful response %+v", br) + } + // If we cancelled the context early enough, we expect to receive a + // corresponding error and not wait for the command. + if !testutils.IsPError(pErr, context.Canceled.Error()) { + t.Fatalf("unexpected error: %s", pErr) + } + }() + } +}