Skip to content

Commit

Permalink
Skip calling containsBean for io.micronaut.context.BeanProvider (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dstepanov authored Aug 24, 2023
1 parent 4a503e6 commit 287b670
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package io.micronaut.inject.provider

import io.micronaut.annotation.processing.test.AbstractTypeElementSpec
import io.micronaut.context.ApplicationContext
import io.micronaut.context.DefaultApplicationContext
import io.micronaut.context.exceptions.NoSuchBeanException
import io.micronaut.context.exceptions.NonUniqueBeanException
import io.micronaut.inject.BeanDefinition
Expand Down Expand Up @@ -259,6 +260,9 @@ class Test {
then: 'BeanOneTwo is found'
foundBean.isPresent()
foundBean.get().class.name == 'test.BeanOneTwo'

cleanup:
context.close()
}

void "test BeanProvider's get by qualifier method" () {
Expand Down Expand Up @@ -319,5 +323,87 @@ class Test {

then: 'BeanOneTwo is returned'
foundBean.class.name == 'test.BeanOneTwo'

cleanup:
context.close()
}

void "test Jakarta Provider is triggering containsBean" () {
given:
DefaultApplicationContext context = buildContext('''\
package test;
import io.micronaut.inject.annotation.*;
import io.micronaut.context.annotation.*;
import jakarta.inject.Provider;
interface BeanNumber { }
@jakarta.inject.Singleton
class BeanNumberImpl implements BeanNumber {
}
@jakarta.inject.Singleton
class Test {
public Provider<BeanNumber> provider;
Test(Provider<BeanNumber> provider) {
this.provider = provider;
}
}
''')
def containsBeanCacheField = context.getClass().superclass.superclass.declaredFields.find {it.name == "containsBeanCache"}
containsBeanCacheField.accessible = true
Map containsBeanCache = containsBeanCacheField.get(context)

when: 'retrieve test bean'
int mapSize = containsBeanCache.size()
def bean = getBean(context, 'test.Test')

then: 'containsBean is triggered'
containsBeanCache.size() == mapSize + 1

then: 'bean exists'
bean.provider.get()

cleanup:
context.close()
}

void "test BeanProvider is not triggering containsBean" () {
given:
DefaultApplicationContext context = buildContext('''\
package test;
import io.micronaut.inject.annotation.*;
import io.micronaut.context.annotation.*;
import io.micronaut.context.BeanProvider;
interface BeanNumber { }
@jakarta.inject.Singleton
class Test {
public BeanProvider<BeanNumber> provider;
Test(BeanProvider<BeanNumber> provider) {
this.provider = provider;
}
}
''')
def containsBeanCacheField = context.getClass().superclass.superclass.declaredFields.find {it.name == "containsBeanCache"}
containsBeanCacheField.accessible = true
Map containsBeanCache = containsBeanCacheField.get(context)

when: 'retrieve test bean'
int mapSize = containsBeanCache.size()
def bean = getBean(context, 'test.Test')

then: 'containsBean is not triggered'
containsBeanCache.size() == mapSize

then: 'containsBean is triggered'
!bean.provider.isPresent()
containsBeanCache.size() == mapSize + 1

cleanup:
context.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,16 @@ public T instantiate(BeanResolutionContext resolutionContext, BeanContext contex
final BeanResolutionContext.Segment<?, ?> segment = resolutionContext.getPath().currentSegment().orElse(null);
if (segment != null) {
final InjectionPoint<?> injectionPoint = segment.getInjectionPoint();
if (injectionPoint instanceof ArgumentCoercible) {
Argument<?> injectionPointArgument = ((ArgumentCoercible<?>) injectionPoint)
.asArgument();

if (injectionPoint instanceof ArgumentCoercible<?> argumentCoercible) {
Argument<?> injectionPointArgument = argumentCoercible.asArgument();
Argument<?> resolveArgument = injectionPointArgument;
boolean isNullableProvider = injectionPointArgument.isNullable();
boolean isOptionalProvider;
if (resolveArgument.isOptional()) {
resolveArgument = resolveArgument.getFirstTypeVariable().orElse(Argument.OBJECT_ARGUMENT);
isOptionalProvider = true;
} else {
isOptionalProvider = false;
}
@SuppressWarnings("unchecked") Argument<Object> argument =
(Argument<Object>) resolveArgument
Expand All @@ -144,36 +147,26 @@ public T instantiate(BeanResolutionContext resolutionContext, BeanContext contex
qualifier = Qualifiers.byName(n.toString());
}
}

boolean hasBean = context.containsBean(argument, qualifier);
if (hasBean) {
return buildProvider(
resolutionContext,
context,
argument,
qualifier,
isSingleton()
);
} else {
if (injectionPointArgument.isOptional()) {
return (T) Optional.empty();
} else if (injectionPointArgument.isNullable()) {
throw new DisabledBeanException("Nullable bean doesn't exist");
} else {
if (qualifier instanceof AnyQualifier || isAllowEmptyProviders(context)) {
return buildProvider(
resolutionContext,
context,
argument,
qualifier,
isSingleton()
);
} else {
throw new NoSuchBeanException(argument, qualifier);
if (isNullableProvider || isOptionalProvider || !(isAllowEmptyProviders(context) || qualifier instanceof AnyQualifier)) {
// Skip the contains bean for the providers that support an empty value and aren't nullable or optional
boolean hasBean = context.containsBean(argument, qualifier);
if (!hasBean) {
if (isNullableProvider) {
throw new DisabledBeanException("Nullable bean doesn't exist");
}
if (isOptionalProvider) {
return (T) Optional.empty();
}
throw new NoSuchBeanException(argument, qualifier);
}
}

return buildProvider(
resolutionContext,
context,
argument,
qualifier,
isSingleton()
);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ protected BeanProvider<Object> buildProvider(
@NonNull Argument<Object> argument,
@Nullable Qualifier<Object> qualifier,
boolean singleton) {
return new BeanProvider<Object>() {
return new BeanProvider<>() {

private final DefaultBeanContext defaultBeanContext = (DefaultBeanContext) context;
private final Qualifier<Object> finalQualifier =
qualifier instanceof AnyQualifier ? null : qualifier;

Expand All @@ -78,12 +79,12 @@ private Qualifier<Object> qualify(Qualifier<Object> qualifier) {

@Override
public Object get() {
return ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, finalQualifier);
return defaultBeanContext.getBean(resolutionContext.copy(), argument, finalQualifier);
}

@Override
public Optional<Object> find(Qualifier<Object> qualifier) {
return ((DefaultBeanContext) context).findBean(resolutionContext.copy(), argument, qualify(qualifier));
return defaultBeanContext.findBean(resolutionContext.copy(), argument, qualify(qualifier));
}

@Override
Expand All @@ -93,7 +94,7 @@ public BeanDefinition<Object> getDefinition() {

@Override
public Object get(Qualifier<Object> qualifier) {
return ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, qualify(qualifier));
return defaultBeanContext.getBean(resolutionContext.copy(), argument, qualify(qualifier));
}

@Override
Expand All @@ -113,12 +114,12 @@ public boolean isPresent() {
@NonNull
@Override
public Iterator<Object> iterator() {
return ((DefaultBeanContext) context).getBeansOfType(resolutionContext.copy(), argument, finalQualifier).iterator();
return defaultBeanContext.getBeansOfType(resolutionContext.copy(), argument, finalQualifier).iterator();
}

@Override
public Stream<Object> stream() {
return ((DefaultBeanContext) context).streamOfType(resolutionContext.copy(), argument, finalQualifier);
return defaultBeanContext.streamOfType(resolutionContext.copy(), argument, finalQualifier);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ public boolean isPresent() {

@Override
protected Provider<Object> buildProvider(BeanResolutionContext resolutionContext, BeanContext context, Argument<Object> argument, Qualifier<Object> qualifier, boolean singleton) {
DefaultBeanContext defaultBeanContext = (DefaultBeanContext) context;
if (singleton) {
return new Provider<Object>() {
return new Provider<>() {
Object bean;
@Override
public Object get() {
if (bean == null) {
bean = ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, qualifier);
bean = defaultBeanContext.getBean(resolutionContext.copy(), argument, qualifier);
}
return bean;
}
};
}
return () -> ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, qualifier);
return () -> defaultBeanContext.getBean(resolutionContext.copy(), argument, qualifier);
}

static boolean isTypePresent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,21 @@ public boolean isPresent() {

@Override
protected Provider<Object> buildProvider(BeanResolutionContext resolutionContext, BeanContext context, Argument<Object> argument, Qualifier<Object> qualifier, boolean singleton) {
DefaultBeanContext defaultBeanContext = (DefaultBeanContext) context;
if (singleton) {
return new Provider<Object>() {
return new Provider<>() {
Object bean;

@Override
public Object get() {
if (bean == null) {
bean = ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, qualifier);
bean = defaultBeanContext.getBean(resolutionContext.copy(), argument, qualifier);
}
return bean;
}
};
}
return () -> ((DefaultBeanContext) context).getBean(resolutionContext.copy(), argument, qualifier);
return () -> defaultBeanContext.getBean(resolutionContext.copy(), argument, qualifier);
}

private static boolean isTypePresent() {
Expand Down

0 comments on commit 287b670

Please sign in to comment.