Skip to content

Commit

Permalink
Merge pull request #36123 from manovotn/issue36118_2
Browse files Browse the repository at this point in the history
Make rest-client invocation context implement ArcInvocationContext
  • Loading branch information
geoand authored Sep 26, 2023
2 parents 1e3a64b + f35b854 commit aaa5689
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package io.quarkus.restclient.runtime;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletionException;

import jakarta.enterprise.inject.spi.InterceptionType;
import jakarta.enterprise.inject.spi.Interceptor;
import jakarta.interceptor.InvocationContext;
import jakarta.ws.rs.client.ResponseProcessingException;

import org.jboss.resteasy.microprofile.client.ExceptionMapping;

import io.quarkus.arc.ArcInvocationContext;

/**
* A Quarkus copy of {@link org.jboss.resteasy.microprofile.client.InvocationContextImpl} which makes it implement
* {@link ArcInvocationContext} instead so that it's compatible with Quarkus interceptors.
*/
public class QuarkusInvocationContextImpl implements ArcInvocationContext {

private final Object target;

private final Method method;

private Object[] args;

private final int position;

private final Map<String, Object> contextData;

private final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain;

private final Set<Annotation> interceptorBindings;

public QuarkusInvocationContextImpl(final Object target, final Method method, final Object[] args,
final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain, Set<Annotation> interceptorBindings) {
this(target, method, args, chain, 0, interceptorBindings);
}

private QuarkusInvocationContextImpl(final Object target, final Method method, final Object[] args,
final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain, final int position,
Set<Annotation> interceptorBindings) {
this.target = target;
this.method = method;
this.args = args;
this.interceptorBindings = interceptorBindings == null ? Collections.emptySet() : interceptorBindings;
this.contextData = new HashMap<>();
// put in bindings under Arc's specific key
this.contextData.put(ArcInvocationContext.KEY_INTERCEPTOR_BINDINGS, interceptorBindings);
this.position = position;
this.chain = chain;
}

boolean hasNextInterceptor() {
return position < chain.size();
}

protected Object invokeNext() throws Exception {
return chain.get(position).invoke(nextContext());
}

private InvocationContext nextContext() {
return new QuarkusInvocationContextImpl(target, method, args, chain, position + 1, interceptorBindings);
}

protected Object interceptorChainCompleted() throws Exception {
try {
return method.invoke(target, args);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof CompletionException) {
cause = cause.getCause();
}
if (cause instanceof ExceptionMapping.HandlerException) {
((ExceptionMapping.HandlerException) cause).mapException(method);
}
if (cause instanceof ResponseProcessingException) {
ResponseProcessingException rpe = (ResponseProcessingException) cause;
// Note that the default client engine leverages a single connection
// MP FT: we need to close the response otherwise we would not be able to retry if the method returns jakarta.ws.rs.core.Response
rpe.getResponse().close();
cause = rpe.getCause();
if (cause instanceof RuntimeException) {
throw (RuntimeException) cause;
}
}
throw e;
}
}

@Override
public Object proceed() throws Exception {
try {
if (hasNextInterceptor()) {
return invokeNext();
} else {
return interceptorChainCompleted();
}
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof Error) {
throw (Error) cause;
}
if (cause instanceof Exception) {
throw (Exception) cause;
}
throw new RuntimeException(cause);
}
}

@Override
public Object getTarget() {
return target;
}

@Override
public Method getMethod() {
return method;
}

@Override
public Constructor<?> getConstructor() {
return null;
}

@Override
public Object[] getParameters() throws IllegalStateException {
return args;
}

@Override
public void setParameters(Object[] params) throws IllegalStateException, IllegalArgumentException {
this.args = params;
}

@Override
public Map<String, Object> getContextData() {
return contextData;
}

@Override
public Object getTimer() {
return null;
}

@Override
public Set<Annotation> getInterceptorBindings() {
return interceptorBindings;
}

@Override
public <T extends Annotation> T findIterceptorBinding(Class<T> annotationType) {
for (Annotation annotation : getInterceptorBindings()) {
if (annotation.annotationType().equals(annotationType)) {
return (T) annotation;
}
}
return null;
}

@Override
public <T extends Annotation> List<T> findIterceptorBindings(Class<T> annotationType) {
List<T> found = new ArrayList<>();
for (Annotation annotation : getInterceptorBindings()) {
if (annotation.annotationType().equals(annotationType)) {
found.add((T) annotation);
}
}
return found;
}

