Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make rest-client invocation context implement ArcInvocationContext #36123

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package io.quarkus.restclient.runtime;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletionException;

import jakarta.enterprise.inject.spi.InterceptionType;
import jakarta.enterprise.inject.spi.Interceptor;
import jakarta.interceptor.InvocationContext;
import jakarta.ws.rs.client.ResponseProcessingException;

import org.jboss.resteasy.microprofile.client.ExceptionMapping;

import io.quarkus.arc.ArcInvocationContext;

/**
* A Quarkus copy of {@link org.jboss.resteasy.microprofile.client.InvocationContextImpl} which makes it implement
* {@link ArcInvocationContext} instead so that it's compatible with Quarkus interceptors.
*/
public class QuarkusInvocationContextImpl implements ArcInvocationContext {

private final Object target;

private final Method method;

private Object[] args;

private final int position;

private final Map<String, Object> contextData;

private final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain;

private final Set<Annotation> interceptorBindings;

public QuarkusInvocationContextImpl(final Object target, final Method method, final Object[] args,
final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain, Set<Annotation> interceptorBindings) {
this(target, method, args, chain, 0, interceptorBindings);
}

private QuarkusInvocationContextImpl(final Object target, final Method method, final Object[] args,
final List<QuarkusInvocationContextImpl.InterceptorInvocation> chain, final int position,
Set<Annotation> interceptorBindings) {
this.target = target;
this.method = method;
this.args = args;
this.interceptorBindings = interceptorBindings == null ? Collections.emptySet() : interceptorBindings;
this.contextData = new HashMap<>();
// put in bindings under Arc's specific key
this.contextData.put(ArcInvocationContext.KEY_INTERCEPTOR_BINDINGS, interceptorBindings);
this.position = position;
this.chain = chain;
}

boolean hasNextInterceptor() {
return position < chain.size();
}

protected Object invokeNext() throws Exception {
return chain.get(position).invoke(nextContext());
}

private InvocationContext nextContext() {
return new QuarkusInvocationContextImpl(target, method, args, chain, position + 1, interceptorBindings);
}

protected Object interceptorChainCompleted() throws Exception {
try {
return method.invoke(target, args);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof CompletionException) {
cause = cause.getCause();
}
if (cause instanceof ExceptionMapping.HandlerException) {
((ExceptionMapping.HandlerException) cause).mapException(method);
}
if (cause instanceof ResponseProcessingException) {
ResponseProcessingException rpe = (ResponseProcessingException) cause;
// Note that the default client engine leverages a single connection
// MP FT: we need to close the response otherwise we would not be able to retry if the method returns jakarta.ws.rs.core.Response
rpe.getResponse().close();
cause = rpe.getCause();
if (cause instanceof RuntimeException) {
throw (RuntimeException) cause;
}
}
throw e;
}
}

@Override
public Object proceed() throws Exception {
try {
if (hasNextInterceptor()) {
return invokeNext();
} else {
return interceptorChainCompleted();
}
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof Error) {
throw (Error) cause;
}
if (cause instanceof Exception) {
throw (Exception) cause;
}
throw new RuntimeException(cause);
}
}

@Override
public Object getTarget() {
return target;
}

@Override
public Method getMethod() {
return method;
}

@Override
public Constructor<?> getConstructor() {
return null;
}

@Override
public Object[] getParameters() throws IllegalStateException {
return args;
}

@Override
public void setParameters(Object[] params) throws IllegalStateException, IllegalArgumentException {
this.args = params;
}

@Override
public Map<String, Object> getContextData() {
return contextData;
}

@Override
public Object getTimer() {
return null;
}

@Override
public Set<Annotation> getInterceptorBindings() {
return interceptorBindings;
}

@Override
public <T extends Annotation> T findIterceptorBinding(Class<T> annotationType) {
for (Annotation annotation : getInterceptorBindings()) {
if (annotation.annotationType().equals(annotationType)) {
return (T) annotation;
}
}
return null;
}

@Override
public <T extends Annotation> List<T> findIterceptorBindings(Class<T> annotationType) {
List<T> found = new ArrayList<>();
for (Annotation annotation : getInterceptorBindings()) {
if (annotation.annotationType().equals(annotationType)) {
found.add((T) annotation);
}
}
return found;
}

