Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override gRPC Context storage #28199

Merged
merged 1 commit into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void shouldSecureUniEndpoint() {
client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build())
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> resultCount.get() == 1);
}

Expand All @@ -82,7 +82,7 @@ void shouldSecureMultiEndpoint() {
.supplier(() -> (Security.Container.newBuilder().setText("woo-hoo").build())).atMost(4))
.subscribe().with(e -> results.add(e.getIsOnEventLoop()));

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> results.size() == 5);

assertThat(results.stream().filter(e -> !e)).isEmpty();
Expand All @@ -101,7 +101,7 @@ void shouldFailWithInvalidCredentials() {
.onFailure().invoke(error::set)
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> error.get() != null);
}

Expand All @@ -118,7 +118,7 @@ void shouldFailWithInvalidInsufficientRole() {
.onFailure().invoke(error::set)
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> error.get() != null);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThat;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.test.QuarkusUnitTest;

/**
* Test reproducing <a href="https://github.com/quarkusio/quarkus/issues/26830">#26830</a>.
*/
public class GrpcContextPropagationTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFirstInterceptor.class, MyInterceptedGreeting.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
HelloReply foo = greeter.sayHello(HelloRequest.newBuilder().setName("foo").build()).await().indefinitely();
assertThat(foo.getMessage()).isEqualTo("hello k1 - 1");
foo = greeter.sayHello(HelloRequest.newBuilder().setName("foo").build()).await().indefinitely();
assertThat(foo.getMessage()).isEqualTo("hello k1 - 2");
}

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package io.quarkus.grpc.server.interceptors;

import java.util.concurrent.atomic.AtomicInteger;

import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.inject.spi.Prioritized;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
Expand All @@ -15,19 +19,26 @@
@GlobalInterceptor
public class MyFirstInterceptor implements ServerInterceptor, Prioritized {

public static Context.Key<String> KEY_1 = Context.key("X-TEST_1");
public static Context.Key<Integer> KEY_2 = Context.keyWithDefault("X-TEST_2", -1);
private volatile long callTime;

private AtomicInteger counter = new AtomicInteger();

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall,
Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
return serverCallHandler
.startCall(new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {
@Override
public void close(Status status, Metadata trailers) {
callTime = System.nanoTime();
super.close(status, trailers);
}
}, metadata);

Context ctx = Context.current().withValue(KEY_1, "k1").withValue(KEY_2, counter.incrementAndGet());
return Contexts.interceptCall(ctx, new ForwardingServerCall.SimpleForwardingServerCall<>(serverCall) {

@Override
public void close(Status status, Metadata trailers) {
callTime = System.nanoTime();
super.close(status, trailers);
}
}, metadata, serverCallHandler);

}

public long getLastCall() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkus.grpc.server.interceptors;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcService;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;

@GrpcService
public class MyInterceptedGreeting implements Greeter {
@Override
@Blocking
public Uni<HelloReply> sayHello(HelloRequest request) {
return Uni.createFrom().item(() -> HelloReply.newBuilder()
.setMessage("hello " + MyFirstInterceptor.KEY_1.get() + " - " + MyFirstInterceptor.KEY_2.get()).build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.grpc.override;

import io.grpc.Context;
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Vertx;

/**
* Override gRPC context storage to rely on duplicated context when available.
*/
public class ContextStorageOverride extends Context.Storage {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow that's nasty (in gRPC) , the io.grpc.override.ContextStorageOverride hardcoded class.


private static final ThreadLocal<Context> fallback = new ThreadLocal<>();

private static final String GRPC_CONTEXT = "GRPC_CONTEXT";

@Override
public Context doAttach(Context toAttach) {
Context current = current();
io.vertx.core.Context dc = Vertx.currentContext();
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
dc.putLocal(GRPC_CONTEXT, toAttach);
} else {
fallback.set(toAttach);
}
return current;
}

@Override
public void detach(Context context, Context toRestore) {
io.vertx.core.Context dc = Vertx.currentContext();
if (toRestore != Context.ROOT) {
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
dc.putLocal(GRPC_CONTEXT, toRestore);
} else {
fallback.set(toRestore);
}
} else {
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
// Do nothing duplicated context are not shared.
} else {
fallback.set(null);
}
}
}

@Override
public Context current() {
if (VertxContext.isOnDuplicatedContext()) {
Context current = Vertx.currentContext().getLocal(GRPC_CONTEXT);
if (current == null) {
return Context.ROOT;
}
return current;
} else {
Context current = fallback.get();
if (current == null) {
return Context.ROOT;
}
return current;
}
}

@Override
public void attach(Context toAttach) {
// do nothing, should not be called.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private synchronized ServerCall.Listener<ReqT> getDelegate() {
delegate = supplier.get();
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warnf("Unable to retrieve gRPC Server call listener", t);
log.warn("Unable to retrieve gRPC Server call listener", t);
close(t);
return null;
}
Expand Down