Skip to content

Commit

Permalink
Move Kotlin value class unboxing to InvocableHandlerMethod
Browse files Browse the repository at this point in the history
Before this commit, in Spring Framework 6.2, Kotlin value class
unboxing was done at CoroutinesUtils level, which is a good fit
for InvocableHandlerMethod use case, but not for other ones like
AopUtils.

This commit moves such unboxing to InvocableHandlerMethod in
order to keep the HTTP response body support while fixing other
regressions.

Closes gh-33943
  • Loading branch information
sdeleuze committed Nov 27, 2024
1 parent ea3bd7a commit 1aede29
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;

import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -109,7 +108,7 @@ public static Publisher<?> invokeSuspendingFunction(Method method, Object target
* @throws IllegalArgumentException if {@code method} is not a suspending function
* @since 6.0
*/
@SuppressWarnings({"deprecation", "DataFlowIssue", "NullAway"})
@SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunction(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {

Expand Down Expand Up @@ -146,7 +145,7 @@ public static Publisher<?> invokeSuspendingFunction(
}
return KCallables.callSuspendBy(function, argMap, continuation);
})
.handle(CoroutinesUtils::handleResult)
.filter(result -> result != Unit.INSTANCE)
.onErrorMap(InvocationTargetException.class, InvocationTargetException::getTargetException);

KType returnType = function.getReturnType();
Expand All @@ -166,22 +165,4 @@ private static Flux<?> asFlux(Object flow) {
return ReactorFlowKt.asFlux(((Flow<?>) flow));
}

private static void handleResult(Object result, SynchronousSink<Object> sink) {
if (result == Unit.INSTANCE) {
sink.complete();
}
else if (KotlinDetector.isInlineClass(result.getClass())) {
try {
sink.next(result.getClass().getDeclaredMethod("unbox-impl").invoke(result));
sink.complete();
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
sink.error(ex);
}
}
else {
sink.next(result);
sink.complete();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class CoroutinesUtilsTests {

@Test
fun invokeSuspendingFunctionWithValueClassParameter() {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClass") }
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassParameter") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, "foo", null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo")
Expand All @@ -204,7 +204,16 @@ class CoroutinesUtilsTests {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassReturnValue") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo")
Assertions.assertThat(mono.awaitSingle()).isEqualTo(ValueClass("foo"))
}
}

@Test
fun invokeSuspendingFunctionWithResultOfUnitReturnValue() {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithResultOfUnitReturnValue") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo(Result.success(Unit))
}
}

Expand Down Expand Up @@ -314,7 +323,7 @@ class CoroutinesUtilsTests {
return null
}

suspend fun suspendingFunctionWithValueClass(value: ValueClass): String {
suspend fun suspendingFunctionWithValueClassParameter(value: ValueClass): String {
delay(1)
return value.value
}
Expand All @@ -324,6 +333,11 @@ class CoroutinesUtilsTests {
return ValueClass("foo")
}

suspend fun suspendingFunctionWithResultOfUnitReturnValue(): Result<Unit> {
delay(1)
return Result.success(Unit)
}

suspend fun suspendingFunctionWithValueClassWithInit(value: ValueClassWithInit): String {
delay(1)
return value.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import kotlin.reflect.full.KClasses;
import kotlin.reflect.jvm.KCallablesJvm;
import kotlin.reflect.jvm.ReflectJvmMapping;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;

import org.springframework.context.MessageSource;
import org.springframework.core.CoroutinesUtils;
Expand Down Expand Up @@ -288,7 +290,8 @@ else if (targetException instanceof Exception exception) {
* @since 6.0
*/
protected Object invokeSuspendingFunction(Method method, Object target, Object[] args) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
Object result = CoroutinesUtils.invokeSuspendingFunction(method, target, args);
return (result instanceof Mono<?> mono ? mono.handle(KotlinDelegate::handleResult) : result);
}


Expand All @@ -298,7 +301,7 @@ protected Object invokeSuspendingFunction(Method method, Object target, Object[]
private static class KotlinDelegate {

@Nullable
@SuppressWarnings({"deprecation", "DataFlowIssue"})
@SuppressWarnings("DataFlowIssue")
public static Object invokeFunction(Method method, Object target, Object[] args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
// For property accessors
Expand Down Expand Up @@ -333,10 +336,33 @@ public static Object invokeFunction(Method method, Object target, Object[] args)
}
Object result = function.callBy(argMap);
if (result != null && KotlinDetector.isInlineClass(result.getClass())) {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
result = unbox(result);
}
return (result == Unit.INSTANCE ? null : result);
}

private static void handleResult(Object result, SynchronousSink<Object> sink) {
if (KotlinDetector.isInlineClass(result.getClass())) {
try {
Object unboxed = unbox(result);
if (unboxed != Unit.INSTANCE) {
sink.next(unboxed);
}
sink.complete();
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
sink.error(ex);
}
}
else {
sink.next(result);
sink.complete();
}
}

private static Object unbox(Object result) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@

package org.springframework.web.method.support

import kotlinx.coroutines.delay
import org.assertj.core.api.Assertions
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.util.ReflectionUtils
import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.context.request.NativeWebRequest
import org.springframework.web.context.request.ServletWebRequest
import org.springframework.web.testfixture.method.ResolvableMethod
import org.springframework.web.testfixture.servlet.MockHttpServletRequest
import org.springframework.web.testfixture.servlet.MockHttpServletResponse
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import java.lang.reflect.Method
import kotlin.reflect.jvm.javaGetter
import kotlin.reflect.jvm.javaMethod
Expand All @@ -33,6 +37,7 @@ import kotlin.reflect.jvm.javaMethod
*
* @author Sebastien Deleuze
*/
@Suppress("UNCHECKED_CAST")
class InvocableHandlerMethodKotlinTests {

private val request: NativeWebRequest = ServletWebRequest(MockHttpServletRequest(), MockHttpServletResponse())
Expand Down Expand Up @@ -110,6 +115,12 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo("foo")
}

@Test
fun resultOfUnitReturnValue() {
val value = getInvocable(ValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null)
Assertions.assertThat(value).isNull()
}

@Test
fun valueClassDefaultValue() {
composite.addResolver(StubArgumentResolver(Double::class.java))
Expand Down Expand Up @@ -138,6 +149,60 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo('a')
}

@Test
fun suspendingValueClass() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Long::class.java, 1L))
val value = getInvocable(SuspendingValueClassHandler::longValueClass.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Long>).expectNext(1L).verifyComplete()
}

@Test
fun suspendingValueClassReturnValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
val value = getInvocable(SuspendingValueClassHandler::valueClassReturnValue.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<String>).expectNext("foo").verifyComplete()
}

