Skip to content

Commit

Permalink
Add Status parameter to GrpcExceptionHandlerFunction.apply() meth…
Browse files Browse the repository at this point in the history
…od. (#5786)

Motivation:
When an exception is raised, `GrpcService` creates a new `Status` from `Status.getCause()` via `GrpcExceptionHandlerFunction`. This can result in sending an incorrect status in certain situations:
- If the original exception is a `StatusRuntimeException` containing the `Status` that the user wants to send.
- The original exception has already been converted into a `Status` via `Status.fromThrowable()`: https://github.com/grpc/grpc-java/blob/5770114d08dcd352f2288ef52d17e1833530323c/stub/src/main/java/io/grpc/stub/ServerCalls.java#L389
- This converted `Status` is ignored and a new, potentially incorrect `Status` is created by `GrpcExceptionHandlerFunction` and sent.

Modifications:
- Added a `Status` parameter to the `GrpcExceptionHandlerFunction.apply()` method.
- Updated the default implementation of `GrpcExceptionHandlerFunction.of()` to return the provided `Status` if it is not an unknown status.

Result:
- The `GrpcExceptionHandlerFunction` now properly handles and returns the correct `Status`.
- (Breaking) The `apply` method of `GrpcExceptionHandlerFunction` now takes a `Status`.
  • Loading branch information
minwoox committed Jun 28, 2024
1 parent a48670e commit 142aff8
Show file tree
Hide file tree
Showing 26 changed files with 114 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class GrpcExceptionHandler implements GrpcExceptionHandlerFunction {

@Nullable
@Override
public Status apply(RequestContext ctx, Throwable cause, Metadata metadata) {
public Status apply(RequestContext ctx, @Nullable Status status, Throwable cause, Metadata metadata) {
if (cause instanceof IllegalArgumentException) {
return Status.INVALID_ARGUMENT.withCause(cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ internal class CoroutineServerInterceptorTest {
object : ServerExtension() {
override fun configure(sb: ServerBuilder) {
val exceptionHandler =
GrpcExceptionHandlerFunction { _: RequestContext, throwable: Throwable, _: Metadata ->
GrpcExceptionHandlerFunction {
_: RequestContext,
_: Status?,
throwable: Throwable,
_: Metadata,
->
if (throwable is AnticipatedException && throwable.message == "Invalid access") {
return@GrpcExceptionHandlerFunction Status.UNAUTHENTICATED
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.linecorp.armeria.common.ContentTooLargeException;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.stream.ClosedStreamException;
import com.linecorp.armeria.server.RequestTimeoutException;
import com.linecorp.armeria.server.ServiceRequestContext;
Expand All @@ -45,7 +46,10 @@ enum DefaultGrpcExceptionHandlerFunction implements GrpcExceptionHandlerFunction
* well and the protocol package.
*/
@Override
public Status apply(RequestContext ctx, Throwable cause, Metadata metadata) {
public Status apply(RequestContext ctx, @Nullable Status status, Throwable cause, Metadata metadata) {
if (status != null && status.getCode() != Code.UNKNOWN) {
return status;
}
final Status s = Status.fromThrowable(cause);
if (s.getCode() != Code.UNKNOWN) {
return s;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ public interface GoogleGrpcExceptionHandlerFunction extends GrpcExceptionHandler

@Nullable
@Override
default Status apply(RequestContext ctx, Throwable throwable, Metadata metadata) {
default Status apply(RequestContext ctx, @Nullable Status status, Throwable throwable, Metadata metadata) {
return handleException(ctx, throwable, metadata, this::applyStatusProto);
}

/**
* Maps the specified {@link Throwable} to a {@link com.google.rpc.Status},
* and mutates the specified {@link Metadata}.
* The `grpc-status-details-bin` key is ignored since it will be overwritten
* by {@link GoogleGrpcExceptionHandlerFunction#apply(RequestContext, Throwable, Metadata)}.
* by {@link GrpcExceptionHandlerFunction#apply(RequestContext, Status, Throwable, Metadata)}.
* If {@code null} is returned, the built-in mapping rule is used by default.
*/
com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ static GrpcExceptionHandlerFunction of() {
}

/**
* Maps the specified {@link Throwable} to a gRPC {@link Status},
* and mutates the specified {@link Metadata}.
* If {@code null} is returned, the built-in mapping rule is used by default.
* Maps the specified {@link Throwable} to a gRPC {@link Status} and mutates the specified {@link Metadata}.
* If {@code null} is returned, {@link #of()} will be used to return {@link Status} as the default.
*
* <p>The {@link Status} may also be specified as a parameter if it is created by
* the upstream gRPC framework.
* You can return the {@link Status} or any other {@link Status} as needed. If the exception is raised
* internally in Armeria, no {@link Status} created, so {@code null} will be specified.
*/
@Nullable
Status apply(RequestContext ctx, Throwable cause, Metadata metadata);
Status apply(RequestContext ctx, @Nullable Status status, Throwable cause, Metadata metadata);

/**
* Returns a {@link GrpcExceptionHandlerFunction} that returns the result of this function
Expand All @@ -63,12 +67,12 @@ static GrpcExceptionHandlerFunction of() {
*/
default GrpcExceptionHandlerFunction orElse(GrpcExceptionHandlerFunction next) {
requireNonNull(next, "next");
return (ctx, cause, metadata) -> {
final Status status = apply(ctx, cause, metadata);
if (status != null) {
return status;
return (ctx, status, cause, metadata) -> {
final Status newStatus = apply(ctx, status, cause, metadata);
if (newStatus != null) {
return newStatus;
}
return next.apply(ctx, cause, metadata);
return next.apply(ctx, status, cause, metadata);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public final class GrpcExceptionHandlerFunctionBuilder {
*/
public GrpcExceptionHandlerFunctionBuilder on(Class<? extends Throwable> exceptionType, Status status) {
requireNonNull(status, "status");
return on(exceptionType, (ctx, cause, metadata) -> status);
return on(exceptionType, (ctx, unused, cause, metadata) -> status);
}

/**
Expand All @@ -66,7 +66,7 @@ public <T extends Throwable> GrpcExceptionHandlerFunctionBuilder on(
requireNonNull(exceptionType, "exceptionType");
requireNonNull(exceptionHandler, "exceptionHandler");
//noinspection unchecked
return on(exceptionType, (ctx, cause, metadata) -> exceptionHandler.apply((T) cause, metadata));
return on(exceptionType, (ctx, status, cause, metadata) -> exceptionHandler.apply((T) cause, metadata));
}

/**
Expand Down Expand Up @@ -107,11 +107,11 @@ public GrpcExceptionHandlerFunction build() {

final List<Entry<Class<? extends Throwable>, GrpcExceptionHandlerFunction>> mappings =
ImmutableList.copyOf(exceptionMappings);
return (ctx, cause, metadata) -> {
return (ctx, status, cause, metadata) -> {
for (Map.Entry<Class<? extends Throwable>, GrpcExceptionHandlerFunction> mapping : mappings) {
if (mapping.getKey().isInstance(cause)) {
final Status status = mapping.getValue().apply(ctx, cause, metadata);
return status == null ? null : status.withCause(cause);
final Status newStatus = mapping.getValue().apply(ctx, status, cause, metadata);
return newStatus == null ? null : newStatus.withCause(cause);
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public void start(Listener<O> responseListener, Metadata metadata) {
prepareHeaders(compressor, metadata, remainingNanos);

final BiFunction<ClientRequestContext, Throwable, HttpResponse> errorResponseFactory =
(unused, cause) -> HttpResponse.ofFailure(exceptionHandler.apply(ctx, cause, metadata)
(unused, cause) -> HttpResponse.ofFailure(exceptionHandler.apply(ctx, null, cause, metadata)
.withDescription(cause.getMessage())
.asRuntimeException());
final HttpResponse res = initContextAndExecuteWithFallback(
Expand Down Expand Up @@ -454,7 +454,7 @@ public void onNext(DeframedMessage message) {
});
} catch (Throwable t) {
final Metadata metadata = new Metadata();
close(exceptionHandler.apply(ctx, t, metadata), metadata);
close(exceptionHandler.apply(ctx, null, t, metadata), metadata);
}
}

Expand Down Expand Up @@ -511,7 +511,7 @@ private void prepareHeaders(Compressor compressor, Metadata metadata, long remai

private void closeWhenListenerThrows(Throwable t) {
final Metadata metadata = new Metadata();
closeWhenEos(exceptionHandler.apply(ctx, t, metadata), metadata);
closeWhenEos(exceptionHandler.apply(ctx, null, t, metadata), metadata);
}

private void closeWhenEos(Status status, Metadata metadata) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public void processHeaders(HttpHeaders headers, StreamDecoderOutput<DeframedMess
decompressor(ForwardingDecompressor.forGrpc(decompressor));
} catch (Throwable t) {
final Metadata metadata = new Metadata();
transportStatusListener.transportReportStatus(exceptionHandler.apply(ctx, t, metadata),
transportStatusListener.transportReportStatus(exceptionHandler.apply(ctx, null, t, metadata),
metadata);
return;
}
Expand Down Expand Up @@ -148,7 +148,8 @@ public void processTrailers(HttpHeaders headers, StreamDecoderOutput<DeframedMes
@Override
public void processOnError(Throwable cause) {
final Metadata metadata = new Metadata();
transportStatusListener.transportReportStatus(exceptionHandler.apply(ctx, cause, metadata), metadata);
transportStatusListener.transportReportStatus(
exceptionHandler.apply(ctx, null, cause, metadata), metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ public UnwrappingGrpcExceptionHandleFunction(GrpcExceptionHandlerFunction handle
delegate = handlerFunction;
}

@Nullable
@Override
public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) {
public Status apply(RequestContext ctx, @Nullable Status status, Throwable cause, Metadata metadata) {
final Throwable t = peelAndUnwrap(cause);
return delegate.apply(ctx, t, metadata);
return delegate.apply(ctx, status, t, metadata);
}

private static Throwable peelAndUnwrap(Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public final void close(Throwable exception) {
public final void close(Throwable exception, boolean cancelled) {
exception = Exceptions.peel(exception);
final Metadata metadata = generateMetadataFromThrowable(exception);
final Status status = exceptionHandler.apply(ctx, exception, metadata);
final Status status = exceptionHandler.apply(ctx, null, exception, metadata);
close(new ServerStatusAndMetadata(status, metadata, false, cancelled), exception);
}

Expand All @@ -223,7 +223,7 @@ public final void close(Status status, Metadata metadata) {
close(new ServerStatusAndMetadata(status, metadata, false));
return;
}
Status newStatus = exceptionHandler.apply(ctx, status.getCause(), metadata);
Status newStatus = exceptionHandler.apply(ctx, status, status.getCause(), metadata);
assert newStatus != null;
if (status.getDescription() != null) {
newStatus = newStatus.withDescription(status.getDescription());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ protected HttpResponse doPost(ServiceRequestContext ctx, HttpRequest req) throws
return HttpResponse.of(
(ResponseHeaders) AbstractServerCall.statusToTrailers(
ctx, defaultHeaders.get(serializationFormat).toBuilder(),
exceptionHandler.apply(ctx, e, metadata), metadata));
exceptionHandler.apply(ctx, null, e, metadata), metadata));
}
} else {
if (Boolean.TRUE.equals(ctx.attr(AbstractUnframedGrpcService.IS_UNFRAMED_GRPC))) {
Expand Down Expand Up @@ -320,7 +320,7 @@ private <I, O> void startCall(ServerMethodDefinition<I, O> methodDef, ServiceReq
call.setListener(listener);
call.startDeframing();
ctx.whenRequestCancelling().handle((cancellationCause, unused) -> {
final Status status = call.exceptionHandler().apply(ctx, cancellationCause, headers);
final Status status = call.exceptionHandler().apply(ctx, null, cancellationCause, headers);
assert status != null;
call.close(new ServerStatusAndMetadata(status, new Metadata(), true, true));
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,8 @@ public GrpcServiceBuilder exceptionHandler(GrpcExceptionHandlerFunction exceptio
@Deprecated
public GrpcServiceBuilder exceptionMapping(GrpcStatusFunction statusFunction) {
requireNonNull(statusFunction, "statusFunction");
return exceptionHandler(statusFunction::apply);
return exceptionHandler(
(ctx, status, throwable, metadata) -> statusFunction.apply(ctx, throwable, metadata));
}

/**
Expand Down Expand Up @@ -943,7 +944,9 @@ public GrpcServiceBuilder addExceptionMapping(Class<? extends Throwable> excepti
checkState(exceptionHandler == null,
"addExceptionMapping() and exceptionMapping() are mutually exclusive.");

exceptionMappingsBuilder().on(exceptionType, statusFunction::apply);
exceptionMappingsBuilder().on(exceptionType,
(ctx, status, throwable, metadata) ->
statusFunction.apply(ctx, throwable, metadata));
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public O parse(InputStream inputStream) {

@Test
void useDefaultGrpcExceptionHandlerFunctionAsFallback() {
final GrpcExceptionHandlerFunction noopExceptionHandler = (ctx, cause, metadata) -> null;
final GrpcExceptionHandlerFunction noopExceptionHandler = (ctx, status, cause, metadata) -> null;
final GrpcExceptionHandlerFunction exceptionHandler =
GrpcExceptionHandlerFunction.builder()
.on(ContentTooLargeException.class, noopExceptionHandler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ void chaining() {
final RuntimeException exception = new RuntimeException();
final TestServiceBlockingStub stub =
GrpcClients.builder(server.httpUri())
.exceptionHandler(((ctx, cause, metadata) -> {
.exceptionHandler(((ctx, status, cause, metadata) -> {
stringDeque.add("1");
return null;
}))
.exceptionHandler(((ctx, cause, metadata) -> {
.exceptionHandler(((ctx, status, cause, metadata) -> {
stringDeque.add("2");
return null;
}))
.exceptionHandler(((ctx, cause, metadata) -> {
.exceptionHandler(((ctx, status, cause, metadata) -> {
if (cause == exception) {
stringDeque.add("3");
return Status.DATA_LOSS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ void cancelAfterBegin() throws Exception {
responseObserver.awaitCompletion();
assertThat(responseObserver.getValues()).isEmpty();
assertThat(GrpcExceptionHandlerFunction.of()
.apply(null, responseObserver.getError(), null)
.apply(null, null, responseObserver.getError(), null)
.getCode()).isEqualTo(Code.CANCELLED);

final RequestLog log = requestLogQueue.take();
Expand Down Expand Up @@ -783,7 +783,7 @@ void cancelAfterFirstResponse() throws Exception {
responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS);
assertThat(responseObserver.getValues()).hasSize(1);
assertThat(GrpcExceptionHandlerFunction.of()
.apply(null, responseObserver.getError(), null)
.apply(null, null, responseObserver.getError(), null)
.getCode()).isEqualTo(Code.CANCELLED);

checkRequestLog((rpcReq, rpcRes, grpcStatus) -> {
Expand Down Expand Up @@ -1418,7 +1418,7 @@ void deadlineExceededServerStreaming() throws Exception {

assertThat(recorder.getError()).isNotNull();
assertThat(GrpcExceptionHandlerFunction.of()
.apply(null, recorder.getError(), null)
.apply(null, null, recorder.getError(), null)
.getCode())
.isEqualTo(Status.DEADLINE_EXCEEDED.getCode());

Expand Down Expand Up @@ -1618,10 +1618,10 @@ void statusCodeAndMessage() throws Exception {
verify(responseObserver,
timeout(operationTimeoutMillis())).onError(captor.capture());
assertThat(GrpcExceptionHandlerFunction.of()
.apply(null, captor.getValue(), null)
.apply(null, null, captor.getValue(), null)
.getCode()).isEqualTo(Status.UNKNOWN.getCode());
assertThat(GrpcExceptionHandlerFunction.of()
.apply(null, captor.getValue(), null)
.apply(null, null, captor.getValue(), null)
.getDescription()).isEqualTo(errorMessage);
verifyNoMoreInteractions(responseObserver);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ class DefaultGrpcExceptionHandlerFunctionTest {
void failFastExceptionToUnavailableCode() {
assertThat(GrpcExceptionHandlerFunction
.of()
.apply(null, new FailFastException(CircuitBreaker.ofDefaultName()), null)
.apply(null, null, new FailFastException(CircuitBreaker.ofDefaultName()), null)
.getCode()).isEqualTo(Status.Code.UNAVAILABLE);
}

@Test
void invalidProtocolBufferExceptionToInvalidArgumentCode() {
assertThat(GrpcExceptionHandlerFunction
.of()
.apply(null, new InvalidProtocolBufferException("Failed to parse message"), null)
.apply(null, null,
new InvalidProtocolBufferException("Failed to parse message"), null)
.getCode()).isEqualTo(Status.Code.INVALID_ARGUMENT);
}
}
Loading

0 comments on commit 142aff8

Please sign in to comment.