From e431054b3081fbd4c477f44b61a60b962a1ba9c6 Mon Sep 17 00:00:00 2001
From: Dave Maughan <davidamaughan@gmail.com>
Date: Tue, 27 Jun 2023 20:51:04 +0100
Subject: [PATCH] Add support for @GlobalInterceptor on producer methods

Fixes #21358
---
 .../asciidoc/grpc-service-consumption.adoc    | 17 ++++++
 .../asciidoc/grpc-service-implementation.adoc | 17 ++++++
 .../io/quarkus/grpc/GlobalInterceptor.java    |  3 +-
 .../grpc/deployment/GrpcInterceptors.java     | 44 +++++++++++++-
 .../interceptors/ClientInterceptors.java      | 38 ++++++++++++
 .../interceptors/HeaderClientInterceptor.java | 58 -------------------
 .../interceptors/HeaderServerInterceptor.java | 43 --------------
 .../interceptors/HelloWorldEndpoint.java      | 11 +++-
 .../interceptors/ServerInterceptors.java      | 37 ++++++++++++
 .../HelloWorldEndpointTestBase.java           | 11 +++-
 10 files changed, 169 insertions(+), 110 deletions(-)
 create mode 100644 integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ClientInterceptors.java
 delete mode 100644 integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderClientInterceptor.java
 delete mode 100644 integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderServerInterceptor.java
 create mode 100644 integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ServerInterceptors.java

