diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInInterceptorTest.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInInterceptorTest.java new file mode 100644 index 00000000000000..8af683c485c5e0 --- /dev/null +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInInterceptorTest.java @@ -0,0 +1,64 @@ +package io.quarkus.grpc.server.interceptors; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.time.Duration; + +import javax.enterprise.context.ApplicationScoped; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.grpc.*; +import io.grpc.examples.helloworld.*; +import io.quarkus.grpc.GlobalInterceptor; +import io.quarkus.grpc.GrpcClient; +import io.quarkus.grpc.server.services.HelloService; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Uni; + +public class FailingInInterceptorTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class) + .addPackage(GreeterGrpc.class.getPackage()) + .addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class)); + + @GrpcClient + Greeter greeter; + + @Test + void test() { + Uni result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build()); + assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4))) + .isInstanceOf(StatusRuntimeException.class) + .hasMessageContaining("UNKNOWN"); + } + + @ApplicationScoped + @GlobalInterceptor + public static class MyFailingInterceptor implements ServerInterceptor { + + @Override + public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + return next + .startCall(new ForwardingServerCall.SimpleForwardingServerCall(call) { + + @Override + public void sendMessage(RespT message) { + throw new IllegalArgumentException("BOOM"); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + } + }, headers); + } + } + +} diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInterceptorTest.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInterceptorTest.java new file mode 100644 index 00000000000000..2560dc8fb62069 --- /dev/null +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/interceptors/FailingInterceptorTest.java @@ -0,0 +1,52 @@ +package io.quarkus.grpc.server.interceptors; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.time.Duration; + +import javax.enterprise.context.ApplicationScoped; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.grpc.*; +import io.grpc.examples.helloworld.*; +import io.quarkus.grpc.GlobalInterceptor; +import io.quarkus.grpc.GrpcClient; +import io.quarkus.grpc.server.services.HelloService; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Uni; + +public class FailingInterceptorTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class) + .addPackage(GreeterGrpc.class.getPackage()) + .addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class)); + + @GrpcClient + Greeter greeter; + + @Test + void test() { + Uni result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build()); + assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4))) + .isInstanceOf(StatusRuntimeException.class) + .hasMessageContaining("UNKNOWN"); + } + + @ApplicationScoped + @GlobalInterceptor + public static class MyFailingInterceptor implements ServerInterceptor { + + @Override + public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + throw new IllegalArgumentException("BOOM!"); + } + } + +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcDuplicatedContextGrpcInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcDuplicatedContextGrpcInterceptor.java index 21d7b0ee422495..cb009a460e778f 100644 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcDuplicatedContextGrpcInterceptor.java +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcDuplicatedContextGrpcInterceptor.java @@ -2,6 +2,8 @@ import static io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle.setContextSafe; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import java.util.function.Supplier; import javax.enterprise.context.ApplicationScoped; @@ -9,13 +11,11 @@ import org.jboss.logging.Logger; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; +import io.grpc.*; import io.quarkus.grpc.GlobalInterceptor; import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; +import io.vertx.core.Handler; import io.vertx.core.Vertx; @ApplicationScoped @@ -44,7 +44,7 @@ public ServerCall.Listener interceptCall(ServerCall(() -> next.startCall(call, headers), local); + return new ListenedOnDuplicatedContext<>(call, () -> next.startCall(call, headers), local); } else { log.warn("Unable to run on a duplicated context - interceptor not called on the Vert.x event loop"); return next.startCall(call, headers); @@ -56,67 +56,99 @@ public int getPriority() { return Integer.MAX_VALUE; } - static class ListenedOnDuplicatedContext extends ServerCall.Listener { + static class ListenedOnDuplicatedContext extends ServerCall.Listener { private final Context context; private final Supplier> supplier; + private final ServerCall call; private ServerCall.Listener delegate; - public ListenedOnDuplicatedContext(Supplier> supplier, Context context) { + private final AtomicBoolean closed = new AtomicBoolean(); + + public ListenedOnDuplicatedContext(ServerCall call, Supplier> supplier, + Context context) { this.context = context; this.supplier = supplier; + this.call = call; } private synchronized ServerCall.Listener getDelegate() { if (delegate == null) { - delegate = supplier.get(); + try { + delegate = supplier.get(); + } catch (Throwable t) { + // If the interceptor supplier throws an exception, catch it, and close the call. + log.warnf("Unable to retrieve gRPC Server call listener", t); + close(t); + return null; + } } return delegate; } - @Override - public void onMessage(ReqT message) { + private void close(Throwable t) { + if (closed.compareAndSet(false, true)) { + call.close(Status.fromThrowable(t), new Metadata()); + } + } + + private void invoke(Consumer> invocation) { if (Vertx.currentContext() == context) { - getDelegate().onMessage(message); + ServerCall.Listener listener = getDelegate(); + if (listener == null) { + return; + } + try { + invocation.accept(listener); + } catch (Throwable t) { + close(t); + } } else { - context.runOnContext(x -> getDelegate().onMessage(message)); + context.runOnContext(new Handler() { + @Override + public void handle(Void x) { + ServerCall.Listener listener = ListenedOnDuplicatedContext.this.getDelegate(); + if (listener == null) { + return; + } + try { + invocation.accept(listener); + } catch (Throwable t) { + close(t); + } + } + }); } } + @Override + public void onMessage(ReqT message) { + invoke(new Consumer>() { + @Override + public void accept(ServerCall.Listener listener) { + listener.onMessage(message); + } + }); + } + @Override public void onReady() { - if (Vertx.currentContext() == context) { - getDelegate().onReady(); - } else { - context.runOnContext(x -> getDelegate().onReady()); - } + invoke(ServerCall.Listener::onReady); } @Override public void onHalfClose() { - if (Vertx.currentContext() == context) { - getDelegate().onHalfClose(); - } else { - context.runOnContext(x -> getDelegate().onHalfClose()); - } + invoke(ServerCall.Listener::onHalfClose); } @Override public void onCancel() { - if (Vertx.currentContext() == context) { - getDelegate().onCancel(); - } else { - context.runOnContext(x -> getDelegate().onCancel()); - } + invoke(ServerCall.Listener::onCancel); } @Override public void onComplete() { - if (Vertx.currentContext() == context) { - getDelegate().onComplete(); - } else { - context.runOnContext(x -> getDelegate().onComplete()); - } + invoke(ServerCall.Listener::onComplete); } } }