Skip to content

Commit

Permalink
gRPC fix request context clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed Apr 30, 2021
1 parent 1af8504 commit 35a0e98
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.eclipse.microprofile.config.Config;
import org.eclipse.microprofile.config.ConfigProvider;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.logging.Logger;

import io.grpc.internal.ServerImpl;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.IsNormal;
Expand All @@ -39,6 +45,8 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.CleanUpRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;
Expand All @@ -61,9 +69,13 @@ MinNettyAllocatorMaxOrderBuildItem setMinimalNettyMaxOrderSize() {

@BuildStep
void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
CombinedIndexBuildItem combinedIndexBuildItem) {
CombinedIndexBuildItem combinedIndexBuildItem,
BuildProducer<AnnotationsTransformerBuildItem> annotationTransformers) {
Collection<ClassInfo> bindableServices = combinedIndexBuildItem.getIndex()
.getAllKnownImplementors(GrpcDotNames.BINDABLE_SERVICE);

final Set<DotName> grpcServiceNames = new HashSet<>();

for (ClassInfo service : bindableServices) {
if (!Modifier.isAbstract(service.flags()) && service.classAnnotation(DotNames.SINGLETON) != null) {
BindableServiceBuildItem item = new BindableServiceBuildItem(service.name());
Expand All @@ -72,9 +84,25 @@ void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
item.registerBlockingMethod(method.name());
}
}
grpcServiceNames.add(service.name());
bindables.produce(item);
}
}

annotationTransformers.produce(new AnnotationsTransformerBuildItem(new AnnotationsTransformer() {
@Override
public boolean appliesTo(AnnotationTarget.Kind kind) {
return kind == AnnotationTarget.Kind.METHOD;
}

@Override
public void transform(TransformationContext transformationContext) {
MethodInfo target = transformationContext.getTarget().asMethod();
if (grpcServiceNames.contains(target.declaringClass().name()) && !"bindService".equals(target.name())) {
transformationContext.transform().add(CleanUpRequestContext.class).done();
}
}
}));
}

@BuildStep(onlyIf = IsNormal.class)
Expand All @@ -87,6 +115,12 @@ KubernetesPortBuildItem registerGrpcServiceInKubernetes(List<BindableServiceBuil
return null;
}

@BuildStep
void registerAdditionalBeans(BuildProducer<AdditionalBeanBuildItem> beans) {
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(CleanUpRequestContext.class));
}