diff --git a/docs/src/main/asciidoc/grpc-service-consumption.adoc b/docs/src/main/asciidoc/grpc-service-consumption.adoc
index 21a5b22ff689a..a05c32c1bde93 100644
--- a/docs/src/main/asciidoc/grpc-service-consumption.adoc
+++ b/docs/src/main/asciidoc/grpc-service-consumption.adoc
@@ -325,6 +325,23 @@ public class MyInterceptor implements ClientInterceptor {
 ----
 <1> This interceptor is applied to all injected gRPC clients.
 
+It's also possible to annotate a producer method as a global interceptor:
+
+[source, java]
+----
+import io.quarkus.grpc.GlobalInterceptor;
+
+import jakarta.enterprise.inject.Produces;
+
+public class MyProducer {
+    @GlobalInterceptor
+    @Produces
+    public MyInterceptor myInterceptor() {
+        return new MyInterceptor();
+    }
+}
+----
+
 TIP: Check the https://grpc.github.io/grpc-java/javadoc/io/grpc/ClientInterceptor.html[ClientInterceptor JavaDoc] to properly implement your interceptor.
 
 .`@RegisterClientInterceptor` Example
diff --git a/docs/src/main/asciidoc/grpc-service-implementation.adoc b/docs/src/main/asciidoc/grpc-service-implementation.adoc
index ce1f987c3e2e8..0b909e6b38f9b 100644
--- a/docs/src/main/asciidoc/grpc-service-implementation.adoc
+++ b/docs/src/main/asciidoc/grpc-service-implementation.adoc
@@ -271,6 +271,23 @@ public class MyInterceptor implements ServerInterceptor {
 }
 ----
 
+It's also possible to annotate a producer method as a global interceptor:
+
+[source, java]
+----
+import io.quarkus.grpc.GlobalInterceptor;
+
+import jakarta.enterprise.inject.Produces;
+
+public class MyProducer {
+    @GlobalInterceptor
+    @Produces
+    public MyInterceptor myInterceptor() {
+        return new MyInterceptor();
+    }
+}
+----
+
 TIP: Check the https://grpc.github.io/grpc-java/javadoc/io/grpc/ServerInterceptor.html[ServerInterceptor JavaDoc] to properly implement your interceptor.
 
 To apply an interceptor to all exposed services, annotate it with `@io.quarkus.grpc.GlobalInterceptor`.
diff --git a/extensions/grpc/api/src/main/java/io/quarkus/grpc/GlobalInterceptor.java b/extensions/grpc/api/src/main/java/io/quarkus/grpc/GlobalInterceptor.java
index b78ceab79bdb8..4935c32280ac5 100644
--- a/extensions/grpc/api/src/main/java/io/quarkus/grpc/GlobalInterceptor.java
+++ b/extensions/grpc/api/src/main/java/io/quarkus/grpc/GlobalInterceptor.java
@@ -1,6 +1,7 @@
 package io.quarkus.grpc;
 
 import static java.lang.annotation.ElementType.FIELD;
+import static java.lang.annotation.ElementType.METHOD;
 import static java.lang.annotation.ElementType.PARAMETER;
 import static java.lang.annotation.ElementType.TYPE;
 import static java.lang.annotation.RetentionPolicy.RUNTIME;
@@ -15,7 +16,7 @@
  * @see RegisterInterceptor
  * @see RegisterClientInterceptor
  */
-@Target({ FIELD, PARAMETER, TYPE })
+@Target({ FIELD, PARAMETER, TYPE, METHOD })
 @Retention(RUNTIME)
 public @interface GlobalInterceptor {
 }
diff --git a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcInterceptors.java b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcInterceptors.java
index 8b436cae4e1ca..07400477fc693 100644
--- a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcInterceptors.java
+++ b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcInterceptors.java
@@ -1,6 +1,8 @@
 package io.quarkus.grpc.deployment;
 
 import static io.quarkus.grpc.deployment.GrpcDotNames.GLOBAL_INTERCEPTOR;
+import static org.jboss.jandex.AnnotationTarget.Kind.CLASS;
+import static org.jboss.jandex.AnnotationTarget.Kind.METHOD;
 
 import java.lang.reflect.Modifier;
 import java.util.Collection;
@@ -8,6 +10,8 @@
 import java.util.List;
 import java.util.Set;
 
+import org.jboss.jandex.AnnotationInstance;
+import org.jboss.jandex.AnnotationTarget;
 import org.jboss.jandex.ClassInfo;
 import org.jboss.jandex.DotName;
 import org.jboss.jandex.IndexView;
@@ -27,6 +31,7 @@ final class GrpcInterceptors {
     }
 
     static GrpcInterceptors gatherInterceptors(IndexView index, DotName interceptorInterface) {
+        Set<DotName> allGlobalInterceptors = allGlobalInterceptors(index, interceptorInterface);
         Set<String> globalInterceptors = new HashSet<>();
         Set<String> nonGlobalInterceptors = new HashSet<>();
 
@@ -36,13 +41,46 @@ static GrpcInterceptors gatherInterceptors(IndexView index, DotName interceptorI
                     || Modifier.isInterface(interceptorImplClass.flags())) {
                 continue;
             }
-            if (interceptorImplClass.declaredAnnotation(GLOBAL_INTERCEPTOR) == null) {
-                nonGlobalInterceptors.add(interceptorImplClass.name().toString());
-            } else {
+            if (allGlobalInterceptors.contains(interceptorImplClass.name())) {
                 globalInterceptors.add(interceptorImplClass.name().toString());
+            } else {
+                nonGlobalInterceptors.add(interceptorImplClass.name().toString());
             }
         }
         return new GrpcInterceptors(globalInterceptors, nonGlobalInterceptors);
     }
 
+    private static Set<DotName> allGlobalInterceptors(IndexView index, DotName interceptorInterface) {
+        Set<DotName> result = new HashSet<>();
+        for (AnnotationInstance instance : index.getAnnotations(GLOBAL_INTERCEPTOR)) {
+            ClassInfo classInfo = classInfo(index, instance.target());
+            if (isAssignableFrom(index, classInfo, interceptorInterface)) {
+                result.add(classInfo.name());
+            }
+        }
+        return result;
+    }
+
+    private static ClassInfo classInfo(IndexView index, AnnotationTarget target) {
+        if (target.kind() == CLASS) {
+            return target.asClass();
+        } else if (target.kind() == METHOD) {
+            return index.getClassByName(target.asMethod().returnType().name());
+        }
+        return null;
+    }
+
+    private static boolean isAssignableFrom(IndexView index, ClassInfo classInfo, DotName interceptorInterface) {
+        if (classInfo == null) {
+            return false;
+        }
+        if (classInfo.interfaceNames().contains(interceptorInterface)) {
+            return true;
+        }
+        if (classInfo.superName() == null) {
+            return false;
+        }
+        return isAssignableFrom(index, index.getClassByName(classInfo.superName()), interceptorInterface);
+    }
+
 }
