Skip to content

Commit

Permalink
Properly cache ContextResolver usage for ObjectMapper in client code
Browse files Browse the repository at this point in the history
Fixes: #36067
  • Loading branch information
geoand committed Sep 25, 2023
1 parent 565f10f commit b957b27
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import static io.restassured.RestAssured.given;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

import jakarta.inject.Singleton;
import jakarta.ws.rs.GET;
Expand Down Expand Up @@ -70,10 +68,17 @@ void serverShouldWrapRootElement() {
*/
@Test
void shouldClientUseCustomObjectMapperUnwrappingRootElement() {
assertFalse(ClientObjectMapperUnwrappingRootElement.USED.get());
AtomicLong count = ClientObjectMapperUnwrappingRootElement.COUNT;
assertEquals(0, count.get());
Request request = clientUnwrappingRootElement.get();
assertEquals("good", request.value);
assertTrue(ClientObjectMapperUnwrappingRootElement.USED.get());
assertEquals(1, count.get());

assertEquals("good", clientUnwrappingRootElement.get().value);
assertEquals("good", clientUnwrappingRootElement.get().value);
assertEquals("good", clientUnwrappingRootElement.get().value);
// count should not change as the resolution of the ObjectMapper should be cached
assertEquals(1, count.get());
}

/**
Expand All @@ -82,10 +87,17 @@ void shouldClientUseCustomObjectMapperUnwrappingRootElement() {
*/
@Test
void shouldClientUseCustomObjectMapperNotUnwrappingRootElement() {
assertFalse(MyClientNotUnwrappingRootElement.CUSTOM_OBJECT_MAPPER_USED.get());
AtomicLong count = MyClientNotUnwrappingRootElement.CUSTOM_OBJECT_MAPPER_COUNT;
assertEquals(0, count.get());
Request request = clientNotUnwrappingRootElement.get();
assertNull(request.value);
assertTrue(MyClientNotUnwrappingRootElement.CUSTOM_OBJECT_MAPPER_USED.get());
assertEquals(1, count.get());

assertNull(clientNotUnwrappingRootElement.get().value);
assertNull(clientNotUnwrappingRootElement.get().value);
assertNull(clientNotUnwrappingRootElement.get().value);
// count should not change as the resolution of the ObjectMapper should be cached
assertEquals(1, count.get());
}

@Path("/server")
Expand All @@ -108,14 +120,14 @@ public interface MyClientUnwrappingRootElement {
@Path("/server")
@Produces(MediaType.APPLICATION_JSON)
public interface MyClientNotUnwrappingRootElement {
AtomicBoolean CUSTOM_OBJECT_MAPPER_USED = new AtomicBoolean(false);
AtomicLong CUSTOM_OBJECT_MAPPER_COUNT = new AtomicLong();

@GET
Request get();

@ClientObjectMapper
static ObjectMapper objectMapper(ObjectMapper defaultObjectMapper) {
CUSTOM_OBJECT_MAPPER_USED.set(true);
CUSTOM_OBJECT_MAPPER_COUNT.incrementAndGet();
return defaultObjectMapper.copy()
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
.disable(DeserializationFeature.UNWRAP_ROOT_VALUE);
Expand Down Expand Up @@ -158,11 +170,11 @@ public int hashCode() {
}

public static class ClientObjectMapperUnwrappingRootElement implements ContextResolver<ObjectMapper> {
static final AtomicBoolean USED = new AtomicBoolean(false);
static final AtomicLong COUNT = new AtomicLong();

@Override
public ObjectMapper getContext(Class<?> type) {
USED.set(true);
COUNT.incrementAndGet();
return new ObjectMapper().enable(DeserializationFeature.UNWRAP_ROOT_VALUE);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.rest.client.reactive.jackson.runtime.serialisers;

import static io.quarkus.rest.client.reactive.jackson.runtime.serialisers.JacksonUtil.getObjectMapperFromContext;

import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
Expand All @@ -13,8 +15,6 @@
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.ext.ContextResolver;
import jakarta.ws.rs.ext.Providers;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.ClientWebApplicationException;
Expand All @@ -33,7 +33,8 @@ public class ClientJacksonMessageBodyReader extends JacksonBasicMessageBodyReade

private static final Logger log = Logger.getLogger(ClientJacksonMessageBodyReader.class);

private final ConcurrentMap<ObjectMapper, ObjectReader> contextResolverMap = new ConcurrentHashMap<>();
private final ConcurrentMap<ResolverMapKey, ObjectMapper> contextResolverMap = new ConcurrentHashMap<>();
private final ConcurrentMap<ObjectMapper, ObjectReader> objectReaderMap = new ConcurrentHashMap<>();
private RestClientRequestContext context;

@Inject
Expand Down Expand Up @@ -66,43 +67,16 @@ public void handle(RestClientRequestContext requestContext) {
}

private ObjectReader getEffectiveReader(Class<Object> type, MediaType responseMediaType) {
ObjectMapper effectiveMapper = getObjectMapperFromContext(type, responseMediaType);
ObjectMapper effectiveMapper = getObjectMapperFromContext(type, responseMediaType, context, contextResolverMap);
if (effectiveMapper == null) {
return getEffectiveReader();
}

return contextResolverMap.computeIfAbsent(effectiveMapper, new Function<>() {
return objectReaderMap.computeIfAbsent(effectiveMapper, new Function<>() {
@Override
public ObjectReader apply(ObjectMapper objectMapper) {
return objectMapper.reader();
}
});
}

private ObjectMapper getObjectMapperFromContext(Class<Object> type, MediaType responseMediaType) {
Providers providers = getProviders();
if (providers == null) {
return null;
}

ContextResolver<ObjectMapper> contextResolver = providers.getContextResolver(ObjectMapper.class,
responseMediaType);
if (contextResolver == null) {
// TODO: not sure if this is correct, but Jackson does this as well...
contextResolver = providers.getContextResolver(ObjectMapper.class, null);
}
if (contextResolver != null) {
return contextResolver.getContext(type);
}

return null;
}

private Providers getProviders() {
if (context != null && context.getClientRequestContext() != null) {
return context.getClientRequestContext().getProviders();
}

return null;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.rest.client.reactive.jackson.runtime.serialisers;

import static io.quarkus.rest.client.reactive.jackson.runtime.serialisers.JacksonUtil.getObjectMapperFromContext;
import static org.jboss.resteasy.reactive.server.jackson.JacksonMessageBodyWriterUtil.createDefaultWriter;
import static org.jboss.resteasy.reactive.server.jackson.JacksonMessageBodyWriterUtil.doLegacyWrite;

Expand Down Expand Up @@ -27,7 +28,8 @@ public class ClientJacksonMessageBodyWriter implements MessageBodyWriter<Object>

protected final ObjectMapper originalMapper;
protected final ObjectWriter defaultWriter;
private final ConcurrentMap<ObjectMapper, ObjectWriter> contextResolverMap = new ConcurrentHashMap<>();
private final ConcurrentMap<ResolverMapKey, ObjectMapper> contextResolverMap = new ConcurrentHashMap<>();
private final ConcurrentMap<ObjectMapper, ObjectWriter> objectWriterMap = new ConcurrentHashMap<>();
private RestClientRequestContext context;

@Inject
Expand All @@ -44,30 +46,26 @@ public boolean isWriteable(Class type, Type genericType, Annotation[] annotation
@Override
public void writeTo(Object o, Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, Object> httpHeaders, OutputStream entityStream) throws IOException, WebApplicationException {
doLegacyWrite(o, annotations, httpHeaders, entityStream, getEffectiveWriter());
doLegacyWrite(o, annotations, httpHeaders, entityStream, getEffectiveWriter(type, mediaType));
}

@Override
public void handle(RestClientRequestContext requestContext) throws Exception {
this.context = requestContext;
}

protected ObjectWriter getEffectiveWriter() {
if (context == null) {
// no context injected when writer is not running within a rest client context
return defaultWriter;
}

ObjectMapper objectMapper = context.getConfiguration().getFromContext(ObjectMapper.class);
protected ObjectWriter getEffectiveWriter(Class<?> type, MediaType responseMediaType) {
ObjectMapper objectMapper = getObjectMapperFromContext(type, responseMediaType, context, contextResolverMap);
if (objectMapper == null) {
return defaultWriter;
}

return contextResolverMap.computeIfAbsent(objectMapper, new Function<>() {
return objectWriterMap.computeIfAbsent(objectMapper, new Function<>() {
@Override
public ObjectWriter apply(ObjectMapper objectMapper) {
return createDefaultWriter(objectMapper);
}
});
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.quarkus.rest.client.reactive.jackson.runtime.serialisers;

import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.ext.ContextResolver;
import jakarta.ws.rs.ext.Providers;

import org.jboss.resteasy.reactive.client.impl.RestClientRequestContext;

import com.fasterxml.jackson.databind.ObjectMapper;

final class JacksonUtil {

private JacksonUtil() {
}

static ObjectMapper getObjectMapperFromContext(Class<?> type, MediaType responseMediaType, RestClientRequestContext context,
ConcurrentMap<ResolverMapKey, ObjectMapper> contextResolverMap) {
Providers providers = getProviders(context);
if (providers == null) {
return null;
}

ContextResolver<ObjectMapper> contextResolver = providers.getContextResolver(ObjectMapper.class,
responseMediaType);
if (contextResolver == null) {
// TODO: not sure if this is correct, but Jackson does this as well...
contextResolver = providers.getContextResolver(ObjectMapper.class, null);
}
if (contextResolver != null) {
var cr = contextResolver;
var key = new ResolverMapKey(type, context.getConfiguration(), context.getInvokedMethod().getDeclaringClass());
return contextResolverMap.computeIfAbsent(key, new Function<>() {
@Override
public ObjectMapper apply(ResolverMapKey resolverMapKey) {
return cr.getContext(resolverMapKey.getType());
}
});
}

return null;
}

private static Providers getProviders(RestClientRequestContext context) {
if (context != null && context.getClientRequestContext() != null) {
return context.getClientRequestContext().getProviders();
}

return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.quarkus.rest.client.reactive.jackson.runtime.serialisers;

import java.util.Objects;

import jakarta.ws.rs.core.Configuration;

/**
* Each REST Client can potentially have different providers, so we need to make sure that
* caching for one client does not affect caching of another
*/
public final class ResolverMapKey {

private final Class<?> type;
private final Configuration configuration;

private final Class<?> restClientClass;

public ResolverMapKey(Class<?> type, Configuration configuration, Class<?> restClientClass) {
this.type = type;
this.configuration = configuration;
this.restClientClass = restClientClass;
}

public Class<?> getType() {
return type;
}

public Configuration getConfiguration() {
return configuration;
}

public Class<?> getRestClientClass() {
return restClientClass;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof ResolverMapKey)) {
return false;
}
ResolverMapKey that = (ResolverMapKey) o;
return Objects.equals(type, that.type) && Objects.equals(configuration, that.configuration)
&& Objects.equals(restClientClass, that.restClientClass);
}

@Override
public int hashCode() {
return Objects.hash(type, configuration, restClientClass);
}
}

0 comments on commit b957b27

Please sign in to comment.