diff --git a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java index dbf2df4fa192..63f40110c4bb 100644 --- a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java +++ b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java @@ -16,7 +16,6 @@ package org.springframework.aop.scope; -import java.lang.reflect.Executable; import java.util.function.Predicate; import javax.lang.model.element.Modifier; @@ -109,7 +108,7 @@ private static class ScopedProxyBeanRegistrationCodeFragments extends BeanRegist } @Override - public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + public ClassName getTarget(RegisteredBean registeredBean) { return ClassName.get(this.targetBeanDefinition.getResolvableType().toClass()); } @@ -139,9 +138,7 @@ public CodeBlock generateSetBeanDefinitionPropertiesCode( @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, - boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() .add("getScopedProxyInstance", method -> { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index 68a06fd7dbfd..6986183d14f2 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -16,10 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Method; -import java.lang.reflect.Proxy; import java.util.List; import javax.lang.model.element.Modifier; @@ -29,14 +25,9 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; -import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.config.DependencyDescriptor; -import org.springframework.beans.factory.support.AutowireCandidateResolver; -import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; -import org.springframework.core.MethodParameter; import org.springframework.javapoet.ClassName; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -56,8 +47,6 @@ class BeanDefinitionMethodGenerator { private final RegisteredBean registeredBean; - private final Executable constructorOrFactoryMethod; - @Nullable private final String currentPropertyName; @@ -83,7 +72,6 @@ class BeanDefinitionMethodGenerator { } this.methodGeneratorFactory = methodGeneratorFactory; this.registeredBean = registeredBean; - this.constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod(); this.currentPropertyName = currentPropertyName; this.aotContributions = aotContributions; } @@ -98,9 +86,8 @@ class BeanDefinitionMethodGenerator { MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { - registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints()); BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, beanRegistrationsCode); - ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); + ClassName target = codeFragments.getTarget(this.registeredBean); if (isWritablePackageName(target)) { GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target); GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName()); @@ -178,8 +165,7 @@ private GeneratedMethod generateBeanDefinitionMethod(GenerationContext generatio BeanRegistrationCodeFragments codeFragments, Modifier modifier) { BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator( - className, generatedMethods, this.registeredBean, - this.constructorOrFactoryMethod, codeFragments); + className, generatedMethods, this.registeredBean, codeFragments); this.aotContributions.forEach(aotContribution -> aotContribution.applyTo(generationContext, codeGenerator)); @@ -218,52 +204,4 @@ private String getSimpleBeanName(String beanName) { return StringUtils.uncapitalize(beanName); } - private void registerRuntimeHintsIfNecessary(RuntimeHints runtimeHints) { - if (this.registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) { - ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver()); - if (this.constructorOrFactoryMethod instanceof Method method) { - registrar.registerRuntimeHints(runtimeHints, method); - } - else if (this.constructorOrFactoryMethod instanceof Constructor constructor) { - registrar.registerRuntimeHints(runtimeHints, constructor); - } - } - } - - - private static class ProxyRuntimeHintsRegistrar { - - private final AutowireCandidateResolver candidateResolver; - - public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) { - this.candidateResolver = candidateResolver; - } - - public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) { - Class[] parameterTypes = method.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - MethodParameter methodParam = new MethodParameter(method, i); - DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true); - registerProxyIfNecessary(runtimeHints, dependencyDescriptor); - } - } - - public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor constructor) { - Class[] parameterTypes = constructor.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - MethodParameter methodParam = new MethodParameter(constructor, i); - DependencyDescriptor dependencyDescriptor = new DependencyDescriptor( - methodParam, true); - registerProxyIfNecessary(runtimeHints, dependencyDescriptor); - } - } - - private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) { - Class proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null); - if (proxyType != null && Proxy.isProxyClass(proxyType)) { - runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces()); - } - } - } - } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java index d7f02a43e5ec..db1bd2e81556 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; +import java.util.function.UnaryOperator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; @@ -31,9 +31,19 @@ /** * Generate the various fragments of code needed to register a bean. - * + *

+ * A default implementation is provided that suits most needs and custom code + * fragments are only expected to be used by library authors having built custom + * arrangement on top of the core container. + *

+ * Users are not expected to implement this interface directly, but rather extends + * from {@link BeanRegistrationCodeFragmentsDecorator} and only override the + * necessary method(s). * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 + * @see BeanRegistrationCodeFragmentsDecorator + * @see BeanRegistrationAotContribution#withCustomCodeFragments(UnaryOperator) */ public interface BeanRegistrationCodeFragments { @@ -50,16 +60,19 @@ public interface BeanRegistrationCodeFragments { /** * Return the target for the registration. Used to determine where to write - * the code. + * the code. This should take into account visibility issue, such as + * package access of an element of the bean to register. * @param registeredBean the registered bean - * @param constructorOrFactoryMethod the constructor or factory method * @return the target {@link ClassName} */ - ClassName getTarget(RegisteredBean registeredBean, - Executable constructorOrFactoryMethod); + ClassName getTarget(RegisteredBean registeredBean); /** * Generate the code that defines the new bean definition instance. + *

+ * This should declare a variable named {@value BEAN_DEFINITION_VARIABLE} + * so that further fragments can refer to the variable to further tune + * the bean definition. * @param generationContext the generation context * @param beanType the bean type * @param beanRegistrationCode the bean registration code @@ -81,6 +94,11 @@ CodeBlock generateSetBeanDefinitionPropertiesCode( /** * Generate the code that sets the instance supplier on the bean definition. + *

+ * The {@code postProcessors} represent methods to be exposed once the + * instance has been created to further configure it. Each method should + * accept two parameters, the {@link RegisteredBean} and the bean + * instance, and should return the modified bean instance. * @param generationContext the generation context * @param beanRegistrationCode the bean registration code * @param instanceSupplierCode the instance supplier code supplier code @@ -96,15 +114,13 @@ CodeBlock generateSetBeanInstanceSupplierCode( * Generate the instance supplier code. * @param generationContext the generation context * @param beanRegistrationCode the bean registration code - * @param constructorOrFactoryMethod the constructor or factory method for - * the bean * @param allowDirectSupplierShortcut if direct suppliers may be used rather * than always needing an {@link InstanceSupplier} * @return the generated code */ CodeBlock generateInstanceSupplierCode( GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut); + boolean allowDirectSupplierShortcut); /** * Generate the return statement. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java index e4ff961262e1..4a493d0d9395 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; import java.util.function.UnaryOperator; @@ -51,8 +50,8 @@ protected BeanRegistrationCodeFragmentsDecorator(BeanRegistrationCodeFragments d } @Override - public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { - return this.delegate.getTarget(registeredBean, constructorOrFactoryMethod); + public ClassName getTarget(RegisteredBean registeredBean) { + return this.delegate.getTarget(registeredBean); } @Override @@ -83,11 +82,10 @@ public CodeBlock generateSetBeanInstanceSupplierCode(GenerationContext generatio @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, - boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { return this.delegate.generateInstanceSupplierCode(generationContext, - beanRegistrationCode, constructorOrFactoryMethod, allowDirectSupplierShortcut); + beanRegistrationCode, allowDirectSupplierShortcut); } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java index 3547378b0673..98564d4852e7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.ArrayList; import java.util.List; import java.util.function.Predicate; @@ -47,19 +46,15 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { private final RegisteredBean registeredBean; - private final Executable constructorOrFactoryMethod; - private final BeanRegistrationCodeFragments codeFragments; BeanRegistrationCodeGenerator(ClassName className, GeneratedMethods generatedMethods, - RegisteredBean registeredBean, Executable constructorOrFactoryMethod, - BeanRegistrationCodeFragments codeFragments) { + RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments) { this.className = className; this.generatedMethods = generatedMethods; this.registeredBean = registeredBean; - this.constructorOrFactoryMethod = constructorOrFactoryMethod; this.codeFragments = codeFragments; } @@ -87,8 +82,7 @@ CodeBlock generateCode(GenerationContext generationContext) { generationContext, this, this.registeredBean.getMergedBeanDefinition(), REJECT_ALL_ATTRIBUTES_FILTER)); CodeBlock instanceSupplierCode = this.codeFragments.generateInstanceSupplierCode( - generationContext, this, this.constructorOrFactoryMethod, - this.instancePostProcessors.isEmpty()); + generationContext, this, this.instancePostProcessors.isEmpty()); code.add(this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext, this, instanceSupplierCode, this.instancePostProcessors)); code.add(this.codeFragments.generateReturnCode(generationContext, this)); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 26df73d45ee5..e51aebd7f024 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -21,6 +21,7 @@ import java.lang.reflect.Modifier; import java.util.List; import java.util.function.Predicate; +import java.util.function.Supplier; import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.GenerationContext; @@ -39,12 +40,14 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.function.SingletonSupplier; /** * Internal {@link BeanRegistrationCodeFragments} implementation used by * default. * * @author Phillip Webb + * @author Stephane Nicoll */ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments { @@ -54,6 +57,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; + private final Supplier constructorOrFactoryMethod; + DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, RegisteredBean registeredBean, @@ -62,14 +67,13 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme this.beanRegistrationsCode = beanRegistrationsCode; this.registeredBean = registeredBean; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; + this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod); } @Override - public ClassName getTarget(RegisteredBean registeredBean, - Executable constructorOrFactoryMethod) { - - Class target = extractDeclaringClass(registeredBean.getBeanType(), constructorOrFactoryMethod); + public ClassName getTarget(RegisteredBean registeredBean) { + Class target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get()); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { RegisteredBean parent = registeredBean.getParent(); Assert.state(parent != null, "No parent available for inner bean"); @@ -219,12 +223,11 @@ public CodeBlock generateSetBeanInstanceSupplierCode( @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, constructorOrFactoryMethod); + .generateCode(this.registeredBean,this.constructorOrFactoryMethod.get()); } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index e21ead29bf5b..87de98852ec2 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -21,6 +21,7 @@ import java.lang.reflect.Member; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Proxy; import java.util.Arrays; import java.util.function.Consumer; @@ -38,9 +39,14 @@ import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.factory.config.DependencyDescriptor; +import org.springframework.beans.factory.support.AutowireCandidateResolver; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.core.KotlinDetector; +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -51,9 +57,10 @@ import org.springframework.util.function.ThrowingSupplier; /** - * Internal code generator to create an {@link InstanceSupplier}, usually in + * Default code generator to create an {@link InstanceSupplier}, usually in * the form of a {@link BeanInstanceSupplier} that retains the executable - * that is used to instantiate the bean. + * that is used to instantiate the bean. Takes care of registering the + * necessary hints if reflection or a JDK proxy is required. * *

Generated code is usually a method reference that generates the * {@link BeanInstanceSupplier}, but some shortcut can be used as well such as: @@ -66,8 +73,9 @@ * @author Juergen Hoeller * @author Sebastien Deleuze * @since 6.0 + * @see BeanRegistrationCodeFragments */ -class InstanceSupplierCodeGenerator { +public class InstanceSupplierCodeGenerator { private static final String REGISTERED_BEAN_PARAMETER_NAME = "registeredBean"; @@ -89,7 +97,15 @@ class InstanceSupplierCodeGenerator { private final boolean allowDirectSupplierShortcut; - InstanceSupplierCodeGenerator(GenerationContext generationContext, + /** + * Create a new instance. + * @param generationContext the generation context + * @param className the class name of the bean to instantiate + * @param generatedMethods the generated methods + * @param allowDirectSupplierShortcut whether a direct supplier may be used rather + * than always needing an {@link InstanceSupplier} + */ + public InstanceSupplierCodeGenerator(GenerationContext generationContext, ClassName className, GeneratedMethods generatedMethods, boolean allowDirectSupplierShortcut) { this.generationContext = generationContext; @@ -98,8 +114,14 @@ class InstanceSupplierCodeGenerator { this.allowDirectSupplierShortcut = allowDirectSupplierShortcut; } - - CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + /** + * Generate the instance supplier code. + * @param registeredBean the bean to handle + * @param constructorOrFactoryMethod the executable to use to create the bean + * @return the generated code + */ + public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod); if (constructorOrFactoryMethod instanceof Constructor constructor) { return generateCodeForConstructor(registeredBean, constructor); } @@ -110,6 +132,19 @@ CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFa "No suitable executor found for " + registeredBean.getBeanName()); } + private void registerRuntimeHintsIfNecessary(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + if (registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) { + RuntimeHints runtimeHints = this.generationContext.getRuntimeHints(); + ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver()); + if (constructorOrFactoryMethod instanceof Method method) { + registrar.registerRuntimeHints(runtimeHints, method); + } + else if (constructorOrFactoryMethod instanceof Constructor constructor) { + registrar.registerRuntimeHints(runtimeHints, constructor); + } + } + } + private CodeBlock generateCodeForConstructor(RegisteredBean registeredBean, Constructor constructor) { String beanName = registeredBean.getBeanName(); Class beanClass = registeredBean.getBeanClass(); @@ -372,4 +407,40 @@ public static boolean hasConstructorWithOptionalParameter(Class beanClass) { } + + private static class ProxyRuntimeHintsRegistrar { + + private final AutowireCandidateResolver candidateResolver; + + public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) { + this.candidateResolver = candidateResolver; + } + + public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) { + Class[] parameterTypes = method.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + MethodParameter methodParam = new MethodParameter(method, i); + DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true); + registerProxyIfNecessary(runtimeHints, dependencyDescriptor); + } + } + + public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor constructor) { + Class[] parameterTypes = constructor.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + MethodParameter methodParam = new MethodParameter(constructor, i); + DependencyDescriptor dependencyDescriptor = new DependencyDescriptor( + methodParam, true); + registerProxyIfNecessary(runtimeHints, dependencyDescriptor); + } + } + + private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) { + Class proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null); + if (proxyType != null && Proxy.isProxyClass(proxyType)) { + runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces()); + } + } + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java index 1ab505e3d35e..859cbe62b667 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java @@ -16,10 +16,14 @@ package org.springframework.beans.factory.aot; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; import java.lang.reflect.Method; +import java.util.function.UnaryOperator; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean; @@ -28,6 +32,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.beans.testfixture.beans.factory.aot.GenericFactoryBean; +import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.NumberFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; @@ -35,9 +40,14 @@ import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.lang.Nullable; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link DefaultBeanRegistrationCodeFragments}. @@ -48,136 +58,202 @@ class DefaultBeanRegistrationCodeFragmentsTests { private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext()); + private final GenerationContext generationContext = new TestGenerationContext(); + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @Test void getTargetOnConstructor() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + SimpleBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicFactoryBean() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanExtractTargetFromFactoryBeanType() { - RegisteredBean registeredBean = registerTestBean(ResolvableType - .forClassWithGenerics(GenericFactoryBean.class, SimpleBean.class)); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + ResolvableType beanType = ResolvableType.forClassWithGenerics( + GenericFactoryBean.class, SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(beanType, + GenericFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanWithBoundExtractTargetFromFactoryBeanType() { - RegisteredBean registeredBean = registerTestBean(ResolvableType - .forClassWithGenerics(NumberFactoryBean.class, Integer.class)); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - NumberFactoryBean.class.getDeclaredConstructors()[0]), Integer.class); + ResolvableType beanType = ResolvableType.forClassWithGenerics( + NumberFactoryBean.class, Integer.class); + RegisteredBean registeredBean = registerTestBean(beanType, + NumberFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), Integer.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanUseBeanTypeAsFallback() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + GenericFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToProtectedFactoryBean() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]), + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), PrivilegedTestBeanFactoryBean.class); } @Test void getTargetOnMethod() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); assertThat(method).isNotNull(); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, method), + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, method); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBeanConfiguration.class); } @Test void getTargetOnMethodWithInnerBeanInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(String.class)); Method method = ReflectionUtils.findMethod(getClass(), "createString"); assertThat(method).isNotNull(); - assertTarget(createInstance(innerBean).getTarget(innerBean, method), getClass()); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + applyConstructorOrFactoryMethod(new RootBeanDefinition(String.class), method)); + assertTarget(createInstance(innerBean).getTarget(innerBean), getClass()); } @Test void getTargetOnConstructorWithInnerBeanInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - String.class.getDeclaredConstructors()[0]), SimpleBean.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(String.class), String.class.getDeclaredConstructors()[0]); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(StringFactoryBean.class), + StringFactoryBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(StringFactoryBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - StringFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnMethodWithInnerBeanInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); assertThat(method).isNotNull(); - assertTarget(createInstance(innerBean).getTarget(innerBean, method), + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + applyConstructorOrFactoryMethod(new RootBeanDefinition(SimpleBean.class), method)); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBeanConfiguration.class); } @Test void getTargetOnConstructorWithInnerBeanInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(SimpleBean.class), SimpleBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(SimpleBean.class), + SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); + } + + @Test + void customizedGetTargetDoesNotResolveConstructorOrFactoryMethod() { + RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); + BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(String.class); + } + }); + assertTarget(customCodeFragments.getTarget(registeredBean), String.class); + verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + } + + @Test + void customizedGenerateInstanceSupplierCodeDoesNotResolveConstructorOrFactoryMethod() { + RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); + BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + return CodeBlock.of("// Hello"); + } + }); + assertThat(customCodeFragments.generateInstanceSupplierCode(this.generationContext, + new MockBeanRegistrationCode(this.generationContext), false)).hasToString("// Hello"); + verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + } + + private BeanRegistrationCodeFragments createCustomCodeFragments(RegisteredBean registeredBean, UnaryOperator customFragments) { + BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution. + withCustomCodeFragments(customFragments); + BeanRegistrationCodeFragments defaultCodeFragments = createInstance(registeredBean); + return aotContribution.customizeBeanRegistrationCodeFragments( + this.generationContext, defaultCodeFragments); } private void assertTarget(ClassName target, Class expected) { assertThat(target).isEqualTo(ClassName.get(expected)); } - private RegisteredBean registerTestBean(Class beanType) { - this.beanFactory.registerBeanDefinition("testBean", - new RootBeanDefinition(beanType)); + return registerTestBean(beanType, null); + } + + private RegisteredBean registerTestBean(Class beanType, + @Nullable Executable constructorOrFactoryMethod) { + this.beanFactory.registerBeanDefinition("testBean", applyConstructorOrFactoryMethod( + new RootBeanDefinition(beanType), constructorOrFactoryMethod)); return RegisteredBean.of(this.beanFactory, "testBean"); } - private RegisteredBean registerTestBean(ResolvableType beanType) { + + private RegisteredBean registerTestBean(ResolvableType beanType, + @Nullable Executable constructorOrFactoryMethod) { RootBeanDefinition beanDefinition = new RootBeanDefinition(); beanDefinition.setTargetType(beanType); - this.beanFactory.registerBeanDefinition("testBean", beanDefinition); + this.beanFactory.registerBeanDefinition("testBean", + applyConstructorOrFactoryMethod(beanDefinition, constructorOrFactoryMethod)); return RegisteredBean.of(this.beanFactory, "testBean"); } + private RootBeanDefinition applyConstructorOrFactoryMethod(RootBeanDefinition beanDefinition, + @Nullable Executable constructorOrFactoryMethod) { + + if (constructorOrFactoryMethod instanceof Method method) { + beanDefinition.setResolvedFactoryMethod(method); + } + else if (constructorOrFactoryMethod instanceof Constructor constructor) { + beanDefinition.setAttribute(RootBeanDefinition.PREFERRED_CONSTRUCTORS_ATTRIBUTE, constructor); + } + return beanDefinition; + } + private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) { return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory)); diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index aefd1f7ae101..87ad3b39aedb 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -58,6 +58,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; +import org.springframework.beans.factory.aot.InstanceSupplierCodeGenerator; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; @@ -315,9 +316,8 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe Object configClassAttr = registeredBean.getMergedBeanDefinition() .getAttribute(ConfigurationClassUtils.CONFIGURATION_CLASS_ATTRIBUTE); if (ConfigurationClassUtils.CONFIGURATION_CLASS_FULL.equals(configClassAttr)) { - Class proxyClass = registeredBean.getBeanType().toClass(); return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments -> - new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, proxyClass)); + new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, registeredBean)); } return null; } @@ -749,12 +749,15 @@ private CodeBlock handleNull(@Nullable Object value, Supplier nonNull private static class ConfigurationClassProxyBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator { + private final RegisteredBean registeredBean; + private final Class proxyClass; public ConfigurationClassProxyBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, - Class proxyClass) { + RegisteredBean registeredBean) { super(codeFragments); - this.proxyClass = proxyClass; + this.registeredBean = registeredBean; + this.proxyClass = registeredBean.getBeanType().toClass(); } @Override @@ -770,11 +773,14 @@ public CodeBlock generateSetBeanDefinitionPropertiesCode(GenerationContext gener @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { - Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), constructorOrFactoryMethod); - return super.generateInstanceSupplierCode(generationContext, beanRegistrationCode, - executableToUse, allowDirectSupplierShortcut); + + Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), + this.registeredBean.resolveConstructorOrFactoryMethod()); + return new InstanceSupplierCodeGenerator(generationContext, + beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) + .generateCode(this.registeredBean, executableToUse); } private Executable proxyExecutable(RuntimeHints runtimeHints, Executable userExecutable) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java index c1a5cf8eae26..2a6d68a2b615 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java @@ -17,7 +17,6 @@ package org.springframework.orm.jpa.persistenceunit; import java.lang.annotation.Annotation; -import java.lang.reflect.Executable; import java.util.List; import javax.lang.model.element.Modifier; @@ -97,7 +96,6 @@ public JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragment @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory() .getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class);