Skip to content

Commit

Permalink
Polish Bean Override support in the TestContext framework
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrannen committed Mar 10, 2024
1 parent 4c246b7 commit 6f5d3a4
Show file tree
Hide file tree
Showing 21 changed files with 332 additions and 291 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,27 @@

/**
* Mark an annotation as eligible for Bean Override parsing.
* This meta-annotation provides a {@link BeanOverrideProcessor} class which
* must be capable of handling the annotated annotation.
*
* <p>Target annotation must have a {@link RetentionPolicy} of {@code RUNTIME}
* and be applicable to {@link java.lang.reflect.Field Fields} only.
* @see BeanOverrideBeanPostProcessor
* <p>This meta-annotation specifies a {@link BeanOverrideProcessor} class which
* must be capable of handling the composed annotation that is meta-annotated
* with {@code @BeanOverride}.
*
* <p>The composed annotation that is meta-annotated with {@code @BeanOverride}
* must have a {@code RetentionPolicy} of {@link RetentionPolicy#RUNTIME RUNTIME}
* and a {@code Target} of {@link ElementType#FIELD FIELD}.
*
* @author Simon Baslé
* @since 6.2
* @see BeanOverrideBeanPostProcessor
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.ANNOTATION_TYPE})
@Target(ElementType.ANNOTATION_TYPE)
public @interface BeanOverride {

/**
* A {@link BeanOverrideProcessor} implementation class by which the target
* annotation should be processed. Implementations must have a no-argument
* constructor.
* A {@link BeanOverrideProcessor} implementation class by which the composed
* annotation should be processed.
*/
Class<? extends BeanOverrideProcessor> value();

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,20 @@

/**
* A {@link BeanFactoryPostProcessor} used to register and inject overriding
* bean metadata with the {@link ApplicationContext}. A set of
* {@link OverrideMetadata} must be passed to the processor.
* A {@link BeanOverrideParser} can typically be used to parse these from test
* classes that use any annotation meta-annotated with {@link BeanOverride} to
* mark override sites.
* bean metadata with the {@link ApplicationContext}.
*
* <p>This processor supports two {@link BeanOverrideStrategy}:
* <p>A set of {@link OverrideMetadata} must be provided to this processor. A
* {@link BeanOverrideParser} can typically be used to parse this metadata from
* test classes that use any annotation meta-annotated with
* {@link BeanOverride @BeanOverride} to mark override sites.
*
* <p>This processor supports two types of {@link BeanOverrideStrategy}:
* <ul>
* <li>replacing a given bean's definition, immediately preparing a singleton
* <li>Replacing a given bean's definition, immediately preparing a singleton
* instance</li>
* <li>intercepting the actual bean instance upon creation and wrapping it,
* <li>Intercepting the actual bean instance upon creation and wrapping it,
* using the early bean definition mechanism of
* {@link SmartInstantiationAwareBeanPostProcessor}).</li>
* {@link SmartInstantiationAwareBeanPostProcessor}</li>
* </ul>
*
* <p>This processor also provides support for injecting the overridden bean
Expand All @@ -78,19 +79,25 @@ public class BeanOverrideBeanPostProcessor implements InstantiationAwareBeanPost
BeanFactoryAware, BeanFactoryPostProcessor, Ordered {

private static final String INFRASTRUCTURE_BEAN_NAME = BeanOverrideBeanPostProcessor.class.getName();
private static final String EARLY_INFRASTRUCTURE_BEAN_NAME = BeanOverrideBeanPostProcessor.WrapEarlyBeanPostProcessor.class.getName();

private final Set<OverrideMetadata> overrideMetadata;
private final Map<String, OverrideMetadata> earlyOverrideMetadata = new HashMap<>();
private static final String EARLY_INFRASTRUCTURE_BEAN_NAME =
BeanOverrideBeanPostProcessor.WrapEarlyBeanPostProcessor.class.getName();

private ConfigurableListableBeanFactory beanFactory;

private final Map<String, OverrideMetadata> earlyOverrideMetadata = new HashMap<>();

private final Map<OverrideMetadata, String> beanNameRegistry = new HashMap<>();

private final Map<Field, String> fieldRegistry = new HashMap<>();

private final Set<OverrideMetadata> overrideMetadata;

@Nullable
private ConfigurableListableBeanFactory beanFactory;


/**
* Create a new {@link BeanOverrideBeanPostProcessor} instance with the
* Create a new {@code BeanOverrideBeanPostProcessor} instance with the
* given {@link OverrideMetadata} set.
* @param overrideMetadata the initial override metadata
*/
Expand All @@ -107,7 +114,7 @@ public int getOrder() {
@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
Assert.isInstanceOf(ConfigurableListableBeanFactory.class, beanFactory,
"Beans overriding can only be used with a ConfigurableListableBeanFactory");
"Bean overriding can only be used with a ConfigurableListableBeanFactory");
this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
}