diff --git a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ClientInterceptors.java b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ClientInterceptors.java
new file mode 100644
index 0000000000000..5f69a36be2668
--- /dev/null
+++ b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ClientInterceptors.java
@@ -0,0 +1,38 @@
+package io.quarkus.grpc.examples.interceptors;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.inject.Produces;
+
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.MethodDescriptor;
+import io.quarkus.grpc.GlobalInterceptor;
+
+class ClientInterceptors {
+    @GlobalInterceptor
+    @ApplicationScoped
+    static class TypeTarget extends Base {
+    }
+
+    static class MethodTarget extends Base {
+    }
+
+    static class Producer {
+        @GlobalInterceptor
+        @Produces
+        MethodTarget methodTarget() {
+            return new MethodTarget();
+        }
+    }
+
+    abstract static class Base implements ClientInterceptor {
+        @Override
+        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions options,
+                Channel next) {
+            HelloWorldEndpoint.invoked.add(getClass().getName());
+            return next.newCall(method, options);
+        }
+    }
+}
diff --git a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderClientInterceptor.java b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderClientInterceptor.java
deleted file mode 100644
index c6dcea58ef74f..0000000000000
--- a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderClientInterceptor.java
+++ /dev/null
@@ -1,58 +0,0 @@
-package io.quarkus.grpc.examples.interceptors;
-
-import java.util.logging.Logger;
-
-import jakarta.enterprise.context.ApplicationScoped;
-
-import io.grpc.CallOptions;
-import io.grpc.Channel;
-import io.grpc.ClientCall;
-import io.grpc.ClientInterceptor;
-import io.grpc.ForwardingClientCall;
-import io.grpc.ForwardingClientCallListener;
-import io.grpc.Metadata;
-import io.grpc.MethodDescriptor;
-import io.quarkus.grpc.GlobalInterceptor;
-
-/**
- * A interceptor to handle client header.
- */
-@GlobalInterceptor
-@ApplicationScoped
-public class HeaderClientInterceptor implements ClientInterceptor {
-
-    private static final Logger logger = Logger.getLogger(HeaderClientInterceptor.class.getName());
-
-    static volatile boolean invoked = false;
-
-    static final Metadata.Key<String> CUSTOM_HEADER_KEY = Metadata.Key.of("custom_client_header_key",
-            Metadata.ASCII_STRING_MARSHALLER);
-
-    @Override
-    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
-            CallOptions callOptions, Channel next) {
-        return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
-
-            @Override
-            public void start(Listener<RespT> responseListener, Metadata headers) {
-                /* put custom header */
-                headers.put(CUSTOM_HEADER_KEY, "customRequestValue");
-                super.start(
-                        new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
-                            @Override
-                            public void onHeaders(Metadata headers) {
-
-                                //
-                                // if you don't need receive header from server,
-                                // you can use {@link io.grpc.stub.MetadataUtils#attachHeaders}
-                                // directly to send header
-                                //
-                                invoked = true;
-                                logger.info("header received from server:" + headers);
-                                super.onHeaders(headers);
-                            }
-                        }, headers);
-            }
-        };
-    }
-}
diff --git a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderServerInterceptor.java b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderServerInterceptor.java
deleted file mode 100644
index 82bb2ed772852..0000000000000
--- a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HeaderServerInterceptor.java
+++ /dev/null
@@ -1,43 +0,0 @@
-package io.quarkus.grpc.examples.interceptors;
-
-import java.util.logging.Logger;
-
-import jakarta.enterprise.context.ApplicationScoped;
-
-import com.google.common.annotations.VisibleForTesting;
-
-import io.grpc.ForwardingServerCall;
-import io.grpc.Metadata;
-import io.grpc.ServerCall;
-import io.grpc.ServerCallHandler;
-import io.grpc.ServerInterceptor;
-import io.quarkus.grpc.GlobalInterceptor;
-
-/**
- * A interceptor to handle server header.
- */
-@ApplicationScoped
-@GlobalInterceptor
-public class HeaderServerInterceptor implements ServerInterceptor {
-
-    private static final Logger logger = Logger.getLogger(HeaderServerInterceptor.class.getName());
-
-    @VisibleForTesting
-    static final Metadata.Key<String> CUSTOM_HEADER_KEY = Metadata.Key.of("custom_server_header_key",
-            Metadata.ASCII_STRING_MARSHALLER);
-
-    @Override
-    public <I, O> ServerCall.Listener<I> interceptCall(
-            ServerCall<I, O> call,
-            final Metadata requestHeaders,
-            ServerCallHandler<I, O> next) {
-        logger.info("header received from client:" + requestHeaders);
-        return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) {
-            @Override
-            public void sendHeaders(Metadata responseHeaders) {
-                responseHeaders.put(CUSTOM_HEADER_KEY, "customRespondValue");
-                super.sendHeaders(responseHeaders);
-            }
-        }, requestHeaders);
-    }
-}
diff --git a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HelloWorldEndpoint.java b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HelloWorldEndpoint.java
index 45a7e12e70ca3..0d9dc839a2d1a 100644
--- a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HelloWorldEndpoint.java
+++ b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/HelloWorldEndpoint.java
@@ -1,5 +1,8 @@
 package io.quarkus.grpc.examples.interceptors;
 
