Skip to content

Commit

Permalink
gRPC fix request context clean-up and classloaders after dev mode reload
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed Apr 30, 2021
1 parent cfd5d98 commit e7dc28b
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.eclipse.microprofile.config.Config;
import org.eclipse.microprofile.config.ConfigProvider;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.logging.Logger;

import io.grpc.BindableService;
import io.grpc.internal.ServerImpl;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem;
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.deployment.IsDevelopment;
Expand All @@ -44,6 +49,8 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.CleanUpRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;
Expand All @@ -66,9 +73,13 @@ MinNettyAllocatorMaxOrderBuildItem setMinimalNettyMaxOrderSize() {

@BuildStep
void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
CombinedIndexBuildItem combinedIndexBuildItem) {
CombinedIndexBuildItem combinedIndexBuildItem,
BuildProducer<AnnotationsTransformerBuildItem> annotationTransformers) {
Collection<ClassInfo> bindableServices = combinedIndexBuildItem.getIndex()
.getAllKnownImplementors(GrpcDotNames.BINDABLE_SERVICE);

final Set<DotName> grpcServiceNames = new HashSet<>();

for (ClassInfo service : bindableServices) {
if (Modifier.isAbstract(service.flags())) {
continue;
Expand All @@ -80,15 +91,31 @@ void discoverBindableServices(BuildProducer<BindableServiceBuildItem> bindables,
}
}
bindables.produce(item);
grpcServiceNames.add(service.name());
}

annotationTransformers.produce(new AnnotationsTransformerBuildItem(new AnnotationsTransformer() {
@Override
public boolean appliesTo(AnnotationTarget.Kind kind) {
return kind == AnnotationTarget.Kind.METHOD;
}

@Override
public void transform(TransformationContext transformationContext) {
MethodInfo target = transformationContext.getTarget().asMethod();
if (grpcServiceNames.contains(target.declaringClass().name()) && !"bindService".equals(target.name())) {
transformationContext.transform().add(CleanUpRequestContext.class).done();
}
}
}));
}

@BuildStep
void validateBindableServices(ValidationPhaseBuildItem validationPhase,
BuildProducer<ValidationErrorBuildItem> errors) {
BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> errors) {
for (BeanInfo bean : validationPhase.getContext().beans().classBeans().withBeanType(BindableService.class)) {
if (!bean.getScope().getDotName().equals(BuiltinScope.SINGLETON.getName())) {
errors.produce(new ValidationErrorBuildItem(
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new IllegalStateException("A gRPC service bean must have the javax.inject.Singleton scope: " + bean)));
}
}
Expand All @@ -109,6 +136,8 @@ void registerBeans(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
// @GrpcService is a CDI stereotype
beans.produce(new AdditionalBeanBuildItem(GrpcService.class));
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(CleanUpRequestContext.class));
if (!bindables.isEmpty()) {
beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcContainer.class));
features.produce(new FeatureBuildItem(GRPC_SERVER));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.quarkus.grpc.server.devmode;

import javax.enterprise.context.RequestScoped;
import javax.inject.Singleton;

import devmodetest.v1.Devmodetest;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import io.quarkus.arc.Arc;

@Singleton
public class DevModeTestService extends GreeterGrpc.GreeterImplBase {
Expand All @@ -20,7 +22,11 @@ public void sayHello(HelloRequest request, StreamObserver<HelloReply> responseOb
} else {
response = greeting + request.getName();
}
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
if (Arc.container().getActiveContext(RequestScoped.class) != null) {
responseObserver.onNext(HelloReply.newBuilder().setMessage(response).build());
} else {
throw new IllegalStateException("request context not active, failing");
}
responseObserver.onCompleted();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ public void onCompleted() {
};
}

;

