Skip to content

Commit

Permalink
Fix gRPC context propagation.
Browse files Browse the repository at this point in the history
  • Loading branch information
alesj authored and cescoffier committed Oct 23, 2023
1 parent fdd4c15 commit c2e2b7e
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.quarkus.grpc.client.bd;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.Deadline;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.test.QuarkusUnitTest;

public class ClientBlockingDeadlineTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage()).addClasses(HelloService.class))
.withConfigurationResource("hello-config-deadline.properties");

@GrpcClient("hello-service")
GreeterGrpc.GreeterBlockingStub stub;

@Test
public void testCallOptions() {
Deadline deadline = stub.getCallOptions().getDeadline();
assertNotNull(deadline);
try {
//noinspection ResultOfMethodCallIgnored
stub.sayHello(HelloRequest.newBuilder().setName("Scaladar").build());
} catch (Exception e) {
Assertions.assertTrue(e instanceof StatusRuntimeException);
StatusRuntimeException sre = (StatusRuntimeException) e;
Status status = sre.getStatus();
Assertions.assertNotNull(status);
Assertions.assertEquals(Status.DEADLINE_EXCEEDED.getCode(), status.getCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.quarkus.grpc.client.bd;

import java.time.Duration;

import io.grpc.Context;
import io.grpc.Deadline;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import io.quarkus.grpc.GrpcService;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;

@GrpcService
public class HelloService extends GreeterGrpc.GreeterImplBase {

@Override
@Blocking
public void sayHello(HelloRequest request, StreamObserver<HelloReply> observer) {
Deadline deadline = Context.current().getDeadline();
if (deadline == null) {
throw new IllegalStateException("Null deadline");
}
Uni.createFrom()
.item(HelloReply.newBuilder().setMessage("OK").build())
.onItem()
.delayIt()
.by(Duration.ofMillis(400)).invoke(observer::onNext)
.invoke(observer::onCompleted)
.await()
.indefinitely();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.time.Duration;

import io.grpc.Context;
import io.grpc.Deadline;
import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
Expand All @@ -13,6 +15,10 @@ public class HelloService implements Greeter {

@Override
public Uni<HelloReply> sayHello(HelloRequest request) {
Deadline deadline = Context.current().getDeadline();
if (deadline == null) {
throw new IllegalStateException("Null deadline");
}
return Uni.createFrom().item(HelloReply.newBuilder().setMessage("OK").build()).onItem().delayIt()
.by(Duration.ofMillis(400));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {

Context ctx = Context.current().withValue(KEY_1, "k1").withValue(KEY_2, counter.incrementAndGet());
ctx.attach(); // Make sure the context is attached to the current duplicated context.
return Contexts.interceptCall(ctx, new ForwardingServerCall.SimpleForwardingServerCall<>(serverCall) {

@Override
Expand All @@ -38,7 +39,6 @@ public void close(Status status, Metadata trailers) {
super.close(status, trailers);
}
}, metadata, serverCallHandler);

}

public long getLastCall() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
Expand All @@ -13,6 +12,8 @@
import java.util.function.Consumer;
import java.util.function.Function;

import org.jboss.logging.Logger;

import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.ServerCall;
Expand All @@ -31,6 +32,7 @@
* For non-annotated methods, the interceptor acts as a pass-through.
*/
public class BlockingServerInterceptor implements ServerInterceptor, Function<String, Boolean> {
private static final Logger log = Logger.getLogger(BlockingServerInterceptor.class);

private final Vertx vertx;
private final Set<String> blockingMethods;
Expand Down Expand Up @@ -140,14 +142,16 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
*/
private class ReplayListener<ReqT> extends ServerCall.Listener<ReqT> {
private final InjectableContext.ContextState requestContextState;
private final Context grpcContext;

// exclusive to event loop context
private ServerCall.Listener<ReqT> delegate;
private final Queue<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new LinkedList<>();
private boolean isConsumingFromIncomingEvents = false;
private volatile ServerCall.Listener<ReqT> delegate;
private final Queue<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new ConcurrentLinkedQueue<>();
private volatile boolean isConsumingFromIncomingEvents;

private ReplayListener(InjectableContext.ContextState requestContextState) {
this.requestContextState = requestContextState;
this.grpcContext = Context.current();
}

/**
Expand Down Expand Up @@ -185,6 +189,12 @@ private void executeBlockingWithRequestContext(Consumer<ServerCall.Listener<ReqT
final Context grpcContext = Context.current();
Callable<Void> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate,
requestContextState, getRequestContext(), this);

if (!isExecutable()) {
log.warn("Not executable, already shutdown? Ignoring execution ...");
return;
}

if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler(Thread.currentThread().getContextClassLoader(),
blockingHandler);
Expand Down Expand Up @@ -323,6 +333,11 @@ public void onReady() {
}

// protected for tests

protected boolean isExecutable() {
return Arc.container() != null;
}

protected ManagedContext getRequestContext() {
return Arc.container().requestContext();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,28 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
setContextSafe(local, true);

// Must be sure to call next.startCall on the right context
return new ListenedOnDuplicatedContext<>(ehp, call, () -> next.startCall(call, headers), local);
return new ListenedOnDuplicatedContext<>(ehp, call, nextCall(call, headers, next), local);
} else {
log.warn("Unable to run on a duplicated context - interceptor not called on the Vert.x event loop");
return next.startCall(call, headers);
}
}

private <ReqT, RespT> Supplier<ServerCall.Listener<ReqT>> nextCall(ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
// Must be sure to call next.startCall on the right context
io.grpc.Context current = io.grpc.Context.current();
return () -> {
io.grpc.Context previous = current.attach();
try {
return next.startCall(call, headers);
} finally {
current.detach(previous);
}
};
}

@Override
public int getPriority() {
return Interceptors.DUPLICATE_CONTEXT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ void setup() {
when(requestContext.getState()).thenReturn(contextState);
blockingServerInterceptor = new BlockingServerInterceptor(vertx, Collections.singletonList("blocking"),
Collections.emptyList(), null, false) {
@Override
protected boolean isExecutable() {
return true;
}

@Override
protected ManagedContext getRequestContext() {
return requestContext;
Expand All @@ -54,21 +59,25 @@ void testContextPropagation() throws Exception {

// setting grpc context
final Context context = Context.current().withValue(USERNAME, "my-user");
Context previous = context.attach();
try {
final ServerCall.Listener listener = blockingServerInterceptor.interceptCall(serverCall, null, serverCallHandler);
serverCallHandler.awaitSetup();

final ServerCall.Listener listener = blockingServerInterceptor.interceptCall(serverCall, null, serverCallHandler);
serverCallHandler.awaitSetup();
// simulate GRPC call
context.wrap(() -> listener.onMessage("hello")).run();

// simulate GRPC call
context.wrap(() -> listener.onMessage("hello")).run();
// await for the message to be received
serverCallHandler.await();

// await for the message to be received
serverCallHandler.await();
// check that the thread is a worker thread
assertThat(serverCallHandler.threadName).contains("vert.x").contains("worker");

// check that the thread is a worker thread
assertThat(serverCallHandler.threadName).contains("vert.x").contains("worker");

// check that the context was propagated correctly
assertThat(serverCallHandler.contextUserName).isEqualTo("my-user");
// check that the context was propagated correctly
assertThat(serverCallHandler.contextUserName).isEqualTo("my-user");
} finally {
context.detach(previous);
}
}

static class BlockingServerCallHandler implements ServerCallHandler {
Expand Down

0 comments on commit c2e2b7e

Please sign in to comment.