+import java.util.HashSet;
+import java.util.Set;
+
 import jakarta.ws.rs.GET;
 import jakarta.ws.rs.Path;
 import jakarta.ws.rs.PathParam;
@@ -14,7 +17,7 @@
 
 @Path("/hello")
 public class HelloWorldEndpoint {
-
+    static Set<String> invoked = new HashSet<>();
     @GrpcClient("hello")
     GreeterGrpc.GreeterBlockingStub blockingHelloService;
 
@@ -24,10 +27,12 @@ public class HelloWorldEndpoint {
     @GET
     @Path("/blocking/{name}")
     public Response helloBlocking(@PathParam("name") String name) {
-        HeaderClientInterceptor.invoked = false;
+        invoked.clear();
         HelloReply helloReply = blockingHelloService.sayHello(HelloRequest.newBuilder().setName(name).build());
 
-        return Response.ok(helloReply.getMessage()).header("intercepted", HeaderClientInterceptor.invoked).build();
+        return Response.ok(helloReply.getMessage())
+                .header("interceptors", String.join(",", invoked))
+                .build();
     }
 
     @GET
diff --git a/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ServerInterceptors.java b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ServerInterceptors.java
new file mode 100644
index 0000000000000..d88b34bece02e
--- /dev/null
+++ b/integration-tests/grpc-interceptors/src/main/java/io/quarkus/grpc/examples/interceptors/ServerInterceptors.java
@@ -0,0 +1,37 @@
+package io.quarkus.grpc.examples.interceptors;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.inject.Produces;
+
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.quarkus.grpc.GlobalInterceptor;
+
+class ServerInterceptors {
+    @GlobalInterceptor
+    @ApplicationScoped
+    static class TypeTarget extends Base {
+    }
+
+    static class MethodTarget extends Base {
+    }
+
+    static class Producer {
+        @GlobalInterceptor
+        @Produces
+        MethodTarget methodTarget() {
+            return new MethodTarget();
+        }
+    }
+
+    abstract static class Base implements ServerInterceptor {
+        @Override
+        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata metadata,
+                ServerCallHandler<ReqT, RespT> next) {
+            HelloWorldEndpoint.invoked.add(getClass().getName());
+            return next.startCall(call, metadata);
+        }
+    }
+}
diff --git a/integration-tests/grpc-interceptors/src/test/java/io/quarkus/grpc/example/interceptors/HelloWorldEndpointTestBase.java b/integration-tests/grpc-interceptors/src/test/java/io/quarkus/grpc/example/interceptors/HelloWorldEndpointTestBase.java
index e30cf19c4dc9e..78c3ff85cf9b4 100644
--- a/integration-tests/grpc-interceptors/src/test/java/io/quarkus/grpc/example/interceptors/HelloWorldEndpointTestBase.java
+++ b/integration-tests/grpc-interceptors/src/test/java/io/quarkus/grpc/example/interceptors/HelloWorldEndpointTestBase.java
@@ -3,6 +3,8 @@
 import static io.restassured.RestAssured.get;
 import static org.assertj.core.api.Assertions.assertThat;
 
+import java.util.Set;
+
 import org.junit.jupiter.api.Test;
 
 import io.restassured.response.Response;
@@ -12,10 +14,15 @@ class HelloWorldEndpointTestBase {
     @Test
     public void testHelloWorldServiceUsingBlockingStub() {
         Response response = get("/hello/blocking/neo");
-        String intercepted = response.getHeader("intercepted");
         String responseMsg = response.asString();
         assertThat(responseMsg).isEqualTo("Hello neo");
-        assertThat(intercepted).isEqualTo("true");
+
+        Set<String> invoked = Set.of(response.getHeader("interceptors").split(","));
+        assertThat(invoked).containsExactlyInAnyOrder(
+                "io.quarkus.grpc.examples.interceptors.ClientInterceptors$TypeTarget",
+                "io.quarkus.grpc.examples.interceptors.ClientInterceptors$MethodTarget",
+                "io.quarkus.grpc.examples.interceptors.ServerInterceptors$TypeTarget",
+                "io.quarkus.grpc.examples.interceptors.ServerInterceptors$MethodTarget");
 
         ensureThatMetricsAreProduced();
     }