Skip to content

Commit

Permalink
gRPC fix request context clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed Apr 29, 2021
1 parent 45b205c commit 50c4a15
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.eclipse.microprofile.config.Config;
import org.eclipse.microprofile.config.ConfigProvider;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.logging.Logger;

import io.grpc.internal.ServerImpl;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.IsNormal;
Expand All @@ -39,6 +45,8 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.CleanUpRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;
Expand All @@ -61,9 +69,13 @@ MinNettyAllocatorMaxOrderBuildItem setMinimalNettyMaxOrderSize() {

@BuildStep
void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
CombinedIndexBuildItem combinedIndexBuildItem) {
CombinedIndexBuildItem combinedIndexBuildItem,
BuildProducer<AnnotationsTransformerBuildItem> annotationTransformers) {
Collection<ClassInfo> bindableServices = combinedIndexBuildItem.getIndex()
.getAllKnownImplementors(GrpcDotNames.BINDABLE_SERVICE);

final Set<DotName> grpcServiceNames = new HashSet<>();

for (ClassInfo service : bindableServices) {
if (!Modifier.isAbstract(service.flags()) && service.classAnnotation(DotNames.SINGLETON) != null) {
BindableServiceBuildItem item = new BindableServiceBuildItem(service.name());
Expand All @@ -72,9 +84,25 @@ void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
item.registerBlockingMethod(method.name());
}
}
grpcServiceNames.add(service.name());
bindables.produce(item);
}
}

annotationTransformers.produce(new AnnotationsTransformerBuildItem(new AnnotationsTransformer() {
@Override
public boolean appliesTo(AnnotationTarget.Kind kind) {
return kind == AnnotationTarget.Kind.CLASS;
}

@Override
public void transform(TransformationContext transformationContext) {
AnnotationTarget target = transformationContext.getTarget();
if (grpcServiceNames.contains(target.asClass().name())) {
transformationContext.transform().add(CleanUpRequestContext.class).done();
}
}
}));
}

@BuildStep(onlyIf = IsNormal.class)
Expand All @@ -87,6 +115,12 @@ KubernetesPortBuildItem registerGrpcServiceInKubernetes(List<BindableServiceBuil
return null;
}

@BuildStep
void registerAdditionalBeans(BuildProducer<AdditionalBeanBuildItem> beans) {
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(CleanUpRequestContext.class));
}

