Skip to content

Commit

Permalink
Avoid infinite recursion in BeanValidationBeanRegistrationAotProcessor
Browse files Browse the repository at this point in the history
Prior to this commit, AOT processing for bean validation failed with a
StackOverflowError for constraints with fields having recursive generic
types.

With this commit, the algorithm tracks visited classes and aborts
preemptively when a cycle is detected.

Closes spring-projectsgh-33950

Co-authored-by: Sam Brannen <[email protected]>
  • Loading branch information
scordio and sbrannen committed Nov 24, 2024
1 parent 1910d32 commit 5e7b3a3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -104,20 +103,24 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean
}

Class<?> beanClass = registeredBean.getBeanClass();
Set<Class<?>> visitedClasses = new HashSet<>();
Set<Class<?>> validatedClasses = new HashSet<>();
Set<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses = new HashSet<>();

processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses);
processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses);

if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) {
return new AotContribution(validatedClasses, constraintValidatorClasses);
}
return null;
}

private static void processAheadOfTime(Class<?> clazz, Collection<Class<?>> validatedClasses,
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {
private static void processAheadOfTime(Class<?> clazz, Set<Class<?>> visitedClasses, Set<Class<?>> validatedClasses,
Set<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {

if (!visitedClasses.add(clazz)) {
return;
}
Assert.notNull(validator, "Validator can't be null");

BeanDescriptor descriptor;
Expand Down Expand Up @@ -149,12 +152,12 @@ else if (ex instanceof TypeNotPresentException) {

ReflectionUtils.doWithFields(clazz, field -> {
Class<?> type = field.getType();
if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) {
if (Iterable.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) {
ResolvableType resolvableType = ResolvableType.forField(field);
Class<?> genericType = resolvableType.getGeneric(0).toClass();
if (shouldProcess(genericType)) {
validatedClasses.add(clazz);
processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses);
processAheadOfTime(genericType, visitedClasses, validatedClasses, constraintValidatorClasses);
}
}
if (Map.class.isAssignableFrom(type)) {
Expand All @@ -163,11 +166,11 @@ else if (ex instanceof TypeNotPresentException) {
Class<?> valueGenericType = resolvableType.getGeneric(1).toClass();
if (shouldProcess(keyGenericType)) {
validatedClasses.add(clazz);
processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses);
processAheadOfTime(keyGenericType, visitedClasses, validatedClasses, constraintValidatorClasses);
}
if (shouldProcess(valueGenericType)) {
validatedClasses.add(clazz);
processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses);
processAheadOfTime(valueGenericType, visitedClasses, validatedClasses, constraintValidatorClasses);
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import java.lang.annotation.Target;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import jakarta.validation.Constraint;
import jakarta.validation.ConstraintValidator;
Expand All @@ -31,6 +34,8 @@
import jakarta.validation.constraints.Pattern;
import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
Expand Down Expand Up @@ -121,6 +126,15 @@ void shouldProcessTransitiveGenericTypeLevelConstraint() {
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

@ParameterizedTest // gh-33936
@ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class})
void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class<?> beanClass) {
process(beanClass);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(1);
assertThat(RuntimeHintsPredicates.reflection().onType(beanClass)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
}

private void process(Class<?> beanClass) {
BeanRegistrationAotContribution contribution = createContribution(beanClass);
if (contribution != null) {
Expand Down Expand Up @@ -244,4 +258,16 @@ public void setExclude(List<Exclude> exclude) {
}
}

static class BeanWithIterable {
private final Iterable<BeanWithIterable> beans = Set.of();
}

static class BeanWithMap {
private final Map<String, BeanWithMap> beans = Map.of();
}

static class BeanWithOptional {
private final Optional<BeanWithOptional> beans = Optional.empty();
}

}

0 comments on commit 5e7b3a3

Please sign in to comment.