From de4013a4c523262d42289ed47fc2c4c705bb05dd Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 23 Jul 2024 10:13:34 +0200 Subject: [PATCH] Consider declaring class when evaluating method return type for query method post-processing. We now consider the declaring class to properly resolve type variable references for the result post-processing of a query method result. Previously, we attempted to resolve the return type without considering the actual repository class resolving always Object instead of the type parameter. Closes #3125 --- .../support/QueryExecutionResultHandler.java | 53 ++++++++++++++----- .../QueryExecutorMethodInterceptor.java | 11 ++-- .../QueryExecutionResultHandlerUnitTests.java | 27 ++++++++-- 3 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java b/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java index f24f77ef35..9438a3fc92 100644 --- a/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java +++ b/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java @@ -15,7 +15,6 @@ */ package org.springframework.data.repository.core.support; -import java.lang.reflect.Method; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -23,13 +22,16 @@ import java.util.Optional; import org.springframework.core.CollectionFactory; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.support.GenericConversionService; +import org.springframework.data.repository.util.ClassUtils; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters; import org.springframework.data.util.NullableWrapper; +import org.springframework.data.util.ReactiveWrappers; import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; @@ -44,12 +46,14 @@ class QueryExecutionResultHandler { private static final TypeDescriptor WRAPPER_TYPE = TypeDescriptor.valueOf(NullableWrapper.class); + private static final Class FLOW_TYPE = loadIfPresent("kotlinx.coroutines.flow.Flow"); + private final GenericConversionService conversionService; private final Object mutex = new Object(); // concurrent access guarded by mutex. - private Map descriptorCache = Collections.emptyMap(); + private Map descriptorCache = Collections.emptyMap(); /** * Creates a new {@link QueryExecutionResultHandler}. @@ -58,6 +62,17 @@ class QueryExecutionResultHandler { this.conversionService = conversionService; } + @Nullable + @SuppressWarnings("unchecked") + public static Class loadIfPresent(String type) { + + try { + return (Class) org.springframework.util.ClassUtils.forName(type, ClassUtils.class.getClassLoader()); + } catch (ClassNotFoundException | LinkageError e) { + return null; + } + } + /** * Post-processes the given result of a query invocation to match the return type of the given method. * @@ -66,9 +81,9 @@ class QueryExecutionResultHandler { * @return */ @Nullable - Object postProcessInvocationResult(@Nullable Object result, Method method) { + Object postProcessInvocationResult(@Nullable Object result, MethodParameter method) { - if (!processingRequired(result, method.getReturnType())) { + if (!processingRequired(result, method)) { return result; } @@ -77,16 +92,16 @@ Object postProcessInvocationResult(@Nullable Object result, Method method) { return postProcessInvocationResult(result, 0, descriptor); } - private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) { + private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(MethodParameter method) { - Map descriptorCache = this.descriptorCache; + Map descriptorCache = this.descriptorCache; ReturnTypeDescriptor descriptor = descriptorCache.get(method); if (descriptor == null) { descriptor = ReturnTypeDescriptor.of(method); - Map updatedDescriptorCache; + Map updatedDescriptorCache; if (descriptorCache.isEmpty()) { updatedDescriptorCache = Collections.singletonMap(method, descriptor); @@ -94,7 +109,6 @@ private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) { updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1); updatedDescriptorCache.putAll(descriptorCache); updatedDescriptorCache.put(method, descriptor); - } synchronized (mutex) { @@ -234,10 +248,21 @@ private static Object unwrapOptional(@Nullable Object source) { * Returns whether we have to process the given source object in the first place. * * @param source can be {@literal null}. - * @param targetType must not be {@literal null}. + * @param methodParameter must not be {@literal null}. * @return */ - private static boolean processingRequired(@Nullable Object source, Class targetType) { + private static boolean processingRequired(@Nullable Object source, MethodParameter methodParameter) { + + Class targetType = methodParameter.getParameterType(); + + if (source != null && ReactiveWrappers.KOTLIN_COROUTINES_PRESENT + && KotlinDetector.isSuspendingFunction(methodParameter.getMethod())) { + + // Spring's AOP invoker handles Publisher to Flow conversion, so we have to exempt these from post-processing. + if (FLOW_TYPE != null && FLOW_TYPE.isAssignableFrom(targetType)) { + return false; + } + } return !targetType.isInstance(source) // || source == null // @@ -253,19 +278,19 @@ static class ReturnTypeDescriptor { private final TypeDescriptor typeDescriptor; private final @Nullable TypeDescriptor nestedTypeDescriptor; - private ReturnTypeDescriptor(Method method) { - this.methodParameter = new MethodParameter(method, -1); + private ReturnTypeDescriptor(MethodParameter methodParameter) { + this.methodParameter = methodParameter; this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0); this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1); } /** - * Create a {@link ReturnTypeDescriptor} from a {@link Method}. + * Create a {@link ReturnTypeDescriptor} from a {@link MethodParameter}. * * @param method * @return */ - public static ReturnTypeDescriptor of(Method method) { + public static ReturnTypeDescriptor of(MethodParameter method) { return new ReturnTypeDescriptor(method); } diff --git a/src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java b/src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java index d61785a2b6..f75331b8fc 100644 --- a/src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java +++ b/src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java @@ -21,9 +21,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.NamedQueries; @@ -55,6 +58,7 @@ class QueryExecutorMethodInterceptor implements MethodInterceptor { private final RepositoryInformation repositoryInformation; private final Map queries; private final Map invocationMetadataCache = new ConcurrentReferenceHashMap<>(); + private final Map returnTypeMap = new ConcurrentHashMap<>(); private final QueryExecutionResultHandler resultHandler; private final NamedQueries namedQueries; private final List> queryPostProcessors; @@ -135,16 +139,17 @@ private void invokeListeners(RepositoryQuery query) { public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable { Method method = invocation.getMethod(); + MethodParameter returnType = returnTypeMap.computeIfAbsent(method, it -> new MethodParameter(it, -1)); QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters // - .getExecutionAdapter(method.getReturnType()); + .getExecutionAdapter(returnType.getParameterType()); if (executionAdapter == null) { - return resultHandler.postProcessInvocationResult(doInvoke(invocation), method); + return resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType); } return executionAdapter // - .apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), method)); + .apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType)); } @Nullable diff --git a/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java index aeaa2ec163..1d80270c24 100755 --- a/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java @@ -26,7 +26,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.lang.reflect.Method; import java.math.BigDecimal; import java.util.Arrays; import java.util.Collections; @@ -40,6 +39,8 @@ import org.assertj.core.api.SoftAssertions; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; + +import org.springframework.core.MethodParameter; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.repository.Repository; import org.springframework.data.util.Streamable; @@ -404,6 +405,17 @@ void nestedConversion() throws Exception { }); } + @Test // GH-3125 + void considersTypeBoundsFromBaseInterface() throws NoSuchMethodException { + + var method = CustomizedRepository.class.getMethod("findById", Object.class); + + var result = handler.postProcessInvocationResult(Optional.of(new Entity()), + new MethodParameter(method, -1).withContainingClass(CustomizedRepository.class)); + + assertThat(result).isInstanceOf(Entity.class); + } + @Test // DATACMNS-1552 void keepsVavrOptionType() throws Exception { @@ -412,8 +424,17 @@ void keepsVavrOptionType() throws Exception { assertThat(handler.postProcessInvocationResult(source, getMethod("option"))).isSameAs(source); } - private static Method getMethod(String methodName) throws Exception { - return Sample.class.getMethod(methodName); + private static MethodParameter getMethod(String methodName) throws Exception { + return new MethodParameter(Sample.class.getMethod(methodName), -1); + } + + interface BaseRepository extends Repository { + + T findById(ID id); + } + + interface CustomizedRepository extends BaseRepository { + } static interface Sample extends Repository {