@BuildStep
void buildContainerBean(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.quarkus.grpc.server.devmode;

import javax.enterprise.context.RequestScoped;
import javax.inject.Singleton;

import devmodetest.v1.Devmodetest;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import io.quarkus.arc.Arc;

@Singleton
public class DevModeTestService extends GreeterGrpc.GreeterImplBase {
Expand All @@ -20,7 +22,11 @@ public void sayHello(HelloRequest request, StreamObserver<HelloReply> responseOb
} else {
response = greeting + request.getName();
}
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
if (Arc.container().getActiveContext(RequestScoped.class) != null) {
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
} else {
throw new IllegalStateException("request context not active, failing");
}
responseObserver.onCompleted();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.RequestScopeHandlerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestScopeGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.ShutdownContext;
Expand All @@ -66,6 +66,8 @@ public class GrpcServerRecorder {
private static final AtomicInteger grpcVerticleCount = new AtomicInteger(0);
private Map<String, List<String>> blockingMethodsPerService = Collections.emptyMap();

private static volatile DevModeWrapper devModeWrapper;

public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
GrpcConfiguration cfg,
ShutdownContext shutdown,
Expand All @@ -90,7 +92,7 @@ public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
if (GrpcServerReloader.getServer() == null) {
devModeStart(grpcContainer, vertx, configuration, shutdown, launchMode);
} else {
devModeReload(grpcContainer);
devModeReload(grpcContainer, vertx, configuration);
}
} else {
prodStart(grpcContainer, vertx, configuration, launchMode);
Expand Down Expand Up @@ -159,6 +161,8 @@ private void devModeStart(GrpcContainer grpcContainer, Vertx vertx, GrpcServerCo
ShutdownContext shutdown, LaunchMode launchMode) {
CompletableFuture<Boolean> future = new CompletableFuture<>();

devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

VertxServer vertxServer = buildServer(vertx, configuration, grpcContainer, launchMode)
.start(new Handler<AsyncResult<Void>>() { // NOSONAR
@Override
Expand Down Expand Up @@ -255,15 +259,12 @@ public String getImplementationClassName() {
}
}

private static void devModeReload(GrpcContainer grpcContainer) {
List<GrpcServiceDefinition> svc = collectServiceDefinitions(grpcContainer.getServices());
private void devModeReload(GrpcContainer grpcContainer, Vertx vertx, GrpcServerConfiguration configuration) {
List<GrpcServiceDefinition> services = collectServiceDefinitions(grpcContainer.getServices());

List<ServerServiceDefinition> definitions = new ArrayList<>();
Map<String, ServerMethodDefinition<?, ?>> methods = new HashMap<>();
for (GrpcServiceDefinition service : svc) {
for (ServerMethodDefinition<?, ?> method : service.definition.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
for (GrpcServiceDefinition service : services) {
definitions.add(service.definition);
}

Expand All @@ -272,8 +273,20 @@ private static void devModeReload(GrpcContainer grpcContainer) {
for (ServerMethodDefinition<?, ?> method : reflectionService.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
List<ServerServiceDefinition> servicesWithInterceptors = new ArrayList<>();
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);
for (GrpcServiceDefinition service : services) {
servicesWithInterceptors.add(serviceWithInterceptors(vertx, compressionInterceptor, service));
}

for (ServerServiceDefinition serviceWithInterceptors : servicesWithInterceptors) {
for (ServerMethodDefinition<?, ?> method : serviceWithInterceptors.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
}
devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

GrpcServerReloader.reinitialize(definitions, methods, grpcContainer.getSortedInterceptors());
GrpcServerReloader.reinitialize(servicesWithInterceptors, methods, grpcContainer.getSortedInterceptors());
}

public static int getVerticleCount() {
Expand Down Expand Up @@ -320,26 +333,10 @@ public void handle(HttpServerOptions options) {
List<GrpcServiceDefinition> toBeRegistered = collectServiceDefinitions(grpcContainer.getServices());
List<ServerServiceDefinition> definitions = new ArrayList<>();

CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);

for (GrpcServiceDefinition service : toBeRegistered) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new RequestScopeHandlerInterceptor());
builder.addService(ServerInterceptors.intercept(service.definition, interceptors));
builder.addService(serviceWithInterceptors(vertx, compressionInterceptor, service));
LOGGER.debugf("Registered gRPC service '%s'", service.definition.getServiceDescriptor().getName());
definitions.add(service.definition);
}
Expand Down Expand Up @@ -367,7 +364,7 @@ public void handle(Promise<Boolean> event) {
new Handler<AsyncResult<Boolean>>() {
@Override
public void handle(AsyncResult<Boolean> result) {
command.run();
devModeWrapper.run(command);
}
});
}
Expand All @@ -381,6 +378,38 @@ public void handle(AsyncResult<Boolean> result) {
return builder.build();
}

/**
* Compression interceptor if needed, null otherwise
*
* @param configuration gRPC server configuration
* @return interceptor or null
*/
private CompressionInterceptor prepareCompressionInterceptor(GrpcServerConfiguration configuration) {
CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
return compressionInterceptor;
}

private ServerServiceDefinition serviceWithInterceptors(Vertx vertx, CompressionInterceptor compressionInterceptor,
GrpcServiceDefinition service) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new GrpcRequestScopeGrpcInterceptor());
return ServerInterceptors.intercept(service.definition, interceptors);
}

private class GrpcServerVerticle extends AbstractVerticle {
private final GrpcServerConfiguration configuration;
private final GrpcContainer grpcContainer;
Expand Down Expand Up @@ -432,4 +461,17 @@ public void handle(AsyncResult<Void> ar) {
});
}
}

private class DevModeWrapper {
private final ClassLoader classLoader;

public DevModeWrapper(ClassLoader contextClassLoader) {
classLoader = contextClassLoader;
}

public void run(Runnable command) {
Thread.currentThread().setContextClassLoader(classLoader);
command.run();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkus.grpc.runtime.supports.context;

import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import javax.interceptor.InterceptorBinding;

@Inherited
@InterceptorBinding
@Target({ ElementType.TYPE })
@Retention(RetentionPolicy.RUNTIME)
public @interface CleanUpRequestContext {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.quarkus.grpc.runtime.supports.context;

import javax.interceptor.AroundInvoke;
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

@Interceptor
@CleanUpRequestContext
public class GrpcRequestContextCdiInterceptor {
public static final String GRPC_REQUEST_CONTEXT_STATE = "GRPC_REQUEST_CONTEXT_STATE";

@AroundInvoke
public Object cleanUpContext(InvocationContext invocationContext) throws Exception {
boolean cleanUp = false;
ManagedContext requestContext = Arc.container().requestContext();
if (!requestContext.isActive()) {
cleanUp = true;
requestContext.activate();
Context context = Vertx.currentContext();
if (context != null) {
context.put(GRPC_REQUEST_CONTEXT_STATE, requestContext.getState());
}
}
try {
return invocationContext.proceed();
} finally {
if (cleanUp) {
requestContext.deactivate();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.grpc.runtime.supports;
package io.quarkus.grpc.runtime.supports.context;

import org.jboss.logmanager.Logger;

Expand All @@ -9,17 +9,18 @@
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;

public class RequestScopeHandlerInterceptor implements ServerInterceptor {
public class GrpcRequestScopeGrpcInterceptor implements ServerInterceptor {

private final ManagedContext reqContext;
private static final Logger LOGGER = Logger.getLogger(RequestScopeHandlerInterceptor.class.getName());
private static final Logger LOGGER = Logger.getLogger(GrpcRequestScopeGrpcInterceptor.class.getName());

public RequestScopeHandlerInterceptor() {
public GrpcRequestScopeGrpcInterceptor() {
reqContext = Arc.container().requestContext();
}

Expand All @@ -31,19 +32,17 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
// This interceptor is called first, so, we should be on the event loop.
Context capturedVertxContext = Vertx.currentContext();
if (capturedVertxContext != null) {
boolean activateAndDeactivateContext = !reqContext.isActive();
if (activateAndDeactivateContext) {
reqContext.activate();
}
return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
@Override
public void close(Status status, Metadata trailers) {
super.close(status, trailers);
if (activateAndDeactivateContext) {
InjectableContext.ContextState state = capturedVertxContext
.get(GrpcRequestContextCdiInterceptor.GRPC_REQUEST_CONTEXT_STATE);
if (state != null) {
capturedVertxContext.runOnContext(new Handler<Void>() {
@Override
public void handle(Void ignored) {
reqContext.deactivate();
reqContext.destroy(state);
}
});
}
Expand Down

0 comments on commit 50c4a15

Please sign in to comment.