public static class InterceptorInvocation {

@SuppressWarnings("rawtypes")
private final Interceptor interceptor;

private final Object interceptorInstance;

public InterceptorInvocation(final Interceptor<?> interceptor, final Object interceptorInstance) {
this.interceptor = interceptor;
this.interceptorInstance = interceptorInstance;
}

@SuppressWarnings("unchecked")
Object invoke(InvocationContext ctx) throws Exception {
return interceptor.intercept(InterceptionType.AROUND_INVOKE, interceptorInstance, ctx);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.jboss.logging.Logger;
import org.jboss.resteasy.client.jaxrs.ResteasyClient;
import org.jboss.resteasy.microprofile.client.ExceptionMapping;
import org.jboss.resteasy.microprofile.client.InvocationContextImpl;
import org.jboss.resteasy.microprofile.client.RestClientProxy;
import org.jboss.resteasy.microprofile.client.header.ClientHeaderFillingException;

Expand All @@ -52,7 +51,9 @@ public class QuarkusProxyInvocationHandler implements InvocationHandler {

private final Set<Object> providerInstances;

private final Map<Method, List<InvocationContextImpl.InterceptorInvocation>> interceptorChains;
private final Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> interceptorChains;

private final Map<Method, Set<Annotation>> interceptorBindingsMap;

private final ResteasyClient client;

Expand All @@ -70,10 +71,13 @@ public QuarkusProxyInvocationHandler(final Class<?> restClientInterface,
this.closed = new AtomicBoolean();
if (beanManager != null) {
this.creationalContext = beanManager.createCreationalContext(null);
this.interceptorChains = initInterceptorChains(beanManager, creationalContext, restClientInterface);
this.interceptorBindingsMap = new HashMap<>();
this.interceptorChains = initInterceptorChains(beanManager, creationalContext, restClientInterface,
interceptorBindingsMap);
} else {
this.creationalContext = null;
this.interceptorChains = Collections.emptyMap();
this.interceptorBindingsMap = Collections.emptyMap();
}
}

Expand Down Expand Up @@ -152,10 +156,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
args = argsReplacement;
}

List<InvocationContextImpl.InterceptorInvocation> chain = interceptorChains.get(method);
List<QuarkusInvocationContextImpl.InterceptorInvocation> chain = interceptorChains.get(method);
if (chain != null) {
// Invoke business method interceptors
return new InvocationContextImpl(target, method, args, chain).proceed();
return new QuarkusInvocationContextImpl(target, method, args, chain, interceptorBindingsMap.get(method)).proceed();
} else {
try {
return method.invoke(target, args);
Expand Down Expand Up @@ -245,10 +249,11 @@ private static BeanManager getBeanManager(Class<?> restClientInterface) {
}
}

private static Map<Method, List<InvocationContextImpl.InterceptorInvocation>> initInterceptorChains(
BeanManager beanManager, CreationalContext<?> creationalContext, Class<?> restClientInterface) {
private static Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> initInterceptorChains(
BeanManager beanManager, CreationalContext<?> creationalContext, Class<?> restClientInterface,
Map<Method, Set<Annotation>> interceptorBindingsMap) {

Map<Method, List<InvocationContextImpl.InterceptorInvocation>> chains = new HashMap<>();
Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> chains = new HashMap<>();
// Interceptor as a key in a map is not entirely correct (custom interceptors) but should work in most cases
Map<Interceptor<?>, Object> interceptorInstances = new HashMap<>();

Expand All @@ -267,12 +272,13 @@ private static Map<Method, List<InvocationContextImpl.InterceptorInvocation>> in
List<Interceptor<?>> interceptors = beanManager.resolveInterceptors(InterceptionType.AROUND_INVOKE,
interceptorBindings);
if (!interceptors.isEmpty()) {
List<InvocationContextImpl.InterceptorInvocation> chain = new ArrayList<>();
List<QuarkusInvocationContextImpl.InterceptorInvocation> chain = new ArrayList<>();
for (Interceptor<?> interceptor : interceptors) {
chain.add(new InvocationContextImpl.InterceptorInvocation(interceptor,
chain.add(new QuarkusInvocationContextImpl.InterceptorInvocation(interceptor,
interceptorInstances.computeIfAbsent(interceptor,
i -> beanManager.getReference(i, i.getBeanClass(), creationalContext))));
}
interceptorBindingsMap.put(method, Set.of(interceptorBindings));
chains.put(method, chain);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.eclipse.microprofile.rest.client.inject.RestClient;

import io.opentelemetry.instrumentation.annotations.SpanAttribute;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;
import io.vertx.core.MultiMap;
Expand All @@ -34,6 +36,11 @@ public interface PingPongRestClient {
@GET
@Path("/client/pong/{message}")
Uni<String> asyncPingpong(@PathParam("message") String message);

@GET
@Path("/client/pong/{message}")
@WithSpan
String pingpongIntercept(@SpanAttribute(value = "message") @PathParam("message") String message);
}

@Inject
Expand Down Expand Up @@ -81,4 +88,9 @@ public Uni<String> asyncPingNamed(@PathParam("message") String message) {
.onItemOrFailure().call(httpClient::close);
}

@GET
@Path("pong-intercept/{message}")
public String pongIntercept(@PathParam("message") String message) {
return pingRestClient.pingpongIntercept(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -557,6 +558,79 @@ void testAsyncClientTracing() {
assertNotNull(clientServer.get("attr_user_agent.original"));
}

@Test
void testClientTracingWithInterceptor() {
given()
.when().get("/client/pong-intercept/one")
.then()
.statusCode(200)
.body(containsString("one"));

await().atMost(5, SECONDS).until(() -> getSpans().size() == 4);
List<Map<String, Object>> spans = getSpans();
assertEquals(4, spans.size());
assertEquals(1, spans.stream().map(map -> map.get("traceId")).collect(toSet()).size());

Map<String, Object> server = getSpanByKindAndParentId(spans, SERVER, "0000000000000000");
assertEquals(SERVER.toString(), server.get("kind"));
verifyResource(server);
assertEquals("GET /client/pong-intercept/{message}", server.get("name"));
assertEquals(SERVER.toString(), server.get("kind"));
assertTrue((Boolean) server.get("ended"));
assertEquals(SpanId.getInvalid(), server.get("parent_spanId"));
assertEquals(TraceId.getInvalid(), server.get("parent_traceId"));
assertFalse((Boolean) server.get("parent_valid"));
assertFalse((Boolean) server.get("parent_remote"));
assertEquals("GET", server.get("attr_http.method"));
assertEquals("/client/pong-intercept/one", server.get("attr_http.target"));
assertEquals(pathParamUrl.getHost(), server.get("attr_net.host.name"));
assertEquals(pathParamUrl.getPort(), Integer.valueOf((String) server.get("attr_net.host.port")));
assertEquals("http", server.get("attr_http.scheme"));
assertEquals("/client/pong-intercept/{message}", server.get("attr_http.route"));
assertEquals("200", server.get("attr_http.status_code"));
assertNotNull(server.get("attr_http.client_ip"));
assertNotNull(server.get("attr_user_agent.original"));

Map<String, Object> fromInterceptor = getSpanByKindAndParentId(spans, INTERNAL, server.get("spanId"));
assertEquals("PingPongRestClient.pingpongIntercept", fromInterceptor.get("name"));
assertEquals(INTERNAL.toString(), fromInterceptor.get("kind"));
assertTrue((Boolean) fromInterceptor.get("ended"));
assertTrue((Boolean) fromInterceptor.get("parent_valid"));
assertFalse((Boolean) fromInterceptor.get("parent_remote"));
assertNull(fromInterceptor.get("attr_http.method"));
assertNull(fromInterceptor.get("attr_http.status_code"));
assertEquals("one", fromInterceptor.get("attr_message"));

Map<String, Object> client = getSpanByKindAndParentId(spans, CLIENT, fromInterceptor.get("spanId"));
assertEquals("GET", client.get("name"));
assertEquals(SpanKind.CLIENT.toString(), client.get("kind"));
assertTrue((Boolean) client.get("ended"));
assertTrue((Boolean) client.get("parent_valid"));
assertFalse((Boolean) client.get("parent_remote"));
assertEquals("GET", client.get("attr_http.method"));
assertEquals("http://localhost:8081/client/pong/one", client.get("attr_http.url"));
assertEquals("200", client.get("attr_http.status_code"));

Map<String, Object> clientServer = getSpanByKindAndParentId(spans, SERVER, client.get("spanId"));
assertEquals(SERVER.toString(), clientServer.get("kind"));
verifyResource(clientServer);
assertEquals("GET /client/pong/{message}", clientServer.get("name"));
assertEquals(SERVER.toString(), clientServer.get("kind"));
assertTrue((Boolean) clientServer.get("ended"));
assertTrue((Boolean) clientServer.get("parent_valid"));
assertTrue((Boolean) clientServer.get("parent_remote"));
assertEquals("GET", clientServer.get("attr_http.method"));
assertEquals("/client/pong/one", clientServer.get("attr_http.target"));
assertEquals(pathParamUrl.getHost(), server.get("attr_net.host.name"));
assertEquals(pathParamUrl.getPort(), Integer.valueOf((String) server.get("attr_net.host.port")));
assertEquals("http", clientServer.get("attr_http.scheme"));
assertEquals("/client/pong/{message}", clientServer.get("attr_http.route"));
assertEquals("200", clientServer.get("attr_http.status_code"));
assertNotNull(clientServer.get("attr_http.client_ip"));
assertNotNull(clientServer.get("attr_user_agent.original"));
assertEquals(clientServer.get("parentSpanId"), client.get("spanId"));
}

@Test
void testTemplatedPathOnClass() {
given()
Expand Down

0 comments on commit aaa5689

Please sign in to comment.