diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/provider/ContextProvidersPriorityTest.java b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/provider/ContextProvidersPriorityTest.java new file mode 100644 index 0000000000000..c10b4923d2f1f --- /dev/null +++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/provider/ContextProvidersPriorityTest.java @@ -0,0 +1,120 @@ +package io.quarkus.rest.client.reactive.provider; + +import static io.restassured.RestAssured.given; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import jakarta.annotation.Priority; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MultivaluedHashMap; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.ext.ContextResolver; + +import org.eclipse.microprofile.rest.client.ext.ClientHeadersFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; +import io.quarkus.rest.client.reactive.TestJacksonBasicMessageBodyReader; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; + +public class ContextProvidersPriorityTest { + private static final String HEADER_NAME = "my-header"; + private static final String HEADER_VALUE_FROM_LOW_PRIORITY = "low-priority"; + private static final String HEADER_VALUE_FROM_HIGH_PRIORITY = "high-priority"; + + @TestHTTPResource + URI baseUri; + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar.addClasses(Client.class, TestJacksonBasicMessageBodyReader.class)); + + @Test + void shouldUseTheHighestPriorityContextProvider() { + // @formatter:off + var response = + given() + .body(baseUri.toString()) + .when() + .post("/call-client") + .thenReturn(); + // @formatter:on + assertThat(response.statusCode()).isEqualTo(200); + assertThat(response.jsonPath().getString(HEADER_NAME)).isEqualTo(format("[%s]", HEADER_VALUE_FROM_HIGH_PRIORITY)); + } + + @Path("/") + @ApplicationScoped + public static class Resource { + + @GET + @Produces("application/json") + public Map> returnHeaderValues(@Context HttpHeaders headers) { + return headers.getRequestHeaders(); + } + + @Path("/call-client") + @POST + public Map> callClient(String uri) { + Client client = QuarkusRestClientBuilder.newBuilder() + .baseUri(URI.create(uri)) + .register(LowPriorityClientHeadersProvider.class) + .register(HighPriorityClientHeadersProvider.class) + .register(new TestJacksonBasicMessageBodyReader()) + .build(Client.class); + return client.get(); + } + } + + public interface Client { + @GET + Map> get(); + } + + @Priority(2) + public static class LowPriorityClientHeadersProvider implements ContextResolver { + + @Override + public ClientHeadersFactory getContext(Class aClass) { + return new CustomClientHeadersFactory(HEADER_VALUE_FROM_LOW_PRIORITY); + } + } + + @Priority(1) + public static class HighPriorityClientHeadersProvider implements ContextResolver { + + @Override + public ClientHeadersFactory getContext(Class aClass) { + return new CustomClientHeadersFactory(HEADER_VALUE_FROM_HIGH_PRIORITY); + } + } + + public static class CustomClientHeadersFactory implements ClientHeadersFactory { + + private final String value; + + public CustomClientHeadersFactory(String value) { + this.value = value; + } + + @Override + public MultivaluedMap update(MultivaluedMap multivaluedMap, + MultivaluedMap multivaluedMap1) { + MultivaluedHashMap newHeaders = new MultivaluedHashMap<>(); + newHeaders.add(HEADER_NAME, value); + return newHeaders; + } + } +} diff --git a/independent-projects/resteasy-reactive/common/runtime/src/main/java/org/jboss/resteasy/reactive/common/jaxrs/ConfigurationImpl.java b/independent-projects/resteasy-reactive/common/runtime/src/main/java/org/jboss/resteasy/reactive/common/jaxrs/ConfigurationImpl.java index 33f06b164fd69..588237450831f 100644 --- a/independent-projects/resteasy-reactive/common/runtime/src/main/java/org/jboss/resteasy/reactive/common/jaxrs/ConfigurationImpl.java +++ b/independent-projects/resteasy-reactive/common/runtime/src/main/java/org/jboss/resteasy/reactive/common/jaxrs/ConfigurationImpl.java @@ -54,7 +54,7 @@ public class ConfigurationImpl implements Configuration { private final MultivaluedMap, ResourceWriter> resourceWriters; private final MultivaluedMap, ResourceReader> resourceReaders; private final MultivaluedMap, RxInvokerProvider> rxInvokerProviders; - private final MultivaluedMap, ContextResolver> contextResolvers; + private final Map, MultivaluedMap>> contextResolvers; public ConfigurationImpl(RuntimeType runtimeType) { this.runtimeType = runtimeType; @@ -69,7 +69,7 @@ public ConfigurationImpl(RuntimeType runtimeType) { this.resourceReaders = new QuarkusMultivaluedHashMap<>(); this.resourceWriters = new QuarkusMultivaluedHashMap<>(); this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>(); - this.contextResolvers = new QuarkusMultivaluedHashMap<>(); + this.contextResolvers = new HashMap<>(); } public ConfigurationImpl(Configuration configuration) { @@ -96,7 +96,7 @@ public ConfigurationImpl(Configuration configuration) { this.resourceWriters.putAll(configurationImpl.resourceWriters); this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>(); this.rxInvokerProviders.putAll(configurationImpl.rxInvokerProviders); - this.contextResolvers = new QuarkusMultivaluedHashMap<>(); + this.contextResolvers = new HashMap<>(); this.contextResolvers.putAll(configurationImpl.contextResolvers); } else { this.allInstances = new HashMap<>(); @@ -111,7 +111,7 @@ public ConfigurationImpl(Configuration configuration) { this.resourceReaders = new QuarkusMultivaluedHashMap<>(); this.resourceWriters = new QuarkusMultivaluedHashMap<>(); this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>(); - this.contextResolvers = new QuarkusMultivaluedHashMap<>(); + this.contextResolvers = new HashMap<>(); // this is the best we can do - we don't have any of the metadata associated with the registration for (Object i : configuration.getInstances()) { register(i); @@ -314,8 +314,10 @@ private void register(Object component, Integer priority) { added = true; Class componentClass = component.getClass(); Type[] args = Types.findParameterizedTypes(componentClass, ContextResolver.class); - contextResolvers.add(args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class, - (ContextResolver) component); + Class key = args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class; + int effectivePriority = priority != null ? priority : determinePriority(component); + contextResolvers.computeIfAbsent(key, k -> new MultivaluedTreeMap<>()) + .add(effectivePriority, (ContextResolver) component); } if (added) { allInstances.put(component.getClass(), component); @@ -419,8 +421,10 @@ public void register(Object component, Map, Integer> componentContracts if (component instanceof ContextResolver) { added = true; Type[] args = Types.findParameterizedTypes(componentClass, ContextResolver.class); - contextResolvers.add(args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class, - (ContextResolver) component); + Class key = args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class; + int effectivePriority = priority != null ? priority : determinePriority(component); + contextResolvers.computeIfAbsent(key, k -> new MultivaluedTreeMap<>()) + .add(effectivePriority, (ContextResolver) component); } if (added) { allInstances.put(componentClass, component); @@ -525,14 +529,16 @@ public RxInvokerProvider getRxInvokerProvider(Class wantedClass) { } public T getFromContext(Class wantedClass) { - List> candidates = contextResolvers.get(wantedClass); + MultivaluedMap> candidates = contextResolvers.get(wantedClass); if (candidates == null) { return null; } - for (ContextResolver contextResolver : candidates) { - Object instance = contextResolver.getContext(wantedClass); - if (instance != null) { - return (T) instance; + for (List> contextResolvers : candidates.values()) { + for (ContextResolver contextResolver : contextResolvers) { + Object instance = contextResolver.getContext(wantedClass); + if (instance != null) { + return (T) instance; + } } }