diff --git a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java index a66be90b8d7e4b..bacc279198d822 100644 --- a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java +++ b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java @@ -6,17 +6,23 @@ import java.lang.reflect.Modifier; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.eclipse.microprofile.config.Config; import org.eclipse.microprofile.config.ConfigProvider; +import org.jboss.jandex.AnnotationTarget; import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.DotName; import org.jboss.jandex.MethodInfo; import org.jboss.logging.Logger; import io.grpc.internal.ServerImpl; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem; +import io.quarkus.arc.processor.AnnotationsTransformer; import io.quarkus.arc.processor.DotNames; import io.quarkus.deployment.IsDevelopment; import io.quarkus.deployment.IsNormal; @@ -39,6 +45,8 @@ import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig; import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint; import io.quarkus.grpc.runtime.health.GrpcHealthStorage; +import io.quarkus.grpc.runtime.supports.context.CleanUpRequestContext; +import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor; import io.quarkus.kubernetes.spi.KubernetesPortBuildItem; import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem; import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem; @@ -61,9 +69,13 @@ MinNettyAllocatorMaxOrderBuildItem setMinimalNettyMaxOrderSize() { @BuildStep void discoverBindableServices(BuildProducer bindables, - CombinedIndexBuildItem combinedIndexBuildItem) { + CombinedIndexBuildItem combinedIndexBuildItem, + BuildProducer annotationTransformers) { Collection bindableServices = combinedIndexBuildItem.getIndex() .getAllKnownImplementors(GrpcDotNames.BINDABLE_SERVICE); + + final Set grpcServiceNames = new HashSet<>(); + for (ClassInfo service : bindableServices) { if (!Modifier.isAbstract(service.flags()) && service.classAnnotation(DotNames.SINGLETON) != null) { BindableServiceBuildItem item = new BindableServiceBuildItem(service.name()); @@ -72,9 +84,25 @@ void discoverBindableServices(BuildProducer bindables, item.registerBlockingMethod(method.name()); } } + grpcServiceNames.add(service.name()); bindables.produce(item); } } + + annotationTransformers.produce(new AnnotationsTransformerBuildItem(new AnnotationsTransformer() { + @Override + public boolean appliesTo(AnnotationTarget.Kind kind) { + return kind == AnnotationTarget.Kind.METHOD; + } + + @Override + public void transform(TransformationContext transformationContext) { + MethodInfo target = transformationContext.getTarget().asMethod(); + if (grpcServiceNames.contains(target.declaringClass().name()) && !"bindService".equals(target.name())) { + transformationContext.transform().add(CleanUpRequestContext.class).done(); + } + } + })); } @BuildStep(onlyIf = IsNormal.class) @@ -87,6 +115,12 @@ KubernetesPortBuildItem registerGrpcServiceInKubernetes(List beans) { + beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class)); + beans.produce(new AdditionalBeanBuildItem(CleanUpRequestContext.class)); + } + @BuildStep void buildContainerBean(BuildProducer beans, List bindables, BuildProducer features) { diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java index 939b774c07e7ed..274bb1c2398350 100644 --- a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/devmode/DevModeTestService.java @@ -1,5 +1,6 @@ package io.quarkus.grpc.server.devmode; +import javax.enterprise.context.RequestScoped; import javax.inject.Singleton; import devmodetest.v1.Devmodetest; @@ -7,6 +8,7 @@ import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; import io.grpc.stub.StreamObserver; +import io.quarkus.arc.Arc; @Singleton public class DevModeTestService extends GreeterGrpc.GreeterImplBase { @@ -20,7 +22,11 @@ public void sayHello(HelloRequest request, StreamObserver responseOb } else { response = greeting + request.getName(); } - responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build()); + if (Arc.container().getActiveContext(RequestScoped.class) != null) { + responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build()); + } else { + throw new IllegalStateException("request context not active, failing"); + } responseObserver.onCompleted(); } } \ No newline at end of file diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java index 79c5749c867fe2..dcb9ef3f4033a4 100644 --- a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/server/services/TestService.java @@ -82,8 +82,6 @@ public void onCompleted() { }; } - ; - @Override public StreamObserver fullDuplexCall( StreamObserver responseObserver) { @@ -122,6 +120,7 @@ public StreamObserver halfDuplexCall( return new StreamObserver() { @Override public void onNext(Messages.StreamingOutputCallRequest streamingOutputCallRequest) { + assertThatTheRequestScopeIsActive(); String payload = streamingOutputCallRequest.getPayload().getBody().toStringUtf8(); ByteString value = ByteString.copyFromUtf8(payload.toUpperCase()); Messages.Payload response = Messages.Payload.newBuilder().setBody(value).build(); diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java index 0cbbdca5d85c24..19b455b03e5364 100644 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/GrpcServerRecorder.java @@ -42,7 +42,7 @@ import io.quarkus.grpc.runtime.reflection.ReflectionService; import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor; import io.quarkus.grpc.runtime.supports.CompressionInterceptor; -import io.quarkus.grpc.runtime.supports.RequestScopeHandlerInterceptor; +import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor; import io.quarkus.runtime.LaunchMode; import io.quarkus.runtime.RuntimeValue; import io.quarkus.runtime.ShutdownContext; @@ -66,6 +66,8 @@ public class GrpcServerRecorder { private static final AtomicInteger grpcVerticleCount = new AtomicInteger(0); private Map> blockingMethodsPerService = Collections.emptyMap(); + private static volatile DevModeWrapper devModeWrapper; + public void initializeGrpcServer(RuntimeValue vertxSupplier, GrpcConfiguration cfg, ShutdownContext shutdown, @@ -90,7 +92,7 @@ public void initializeGrpcServer(RuntimeValue vertxSupplier, if (GrpcServerReloader.getServer() == null) { devModeStart(grpcContainer, vertx, configuration, shutdown, launchMode); } else { - devModeReload(grpcContainer); + devModeReload(grpcContainer, vertx, configuration); } } else { prodStart(grpcContainer, vertx, configuration, launchMode); @@ -159,6 +161,8 @@ private void devModeStart(GrpcContainer grpcContainer, Vertx vertx, GrpcServerCo ShutdownContext shutdown, LaunchMode launchMode) { CompletableFuture future = new CompletableFuture<>(); + devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader()); + VertxServer vertxServer = buildServer(vertx, configuration, grpcContainer, launchMode) .start(new Handler>() { // NOSONAR @Override @@ -251,19 +255,18 @@ private static class GrpcServiceDefinition { } public String getImplementationClassName() { - return service.getClass().getName(); + // all grpc services have a io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor + // this means Arc passes a subclass to grpc internals. That's why we take superclass here + return service.getClass().getSuperclass().getName(); } } - private static void devModeReload(GrpcContainer grpcContainer) { - List svc = collectServiceDefinitions(grpcContainer.getServices()); + private void devModeReload(GrpcContainer grpcContainer, Vertx vertx, GrpcServerConfiguration configuration) { + List services = collectServiceDefinitions(grpcContainer.getServices()); List definitions = new ArrayList<>(); Map> methods = new HashMap<>(); - for (GrpcServiceDefinition service : svc) { - for (ServerMethodDefinition method : service.definition.getMethods()) { - methods.put(method.getMethodDescriptor().getFullMethodName(), method); - } + for (GrpcServiceDefinition service : services) { definitions.add(service.definition); } @@ -272,8 +275,20 @@ private static void devModeReload(GrpcContainer grpcContainer) { for (ServerMethodDefinition method : reflectionService.getMethods()) { methods.put(method.getMethodDescriptor().getFullMethodName(), method); } + List servicesWithInterceptors = new ArrayList<>(); + CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration); + for (GrpcServiceDefinition service : services) { + servicesWithInterceptors.add(serviceWithInterceptors(vertx, compressionInterceptor, service)); + } + + for (ServerServiceDefinition serviceWithInterceptors : servicesWithInterceptors) { + for (ServerMethodDefinition method : serviceWithInterceptors.getMethods()) { + methods.put(method.getMethodDescriptor().getFullMethodName(), method); + } + } + devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader()); - GrpcServerReloader.reinitialize(definitions, methods, grpcContainer.getSortedInterceptors()); + GrpcServerReloader.reinitialize(servicesWithInterceptors, methods, grpcContainer.getSortedInterceptors()); } public static int getVerticleCount() { @@ -320,26 +335,10 @@ public void handle(HttpServerOptions options) { List toBeRegistered = collectServiceDefinitions(grpcContainer.getServices()); List definitions = new ArrayList<>(); - CompressionInterceptor compressionInterceptor = null; - if (configuration.compression.isPresent()) { - compressionInterceptor = new CompressionInterceptor(configuration.compression.get()); - } + CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration); for (GrpcServiceDefinition service : toBeRegistered) { - List interceptors = new ArrayList<>(); - if (compressionInterceptor != null) { - interceptors.add(compressionInterceptor); - } - // We only register the blocking interceptor if needed by at least one method of the service. - if (!blockingMethodsPerService.isEmpty()) { - List list = blockingMethodsPerService.get(service.getImplementationClassName()); - if (list != null) { - interceptors.add(new BlockingServerInterceptor(vertx, list)); - } - } - // Order matters! Request scope must be called first (on the event loop) and so should be last in the list... - interceptors.add(new RequestScopeHandlerInterceptor()); - builder.addService(ServerInterceptors.intercept(service.definition, interceptors)); + builder.addService(serviceWithInterceptors(vertx, compressionInterceptor, service)); LOGGER.debugf("Registered gRPC service '%s'", service.definition.getServiceDescriptor().getName()); definitions.add(service.definition); } @@ -367,7 +366,7 @@ public void handle(Promise event) { new Handler>() { @Override public void handle(AsyncResult result) { - command.run(); + devModeWrapper.run(command); } }); } @@ -381,6 +380,38 @@ public void handle(AsyncResult result) { return builder.build(); } + /** + * Compression interceptor if needed, null otherwise + * + * @param configuration gRPC server configuration + * @return interceptor or null + */ + private CompressionInterceptor prepareCompressionInterceptor(GrpcServerConfiguration configuration) { + CompressionInterceptor compressionInterceptor = null; + if (configuration.compression.isPresent()) { + compressionInterceptor = new CompressionInterceptor(configuration.compression.get()); + } + return compressionInterceptor; + } + + private ServerServiceDefinition serviceWithInterceptors(Vertx vertx, CompressionInterceptor compressionInterceptor, + GrpcServiceDefinition service) { + List interceptors = new ArrayList<>(); + if (compressionInterceptor != null) { + interceptors.add(compressionInterceptor); + } + // We only register the blocking interceptor if needed by at least one method of the service. + if (!blockingMethodsPerService.isEmpty()) { + List list = blockingMethodsPerService.get(service.getImplementationClassName()); + if (list != null) { + interceptors.add(new BlockingServerInterceptor(vertx, list)); + } + } + // Order matters! Request scope must be called first (on the event loop) and so should be last in the list... + interceptors.add(new GrpcRequestContextGrpcInterceptor()); + return ServerInterceptors.intercept(service.definition, interceptors); + } + private class GrpcServerVerticle extends AbstractVerticle { private final GrpcServerConfiguration configuration; private final GrpcContainer grpcContainer; @@ -432,4 +463,17 @@ public void handle(AsyncResult ar) { }); } } + + private class DevModeWrapper { + private final ClassLoader classLoader; + + public DevModeWrapper(ClassLoader contextClassLoader) { + classLoader = contextClassLoader; + } + + public void run(Runnable command) { + Thread.currentThread().setContextClassLoader(classLoader); + command.run(); + } + } } diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java deleted file mode 100644 index 40ed99d29f5e53..00000000000000 --- a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/RequestScopeHandlerInterceptor.java +++ /dev/null @@ -1,58 +0,0 @@ -package io.quarkus.grpc.runtime.supports; - -import org.jboss.logmanager.Logger; - -import io.grpc.ForwardingServerCall; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.quarkus.arc.Arc; -import io.quarkus.arc.ManagedContext; -import io.vertx.core.Context; -import io.vertx.core.Handler; -import io.vertx.core.Vertx; - -public class RequestScopeHandlerInterceptor implements ServerInterceptor { - - private final ManagedContext reqContext; - private static final Logger LOGGER = Logger.getLogger(RequestScopeHandlerInterceptor.class.getName()); - - public RequestScopeHandlerInterceptor() { - reqContext = Arc.container().requestContext(); - } - - @Override - public ServerCall.Listener interceptCall(ServerCall call, - Metadata headers, - ServerCallHandler next) { - - // This interceptor is called first, so, we should be on the event loop. - Context capturedVertxContext = Vertx.currentContext(); - if (capturedVertxContext != null) { - boolean activateAndDeactivateContext = !reqContext.isActive(); - if (activateAndDeactivateContext) { - reqContext.activate(); - } - return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall(call) { - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); - if (activateAndDeactivateContext) { - capturedVertxContext.runOnContext(new Handler() { - @Override - public void handle(Void ignored) { - reqContext.deactivate(); - } - }); - } - } - }, headers); - } else { - LOGGER.warning("Unable to activate the request scope - interceptor not called on the Vert.x event loop"); - return next.startCall(call, headers); - } - } - -} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java new file mode 100644 index 00000000000000..434c49f7c2fc08 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/CleanUpRequestContext.java @@ -0,0 +1,16 @@ +package io.quarkus.grpc.runtime.supports.context; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.interceptor.InterceptorBinding; + +@Inherited +@InterceptorBinding +@Target({ ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +public @interface CleanUpRequestContext { +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java new file mode 100644 index 00000000000000..7b4793a5ea99e0 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextCdiInterceptor.java @@ -0,0 +1,40 @@ +package io.quarkus.grpc.runtime.supports.context; + +import javax.interceptor.AroundInvoke; +import javax.interceptor.Interceptor; +import javax.interceptor.InvocationContext; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +@Interceptor +@CleanUpRequestContext +public class GrpcRequestContextCdiInterceptor { + + @AroundInvoke + public Object cleanUpContext(InvocationContext invocationContext) throws Exception { + boolean cleanUp = false; + ManagedContext requestContext = Arc.container().requestContext(); + if (!requestContext.isActive()) { + Context context = Vertx.currentContext(); + + if (context != null) { + cleanUp = true; + requestContext.activate(); + GrpcRequestContextHolder contextHolder = GrpcRequestContextHolder.get(context); + if (contextHolder != null) { + contextHolder.state = requestContext.getState(); + } + } + } + try { + return invocationContext.proceed(); + } finally { + if (cleanUp) { + requestContext.deactivate(); + } + } + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java new file mode 100644 index 00000000000000..2c3356c427687f --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextGrpcInterceptor.java @@ -0,0 +1,78 @@ +package io.quarkus.grpc.runtime.supports.context; + +import org.jboss.logmanager.Logger; + +import io.grpc.ForwardingServerCall; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +public class GrpcRequestContextGrpcInterceptor implements ServerInterceptor { + + private final ManagedContext reqContext; + private static final Logger LOGGER = Logger.getLogger(GrpcRequestContextGrpcInterceptor.class.getName()); + + public GrpcRequestContextGrpcInterceptor() { + reqContext = Arc.container().requestContext(); + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, + ServerCallHandler next) { + + // This interceptor is called first, so, we should be on the event loop. + Context capturedVertxContext = Vertx.currentContext(); + if (capturedVertxContext != null) { + GrpcRequestContextHolder contextHolder = GrpcRequestContextHolder.initialize(capturedVertxContext); + ServerCall.Listener delegate = next + .startCall(new ForwardingServerCall.SimpleForwardingServerCall(call) { + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + if (contextHolder.state != null) { + reqContext.destroy(contextHolder.state); + } + } + }, headers); + + return new ForwardingServerCallListener.SimpleForwardingServerCallListener(delegate) { + + @Override + public void onMessage(ReqT message) { + activateContext(); + super.onMessage(message); + } + + @Override + public void onReady() { + activateContext(); + super.onReady(); + } + + @Override + public void onComplete() { + activateContext(); + super.onComplete(); + } + + private void activateContext() { + if (contextHolder.state != null && !reqContext.isActive()) { + reqContext.activate(contextHolder.state); + } + } + }; + } else { + LOGGER.warning("Unable to activate the request scope - interceptor not called on the Vert.x event loop"); + return next.startCall(call, headers); + } + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java new file mode 100644 index 00000000000000..bfd5e6753b8179 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/runtime/supports/context/GrpcRequestContextHolder.java @@ -0,0 +1,21 @@ +package io.quarkus.grpc.runtime.supports.context; + +import io.quarkus.arc.InjectableContext; +import io.vertx.core.Context; + +public class GrpcRequestContextHolder { + + private static final String GRPC_REQUEST_CONTEXT_STATE = "GRPC_REQUEST_CONTEXT_STATE"; + + volatile InjectableContext.ContextState state; + + public static GrpcRequestContextHolder initialize(Context vertxContext) { + GrpcRequestContextHolder contextHolder = new GrpcRequestContextHolder(); + vertxContext.put(GRPC_REQUEST_CONTEXT_STATE, contextHolder); + return contextHolder; + } + + public static GrpcRequestContextHolder get(Context vertxContext) { + return vertxContext.get(GRPC_REQUEST_CONTEXT_STATE); + } +}