Expand All @@ -120,25 +127,25 @@ protected Set<OverrideMetadata> getOverrideMetadata() {

@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
Assert.state(this.beanFactory == beanFactory, "Unexpected beanFactory to postProcess");
Assert.state(this.beanFactory == beanFactory, "Unexpected BeanFactory to post-process");
Assert.isInstanceOf(BeanDefinitionRegistry.class, beanFactory,
"Bean overriding annotations can only be used on bean factories that implement "
+ "BeanDefinitionRegistry");
postProcessWithRegistry((BeanDefinitionRegistry) beanFactory);
}

private void postProcessWithRegistry(BeanDefinitionRegistry registry) {
//Note that a tracker bean is registered down the line only if there is some overrideMetadata parsed
Set<OverrideMetadata> overrideMetadata = getOverrideMetadata();
for (OverrideMetadata metadata : overrideMetadata) {
// Note that a tracker bean is registered down the line only if there is some overrideMetadata parsed.
for (OverrideMetadata metadata : getOverrideMetadata()) {
registerBeanOverride(registry, metadata);
}
}

/**
* Copy the details of a {@link BeanDefinition} to the definition created by
* this processor for a given {@link OverrideMetadata}. Defaults to copying
* the {@link BeanDefinition#isPrimary()} attribute and scope.
* Copy certain details of a {@link BeanDefinition} to the definition created by
* this processor for a given {@link OverrideMetadata}.
* <p>The default implementation copies the {@linkplain BeanDefinition#isPrimary()
* primary flag} and the {@linkplain BeanDefinition#getScope() scope}.
*/
protected void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) {
to.setPrimary(from.isPrimary());
Expand All @@ -155,6 +162,7 @@ private void registerBeanOverride(BeanDefinitionRegistry registry, OverrideMetad

private void registerReplaceDefinition(BeanDefinitionRegistry registry, OverrideMetadata overrideMetadata,
boolean enforceExistingDefinition) {

RootBeanDefinition beanDefinition = createBeanDefinition(overrideMetadata);
String beanName = overrideMetadata.getExpectedBeanName();

Expand All @@ -166,7 +174,7 @@ private void registerReplaceDefinition(BeanDefinitionRegistry registry, Override
}
else if (enforceExistingDefinition) {
throw new IllegalStateException("Unable to override " + overrideMetadata.getBeanOverrideDescription() +
" bean, expected a bean definition to replace with name '" + beanName + "'");
" bean; expected a bean definition to replace with name '" + beanName + "'");
}
registry.registerBeanDefinition(beanName, beanDefinition);

Expand All @@ -185,10 +193,10 @@ else if (enforceExistingDefinition) {

/**
* Check that the expected bean name is registered and matches the type to override.
* If so, put the override metadata in the early tracking map.
* The map will later be checked to see if a given bean should be wrapped
* <p>If so, put the override metadata in the early tracking map.
* <p>The map will later be checked to see if a given bean should be wrapped
* upon creation, during the {@link WrapEarlyBeanPostProcessor#getEarlyBeanReference(Object, String)}
* phase
* phase.
*/
private void registerWrapEarly(OverrideMetadata metadata) {
Set<String> existingBeanNames = getExistingBeanNames(metadata.typeToOverride());
Expand All @@ -203,11 +211,12 @@ private void registerWrapEarly(OverrideMetadata metadata) {
}

/**
* Check early overrides records and use the {@link OverrideMetadata} to
* Check early override records and use the {@link OverrideMetadata} to
* create an override instance from the provided bean, if relevant.
* <p>Called during the {@link SmartInstantiationAwareBeanPostProcessor}
* phases (see {@link WrapEarlyBeanPostProcessor#getEarlyBeanReference(Object, String)}
* and {@link WrapEarlyBeanPostProcessor#postProcessAfterInitialization(Object, String)}).
* phases.
* @see WrapEarlyBeanPostProcessor#getEarlyBeanReference(Object, String)
* @see WrapEarlyBeanPostProcessor#postProcessAfterInitialization(Object, String)
*/
protected final Object wrapIfNecessary(Object bean, String beanName) throws BeansException {
final OverrideMetadata metadata = this.earlyOverrideMetadata.get(beanName);
Expand Down Expand Up @@ -236,17 +245,15 @@ private Set<String> getExistingBeanNames(ResolvableType resolvableType) {
beans.add(beanName);
}
}
beans.removeIf(this::isScopedTarget);
beans.removeIf(ScopedProxyUtils::isScopedTarget);
return beans;
}

private boolean isScopedTarget(String beanName) {
try {
return ScopedProxyUtils.isScopedTarget(beanName);
}
catch (Throwable ex) {
return false;
}
@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName)
throws BeansException {
ReflectionUtils.doWithFields(bean.getClass(), field -> postProcessField(bean, field));
return pvs;
}

private void postProcessField(Object bean, Field field) {
Expand All @@ -256,16 +263,10 @@ private void postProcessField(Object bean, Field field) {
}
}

@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName)
throws BeansException {
ReflectionUtils.doWithFields(bean.getClass(), field -> postProcessField(bean, field));
return pvs;
}

void inject(Field field, Object target, OverrideMetadata overrideMetadata) {
String beanName = this.beanNameRegistry.get(overrideMetadata);
Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for overrideMetadata " + overrideMetadata);
Assert.state(StringUtils.hasLength(beanName),
() -> "No bean found for OverrideMetadata: " + overrideMetadata);
inject(field, target, beanName);
}

Expand All @@ -287,25 +288,26 @@ private void inject(Field field, Object target, String beanName) {
}

/**
* Register the processor with a {@link BeanDefinitionRegistry}.
* Not required when using the Spring TestContext Framework, as registration
* is automatic via the {@link org.springframework.core.io.support.SpringFactoriesLoader SpringFactoriesLoader}
* Register a {@link BeanOverrideBeanPostProcessor} with a {@link BeanDefinitionRegistry}.
* <p>Not required when using the Spring TestContext Framework, as registration
* is automatic via the
* {@link org.springframework.core.io.support.SpringFactoriesLoader SpringFactoriesLoader}
* mechanism.
* @param registry the bean definition registry
* @param overrideMetadata the initial override metadata set
*/
public static void register(BeanDefinitionRegistry registry, @Nullable Set<OverrideMetadata> overrideMetadata) {
//early processor
getOrAddInfrastructureBeanDefinition(registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME,
constructorArguments -> constructorArguments.addIndexedArgumentValue(0,
new RuntimeBeanReference(INFRASTRUCTURE_BEAN_NAME)));

//main processor
BeanDefinition definition = getOrAddInfrastructureBeanDefinition(registry, BeanOverrideBeanPostProcessor.class,
INFRASTRUCTURE_BEAN_NAME, constructorArguments -> constructorArguments
.addIndexedArgumentValue(0, new LinkedHashSet<OverrideMetadata>()));
ConstructorArgumentValues.ValueHolder constructorArg = definition.getConstructorArgumentValues()
.getIndexedArgumentValue(0, Set.class);
// Early processor
getOrAddInfrastructureBeanDefinition(
registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME, constructorArgs ->
constructorArgs.addIndexedArgumentValue(0, new RuntimeBeanReference(INFRASTRUCTURE_BEAN_NAME)));

// Main processor
BeanDefinition definition = getOrAddInfrastructureBeanDefinition(
registry, BeanOverrideBeanPostProcessor.class, INFRASTRUCTURE_BEAN_NAME, constructorArgs ->
constructorArgs.addIndexedArgumentValue(0, new LinkedHashSet<OverrideMetadata>()));
ConstructorArgumentValues.ValueHolder constructorArg =
definition.getConstructorArgumentValues().getIndexedArgumentValue(0, Set.class);
@SuppressWarnings("unchecked")
Set<OverrideMetadata> existing = (Set<OverrideMetadata>) constructorArg.getValue();
if (overrideMetadata != null && existing != null) {
Expand All @@ -315,6 +317,7 @@ public static void register(BeanDefinitionRegistry registry, @Nullable Set<Overr

private static BeanDefinition getOrAddInfrastructureBeanDefinition(BeanDefinitionRegistry registry,
Class<?> clazz, String beanName, Consumer<ConstructorArgumentValues> constructorArgumentsConsumer) {

if (!registry.containsBeanDefinition(beanName)) {
RootBeanDefinition definition = new RootBeanDefinition(clazz);
definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
Expand All @@ -326,17 +329,21 @@ private static BeanDefinition getOrAddInfrastructureBeanDefinition(BeanDefinitio
return registry.getBeanDefinition(beanName);
}


private static final class WrapEarlyBeanPostProcessor implements SmartInstantiationAwareBeanPostProcessor,
PriorityOrdered {

private final BeanOverrideBeanPostProcessor mainProcessor;

private final Map<String, Object> earlyReferences;


private WrapEarlyBeanPostProcessor(BeanOverrideBeanPostProcessor mainProcessor) {
this.mainProcessor = mainProcessor;
this.earlyReferences = new ConcurrentHashMap<>(16);
}


@Override
public int getOrder() {
return Ordered.HIGHEST_PRECEDENCE;
Expand All @@ -363,8 +370,9 @@ public Object postProcessAfterInitialization(Object bean, String beanName) throw
}

private String getCacheKey(Object bean, String beanName) {
return StringUtils.hasLength(beanName) ? beanName : bean.getClass().getName();
return (StringUtils.hasLength(beanName) ? beanName : bean.getClass().getName());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.List;
import java.util.Set;

import org.springframework.aot.hint.annotation.Reflective;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextConfigurationAttributes;
Expand All @@ -29,7 +28,8 @@
import org.springframework.test.context.TestContextAnnotationUtils;

/**
* A {@link ContextCustomizerFactory} to add support for Bean Overriding.
* {@link ContextCustomizerFactory} which provides support for Bean Overriding
* in tests.
*
* @author Simon Baslé
* @since 6.2
Expand All @@ -39,6 +39,7 @@ public class BeanOverrideContextCustomizerFactory implements ContextCustomizerFa
@Override
public ContextCustomizer createContextCustomizer(Class<?> testClass,
List<ContextConfigurationAttributes> configAttributes) {

BeanOverrideParser parser = new BeanOverrideParser();
parseMetadata(testClass, parser);
if (parser.getOverrideMetadata().isEmpty()) {
Expand All @@ -56,10 +57,9 @@ private void parseMetadata(Class<?> testClass, BeanOverrideParser parser) {
}

/**
* A {@link ContextCustomizer} for Bean Overriding in tests.
* {@link ContextCustomizer} for Bean Overriding in tests.
*/
@Reflective
static final class BeanOverrideContextCustomizer implements ContextCustomizer {
private static final class BeanOverrideContextCustomizer implements ContextCustomizer {

private final Set<OverrideMetadata> metadata;

Expand Down Expand Up @@ -97,4 +97,5 @@ public int hashCode() {
return this.metadata.hashCode();
}
}

}
Loading

0 comments on commit 6f5d3a4

Please sign in to comment.