public static class InterceptorInvocation {

@SuppressWarnings("rawtypes")
private final Interceptor interceptor;

private final Object interceptorInstance;

public InterceptorInvocation(final Interceptor<?> interceptor, final Object interceptorInstance) {
this.interceptor = interceptor;
this.interceptorInstance = interceptorInstance;
}

@SuppressWarnings("unchecked")
Object invoke(InvocationContext ctx) throws Exception {
return interceptor.intercept(InterceptionType.AROUND_INVOKE, interceptorInstance, ctx);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.jboss.logging.Logger;
import org.jboss.resteasy.client.jaxrs.ResteasyClient;
import org.jboss.resteasy.microprofile.client.ExceptionMapping;
import org.jboss.resteasy.microprofile.client.InvocationContextImpl;
import org.jboss.resteasy.microprofile.client.RestClientProxy;
import org.jboss.resteasy.microprofile.client.header.ClientHeaderFillingException;

Expand All @@ -52,7 +51,9 @@ public class QuarkusProxyInvocationHandler implements InvocationHandler {

private final Set<Object> providerInstances;

private final Map<Method, List<InvocationContextImpl.InterceptorInvocation>> interceptorChains;
private final Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> interceptorChains;

private final Map<Method, Set<Annotation>> interceptorBindingsMap;

private final ResteasyClient client;

Expand All @@ -70,10 +71,13 @@ public QuarkusProxyInvocationHandler(final Class<?> restClientInterface,
this.closed = new AtomicBoolean();
if (beanManager != null) {
this.creationalContext = beanManager.createCreationalContext(null);
this.interceptorChains = initInterceptorChains(beanManager, creationalContext, restClientInterface);
this.interceptorBindingsMap = new HashMap<>();
this.interceptorChains = initInterceptorChains(beanManager, creationalContext, restClientInterface,
interceptorBindingsMap);
} else {
this.creationalContext = null;
this.interceptorChains = Collections.emptyMap();
this.interceptorBindingsMap = Collections.emptyMap();
}
}

Expand Down Expand Up @@ -152,10 +156,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
args = argsReplacement;
}

List<InvocationContextImpl.InterceptorInvocation> chain = interceptorChains.get(method);
List<QuarkusInvocationContextImpl.InterceptorInvocation> chain = interceptorChains.get(method);
if (chain != null) {
// Invoke business method interceptors
return new InvocationContextImpl(target, method, args, chain).proceed();
return new QuarkusInvocationContextImpl(target, method, args, chain, interceptorBindingsMap.get(method)).proceed();
} else {
try {
return method.invoke(target, args);
Expand Down Expand Up @@ -245,10 +249,11 @@ private static BeanManager getBeanManager(Class<?> restClientInterface) {
}
}

private static Map<Method, List<InvocationContextImpl.InterceptorInvocation>> initInterceptorChains(
BeanManager beanManager, CreationalContext<?> creationalContext, Class<?> restClientInterface) {
private static Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> initInterceptorChains(
BeanManager beanManager, CreationalContext<?> creationalContext, Class<?> restClientInterface,
Map<Method, Set<Annotation>> interceptorBindingsMap) {

Map<Method, List<InvocationContextImpl.InterceptorInvocation>> chains = new HashMap<>();
Map<Method, List<QuarkusInvocationContextImpl.InterceptorInvocation>> chains = new HashMap<>();
// Interceptor as a key in a map is not entirely correct (custom interceptors) but should work in most cases
Map<Interceptor<?>, Object> interceptorInstances = new HashMap<>();

Expand All @@ -267,12 +272,13 @@ private static Map<Method, List<InvocationContextImpl.InterceptorInvocation>> in
List<Interceptor<?>> interceptors = beanManager.resolveInterceptors(InterceptionType.AROUND_INVOKE,
interceptorBindings);
if (!interceptors.isEmpty()) {
List<InvocationContextImpl.InterceptorInvocation> chain = new ArrayList<>();
List<QuarkusInvocationContextImpl.InterceptorInvocation> chain = new ArrayList<>();
for (Interceptor<?> interceptor : interceptors) {
chain.add(new InvocationContextImpl.InterceptorInvocation(interceptor,
chain.add(new QuarkusInvocationContextImpl.InterceptorInvocation(interceptor,
interceptorInstances.computeIfAbsent(interceptor,
i -> beanManager.getReference(i, i.getBeanClass(), creationalContext))));
}
interceptorBindingsMap.put(method, Set.of(interceptorBindings));
chains.put(method, chain);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.eclipse.microprofile.rest.client.inject.RestClient;

import io.opentelemetry.instrumentation.annotations.SpanAttribute;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;
import io.vertx.core.MultiMap;
Expand All @@ -34,6 +36,11 @@ public interface PingPongRestClient {
@GET
@Path("/client/pong/{message}")
Uni<String> asyncPingpong(@PathParam("message") String message);

@GET
@Path("/client/pong/{message}")
@WithSpan
String pingpongIntercept(@SpanAttribute(value = "message") @PathParam("message") String message);
}

@Inject
Expand Down Expand Up @@ -81,4 +88,9 @@ public Uni<String> asyncPingNamed(@PathParam("message") String message) {
.onItemOrFailure().call(httpClient::close);
}

@GET
@Path("pong-intercept/{message}")
public String pongIntercept(@PathParam("message") String message) {
return pingRestClient.pingpongIntercept(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -557,6 +558,79 @@ void testAsyncClientTracing() {
assertNotNull(clientServer.get("attr_user_agent.original"));
}

@Test
void testClientTracingWithInterceptor() {
given()
.when().get("/client/pong-intercept/one")
.then()
.statusCode(200)
.body(containsString("one"));

await().atMost(5, SECONDS).until(() -> getSpans().size() == 4);
List<Map<String, Object>> spans = getSpans();
assertEquals(4, spans.size());
assertEquals(1, spans.stream().map(map -> map.get("traceId")).collect(toSet()).size());

Map<String, Object> server = getSpanByKindAndParentId(spans, SERVER, "0000000000000000");
assertEquals(SERVER.toString(), server.get("kind"));
verifyResource(server);
assertEquals("GET /client/pong-intercept/{message}", server.get("name"));
assertEquals(SERVER.toString(), server.get("kind"));
assertTrue((Boolean) server.get("ended"));
assertEquals(SpanId.getInvalid(), server.get("parent_spanId"));
assertEquals(TraceId.getInvalid(), server.get("parent_traceId"));
assertFalse((Boolean) server.get("parent_valid"));
assertFalse((Boolean) server.get("parent_remote"));
assertEquals("GET", server.get("attr_http.method"));
assertEquals("/client/pong-intercept/one", server.get("attr_http.target"));
assertEquals(pathParamUrl.getHost(), server.get("attr_net.host.name"));
assertEquals(pathParamUrl.getPort(), Integer.valueOf((String) server.get("attr_net.host.port")));
assertEquals("http", server.get("attr_http.scheme"));
assertEquals("/client/pong-intercept/{message}", server.get("attr_http.route"));
assertEquals("200", server.get("attr_http.status_code"));
assertNotNull(server.get("attr_http.client_ip"));
assertNotNull(server.get("attr_user_agent.original"));

Map<String, Object> fromInterceptor = getSpanByKindAndParentId(spans, INTERNAL, server.get("spanId"));
assertEquals("PingPongRestClient.pingpongIntercept", fromInterceptor.get("name"));
assertEquals(INTERNAL.toString(), fromInterceptor.get("kind"));
assertTrue((Boolean) fromInterceptor.get("ended"));
assertTrue((Boolean) fromInterceptor.get("parent_valid"));
assertFalse((Boolean) fromInterceptor.get("parent_remote"));
assertNull(fromInterceptor.get("attr_http.method"));
assertNull(fromInterceptor.get("attr_http.status_code"));
assertEquals("one", fromInterceptor.get("attr_message"));

Map<String, Object> client = getSpanByKindAndParentId(spans, CLIENT, fromInterceptor.get("spanId"));
assertEquals("GET", client.get("name"));
assertEquals(SpanKind.CLIENT.toString(), client.get("kind"));
assertTrue((Boolean) client.get("ended"));
assertTrue((Boolean) client.get("parent_valid"));
assertFalse((Boolean) client.get("parent_remote"));
assertEquals("GET", client.get("attr_http.method"));
assertEquals("http://localhost:8081/client/pong/one", client.get("attr_http.url"));
assertEquals("200", client.get("attr_http.status_code"));

Map<String, Object> clientServer = getSpanByKindAndParentId(spans, SERVER, client.get("spanId"));
assertEquals(SERVER.toString(), clientServer.get("kind"));
verifyResource(clientServer);
assertEquals("GET /client/pong/{message}", clientServer.get("name"));
assertEquals(SERVER.toString(), clientServer.get("kind"));
assertTrue((Boolean) clientServer.get("ended"));
assertTrue((Boolean) clientServer.get("parent_valid"));
assertTrue((Boolean) clientServer.get("parent_remote"));
assertEquals("GET", clientServer.get("attr_http.method"));
assertEquals("/client/pong/one", clientServer.get("attr_http.target"));
assertEquals(pathParamUrl.getHost(), server.get("attr_net.host.name"));
assertEquals(pathParamUrl.getPort(), Integer.valueOf((String) server.get("attr_net.host.port")));
assertEquals("http", clientServer.get("attr_http.scheme"));
assertEquals("/client/pong/{message}", clientServer.get("attr_http.route"));
assertEquals("200", clientServer.get("attr_http.status_code"));
assertNotNull(clientServer.get("attr_http.client_ip"));
assertNotNull(clientServer.get("attr_user_agent.original"));
assertEquals(clientServer.get("parentSpanId"), client.get("spanId"));
}

@Test
void testTemplatedPathOnClass() {
given()
Expand Down
Loading