From 90ea2197634e8083eacc1706ccee8ce132d2c639 Mon Sep 17 00:00:00 2001 From: Stefano Cordio Date: Sat, 23 Nov 2024 16:47:20 +0100 Subject: [PATCH] Avoid infinite recursion in AOT processing with recursive generics --- ...alidationBeanRegistrationAotProcessor.java | 19 ++++++++------- ...tionBeanRegistrationAotProcessorTests.java | 24 +++++++++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java index 0647386c21cb..1783f01afe9e 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java @@ -18,7 +18,6 @@ import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -104,10 +103,11 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean } Class beanClass = registeredBean.getBeanClass(); + Set> visitedClasses = new HashSet<>(); Set> validatedClasses = new HashSet<>(); Set>> constraintValidatorClasses = new HashSet<>(); - processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses); + processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses); if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) { return new AotContribution(validatedClasses, constraintValidatorClasses); @@ -115,9 +115,12 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean return null; } - private static void processAheadOfTime(Class clazz, Collection> validatedClasses, - Collection>> constraintValidatorClasses) { + private static void processAheadOfTime(Class clazz, Set> visitedClasses, Set> validatedClasses, + Set>> constraintValidatorClasses) { + if (visitedClasses.add(clazz)) { + return; + } Assert.notNull(validator, "Validator can't be null"); BeanDescriptor descriptor; @@ -149,12 +152,12 @@ else if (ex instanceof TypeNotPresentException) { ReflectionUtils.doWithFields(clazz, field -> { Class type = field.getType(); - if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { + if (Iterable.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { ResolvableType resolvableType = ResolvableType.forField(field); Class genericType = resolvableType.getGeneric(0).toClass(); if (shouldProcess(genericType)) { validatedClasses.add(clazz); - processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(genericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } if (Map.class.isAssignableFrom(type)) { @@ -163,11 +166,11 @@ else if (ex instanceof TypeNotPresentException) { Class valueGenericType = resolvableType.getGeneric(1).toClass(); if (shouldProcess(keyGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(keyGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } if (shouldProcess(valueGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(valueGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } }); diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java index d43d8033317d..6291fb132972 100644 --- a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java @@ -22,6 +22,9 @@ import java.lang.annotation.Target; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; import jakarta.validation.Constraint; import jakarta.validation.ConstraintValidator; @@ -31,6 +34,8 @@ import jakarta.validation.constraints.Pattern; import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.MemberCategory; @@ -121,6 +126,13 @@ void shouldProcessTransitiveGenericTypeLevelConstraint() { .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } + @ParameterizedTest // gh-33936 + @ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class}) + void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class beanClass) { + process(beanClass); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).isEmpty(); + } + private void process(Class beanClass) { BeanRegistrationAotContribution contribution = createContribution(beanClass); if (contribution != null) { @@ -244,4 +256,16 @@ public void setExclude(List exclude) { } } + static class BeanWithIterable { + private final Iterable beans = Set.of(); + } + + static class BeanWithMap { + private final Map beans = Map.of(); + } + + static class BeanWithOptional { + private final Optional beans = Optional.empty(); + } + }