diff --git a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java index 158adc5dac633..c86443da55ff6 100644 --- a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java +++ b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java @@ -6,20 +6,25 @@ import java.lang.reflect.Modifier; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.eclipse.microprofile.config.Config; import org.eclipse.microprofile.config.ConfigProvider; +import org.jboss.jandex.AnnotationTarget; import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.DotName; import org.jboss.jandex.MethodInfo; import org.jboss.logging.Logger; import io.grpc.BindableService; import io.grpc.internal.ServerImpl; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem; import io.quarkus.arc.deployment.ValidationPhaseBuildItem; -import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem; +import io.quarkus.arc.processor.AnnotationsTransformer; import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.deployment.IsDevelopment; @@ -44,6 +49,8 @@ import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig; import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint; import io.quarkus.grpc.runtime.health.GrpcHealthStorage; +import io.quarkus.grpc.runtime.supports.context.CleanUpRequestContext; +import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor; import io.quarkus.kubernetes.spi.KubernetesPortBuildItem; import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem; import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem; @@ -66,9 +73,13 @@ MinNettyAllocatorMaxOrderBuildItem setMinimalNettyMaxOrderSize() { @BuildStep void discoverBindableServices(BuildProducer bindables, - CombinedIndexBuildItem combinedIndexBuildItem) { + CombinedIndexBuildItem combinedIndexBuildItem, + BuildProducer annotationTransformers) { Collection bindableServices = combinedIndexBuildItem.getIndex() .getAllKnownImplementors(GrpcDotNames.BINDABLE_SERVICE); + + final Set grpcServiceNames = new HashSet<>(); + for (ClassInfo service : bindableServices) { if (Modifier.isAbstract(service.flags())) { continue; @@ -80,15 +91,31 @@ void discoverBindableServices(BuildProducer bindables, } } bindables.produce(item); + grpcServiceNames.add(service.name()); } + + annotationTransformers.produce(new AnnotationsTransformerBuildItem(new AnnotationsTransformer() { + @Override + public boolean appliesTo(AnnotationTarget.Kind kind) { + return kind == AnnotationTarget.Kind.METHOD; + } + + @Override + public void transform(TransformationContext transformationContext) { + MethodInfo target = transformationContext.getTarget().asMethod(); + if (grpcServiceNames.contains(target.declaringClass().name()) && !"bindService".equals(target.name())) { + transformationContext.transform().add(CleanUpRequestContext.class).done(); + } + } + })); } @BuildStep void validateBindableServices(ValidationPhaseBuildItem validationPhase, - BuildProducer errors) { + BuildProducer errors) { for (BeanInfo bean : validationPhase.getContext().beans().classBeans().withBeanType(BindableService.class)) { if (!bean.getScope().getDotName().equals(BuiltinScope.SINGLETON.getName())) { - errors.produce(new ValidationErrorBuildItem( + errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( new IllegalStateException("A gRPC service bean must have the javax.inject.Singleton scope: " + bean))); } } @@ -109,6 +136,8 @@ void registerBeans(BuildProducer beans, List bindables, BuildProducer features) { // @GrpcService is a CDI stereotype beans.produce(new AdditionalBeanBuildItem(GrpcService.class)); + beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class)); + beans.produce(new AdditionalBeanBuildItem(CleanUpRequestContext.class)); if (!bindables.isEmpty()) { beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcContainer.class)); features.produce(new FeatureBuildItem(GRPC_SERVER)); diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java index 939b774c07e7e..274bb1c239835 100644 --- a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java @@ -1,5 +1,6 @@ package io.quarkus.grpc.server.devmode; +import javax.enterprise.context.RequestScoped; import javax.inject.Singleton; import devmodetest.v1.Devmodetest; @@ -7,6 +8,7 @@ import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; import io.grpc.stub.StreamObserver; +import io.quarkus.arc.Arc; @Singleton public class DevModeTestService extends GreeterGrpc.GreeterImplBase { @@ -20,7 +22,11 @@ public void sayHello(HelloRequest request, StreamObserver responseOb } else { response = greeting + request.getName(); } - responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build()); + if (Arc.container().getActiveContext(RequestScoped.class) != null) { + responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build()); + } else { + throw new IllegalStateException("request context not active, failing"); + } responseObserver.onCompleted(); } } \ No newline at end of file diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java index 79c5749c867fe..dcb9ef3f4033a 100644 --- a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java @@ -82,8 +82,6 @@ public void onCompleted() { }; } - ; - @Override public StreamObserver fullDuplexCall( StreamObserver responseObserver) { @@ -122,6 +120,7 @@ public StreamObserver halfDuplexCall( return new StreamObserver() { @Override public void onNext(Messages.StreamingOutputCallRequest streamingOutputCallRequest) { + assertThatTheRequestScopeIsActive(); String payload = streamingOutputCallRequest.getPayload().getBody().toStringUtf8(); ByteString value = ByteString.copyFromUtf8(payload.toUpperCase()); Messages.Payload response = Messages.Payload.newBuilder().setBody(value).build(); diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java index 0cbbdca5d85c2..b5887dc299471 100644 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java @@ -42,7 +42,7 @@ import io.quarkus.grpc.runtime.reflection.ReflectionService; import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor; import io.quarkus.grpc.runtime.supports.CompressionInterceptor; -import io.quarkus.grpc.runtime.supports.RequestScopeHandlerInterceptor; +import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor; import io.quarkus.runtime.LaunchMode; import io.quarkus.runtime.RuntimeValue; import io.quarkus.runtime.ShutdownContext; @@ -66,6 +66,8 @@ public class GrpcServerRecorder { private static final AtomicInteger grpcVerticleCount = new AtomicInteger(0); private Map> blockingMethodsPerService = Collections.emptyMap(); + private static volatile DevModeWrapper devModeWrapper; + public void initializeGrpcServer(RuntimeValue vertxSupplier, GrpcConfiguration cfg, ShutdownContext shutdown, @@ -90,7 +92,7 @@ public void initializeGrpcServer(RuntimeValue vertxSupplier, if (GrpcServerReloader.getServer() == null) { devModeStart(grpcContainer, vertx, configuration, shutdown, launchMode); } else { - devModeReload(grpcContainer); + devModeReload(grpcContainer, vertx, configuration); } } else { prodStart(grpcContainer, vertx, configuration, launchMode); @@ -159,6 +161,8 @@ private void devModeStart(GrpcContainer grpcContainer, Vertx vertx, GrpcServerCo ShutdownContext shutdown, LaunchMode launchMode) { CompletableFuture future = new CompletableFuture<>(); + devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader()); + VertxServer vertxServer = buildServer(vertx, configuration, grpcContainer, launchMode) .start(new Handler>() { // NOSONAR @Override @@ -251,19 +255,18 @@ private static class GrpcServiceDefinition { } public String getImplementationClassName() { - return service.getClass().getName(); + // all grpc services have a io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor + // this means Arc passes a subclass to grpc internals. That's why we take superclass here + return service.getClass().getSuperclass().getName(); } } - private static void devModeReload(GrpcContainer grpcContainer) { - List svc = collectServiceDefinitions(grpcContainer.getServices()); + private void devModeReload(GrpcContainer grpcContainer, Vertx vertx, GrpcServerConfiguration configuration) { + List services = collectServiceDefinitions(grpcContainer.getServices()); List definitions = new ArrayList<>(); Map> methods = new HashMap<>(); - for (GrpcServiceDefinition service : svc) { - for (ServerMethodDefinition method : service.definition.getMethods()) { - methods.put(method.getMethodDescriptor().getFullMethodName(), method); - } + for (GrpcServiceDefinition service : services) { definitions.add(service.definition); } @@ -272,8 +275,20 @@ private static void devModeReload(GrpcContainer grpcContainer) { for (ServerMethodDefinition method : reflectionService.getMethods()) { methods.put(method.getMethodDescriptor().getFullMethodName(), method); } + List servicesWithInterceptors = new ArrayList<>(); + CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration); + for (GrpcServiceDefinition service : services) { + servicesWithInterceptors.add(serviceWithInterceptors(vertx, compressionInterceptor, service, true)); + } + + for (ServerServiceDefinition serviceWithInterceptors : servicesWithInterceptors) { + for (ServerMethodDefinition method : serviceWithInterceptors.getMethods()) { + methods.put(method.getMethodDescriptor().getFullMethodName(), method); + } + } + devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader()); - GrpcServerReloader.reinitialize(definitions, methods, grpcContainer.getSortedInterceptors()); + GrpcServerReloader.reinitialize(servicesWithInterceptors, methods, grpcContainer.getSortedInterceptors()); } public static int getVerticleCount() { @@ -320,26 +335,11 @@ public void handle(HttpServerOptions options) { List toBeRegistered = collectServiceDefinitions(grpcContainer.getServices()); List definitions = new ArrayList<>(); - CompressionInterceptor compressionInterceptor = null; - if (configuration.compression.isPresent()) { - compressionInterceptor = new CompressionInterceptor(configuration.compression.get()); - } + CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration); for (GrpcServiceDefinition service : toBeRegistered) { - List interceptors = new ArrayList<>(); - if (compressionInterceptor != null) { - interceptors.add(compressionInterceptor); - } - // We only register the blocking interceptor if needed by at least one method of the service. - if (!blockingMethodsPerService.isEmpty()) { - List list = blockingMethodsPerService.get(service.getImplementationClassName()); - if (list != null) { - interceptors.add(new BlockingServerInterceptor(vertx, list)); - } - } - // Order matters! Request scope must be called first (on the event loop) and so should be last in the list... - interceptors.add(new RequestScopeHandlerInterceptor()); - builder.addService(ServerInterceptors.intercept(service.definition, interceptors)); + builder.addService( + serviceWithInterceptors(vertx, compressionInterceptor, service, launchMode == LaunchMode.DEVELOPMENT)); LOGGER.debugf("Registered gRPC service '%s'", service.definition.getServiceDescriptor().getName()); definitions.add(service.definition); } @@ -367,7 +367,7 @@ public void handle(Promise event) { new Handler>() { @Override public void handle(AsyncResult result) { - command.run(); + devModeWrapper.run(command); } }); } @@ -381,6 +381,38 @@ public void handle(AsyncResult result) { return builder.build(); } + /** + * Compression interceptor if needed, null otherwise + * + * @param configuration gRPC server configuration + * @return interceptor or null + */ + private CompressionInterceptor prepareCompressionInterceptor(GrpcServerConfiguration configuration) { + CompressionInterceptor compressionInterceptor = null; + if (configuration.compression.isPresent()) { + compressionInterceptor = new CompressionInterceptor(configuration.compression.get()); + } + return compressionInterceptor; + } + + private ServerServiceDefinition serviceWithInterceptors(Vertx vertx, CompressionInterceptor compressionInterceptor, + GrpcServiceDefinition service, boolean devMode) { + List interceptors = new ArrayList<>(); + if (compressionInterceptor != null) { + interceptors.add(compressionInterceptor); + } + // We only register the blocking interceptor if needed by at least one method of the service. + if (!blockingMethodsPerService.isEmpty()) { + List list = blockingMethodsPerService.get(service.getImplementationClassName()); + if (list != null) { + interceptors.add(new BlockingServerInterceptor(vertx, list, devMode)); + } + } + // Order matters! Request scope must be called first (on the event loop) and so should be last in the list... + interceptors.add(new GrpcRequestContextGrpcInterceptor()); + return ServerInterceptors.intercept(service.definition, interceptors); + } + private class GrpcServerVerticle extends AbstractVerticle { private final GrpcServerConfiguration configuration; private final GrpcContainer grpcContainer; @@ -432,4 +464,17 @@ public void handle(AsyncResult ar) { }); } } + + private class DevModeWrapper { + private final ClassLoader classLoader; + + public DevModeWrapper(ClassLoader contextClassLoader) { + classLoader = contextClassLoader; + } + + public void run(Runnable command) { + Thread.currentThread().setContextClassLoader(classLoader); + command.run(); + } + } } diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptor.java index ef5720b49f018..ad3a8001dffff 100644 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptor.java +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptor.java @@ -28,10 +28,12 @@ public class BlockingServerInterceptor implements ServerInterceptor { private final Vertx vertx; private final List blockingMethods; private final Map cache = new HashMap<>(); + private final boolean devMode; - public BlockingServerInterceptor(Vertx vertx, List blockingMethods) { + public BlockingServerInterceptor(Vertx vertx, List blockingMethods, boolean devMode) { this.vertx = vertx; this.blockingMethods = new ArrayList<>(); + this.devMode = devMode; for (String method : blockingMethods) { this.blockingMethods.add(method.toLowerCase()); } @@ -98,19 +100,12 @@ synchronized void setDelegate(ServerCall.Listener delegate) { private synchronized void executeOnContextOrEnqueue(Consumer> consumer) { if (this.delegate != null) { final Context grpcContext = Context.current(); - vertx.executeBlocking(new Handler>() { - @Override - public void handle(Promise f) { - final Context previous = Context.current(); - grpcContext.attach(); - try { - consumer.accept(delegate); - f.complete(); - } finally { - grpcContext.detach(previous); - } - } - }, true, null); + Handler> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate); + if (devMode) { + blockingHandler = new DevModeBlockingExecutionHandler(Thread.currentThread().getContextClassLoader(), + blockingHandler); + } + vertx.executeBlocking(blockingHandler, true, null); } else { incomingEvents.add(consumer); } @@ -147,4 +142,45 @@ public void onReady() { } } + private static class DevModeBlockingExecutionHandler implements Handler> { + + final ClassLoader tccl; + final Handler> delegate; + + public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler> delegate) { + this.tccl = tccl; + this.delegate = delegate; + } + + @Override + public void handle(Promise event) { + Thread.currentThread().setContextClassLoader(tccl); + delegate.handle(event); + } + } + + private static class BlockingExecutionHandler implements Handler> { + private final ServerCall.Listener delegate; + private final Context grpcContext; + private final Consumer> consumer; + + public BlockingExecutionHandler(Consumer> consumer, Context grpcContext, + ServerCall.Listener delegate) { + this.consumer = consumer; + this.grpcContext = grpcContext; + this.delegate = delegate; + } + + @Override + public void handle(Promise event) { + final Context previous = Context.current(); + grpcContext.attach(); + try { + consumer.accept(delegate); + event.complete(); + } finally { + grpcContext.detach(previous); + } + } + } } diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java deleted file mode 100644 index 40ed99d29f5e5..0000000000000 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java +++ /dev/null @@ -1,58 +0,0 @@ -package io.quarkus.grpc.runtime.supports; - -import org.jboss.logmanager.Logger; - -import io.grpc.ForwardingServerCall; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.quarkus.arc.Arc; -import io.quarkus.arc.ManagedContext; -import io.vertx.core.Context; -import io.vertx.core.Handler; -import io.vertx.core.Vertx; - -public class RequestScopeHandlerInterceptor implements ServerInterceptor { - - private final ManagedContext reqContext; - private static final Logger LOGGER = Logger.getLogger(RequestScopeHandlerInterceptor.class.getName()); - - public RequestScopeHandlerInterceptor() { - reqContext = Arc.container().requestContext(); - } - - @Override - public ServerCall.Listener interceptCall(ServerCall call, - Metadata headers, - ServerCallHandler next) { - - // This interceptor is called first, so, we should be on the event loop. - Context capturedVertxContext = Vertx.currentContext(); - if (capturedVertxContext != null) { - boolean activateAndDeactivateContext = !reqContext.isActive(); - if (activateAndDeactivateContext) { - reqContext.activate(); - } - return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall(call) { - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); - if (activateAndDeactivateContext) { - capturedVertxContext.runOnContext(new Handler() { - @Override - public void handle(Void ignored) { - reqContext.deactivate(); - } - }); - } - } - }, headers); - } else { - LOGGER.warning("Unable to activate the request scope - interceptor not called on the Vert.x event loop"); - return next.startCall(call, headers); - } - } - -} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java new file mode 100644 index 0000000000000..434c49f7c2fc0 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java @@ -0,0 +1,16 @@ +package io.quarkus.grpc.runtime.supports.context; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.interceptor.InterceptorBinding; + +@Inherited +@InterceptorBinding +@Target({ ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +public @interface CleanUpRequestContext { +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java new file mode 100644 index 0000000000000..7b4793a5ea99e --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java @@ -0,0 +1,40 @@ +package io.quarkus.grpc.runtime.supports.context; + +import javax.interceptor.AroundInvoke; +import javax.interceptor.Interceptor; +import javax.interceptor.InvocationContext; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +@Interceptor +@CleanUpRequestContext +public class GrpcRequestContextCdiInterceptor { + + @AroundInvoke + public Object cleanUpContext(InvocationContext invocationContext) throws Exception { + boolean cleanUp = false; + ManagedContext requestContext = Arc.container().requestContext(); + if (!requestContext.isActive()) { + Context context = Vertx.currentContext(); + + if (context != null) { + cleanUp = true; + requestContext.activate(); + GrpcRequestContextHolder contextHolder = GrpcRequestContextHolder.get(context); + if (contextHolder != null) { + contextHolder.state = requestContext.getState(); + } + } + } + try { + return invocationContext.proceed(); + } finally { + if (cleanUp) { + requestContext.deactivate(); + } + } + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java new file mode 100644 index 0000000000000..2c3356c427687 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java @@ -0,0 +1,78 @@ +package io.quarkus.grpc.runtime.supports.context; + +import org.jboss.logmanager.Logger; + +import io.grpc.ForwardingServerCall; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +public class GrpcRequestContextGrpcInterceptor implements ServerInterceptor { + + private final ManagedContext reqContext; + private static final Logger LOGGER = Logger.getLogger(GrpcRequestContextGrpcInterceptor.class.getName()); + + public GrpcRequestContextGrpcInterceptor() { + reqContext = Arc.container().requestContext(); + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, + ServerCallHandler next) { + + // This interceptor is called first, so, we should be on the event loop. + Context capturedVertxContext = Vertx.currentContext(); + if (capturedVertxContext != null) { + GrpcRequestContextHolder contextHolder = GrpcRequestContextHolder.initialize(capturedVertxContext); + ServerCall.Listener delegate = next + .startCall(new ForwardingServerCall.SimpleForwardingServerCall(call) { + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + if (contextHolder.state != null) { + reqContext.destroy(contextHolder.state); + } + } + }, headers); + + return new ForwardingServerCallListener.SimpleForwardingServerCallListener(delegate) { + + @Override + public void onMessage(ReqT message) { + activateContext(); + super.onMessage(message); + } + + @Override + public void onReady() { + activateContext(); + super.onReady(); + } + + @Override + public void onComplete() { + activateContext(); + super.onComplete(); + } + + private void activateContext() { + if (contextHolder.state != null && !reqContext.isActive()) { + reqContext.activate(contextHolder.state); + } + } + }; + } else { + LOGGER.warning("Unable to activate the request scope - interceptor not called on the Vert.x event loop"); + return next.startCall(call, headers); + } + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java new file mode 100644 index 0000000000000..bfd5e6753b817 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java @@ -0,0 +1,21 @@ +package io.quarkus.grpc.runtime.supports.context; + +import io.quarkus.arc.InjectableContext; +import io.vertx.core.Context; + +public class GrpcRequestContextHolder { + + private static final String GRPC_REQUEST_CONTEXT_STATE = "GRPC_REQUEST_CONTEXT_STATE"; + + volatile InjectableContext.ContextState state; + + public static GrpcRequestContextHolder initialize(Context vertxContext) { + GrpcRequestContextHolder contextHolder = new GrpcRequestContextHolder(); + vertxContext.put(GRPC_REQUEST_CONTEXT_STATE, contextHolder); + return contextHolder; + } + + public static GrpcRequestContextHolder get(Context vertxContext) { + return vertxContext.get(GRPC_REQUEST_CONTEXT_STATE); + } +} diff --git a/extensions/grpc/runtime/src/test/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptorTest.java b/extensions/grpc/runtime/src/test/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptorTest.java index 09ecb78a07158..6c759e8b10287 100644 --- a/extensions/grpc/runtime/src/test/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptorTest.java +++ b/extensions/grpc/runtime/src/test/java/io/quarkus/grpc/runtime/supports/BlockingServerInterceptorTest.java @@ -27,7 +27,7 @@ class BlockingServerInterceptorTest { @BeforeEach void setup() { vertx = Vertx.vertx(); - blockingServerInterceptor = new BlockingServerInterceptor(vertx, Arrays.asList("blocking")); + blockingServerInterceptor = new BlockingServerInterceptor(vertx, Arrays.asList("blocking"), false); } @Test