Skip to content

Commit

Permalink
Merge pull request #32942 from Sgitario/context_provider_priority
Browse files Browse the repository at this point in the history
Taking into account the `@Priority` annotation when registering context providers
  • Loading branch information
Sgitario authored Apr 28, 2023
2 parents 3f4e1fe + a96d221 commit 7f65037
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package io.quarkus.rest.client.reactive.provider;

import static io.restassured.RestAssured.given;
import static java.lang.String.format;
import static org.assertj.core.api.Assertions.assertThat;

import java.net.URI;
import java.util.List;
import java.util.Map;

import jakarta.annotation.Priority;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.ext.ContextResolver;

import org.eclipse.microprofile.rest.client.ext.ClientHeadersFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.quarkus.rest.client.reactive.TestJacksonBasicMessageBodyReader;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;

public class ContextProvidersPriorityTest {
private static final String HEADER_NAME = "my-header";
private static final String HEADER_VALUE_FROM_LOW_PRIORITY = "low-priority";
private static final String HEADER_VALUE_FROM_HIGH_PRIORITY = "high-priority";

@TestHTTPResource
URI baseUri;

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar.addClasses(Client.class, TestJacksonBasicMessageBodyReader.class));

@Test
void shouldUseTheHighestPriorityContextProvider() {
// @formatter:off
var response =
given()
.body(baseUri.toString())
.when()
.post("/call-client")
.thenReturn();
// @formatter:on
assertThat(response.statusCode()).isEqualTo(200);
assertThat(response.jsonPath().getString(HEADER_NAME)).isEqualTo(format("[%s]", HEADER_VALUE_FROM_HIGH_PRIORITY));
}

@Path("/")
@ApplicationScoped
public static class Resource {

@GET
@Produces("application/json")
public Map<String, List<String>> returnHeaderValues(@Context HttpHeaders headers) {
return headers.getRequestHeaders();
}

@Path("/call-client")
@POST
public Map<String, List<String>> callClient(String uri) {
Client client = QuarkusRestClientBuilder.newBuilder()
.baseUri(URI.create(uri))
.register(LowPriorityClientHeadersProvider.class)
.register(HighPriorityClientHeadersProvider.class)
.register(new TestJacksonBasicMessageBodyReader())
.build(Client.class);
return client.get();
}
}

public interface Client {
@GET
Map<String, List<String>> get();
}

@Priority(2)
public static class LowPriorityClientHeadersProvider implements ContextResolver<ClientHeadersFactory> {

@Override
public ClientHeadersFactory getContext(Class<?> aClass) {
return new CustomClientHeadersFactory(HEADER_VALUE_FROM_LOW_PRIORITY);
}
}

@Priority(1)
public static class HighPriorityClientHeadersProvider implements ContextResolver<ClientHeadersFactory> {

@Override
public ClientHeadersFactory getContext(Class<?> aClass) {
return new CustomClientHeadersFactory(HEADER_VALUE_FROM_HIGH_PRIORITY);
}
}