@Override
public StreamObserver<Messages.StreamingOutputCallRequest> fullDuplexCall(
StreamObserver<Messages.StreamingOutputCallResponse> responseObserver) {
Expand Down Expand Up @@ -122,6 +120,7 @@ public StreamObserver<Messages.StreamingOutputCallRequest> halfDuplexCall(
return new StreamObserver<Messages.StreamingOutputCallRequest>() {
@Override
public void onNext(Messages.StreamingOutputCallRequest streamingOutputCallRequest) {
assertThatTheRequestScopeIsActive();
String payload = streamingOutputCallRequest.getPayload().getBody().toStringUtf8();
ByteString value = ByteString.copyFromUtf8(payload.toUpperCase());
Messages.Payload response = Messages.Payload.newBuilder().setBody(value).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.RequestScopeHandlerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.ShutdownContext;
Expand All @@ -66,6 +66,8 @@ public class GrpcServerRecorder {
private static final AtomicInteger grpcVerticleCount = new AtomicInteger(0);
private Map<String, List<String>> blockingMethodsPerService = Collections.emptyMap();

private static volatile DevModeWrapper devModeWrapper;

public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
GrpcConfiguration cfg,
ShutdownContext shutdown,
Expand All @@ -90,7 +92,7 @@ public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
if (GrpcServerReloader.getServer() == null) {
devModeStart(grpcContainer, vertx, configuration, shutdown, launchMode);
} else {
devModeReload(grpcContainer);
devModeReload(grpcContainer, vertx, configuration);
}
} else {
prodStart(grpcContainer, vertx, configuration, launchMode);
Expand Down Expand Up @@ -159,6 +161,8 @@ private void devModeStart(GrpcContainer grpcContainer, Vertx vertx, GrpcServerCo
ShutdownContext shutdown, LaunchMode launchMode) {
CompletableFuture<Boolean> future = new CompletableFuture<>();

devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

VertxServer vertxServer = buildServer(vertx, configuration, grpcContainer, launchMode)
.start(new Handler<AsyncResult<Void>>() { // NOSONAR
@Override
Expand Down Expand Up @@ -251,19 +255,18 @@ private static class GrpcServiceDefinition {
}

public String getImplementationClassName() {
return service.getClass().getName();
// all grpc services have a io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor
// this means Arc passes a subclass to grpc internals. That's why we take superclass here
return service.getClass().getSuperclass().getName();
}
}

private static void devModeReload(GrpcContainer grpcContainer) {
List<GrpcServiceDefinition> svc = collectServiceDefinitions(grpcContainer.getServices());
private void devModeReload(GrpcContainer grpcContainer, Vertx vertx, GrpcServerConfiguration configuration) {
List<GrpcServiceDefinition> services = collectServiceDefinitions(grpcContainer.getServices());

List<ServerServiceDefinition> definitions = new ArrayList<>();
Map<String, ServerMethodDefinition<?, ?>> methods = new HashMap<>();
for (GrpcServiceDefinition service : svc) {
for (ServerMethodDefinition<?, ?> method : service.definition.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
for (GrpcServiceDefinition service : services) {
definitions.add(service.definition);
}

Expand All @@ -272,8 +275,20 @@ private static void devModeReload(GrpcContainer grpcContainer) {
for (ServerMethodDefinition<?, ?> method : reflectionService.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
List<ServerServiceDefinition> servicesWithInterceptors = new ArrayList<>();
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);
for (GrpcServiceDefinition service : services) {
servicesWithInterceptors.add(serviceWithInterceptors(vertx, compressionInterceptor, service, true));
}

for (ServerServiceDefinition serviceWithInterceptors : servicesWithInterceptors) {
for (ServerMethodDefinition<?, ?> method : serviceWithInterceptors.getMethods()) {
methods.put(method.getMethodDescriptor().getFullMethodName(), method);
}
}
devModeWrapper = new DevModeWrapper(Thread.currentThread().getContextClassLoader());

GrpcServerReloader.reinitialize(definitions, methods, grpcContainer.getSortedInterceptors());
GrpcServerReloader.reinitialize(servicesWithInterceptors, methods, grpcContainer.getSortedInterceptors());
}

public static int getVerticleCount() {
Expand Down Expand Up @@ -320,26 +335,11 @@ public void handle(HttpServerOptions options) {
List<GrpcServiceDefinition> toBeRegistered = collectServiceDefinitions(grpcContainer.getServices());
List<ServerServiceDefinition> definitions = new ArrayList<>();

CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
CompressionInterceptor compressionInterceptor = prepareCompressionInterceptor(configuration);

for (GrpcServiceDefinition service : toBeRegistered) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new RequestScopeHandlerInterceptor());
builder.addService(ServerInterceptors.intercept(service.definition, interceptors));
builder.addService(
serviceWithInterceptors(vertx, compressionInterceptor, service, launchMode == LaunchMode.DEVELOPMENT));
LOGGER.debugf("Registered gRPC service '%s'", service.definition.getServiceDescriptor().getName());
definitions.add(service.definition);
}
Expand Down Expand Up @@ -367,7 +367,7 @@ public void handle(Promise<Boolean> event) {
new Handler<AsyncResult<Boolean>>() {
@Override
public void handle(AsyncResult<Boolean> result) {
command.run();
devModeWrapper.run(command);
}
});
}
Expand All @@ -381,6 +381,38 @@ public void handle(AsyncResult<Boolean> result) {
return builder.build();
}

/**
* Compression interceptor if needed, null otherwise
*
* @param configuration gRPC server configuration
* @return interceptor or null
*/
private CompressionInterceptor prepareCompressionInterceptor(GrpcServerConfiguration configuration) {
CompressionInterceptor compressionInterceptor = null;
if (configuration.compression.isPresent()) {
compressionInterceptor = new CompressionInterceptor(configuration.compression.get());
}
return compressionInterceptor;
}

private ServerServiceDefinition serviceWithInterceptors(Vertx vertx, CompressionInterceptor compressionInterceptor,
GrpcServiceDefinition service, boolean devMode) {
List<ServerInterceptor> interceptors = new ArrayList<>();
if (compressionInterceptor != null) {
interceptors.add(compressionInterceptor);
}
// We only register the blocking interceptor if needed by at least one method of the service.
if (!blockingMethodsPerService.isEmpty()) {
List<String> list = blockingMethodsPerService.get(service.getImplementationClassName());
if (list != null) {
interceptors.add(new BlockingServerInterceptor(vertx, list, devMode));
}
}
// Order matters! Request scope must be called first (on the event loop) and so should be last in the list...
interceptors.add(new GrpcRequestContextGrpcInterceptor());
return ServerInterceptors.intercept(service.definition, interceptors);
}

private class GrpcServerVerticle extends AbstractVerticle {
private final GrpcServerConfiguration configuration;
private final GrpcContainer grpcContainer;
Expand Down Expand Up @@ -432,4 +464,17 @@ public void handle(AsyncResult<Void> ar) {
});
}
}

private class DevModeWrapper {
private final ClassLoader classLoader;

public DevModeWrapper(ClassLoader contextClassLoader) {
classLoader = contextClassLoader;
}

public void run(Runnable command) {
Thread.currentThread().setContextClassLoader(classLoader);
command.run();
}
}
}
Loading

0 comments on commit e7dc28b

Please sign in to comment.