Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rest Client Reactive - fixed @RegisterProvider support #16815

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import static io.quarkus.rest.client.reactive.deployment.DotNames.CLIENT_HEADER_PARAM;
import static io.quarkus.rest.client.reactive.deployment.DotNames.CLIENT_HEADER_PARAMS;
import static io.quarkus.rest.client.reactive.deployment.DotNames.REGISTER_CLIENT_HEADERS;
import static io.quarkus.rest.client.reactive.deployment.DotNames.REGISTER_PROVIDER;
import static io.quarkus.rest.client.reactive.deployment.DotNames.REGISTER_PROVIDERS;
import static org.jboss.resteasy.reactive.common.processor.HashUtil.sha1;
import static org.objectweb.asm.Opcodes.ACC_STATIC;

Expand Down Expand Up @@ -43,7 +41,6 @@
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.gizmo.AssignableResultHandle;
import io.quarkus.gizmo.BranchResult;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.CatchBlockCreator;
import io.quarkus.gizmo.ClassCreator;
Expand All @@ -56,7 +53,6 @@
import io.quarkus.gizmo.TryBlock;
import io.quarkus.jaxrs.client.reactive.deployment.JaxrsClientReactiveEnricher;
import io.quarkus.rest.client.reactive.HeaderFiller;
import io.quarkus.rest.client.reactive.runtime.BeanGrabber;
import io.quarkus.rest.client.reactive.runtime.MicroProfileRestClientRequestFilter;
import io.quarkus.rest.client.reactive.runtime.NoOpHeaderFiller;
import io.quarkus.runtime.util.HashUtil;
Expand Down Expand Up @@ -91,16 +87,6 @@ class MicroProfileRestClientEnricher implements JaxrsClientReactiveEnricher {
public void forClass(MethodCreator constructor, AssignableResultHandle webTargetBase,
ClassInfo interfaceClass, IndexView index) {

AnnotationInstance annotation = interfaceClass.classAnnotation(REGISTER_PROVIDER);
AnnotationInstance groupAnnotation = interfaceClass.classAnnotation(REGISTER_PROVIDERS);

if (annotation != null) {
addProvider(constructor, webTargetBase, index, annotation);
}
for (AnnotationInstance annotationInstance : extractAnnotations(groupAnnotation)) {
addProvider(constructor, webTargetBase, index, annotationInstance);
}

ResultHandle clientHeadersFactory = null;

AnnotationInstance registerClientHeaders = interfaceClass.classAnnotation(REGISTER_CLIENT_HEADERS);
Expand Down Expand Up @@ -493,54 +479,14 @@ private AnnotationInstance[] extractAnnotations(AnnotationInstance groupAnnotati
return EMPTY_ANNOTATION_INSTANCES;
}

private void addProvider(MethodCreator ctor, AssignableResultHandle target, IndexView index,
AnnotationInstance registerProvider) {
// if a registered provider is a CDI bean, it has to be reused
// take the name of the provider class from the annotation:
String providerClass = registerProvider.value().asString();

// get bean, or null, with BeanGrabber.getBeanIfDefined(providerClass)
ResultHandle providerBean = ctor.invokeStaticMethod(
MethodDescriptor.ofMethod(BeanGrabber.class, "getBeanIfDefined", Object.class, Class.class),
ctor.loadClass(providerClass));

// if bean != null, register the bean
BranchResult branchResult = ctor.ifNotNull(providerBean);
BytecodeCreator beanProviderAvailable = branchResult.trueBranch();

ResultHandle alteredTarget = beanProviderAvailable.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Configurable.class, "register", Configurable.class, Object.class,
int.class),
target, providerBean,
beanProviderAvailable.load(registerProvider.valueWithDefault(index, "priority").asInt()));
beanProviderAvailable.assign(target, alteredTarget);

// else, create a new instance of the provider class
ClassInfo providerClassInfo = index.getClassByName(DotName.createSimple(providerClass));
BytecodeCreator beanProviderNotAvailable = branchResult.falseBranch();
if ((providerClassInfo != null) && providerClassInfo.hasNoArgsConstructor()) { // if the filter has a no-args constructor, use it
ResultHandle provider = beanProviderNotAvailable.newInstance(MethodDescriptor.ofConstructor(providerClass));
alteredTarget = beanProviderNotAvailable.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Configurable.class, "register", Configurable.class, Object.class,
int.class),
target, provider,
beanProviderNotAvailable.load(registerProvider.valueWithDefault(index, "priority").asInt()));
beanProviderNotAvailable.assign(target, alteredTarget);
} else { // the filter does not have a no-args constructor, so we can do nothing but fail
beanProviderNotAvailable.throwException(IllegalStateException.class,
"Provider " + providerClass + " must either be a CDI bean or have a no-args constructor");
}

}