public static class CustomClientHeadersFactory implements ClientHeadersFactory {

private final String value;

public CustomClientHeadersFactory(String value) {
this.value = value;
}

@Override
public MultivaluedMap<String, String> update(MultivaluedMap<String, String> multivaluedMap,
MultivaluedMap<String, String> multivaluedMap1) {
MultivaluedHashMap<String, String> newHeaders = new MultivaluedHashMap<>();
newHeaders.add(HEADER_NAME, value);
return newHeaders;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class ConfigurationImpl implements Configuration {
private final MultivaluedMap<Class<?>, ResourceWriter> resourceWriters;
private final MultivaluedMap<Class<?>, ResourceReader> resourceReaders;
private final MultivaluedMap<Class<?>, RxInvokerProvider<?>> rxInvokerProviders;
private final MultivaluedMap<Class<?>, ContextResolver<?>> contextResolvers;
private final Map<Class<?>, MultivaluedMap<Integer, ContextResolver<?>>> contextResolvers;

public ConfigurationImpl(RuntimeType runtimeType) {
this.runtimeType = runtimeType;
Expand All @@ -69,7 +69,7 @@ public ConfigurationImpl(RuntimeType runtimeType) {
this.resourceReaders = new QuarkusMultivaluedHashMap<>();
this.resourceWriters = new QuarkusMultivaluedHashMap<>();
this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>();
this.contextResolvers = new QuarkusMultivaluedHashMap<>();
this.contextResolvers = new HashMap<>();
}

public ConfigurationImpl(Configuration configuration) {
Expand All @@ -96,7 +96,7 @@ public ConfigurationImpl(Configuration configuration) {
this.resourceWriters.putAll(configurationImpl.resourceWriters);
this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>();
this.rxInvokerProviders.putAll(configurationImpl.rxInvokerProviders);
this.contextResolvers = new QuarkusMultivaluedHashMap<>();
this.contextResolvers = new HashMap<>();
this.contextResolvers.putAll(configurationImpl.contextResolvers);
} else {
this.allInstances = new HashMap<>();
Expand All @@ -111,7 +111,7 @@ public ConfigurationImpl(Configuration configuration) {
this.resourceReaders = new QuarkusMultivaluedHashMap<>();
this.resourceWriters = new QuarkusMultivaluedHashMap<>();
this.rxInvokerProviders = new QuarkusMultivaluedHashMap<>();
this.contextResolvers = new QuarkusMultivaluedHashMap<>();
this.contextResolvers = new HashMap<>();
// this is the best we can do - we don't have any of the metadata associated with the registration
for (Object i : configuration.getInstances()) {
register(i);
Expand Down Expand Up @@ -314,8 +314,10 @@ private void register(Object component, Integer priority) {
added = true;
Class<?> componentClass = component.getClass();
Type[] args = Types.findParameterizedTypes(componentClass, ContextResolver.class);
contextResolvers.add(args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class,
(ContextResolver<?>) component);
Class<?> key = args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class;
int effectivePriority = priority != null ? priority : determinePriority(component);
contextResolvers.computeIfAbsent(key, k -> new MultivaluedTreeMap<>())
.add(effectivePriority, (ContextResolver<?>) component);
}
if (added) {
allInstances.put(component.getClass(), component);
Expand Down Expand Up @@ -419,8 +421,10 @@ public void register(Object component, Map<Class<?>, Integer> componentContracts
if (component instanceof ContextResolver) {
added = true;
Type[] args = Types.findParameterizedTypes(componentClass, ContextResolver.class);
contextResolvers.add(args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class,
(ContextResolver<?>) component);
Class<?> key = args != null && args.length == 1 ? Types.getRawType(args[0]) : Object.class;
int effectivePriority = priority != null ? priority : determinePriority(component);
contextResolvers.computeIfAbsent(key, k -> new MultivaluedTreeMap<>())
.add(effectivePriority, (ContextResolver<?>) component);
}
if (added) {
allInstances.put(componentClass, component);
Expand Down Expand Up @@ -525,14 +529,16 @@ public RxInvokerProvider<?> getRxInvokerProvider(Class<?> wantedClass) {
}

public <T> T getFromContext(Class<T> wantedClass) {
List<ContextResolver<?>> candidates = contextResolvers.get(wantedClass);
MultivaluedMap<Integer, ContextResolver<?>> candidates = contextResolvers.get(wantedClass);
if (candidates == null) {
return null;
}
for (ContextResolver<?> contextResolver : candidates) {
Object instance = contextResolver.getContext(wantedClass);
if (instance != null) {
return (T) instance;
for (List<ContextResolver<?>> contextResolvers : candidates.values()) {
for (ContextResolver<?> contextResolver : contextResolvers) {
Object instance = contextResolver.getContext(wantedClass);
if (instance != null) {
return (T) instance;
}
}
}

Expand Down

0 comments on commit 7f65037

Please sign in to comment.