Skip to content

Commit

Permalink
Remove usage of gRPC Context cancellation in the remote execution client. (bazelbuild#17438)
Browse files Browse the repository at this point in the history

The gRPC remote execution client frequently "converts" gRPC calls into `ListenableFuture`s by setting a `SettableFuture` in the `onCompleted` or `onError` gRPC stub callbacks. If the future has direct executor callbacks, those callbacks will execute with the gRPC Context of the freshly completed call. That is problematic if the `Context` was canceled (canceling the call `Context` is good hygiene after completing a gRPC call), and the future callback goes to make further gRPC calls.

Therefore, this change removes all usage of gRPC `Context` cancellation. It would be nice if there were some way to avoid leaking `Context`s between calls, rather than having to totally forswear `Context` cancellation. However, I can't see a good way to enforce proper isolation.

Fixes bazelbuild#17298.

Closes bazelbuild#17426.

PiperOrigin-RevId: 507730469
Change-Id: Iea74acad4592952700e41d34672f6478de509d5e

Co-authored-by: Benjamin Peterson <[email protected]>
  • Loading branch information
ShreeM01 and benjaminp authored Feb 9, 2023
1 parent 05eefa1 commit 026a8d0
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.remote.util.Utils;
import io.grpc.Channel;
import io.grpc.Context;
import io.grpc.Context.CancellableContext;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
Expand Down Expand Up @@ -231,7 +229,6 @@ private ListenableFuture<Void> startAsyncUpload(
ListenableFuture<Void> currUpload = newUpload.start();
currUpload.addListener(
() -> {
newUpload.cancel();
if (openedFilePermits != null) {
openedFilePermits.release();
}
Expand All @@ -249,7 +246,6 @@ private static final class AsyncUpload implements AsyncCallable<Long> {
private final String resourceName;
private final Chunker chunker;
private final ProgressiveBackoff progressiveBackoff;
private final CancellableContext grpcContext;

private long lastCommittedOffset = -1;

Expand All @@ -269,7 +265,6 @@ private static final class AsyncUpload implements AsyncCallable<Long> {
this.progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
this.resourceName = resourceName;
this.chunker = chunker;
this.grpcContext = Context.current().withCancellation();
}

ListenableFuture<Void> start() {
Expand Down Expand Up @@ -367,13 +362,11 @@ private ListenableFuture<Long> query() {
Futures.transform(
channel.withChannelFuture(
channel ->
grpcContext.call(
() ->
bsFutureStub(channel)
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder()
.setResourceName(resourceName)
.build()))),
bsFutureStub(channel)
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder()
.setResourceName(resourceName)
.build())),
QueryWriteStatusResponse::getCommittedSize,
MoreExecutors.directExecutor());
return Futures.catchingAsync(
Expand All @@ -395,18 +388,10 @@ private ListenableFuture<Long> upload(long pos) {
return channel.withChannelFuture(
channel -> {
SettableFuture<Long> uploadResult = SettableFuture.create();
grpcContext.run(
() ->
bsAsyncStub(channel)
.write(new Writer(resourceName, chunker, pos, uploadResult)));
bsAsyncStub(channel).write(new Writer(resourceName, chunker, pos, uploadResult));
return uploadResult;
});
}

/** Cancels this upload's gRPC context, using a CANCELLED status exception as the cause. */
void cancel() {
  StatusRuntimeException cancellationCause =
      Status.CANCELLED.withDescription("Cancelled by user").asRuntimeException();
  grpcContext.cancel(cancellationCause);
}
}

private static final class Writer
Expand All @@ -430,6 +415,13 @@ private Writer(
/**
 * Remembers the request observer, forwards cancellation of {@code uploadResult} to the call, and
 * registers this object as the observer's on-ready handler.
 */
@Override
public void beforeStart(ClientCallStreamObserver<WriteRequest> requestObserver) {
  this.requestObserver = requestObserver;

  // When the result future completes, cancel the call only if the completion was a cancellation;
  // any other completion leaves the call alone.
  Runnable cancelCallIfResultCancelled =
      () -> {
        if (!uploadResult.isCancelled()) {
          return;
        }
        requestObserver.cancel("cancelled by user", null);
      };
  uploadResult.addListener(cancelCallIfResultCancelled, MoreExecutors.directExecutor());

  requestObserver.setOnReadyHandler(this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ascii;
import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
Expand All @@ -58,11 +59,11 @@
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
Expand Down Expand Up @@ -371,81 +372,87 @@ private ListenableFuture<Long> requestRead(
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
Context.CancellableContext grpcContext = Context.current().withCancellation();
future.addListener(() -> grpcContext.cancel(null), MoreExecutors.directExecutor());
grpcContext.run(
() ->
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {
/** Writes the payload of one read response to {@code out} and resets the stall backoff. */
@Override
public void onNext(ReadResponse readResponse) {
  try {
    readResponse.getData().writeTo(out);
  } catch (IOException e) {
    // Rethrow unchecked; per the original note, this is what cancels the call.
    throw new RuntimeException(e);
  }
  // Receiving a message counts as progress (or a keep-alive), so start the backoff over.
  progressiveBackoff.reset();
}

/**
 * Handles a failed read stream. If every expected byte has already arrived, the error is logged
 * and the download is completed normally; otherwise the output is released and the future fails.
 */
@Override
public void onError(Throwable t) {
  if (rawOut.getCount() == digest.getSizeBytes()) {
    // A trailing error is irrelevant once the whole blob has been received.
    logger.atInfo().withCause(t).log("ignoring error because file was fully received");
    onCompleted();
  } else {
    releaseOut();
    // NOT_FOUND is reported as a cache miss; every other error is passed through unchanged.
    Throwable failure =
        Status.fromThrowable(t).getCode() == Status.Code.NOT_FOUND
            ? new CacheNotFoundException(digest)
            : t;
    future.setException(failure);
  }
}

@Override
public void onCompleted() {
try {
try {
out.flush();
} finally {
releaseOut();
}
if (digestSupplier != null) {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
}
future.set(rawOut.getCount());
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(rawOut.getCount())
.build(),
new ClientResponseObserver<ReadRequest, ReadResponse>() {
@Override
public void beforeStart(ClientCallStreamObserver<ReadRequest> requestStream) {
future.addListener(
() -> {
if (future.isCancelled()) {
requestStream.cancel("canceled by user", null);
}

/**
 * Shallow-closes {@code out} when it is a {@link ZstdDecompressingOutputStream}; does nothing for
 * any other stream type. A failure to close is logged rather than propagated.
 */
private void releaseOut() {
  if (!(out instanceof ZstdDecompressingOutputStream)) {
    return;
  }
  ZstdDecompressingOutputStream zstdOut = (ZstdDecompressingOutputStream) out;
  try {
    zstdOut.closeShallow();
  } catch (IOException e) {
    logger.atWarning().withCause(e).log("failed to cleanly close output stream");
  }
}
}));
},
MoreExecutors.directExecutor());
}

/** Appends the bytes carried by one read response to {@code out}, then resets the stall backoff. */
@Override
public void onNext(ReadResponse readResponse) {
  ByteString chunk = readResponse.getData();
  try {
    chunk.writeTo(out);
  } catch (IOException writeFailure) {
    // Surface the write failure as an unchecked exception; per the original note, throwing here
    // cancels the call.
    throw new VerifyException(writeFailure);
  }
  // A received message means progress was made or the stream was kept alive.
  progressiveBackoff.reset();
}

/**
 * Handles termination of the read stream with an error. An error that arrives after all
 * {@code digest.getSizeBytes()} bytes were received is only logged and the download finishes via
 * {@link #onCompleted()}; otherwise the output is released and {@code future} is failed.
 */
@Override
public void onError(Throwable t) {
  boolean fullyReceived = rawOut.getCount() == digest.getSizeBytes();
  if (fullyReceived) {
    // Nothing is missing, so an error at the very end of the stream does not matter.
    logger.atInfo().withCause(t).log("ignoring error because file was fully received");
    onCompleted();
    return;
  }

  releaseOut();

  // Translate NOT_FOUND into a cache-miss exception; pass any other error through as-is.
  Throwable failure;
  if (Status.fromThrowable(t).getCode() == Status.Code.NOT_FOUND) {
    failure = new CacheNotFoundException(digest);
  } else {
    failure = t;
  }
  future.setException(failure);
}

// Finishes the download: flushes and releases the output, optionally verifies the received
// bytes against the expected digest, then completes `future` with the number of bytes counted
// by `rawOut`. Also invoked from onError when the blob was already fully received.
@Override
public void onCompleted() {
try {
try {
out.flush();
} finally {
// Release the output even when the flush itself fails.
releaseOut();
}
// Content verification only happens when a digest supplier was provided.
if (digestSupplier != null) {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
}
// This line is also reached after either catch block above.
// NOTE(review): presumably a no-op once setException has already completed the future
// (SettableFuture-style semantics) -- confirm against the declared type of `future`.
future.set(rawOut.getCount());
}

/**
 * Performs a shallow close of {@code out} if it is a {@link ZstdDecompressingOutputStream}. Other
 * stream types are left untouched, and an {@link IOException} from the close is only logged.
 */
private void releaseOut() {
  boolean isZstdStream = out instanceof ZstdDecompressingOutputStream;
  if (!isZstdStream) {
    return;
  }
  try {
    ((ZstdDecompressingOutputStream) out).closeShallow();
  } catch (IOException closeFailure) {
    logger.atWarning().withCause(closeFailure).log("failed to cleanly close output stream");
  }
}
});
return future;
}

Expand Down

0 comments on commit 026a8d0

Please sign in to comment.