Skip to content

Commit

Permalink
Consider declaring class when evaluating method return type for query…
Browse files Browse the repository at this point in the history
… method post-processing.

We now consider the declaring class to properly resolve type variable references for the result post-processing of a query method result.

Previously, we attempted to resolve the return type without considering the actual repository class resolving always Object instead of the type parameter.

Closes #3125
  • Loading branch information
mp911de committed Jul 23, 2024
1 parent 75d0992 commit de4013a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
*/
package org.springframework.data.repository.core.support;

import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import org.springframework.core.CollectionFactory;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.repository.util.ClassUtils;
import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.data.repository.util.ReactiveWrapperConverters;
import org.springframework.data.util.NullableWrapper;
import org.springframework.data.util.ReactiveWrappers;
import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable;

Expand All @@ -44,12 +46,14 @@ class QueryExecutionResultHandler {

private static final TypeDescriptor WRAPPER_TYPE = TypeDescriptor.valueOf(NullableWrapper.class);

private static final Class<?> FLOW_TYPE = loadIfPresent("kotlinx.coroutines.flow.Flow");

private final GenericConversionService conversionService;

private final Object mutex = new Object();

// concurrent access guarded by mutex.
private Map<Method, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap();
private Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap();

/**
* Creates a new {@link QueryExecutionResultHandler}.
Expand All @@ -58,6 +62,17 @@ class QueryExecutionResultHandler {
this.conversionService = conversionService;
}

@Nullable
@SuppressWarnings("unchecked")
public static <T> Class<T> loadIfPresent(String type) {

try {
return (Class<T>) org.springframework.util.ClassUtils.forName(type, ClassUtils.class.getClassLoader());
} catch (ClassNotFoundException | LinkageError e) {
return null;
}
}

/**
* Post-processes the given result of a query invocation to match the return type of the given method.
*
Expand All @@ -66,9 +81,9 @@ class QueryExecutionResultHandler {
* @return
*/
@Nullable
Object postProcessInvocationResult(@Nullable Object result, Method method) {
Object postProcessInvocationResult(@Nullable Object result, MethodParameter method) {

if (!processingRequired(result, method.getReturnType())) {
if (!processingRequired(result, method)) {
return result;
}

Expand All @@ -77,24 +92,23 @@ Object postProcessInvocationResult(@Nullable Object result, Method method) {
return postProcessInvocationResult(result, 0, descriptor);
}

private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) {
private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(MethodParameter method) {

Map<Method, ReturnTypeDescriptor> descriptorCache = this.descriptorCache;
Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = this.descriptorCache;
ReturnTypeDescriptor descriptor = descriptorCache.get(method);

if (descriptor == null) {

descriptor = ReturnTypeDescriptor.of(method);

Map<Method, ReturnTypeDescriptor> updatedDescriptorCache;
Map<MethodParameter, ReturnTypeDescriptor> updatedDescriptorCache;

if (descriptorCache.isEmpty()) {
updatedDescriptorCache = Collections.singletonMap(method, descriptor);
} else {
updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1);
updatedDescriptorCache.putAll(descriptorCache);
updatedDescriptorCache.put(method, descriptor);

}

synchronized (mutex) {
Expand Down Expand Up @@ -234,10 +248,21 @@ private static Object unwrapOptional(@Nullable Object source) {
* Returns whether we have to process the given source object in the first place.
*
* @param source can be {@literal null}.
* @param targetType must not be {@literal null}.
* @param methodParameter must not be {@literal null}.
* @return
*/
private static boolean processingRequired(@Nullable Object source, Class<?> targetType) {
private static boolean processingRequired(@Nullable Object source, MethodParameter methodParameter) {

Class<?> targetType = methodParameter.getParameterType();

if (source != null && ReactiveWrappers.KOTLIN_COROUTINES_PRESENT
&& KotlinDetector.isSuspendingFunction(methodParameter.getMethod())) {

// Spring's AOP invoker handles Publisher to Flow conversion, so we have to exempt these from post-processing.
if (FLOW_TYPE != null && FLOW_TYPE.isAssignableFrom(targetType)) {
return false;
}
}

return !targetType.isInstance(source) //
|| source == null //
Expand All @@ -253,19 +278,19 @@ static class ReturnTypeDescriptor {
private final TypeDescriptor typeDescriptor;
private final @Nullable TypeDescriptor nestedTypeDescriptor;

private ReturnTypeDescriptor(Method method) {
this.methodParameter = new MethodParameter(method, -1);
private ReturnTypeDescriptor(MethodParameter methodParameter) {
this.methodParameter = methodParameter;
this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0);
this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1);
}

/**
* Create a {@link ReturnTypeDescriptor} from a {@link Method}.
* Create a {@link ReturnTypeDescriptor} from a {@link MethodParameter}.
*
* @param method
* @return
*/
public static ReturnTypeDescriptor of(Method method) {
public static ReturnTypeDescriptor of(MethodParameter method) {
return new ReturnTypeDescriptor(method);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;

import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.core.NamedQueries;
Expand Down Expand Up @@ -55,6 +58,7 @@ class QueryExecutorMethodInterceptor implements MethodInterceptor {
private final RepositoryInformation repositoryInformation;
private final Map<Method, RepositoryQuery> queries;
private final Map<Method, RepositoryMethodInvoker> invocationMetadataCache = new ConcurrentReferenceHashMap<>();
private final Map<Method, MethodParameter> returnTypeMap = new ConcurrentHashMap<>();
private final QueryExecutionResultHandler resultHandler;
private final NamedQueries namedQueries;
private final List<QueryCreationListener<?>> queryPostProcessors;
Expand Down Expand Up @@ -135,16 +139,17 @@ private void invokeListeners(RepositoryQuery query) {
public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable {

Method method = invocation.getMethod();
MethodParameter returnType = returnTypeMap.computeIfAbsent(method, it -> new MethodParameter(it, -1));

QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters //
.getExecutionAdapter(method.getReturnType());
.getExecutionAdapter(returnType.getParameterType());

if (executionAdapter == null) {
return resultHandler.postProcessInvocationResult(doInvoke(invocation), method);
return resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType);
}

return executionAdapter //
.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), method));
.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType));
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -40,6 +39,8 @@
import org.assertj.core.api.SoftAssertions;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;

import org.springframework.core.MethodParameter;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.repository.Repository;
import org.springframework.data.util.Streamable;
Expand Down Expand Up @@ -404,6 +405,17 @@ void nestedConversion() throws Exception {
});
}

@Test // GH-3125
void considersTypeBoundsFromBaseInterface() throws NoSuchMethodException {

var method = CustomizedRepository.class.getMethod("findById", Object.class);

var result = handler.postProcessInvocationResult(Optional.of(new Entity()),
new MethodParameter(method, -1).withContainingClass(CustomizedRepository.class));

assertThat(result).isInstanceOf(Entity.class);
}

@Test // DATACMNS-1552
void keepsVavrOptionType() throws Exception {

Expand All @@ -412,8 +424,17 @@ void keepsVavrOptionType() throws Exception {
assertThat(handler.postProcessInvocationResult(source, getMethod("option"))).isSameAs(source);
}

private static Method getMethod(String methodName) throws Exception {
return Sample.class.getMethod(methodName);
private static MethodParameter getMethod(String methodName) throws Exception {
return new MethodParameter(Sample.class.getMethod(methodName), -1);
}

interface BaseRepository<T, ID> extends Repository<T, ID> {

T findById(ID id);
}

interface CustomizedRepository extends BaseRepository<Entity, Long> {

}

static interface Sample extends Repository<Entity, Long> {
Expand Down

0 comments on commit de4013a

Please sign in to comment.