/**
* ClientHeaderParam annotations can be defined on a JAX-RS interface or a sub-client (sub-resource).
* If we're filling headers for a sub-client, we need to know the defining class of the ClientHeaderParam
* to properly resolve default methods of the "root" client
*/
private static class HeaderData {
private AnnotationInstance annotation;
private ClassInfo definingClass;
private final AnnotationInstance annotation;
private final ClassInfo definingClass;

public HeaderData(AnnotationInstance annotation, ClassInfo definingClass) {
this.annotation = annotation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import javax.enterprise.context.SessionScoped;
import javax.enterprise.inject.Typed;
import javax.inject.Singleton;
import javax.ws.rs.core.MediaType;

import org.eclipse.microprofile.config.Config;
Expand All @@ -40,6 +43,7 @@
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.deployment.Capabilities;
Expand All @@ -53,6 +57,7 @@
import io.quarkus.deployment.builditem.ConfigurationTypeBuildItem;
import io.quarkus.deployment.builditem.ExtensionSslNativeSupportBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.FieldDescriptor;
import io.quarkus.gizmo.MethodCreator;
Expand All @@ -61,16 +66,18 @@
import io.quarkus.jaxrs.client.reactive.deployment.JaxrsClientReactiveEnricherBuildItem;
import io.quarkus.jaxrs.client.reactive.deployment.RestClientDefaultConsumesBuildItem;
import io.quarkus.jaxrs.client.reactive.deployment.RestClientDefaultProducesBuildItem;
import io.quarkus.rest.client.reactive.runtime.AnnotationRegisteredProviders;
import io.quarkus.rest.client.reactive.runtime.HeaderCapturingServerFilter;
import io.quarkus.rest.client.reactive.runtime.HeaderContainer;
import io.quarkus.rest.client.reactive.runtime.RestClientCDIDelegateBuilder;
import io.quarkus.rest.client.reactive.runtime.RestClientReactiveConfig;
import io.quarkus.rest.client.reactive.runtime.RestClientRecorder;
import io.quarkus.rest.client.reactive.runtime.RestClientRecorder.RegisteredProvider;
import io.quarkus.resteasy.reactive.spi.ContainerRequestFilterBuildItem;

class ReactiveResteasyMpClientProcessor {
class RestClientReactiveProcessor {

private static final Logger log = Logger.getLogger(ReactiveResteasyMpClientProcessor.class);
private static final Logger log = Logger.getLogger(RestClientReactiveProcessor.class);

private static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class.getName());
private static final DotName SESSION_SCOPED = DotName.createSimple(SessionScoped.class.getName());
Expand Down Expand Up @@ -151,6 +158,59 @@ void registerHeaderFactoryBeans(CombinedIndexBuildItem index,
}
}

/**
* Creates an implementation of `AnnotationRegisteredProviders` class with a constructor that
* puts all the providers registered by the @RegisterProvider annotation in a
* map using the {@link AnnotationRegisteredProviders#addProviders(String, Map)} method
*
* @param indexBuildItem index
* @param generatedBeans build producer for generated beans
*/
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void registerProvidersFromAnnotations(CombinedIndexBuildItem indexBuildItem,
BuildProducer<GeneratedBeanBuildItem> generatedBeans,
BuildProducer<SyntheticBeanBuildItem> syntheticBeanBuildItemBuildProducer,
RestClientRecorder recorder, BuildProducer<ReflectiveClassBuildItem> reflectiveClasses) {

IndexView index = indexBuildItem.getIndex();
Map<String, List<RegisteredProvider>> annotationsByClassName = new HashMap<>();

for (AnnotationInstance annotation : index.getAnnotations(REGISTER_PROVIDER)) {
String targetClass = annotation.target().asClass().name().toString();
addProviderFromAnnotation(annotationsByClassName, annotation, targetClass, reflectiveClasses);
}

for (AnnotationInstance groupAnnotation : index.getAnnotations(REGISTER_PROVIDERS)) {
String targetClass = groupAnnotation.target().asClass().name().toString();

for (AnnotationInstance annotation : groupAnnotation.value().asNestedArray()) {
addProviderFromAnnotation(annotationsByClassName, annotation, targetClass, reflectiveClasses);
}
}

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(AnnotationRegisteredProviders.class)
.scope(Singleton.class)
.setRuntimeInit()
.unremovable()
.runtimeValue(recorder.registeredProviders(annotationsByClassName));
syntheticBeanBuildItemBuildProducer.produce(configurator.done());
}

private void addProviderFromAnnotation(Map<String, List<RegisteredProvider>> annotationsByClassName,
AnnotationInstance annotation, String targetClass, BuildProducer<ReflectiveClassBuildItem> reflectiveClasses) {
AnnotationValue priority = annotation.value("priority");

RegisteredProvider provider = new RegisteredProvider();
provider.className = annotation.value().asString();
provider.priority = priority == null ? -1 : priority.asInt();

annotationsByClassName.computeIfAbsent(targetClass, key -> new ArrayList<>()).add(provider);

reflectiveClasses.produce(new ReflectiveClassBuildItem(false, false, provider.className));
}

@BuildStep
AdditionalBeanBuildItem registerProviderBeans(CombinedIndexBuildItem combinedIndex) {
IndexView index = combinedIndex.getIndex();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkus.rest.client.reactive.runtime;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class AnnotationRegisteredProviders {
private final Map<String, Map<Class<?>, Integer>> providers = new HashMap<>();

public Map<Class<?>, Integer> getProviders(Class<?> clientClass) {
Map<Class<?>, Integer> providersForClass = providers.get(clientClass.getName());
return providersForClass == null ? Collections.emptyMap() : providersForClass;
}

public void addProviders(String className, Map<Class<?>, Integer> providersForClass) {
this.providers.put(className, providersForClass);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ public <T> T build(Class<T> aClass) throws IllegalStateException, RestClientDefi

RestClientListeners.get().forEach(listener -> listener.onNewClient(aClass, this));

AnnotationRegisteredProviders annotationRegisteredProviders = Arc.container()
.instance(AnnotationRegisteredProviders.class).get();
for (Map.Entry<Class<?>, Integer> mapper : annotationRegisteredProviders.getProviders(aClass).entrySet()) {
register(mapper.getKey(), mapper.getValue());
}

Object defaultMapperDisabled = getConfiguration().getProperty(DEFAULT_MAPPER_DISABLED);
Boolean globallyDisabledMapper = ConfigProvider.getConfig()
.getOptionalValue(DEFAULT_MAPPER_DISABLED, Boolean.class).orElse(false);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,46 @@
package io.quarkus.rest.client.reactive.runtime;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.eclipse.microprofile.rest.client.spi.RestClientBuilderResolver;

import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.annotations.Recorder;

@Recorder
public class RestClientRecorder {
public void setRestClientBuilderResolver() {
RestClientBuilderResolver.setInstance(new BuilderResolver());
}

public RuntimeValue<AnnotationRegisteredProviders> registeredProviders(
Map<String, List<RegisteredProvider>> annotationsByClassName) {
AnnotationRegisteredProviders result = new AnnotationRegisteredProviders();

for (Map.Entry<String, List<RegisteredProvider>> providersForClass : annotationsByClassName.entrySet()) {
result.addProviders(providersForClass.getKey(), toMapOfProviders(providersForClass.getValue()));
}

return new RuntimeValue<>(result);
}

private Map<Class<?>, Integer> toMapOfProviders(List<RegisteredProvider> value) {
Map<Class<?>, Integer> result = new HashMap<>();
for (RegisteredProvider registeredProvider : value) {
try {
result.put(Class.forName(registeredProvider.className, false, Thread.currentThread().getContextClassLoader()),
registeredProvider.priority);
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
return result;
}

public static class RegisteredProvider {
public String className;
public int priority;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.it.rest.client;
package io.quarkus.it.rest.client.main;

import io.quarkus.runtime.annotations.RegisterForReflection;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.it.rest.client;
package io.quarkus.it.rest.client.main;

import java.util.concurrent.CompletionStage;

Expand All @@ -15,7 +15,7 @@
@Path("")
@RegisterProvider(DefaultCtorTestFilter.class)
@RegisterProvider(NonDefaultCtorTestFilter.class)
public interface SimpleClient {
public interface AppleClient {
@POST
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.it.rest.client;
package io.quarkus.it.rest.client.main;

import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -7,11 +7,13 @@
import javax.enterprise.event.Observes;

import org.eclipse.microprofile.rest.client.RestClientBuilder;
import org.eclipse.microprofile.rest.client.inject.RestClient;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;

import io.quarkus.it.rest.client.main.MyResponseExceptionMapper.MyException;
import io.smallrye.mutiny.Uni;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
Expand All @@ -21,12 +23,27 @@
public class ClientCallingResource {
private static final ObjectMapper mapper = new JsonMapper();

private static final String[] RESPONSES = { "cortland", "cortland2", "cortland3" };
private static final String[] RESPONSES = { "cortland", "lobo", "golden delicious" };
private final AtomicInteger count = new AtomicInteger(0);

@RestClient
ClientWithExceptionMapper clientWithExceptionMapper;

void init(@Observes Router router) {
router.post().handler(BodyHandler.create());

router.get("/unprocessable").handler(rc -> rc.response().setStatusCode(422).end("the entity was unprocessable"));

router.post("/call-client-with-exception-mapper").blockingHandler(rc -> {
String url = rc.getBody().toString();
ClientWithExceptionMapper client = RestClientBuilder.newBuilder().baseUri(URI.create(url))
.register(MyResponseExceptionMapper.class)
.build(ClientWithExceptionMapper.class);
callGet(rc, client);
});

router.post("/call-cdi-client-with-exception-mapper").blockingHandler(rc -> callGet(rc, clientWithExceptionMapper));

router.post("/apples").handler(rc -> {
int count = this.count.getAndIncrement();
rc.response().putHeader("content-type", "application/json")
Expand All @@ -35,8 +52,8 @@ void init(@Observes Router router) {

router.route("/call-client").blockingHandler(rc -> {
String url = rc.getBody().toString();
SimpleClient client = RestClientBuilder.newBuilder().baseUri(URI.create(url))
.build(SimpleClient.class);
AppleClient client = RestClientBuilder.newBuilder().baseUri(URI.create(url))
.build(AppleClient.class);
Uni<Apple> apple1 = Uni.createFrom().item(client.swapApple(new Apple("lobo")));
Uni<Apple> apple2 = Uni.createFrom().completionStage(client.completionSwapApple(new Apple("lobo2")));
Uni<Apple> apple3 = client.uniSwapApple(new Apple("lobo3"));
Expand All @@ -56,12 +73,23 @@ void init(@Observes Router router) {
} catch (JsonProcessingException e) {
fail(rc, e.getMessage());
}
}, t -> {
fail(rc, t.getMessage());
});
}, t -> fail(rc, t.getMessage()));
});
}

private void callGet(RoutingContext rc, ClientWithExceptionMapper client) {
try {
client.get();
} catch (MyException expected) {
rc.response().setStatusCode(200).end();
return;
} catch (Exception unexpected) {
rc.response().setStatusCode(500).end("Expected MyException to be thrown, got " + unexpected.getClass());
return;
}
rc.response().setStatusCode(500).end("Expected MyException to be thrown but no exception has been thrown");
}

private void fail(RoutingContext rc, String message) {
rc.response().putHeader("content-type", "text/plain").setStatusCode(500).end(message);
}
Expand Down
Loading