Skip to content

Commit

Permalink
Merge pull request quarkusio#17946 from michalszynkiewicz/fix-grpc-re…
Browse files Browse the repository at this point in the history
…quest-scope

gRPC: fix request context propagation
  • Loading branch information
gsmet authored Jun 22, 2021
2 parents 065e34b + 8b9c8b2 commit 71f9dc5
Show file tree
Hide file tree
Showing 34 changed files with 1,185 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.quarkus.grpc.runtime.MutinyStub;
import io.quarkus.grpc.runtime.supports.Channels;
import io.quarkus.grpc.runtime.supports.GrpcClientConfigProvider;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;

Expand All @@ -28,7 +27,6 @@ public class GrpcDotNames {
public static final DotName CHANNEL = DotName.createSimple(Channel.class.getName());
public static final DotName GRPC_CLIENT = DotName.createSimple(GrpcClient.class.getName());
public static final DotName GRPC_SERVICE = DotName.createSimple(GrpcService.class.getName());
public static final DotName GRPC_ENABLE_REQUEST_CONTEXT = DotName.createSimple(GrpcEnableRequestContext.class.getName());

public static final DotName BLOCKING = DotName.createSimple(Blocking.class.getName());
public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.Transformation;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.IsNormal;
import io.quarkus.deployment.annotations.BuildProducer;
Expand All @@ -60,8 +59,6 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.runtime.LaunchMode;
Expand Down Expand Up @@ -240,14 +237,11 @@ public boolean appliesTo(Kind kind) {
@Override
public void transform(TransformationContext context) {
ClassInfo clazz = context.getTarget().asClass();
if (userDefinedServices.contains(clazz.name())) {
// Add @GrpcEnableRequestContext to activate the request context during each call
Transformation transform = context.transform().add(GrpcDotNames.GRPC_ENABLE_REQUEST_CONTEXT);
if (!customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
transform.add(BuiltinScope.SINGLETON.getName());
}
transform.done();
if (userDefinedServices.contains(clazz.name()) && !customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
context.transform()
.add(BuiltinScope.SINGLETON.getName())
.done();
}
}
});
Expand Down Expand Up @@ -303,8 +297,6 @@ void registerBeans(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
// @GrpcService is a CDI qualifier
beans.produce(new AdditionalBeanBuildItem(GrpcService.class));
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(GrpcEnableRequestContext.class));

if (!bindables.isEmpty() || LaunchMode.current() == LaunchMode.DEVELOPMENT) {
beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcContainer.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import io.quarkus.grpc.runtime.devmode.GrpcServerReloader;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.blocking.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import grpc.health.v1.HealthOuterClass.HealthCheckResponse.ServingStatus;
import grpc.health.v1.MutinyHealthGrpc;
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor;

// Note that we need to add the scope and interceptor binding explicitly because this class is not part of the index
// Note that we need to add the scope explicitly because this class is not part of the index
@Singleton
@GrpcEnableRequestContext
@GrpcService
public class GrpcHealthEndpoint extends MutinyHealthGrpc.HealthImplBase {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public static Channel createChannel(String name) throws SSLException {
GrpcClientConfiguration config = configProvider.getConfiguration(name);

if (config == null && LaunchMode.current() == LaunchMode.TEST) {
LOGGER.infof(
"gRPC client %s created without configuration. We are assuming that it's created to test your gRPC services.",
name);
config = testConfig(configProvider.getServerConfiguration());
}

Expand Down Expand Up @@ -164,7 +167,6 @@ public static Channel createChannel(String name) throws SSLException {
}

private static GrpcClientConfiguration testConfig(GrpcServerConfiguration serverConfiguration) {
LOGGER.info("gRPC client created without configuration. We are assuming that it's created to test your gRPC services.");
GrpcClientConfiguration config = new GrpcClientConfiguration();
config.port = serverConfiguration.testPort;
config.host = serverConfiguration.host;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.function.Consumer;

import io.grpc.Context;
import io.grpc.ServerCall;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;

class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;
private final InjectableContext.ContextState state;
private final ManagedContext requestContext;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate, InjectableContext.ContextState state,
ManagedContext requestContext) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
this.state = state;
this.requestContext = requestContext;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
requestContext.activate(state);
try {
consumer.accept(delegate);
} catch (Throwable any) {
event.fail(any);
return;
} finally {
requestContext.deactivate();
}
event.complete();
} finally {
grpcContext.detach(previous);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package io.quarkus.grpc.runtime.supports;
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -13,12 +12,15 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext.ContextState;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;

/**
* gRPC Server interceptor offloading the execution of the gRPC method on a wroker thread if the method is annotated
* gRPC Server interceptor offloading the execution of the gRPC method on a worker thread if the method is annotated
* with {@link io.smallrye.common.annotation.Blocking}.
*
* For non-annotated methods, the interceptor acts as a pass-through.
Expand Down Expand Up @@ -62,13 +64,23 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
boolean isBlocking = cache.computeIfAbsent(fullMethodName, this);

if (isBlocking) {
ReplayListener<ReqT> replay = new ReplayListener<>();

final ManagedContext requestContext = getRequestContext();
// context should always be active here
// it is initialized by io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor
// that should always be called before this interceptor
ContextState state = requestContext.getState();
ReplayListener<ReqT> replay = new ReplayListener<>(state);
vertx.executeBlocking(new Handler<Promise<Object>>() {
@Override
public void handle(Promise<Object> f) {
ServerCall.Listener<ReqT> listener = next.startCall(call, headers);
replay.setDelegate(listener);
ServerCall.Listener<ReqT> listener;
try {
requestContext.activate(state);
listener = next.startCall(call, headers);
} finally {
requestContext.deactivate();
}
replay.setDelegate(listener, requestContext);
f.complete(null);
}
}, null);
Expand All @@ -87,30 +99,46 @@ public void handle(Promise<Object> f) {
*/
private class ReplayListener<ReqT> extends ServerCall.Listener<ReqT> {
private ServerCall.Listener<ReqT> delegate;
private final List<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new LinkedList<>();
private final List<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new ArrayList<>();
private final ContextState requestContextState;

private ReplayListener(ContextState requestContextState) {
this.requestContextState = requestContextState;
}

synchronized void setDelegate(ServerCall.Listener<ReqT> delegate) {
synchronized void setDelegate(ServerCall.Listener<ReqT> delegate,
ManagedContext requestContext) {
this.delegate = delegate;
for (Consumer<ServerCall.Listener<ReqT>> event : incomingEvents) {
event.accept(delegate);
requestContext.activate(requestContextState);
try {
for (Consumer<ServerCall.Listener<ReqT>> event : incomingEvents) {
event.accept(delegate);
}
} finally {
requestContext.deactivate();
}
incomingEvents.clear();
}

private synchronized void executeOnContextOrEnqueue(Consumer<ServerCall.Listener<ReqT>> consumer) {
if (this.delegate != null) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate);
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler<ReqT>(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
executeBlockingWithRequestContext(consumer);
} else {
incomingEvents.add(consumer);
}
}

private void executeBlockingWithRequestContext(Consumer<ServerCall.Listener<ReqT>> consumer) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate,
requestContextState, getRequestContext());
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
}

@Override
public void onMessage(ReqT message) {
executeOnContextOrEnqueue(new Consumer<ServerCall.Listener<ReqT>>() {
Expand Down Expand Up @@ -142,50 +170,8 @@ public void onReady() {
}
}

private static class DevModeBlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
}
}
}

private static class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
consumer.accept(delegate);
event.complete();
} finally {
grpcContext.detach(previous);
}
}
// protected for tests
protected ManagedContext getRequestContext() {
return Arc.container().requestContext();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkus.grpc.runtime.supports.blocking;

import io.vertx.core.Handler;
import io.vertx.core.Promise;

class DevModeBlockingExecutionHandler implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
}
}
}

This file was deleted.

Loading

0 comments on commit 71f9dc5

Please sign in to comment.