@Test
fun suspendingResultOfUnitReturnValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
val value = getInvocable(SuspendingValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Unit>).verifyComplete()
}

@Test
fun suspendingValueClassDefaultValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Double::class.java))
val value = getInvocable(SuspendingValueClassHandler::doubleValueClass.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Double>).expectNext(3.1).verifyComplete()
}

@Test
fun suspendingValueClassWithInit() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(String::class.java, ""))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithInit.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<String>).verifyError(IllegalArgumentException::class.java)
}

@Test
fun suspendingValueClassWithNullable() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(LongValueClass::class.java, null))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithNullable.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Long>).verifyComplete()
}

@Test
fun suspendingValueClassWithPrivateConstructor() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Char::class.java, 'a'))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithPrivateConstructor.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Char>).expectNext('a').verifyComplete()
}

@Test
fun propertyAccessor() {
val value = getInvocable(PropertyAccessorHandler::prop.javaGetter!!).invokeForRequest(request, null)
Expand Down Expand Up @@ -206,23 +271,58 @@ class InvocableHandlerMethodKotlinTests {

private class ValueClassHandler {

fun valueClassReturnValue() =
StringValueClass("foo")
fun valueClassReturnValue() = StringValueClass("foo")

fun resultOfUnitReturnValue() = Result.success(Unit)

fun longValueClass(limit: LongValueClass) = limit.value

fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) = limit.value

fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass

fun valueClassWithNullable(limit: LongValueClass?) = limit?.value

fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = limit.value
}

private class SuspendingValueClassHandler {

suspend fun valueClassReturnValue(): StringValueClass {
delay(1)
return StringValueClass("foo")
}

suspend fun resultOfUnitReturnValue(): Result<Unit> {
delay(1)
return Result.success(Unit)
}

fun longValueClass(limit: LongValueClass) =
limit.value
suspend fun longValueClass(limit: LongValueClass): Long {
delay(1)
return limit.value
}

fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) =
limit.value

fun valueClassWithInit(valueClass: ValueClassWithInit) =
valueClass
suspend fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)): Double {
delay(1)
return limit.value
}

fun valueClassWithNullable(limit: LongValueClass?) =
limit?.value
suspend fun valueClassWithInit(valueClass: ValueClassWithInit): ValueClassWithInit {
delay(1)
return valueClass
}

suspend fun valueClassWithNullable(limit: LongValueClass?): Long? {
delay(1)
return limit?.value
}

fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) =
limit.value
suspend fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor): Char {
delay(1)
return limit.value
}
}

private class PropertyAccessorHandler {
Expand Down Expand Up @@ -282,4 +382,19 @@ class InvocableHandlerMethodKotlinTests {

class CustomException(message: String) : Throwable(message)

// Avoid adding a spring-webmvc dependency
class ContinuationHandlerMethodArgumentResolver : HandlerMethodArgumentResolver {

override fun supportsParameter(parameter: MethodParameter) =
"kotlin.coroutines.Continuation" == parameter.getParameterType().getName()

override fun resolveArgument(
parameter: MethodParameter,
mavContainer: ModelAndViewContainer?,
webRequest: NativeWebRequest,
binderFactory: WebDataBinderFactory?
) = null

}

}
Loading

0 comments on commit 1aede29

Please sign in to comment.