Skip to content

Commit

Permalink
Catch exception happening in the gRPC interceptors and close the call…
Browse files Browse the repository at this point in the history
… immediately.

Fix quarkusio#28053.

(cherry picked from commit 747f6ee)
  • Loading branch information
cescoffier authored and gsmet committed Oct 3, 2022
1 parent 6e1d849 commit 1f2ac33
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;

import javax.enterprise.context.ApplicationScoped;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.*;
import io.quarkus.grpc.GlobalInterceptor;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.grpc.server.services.HelloService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Uni;

public class FailingInInterceptorTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
Uni<HelloReply> result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build());
assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4)))
.isInstanceOf(StatusRuntimeException.class)
.hasMessageContaining("UNKNOWN");
}

@ApplicationScoped
@GlobalInterceptor
public static class MyFailingInterceptor implements ServerInterceptor {

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
return next
.startCall(new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {

@Override
public void sendMessage(RespT message) {
throw new IllegalArgumentException("BOOM");
}

@Override
public void close(Status status, Metadata trailers) {
super.close(status, trailers);
}
}, headers);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;

import javax.enterprise.context.ApplicationScoped;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.GreeterBean;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GlobalInterceptor;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.grpc.server.services.HelloService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Uni;

public class FailingInterceptorTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
Uni<HelloReply> result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build());
assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4)))
.isInstanceOf(StatusRuntimeException.class)
.hasMessageContaining("UNKNOWN");
}

@ApplicationScoped
@GlobalInterceptor
public static class MyFailingInterceptor implements ServerInterceptor {

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
throw new IllegalArgumentException("BOOM!");
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle.setContextSafe;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;

import javax.enterprise.context.ApplicationScoped;
Expand All @@ -13,9 +15,11 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.quarkus.grpc.GlobalInterceptor;
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;

@ApplicationScoped
Expand Down Expand Up @@ -44,7 +48,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
setContextSafe(local, true);

// Must be sure to call next.startCall on the right context
return new ListenedOnDuplicatedContext<>(() -> next.startCall(call, headers), local);
return new ListenedOnDuplicatedContext<>(call, () -> next.startCall(call, headers), local);
} else {
log.warn("Unable to run on a duplicated context - interceptor not called on the Vert.x event loop");
return next.startCall(call, headers);
Expand All @@ -56,67 +60,99 @@ public int getPriority() {
return Integer.MAX_VALUE;
}

static class ListenedOnDuplicatedContext<ReqT> extends ServerCall.Listener<ReqT> {
static class ListenedOnDuplicatedContext<ReqT, RespT> extends ServerCall.Listener<ReqT> {

private final Context context;
private final Supplier<ServerCall.Listener<ReqT>> supplier;
private final ServerCall<ReqT, RespT> call;
private ServerCall.Listener<ReqT> delegate;

public ListenedOnDuplicatedContext(Supplier<ServerCall.Listener<ReqT>> supplier, Context context) {
private final AtomicBoolean closed = new AtomicBoolean();

public ListenedOnDuplicatedContext(ServerCall<ReqT, RespT> call, Supplier<ServerCall.Listener<ReqT>> supplier,
Context context) {
this.context = context;
this.supplier = supplier;
this.call = call;
}

private synchronized ServerCall.Listener<ReqT> getDelegate() {
if (delegate == null) {
delegate = supplier.get();
try {
delegate = supplier.get();
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warnf("Unable to retrieve gRPC Server call listener", t);
close(t);
return null;
}
}
return delegate;
}

@Override
public void onMessage(ReqT message) {
private void close(Throwable t) {
if (closed.compareAndSet(false, true)) {
call.close(Status.fromThrowable(t), new Metadata());
}
}

private void invoke(Consumer<ServerCall.Listener<ReqT>> invocation) {
if (Vertx.currentContext() == context) {
getDelegate().onMessage(message);
ServerCall.Listener<ReqT> listener = getDelegate();
if (listener == null) {
return;
}
try {
invocation.accept(listener);
} catch (Throwable t) {
close(t);
}
} else {
context.runOnContext(x -> getDelegate().onMessage(message));
context.runOnContext(new Handler<Void>() {
@Override
public void handle(Void x) {
ServerCall.Listener<ReqT> listener = ListenedOnDuplicatedContext.this.getDelegate();
if (listener == null) {
return;
}
try {
invocation.accept(listener);
} catch (Throwable t) {
close(t);
}
}
});
}
}

@Override
public void onMessage(ReqT message) {
invoke(new Consumer<ServerCall.Listener<ReqT>>() {
@Override
public void accept(ServerCall.Listener<ReqT> listener) {
listener.onMessage(message);
}
});
}

@Override
public void onReady() {
if (Vertx.currentContext() == context) {
getDelegate().onReady();
} else {
context.runOnContext(x -> getDelegate().onReady());
}
invoke(ServerCall.Listener::onReady);
}

@Override
public void onHalfClose() {
if (Vertx.currentContext() == context) {
getDelegate().onHalfClose();
} else {
context.runOnContext(x -> getDelegate().onHalfClose());
}
invoke(ServerCall.Listener::onHalfClose);
}

@Override
public void onCancel() {
if (Vertx.currentContext() == context) {
getDelegate().onCancel();
} else {
context.runOnContext(x -> getDelegate().onCancel());
}
invoke(ServerCall.Listener::onCancel);
}

@Override
public void onComplete() {
if (Vertx.currentContext() == context) {
getDelegate().onComplete();
} else {
context.runOnContext(x -> getDelegate().onComplete());
}
invoke(ServerCall.Listener::onComplete);
}
}
}

0 comments on commit 1f2ac33

Please sign in to comment.