@BuildStep
void buildContainerBean(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.quarkus.grpc.server.devmode;

import javax.enterprise.context.RequestScoped;
import javax.inject.Singleton;

import devmodetest.v1.Devmodetest;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import io.quarkus.arc.Arc;

@Singleton
public class DevModeTestService extends GreeterGrpc.GreeterImplBase {
Expand All @@ -20,7 +22,11 @@ public void sayHello(HelloRequest request, StreamObserver<HelloReply> responseOb
} else {
response = greeting + request.getName();
}
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
if (Arc.container().getActiveContext(RequestScoped.class) != null) {
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
} else {
throw new IllegalStateException("request context not active, failing");
}
responseObserver.onCompleted();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ public void onCompleted() {
};
}

;

@Override
public StreamObserver<Messages.StreamingOutputCallRequest> fullDuplexCall(
StreamObserver<Messages.StreamingOutputCallResponse> responseObserver) {
Expand Down Expand Up @@ -122,6 +120,7 @@ public StreamObserver<Messages.StreamingOutputCallRequest> halfDuplexCall(
return new StreamObserver<Messages.StreamingOutputCallRequest>() {
@Override
public void onNext(Messages.StreamingOutputCallRequest streamingOutputCallRequest) {
assertThatTheRequestScopeIsActive();
String payload = streamingOutputCallRequest.getPayload().getBody().toStringUtf8();
ByteString value = ByteString.copyFromUtf8(payload.toUpperCase());
Messages.Payload response = Messages.Payload.newBuilder().setBody(value).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.RequestScopeHandlerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.ShutdownContext;
Expand All @@ -66,6 +66,8 @@ public class GrpcServerRecorder {
private static final AtomicInteger grpcVerticleCount = new AtomicInteger(0);
private Map<String, List<String>> blockingMethodsPerService = Collections.emptyMap();

private static volatile DevModeWrapper devModeWrapper;

public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
GrpcConfiguration cfg,
ShutdownContext shutdown,
Expand All @@ -90,7 +92,7 @@ public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
if (GrpcServerReloader.getServer() == null) {
devModeStart(grpcContainer, vertx, configuration, shutdown, launchMode);
} else {
devModeReload(grpcContainer);
devModeReload(grpcContainer, vertx, configuration);
}
} else {
prodStart(grpcContainer, vertx, configuration, launchMode);
Expand Down Expand Up @@ -159,6 +161,8 @@ private void devModeStart(GrpcContainer grpcContainer, Vertx vertx, GrpcServerCo
ShutdownContext shutdown, LaunchMode launchMode) {
CompletableFuture<Boolean> future = new CompletableFuture<>();

devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

VertxServer vertxServer = buildServer(vertx, configuration, grpcContainer, launchMode)
.start(new Handler<AsyncResult<Void>>() { // NOSONAR
@Override
Expand Down Expand Up @@ -251,19 +255,18 @@ private static class GrpcServiceDefinition {
}

public String getImplementationClassName() {
return service.getClass().getName();
// all grpc services have a io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor
// this means Arc passes a subclass to grpc internals. That's why we take superclass here
return service.getClass().getSuperclass().getName();
}
}

private static void devModeReload(GrpcContainer grpcContainer) {
List<GrpcServiceDefinition> svc = collectServiceDefinitions(grpcContainer.getServices());
private void devModeReload(GrpcContainer grpcContainer, Vertx vertx, GrpcServerConfiguration configuration) {
List<GrpcServiceDefinition> services = collectServiceDefinitions(grpcContainer.getServices());

List<ServerServiceDefinition> definitions = new ArrayList<>();
Map<String, ServerMethodDefinition<?, ?>> methods = new HashMap<>();
for (GrpcServiceDefinition service : svc) {
for (ServerMethodDefinition<?, ?> method : service.definition.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
for (GrpcServiceDefinition service : services) {
definitions.add(service.definition);
}

Expand All @@ -272,8 +275,20 @@ private static void devModeReload(GrpcContainer grpcContainer) {
for (ServerMethodDefinition<?, ?> method : reflectionService.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
List<ServerServiceDefinition> servicesWithInterceptors = new ArrayList<>();
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);
for (GrpcServiceDefinition service : services) {
servicesWithInterceptors.add(serviceWithInterceptors(vertx, compressionInterceptor, service));
}

for (ServerServiceDefinition serviceWithInterceptors : servicesWithInterceptors) {
for (ServerMethodDefinition<?, ?> method : serviceWithInterceptors.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
}
devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

GrpcServerReloader.reinitialize(definitions, methods, grpcContainer.getSortedInterceptors());
GrpcServerReloader.reinitialize(servicesWithInterceptors, methods, grpcContainer.getSortedInterceptors());
}

public static int getVerticleCount() {
Expand Down Expand Up @@ -320,26 +335,10 @@ public void handle(HttpServerOptions options) {
List<GrpcServiceDefinition> toBeRegistered = collectServiceDefinitions(grpcContainer.getServices());
List<ServerServiceDefinition> definitions = new ArrayList<>();

CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);

for (GrpcServiceDefinition service : toBeRegistered) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new RequestScopeHandlerInterceptor());
builder.addService(ServerInterceptors.intercept(service.definition, interceptors));
builder.addService(serviceWithInterceptors(vertx, compressionInterceptor, service));
LOGGER.debugf("Registered gRPC service '%s'", service.definition.getServiceDescriptor().getName());
definitions.add(service.definition);
}
Expand Down Expand Up @@ -367,7 +366,7 @@ public void handle(Promise<Boolean> event) {
new Handler<AsyncResult<Boolean>>() {
@Override
public void handle(AsyncResult<Boolean> result) {
command.run();
devModeWrapper.run(command);
}
});
}
Expand All @@ -381,6 +380,38 @@ public void handle(AsyncResult<Boolean> result) {
return builder.build();
}

/**
* Compression interceptor if needed, null otherwise
*
* @param configuration gRPC server configuration
* @return interceptor or null
*/
private CompressionInterceptor prepareCompressionInterceptor(GrpcServerConfiguration configuration) {
CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
return compressionInterceptor;
}

private ServerServiceDefinition serviceWithInterceptors(Vertx vertx, CompressionInterceptor compressionInterceptor,
GrpcServiceDefinition service) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new GrpcRequestContextGrpcInterceptor());
return ServerInterceptors.intercept(service.definition, interceptors);
}

private class GrpcServerVerticle extends AbstractVerticle {
private final GrpcServerConfiguration configuration;
private final GrpcContainer grpcContainer;
Expand Down Expand Up @@ -432,4 +463,17 @@ public void handle(AsyncResult<Void> ar) {
});
}
}

private class DevModeWrapper {
private final ClassLoader classLoader;

public DevModeWrapper(ClassLoader contextClassLoader) {
classLoader = contextClassLoader;
}

public void run(Runnable command) {
Thread.currentThread().setContextClassLoader(classLoader);
command.run();
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkus.grpc.runtime.supports.context;

import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import javax.interceptor.InterceptorBinding;

@Inherited
@InterceptorBinding
@Target({ ElementType.TYPE })
@Retention(RetentionPolicy.RUNTIME)
public @interface CleanUpRequestContext {
}
Loading

0 comments on commit 35a0e98

Please sign in to comment.