diff --git a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java index 5cea61f68c1f2..02c7f65d49242 100644 --- a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java +++ b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java @@ -32,6 +32,7 @@ import org.jboss.logging.Logger; import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ContextReferenceFactory; import io.quarkus.arc.deployment.BeanRegistrationPhaseBuildItem.BeanConfiguratorBuildItem; import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.ObserverRegistrationPhaseBuildItem.ObserverConfiguratorBuildItem; @@ -102,6 +103,7 @@ import io.quarkus.gizmo.ResultHandle; import io.quarkus.runtime.LaunchMode; import io.quarkus.runtime.QuarkusApplication; +import io.quarkus.runtime.RuntimeValue; import io.quarkus.runtime.annotations.QuarkusMain; import io.quarkus.runtime.test.TestApplicationClassPredicate; import io.quarkus.runtime.util.HashUtil; @@ -480,7 +482,8 @@ public BeanContainerBuildItem generateResources(ArcConfig config, ArcRecorder re LiveReloadBuildItem liveReloadBuildItem, BuildProducer generatedResource, BuildProducer bytecodeTransformer, - List reflectiveBeanClasses) throws Exception { + List reflectiveBeanClasses, + Optional customContextReferenceFactory) throws Exception { for (ValidationErrorBuildItem validationError : validationErrors) { for (Throwable error : validationError.getValues()) { @@ -561,7 +564,11 @@ public void registerSubclass(DotName beanClassName, String subclassName) { reflectiveClasses.produce(new ReflectiveClassBuildItem(true, false, binding.name().toString())); } - ArcContainer container = recorder.getContainer(shutdown); + RuntimeValue contextReferenceFactory = null; + if (customContextReferenceFactory.isPresent()) { + contextReferenceFactory = customContextReferenceFactory.get().getFactory(); + } + ArcContainer container = recorder.initContainer(shutdown, contextReferenceFactory); BeanContainer beanContainer = recorder.initBeanContainer(container, beanContainerListenerBuildItems.stream().map(BeanContainerListenerBuildItem::getBeanContainerListener) .collect(Collectors.toList())); diff --git a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ContextReferenceFactoryBuildItem.java b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ContextReferenceFactoryBuildItem.java new file mode 100644 index 0000000000000..767864c7f5439 --- /dev/null +++ b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ContextReferenceFactoryBuildItem.java @@ -0,0 +1,22 @@ +package io.quarkus.arc.deployment; + +import io.quarkus.arc.ContextReferenceFactory; +import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.runtime.RuntimeValue; + +/** + * An extension can provide a custom {@link ContextReferenceFactory}. + */ +public final class ContextReferenceFactoryBuildItem extends SimpleBuildItem { + + private final RuntimeValue factory; + + public ContextReferenceFactoryBuildItem(RuntimeValue factory) { + this.factory = factory; + } + + public RuntimeValue getFactory() { + return factory; + } + +} diff --git a/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/ArcRecorder.java b/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/ArcRecorder.java index 82de993b4c258..8f74876d86bf6 100644 --- a/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/ArcRecorder.java +++ b/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/ArcRecorder.java @@ -14,6 +14,7 @@ import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ContextReferenceFactory; import io.quarkus.arc.InjectableBean; import io.quarkus.arc.InjectableBean.Kind; import io.quarkus.arc.impl.ArcContainerImpl; @@ -36,8 +37,9 @@ public class ArcRecorder { */ public static volatile Map> supplierMap; - public ArcContainer getContainer(ShutdownContext shutdown) throws Exception { - ArcContainer container = Arc.initialize(); + public ArcContainer initContainer(ShutdownContext shutdown, RuntimeValue contextReferenceFactory) + throws Exception { + ArcContainer container = Arc.initialize(contextReferenceFactory != null ? contextReferenceFactory.getValue() : null); shutdown.addShutdownTask(new Runnable() { @Override public void run() { diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/Arc.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/Arc.java index dfa67bb526ec4..d8b64a1b42373 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/Arc.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/Arc.java @@ -11,17 +11,21 @@ public final class Arc { private static final AtomicReference INSTANCE = new AtomicReference<>(); + public static ArcContainer initialize() { + return initialize(null); + } + /** * * @return the initialized container */ - public static ArcContainer initialize() { + public static ArcContainer initialize(ContextReferenceFactory contextReferenceFactory) { ArcContainerImpl container = INSTANCE.get(); if (container == null) { synchronized (INSTANCE) { container = INSTANCE.get(); if (container == null) { - container = new ArcContainerImpl(); + container = new ArcContainerImpl(contextReferenceFactory); // Set the container instance first because Arc.container() can be used within ArcContainerImpl.init() INSTANCE.set(container); container.init(); diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java index 55c7dd5060d36..29e59160e5a13 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java @@ -168,4 +168,10 @@ public interface ArcContainer { * @return the default executor service */ ExecutorService getExecutorService(); + + /** + * + * @return the factory + */ + ContextReferenceFactory getContextReferenceFactory(); } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReference.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReference.java new file mode 100644 index 0000000000000..39231cad9b296 --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReference.java @@ -0,0 +1,31 @@ +package io.quarkus.arc; + +import io.quarkus.arc.InjectableContext.ContextState; + +/** + * Represents the current context of a normal scope. + * + * @param + * @see ContextReferenceFactory + */ +public interface ContextReference { + + /** + * + * @return the current state + */ + T get(); + + /** + * Sets the current state. + * + * @param state + */ + void set(T state); + + /** + * Removes the current state. + */ + void remove(); + +} diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReferenceFactory.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReferenceFactory.java new file mode 100644 index 0000000000000..940f44fae6e86 --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextReferenceFactory.java @@ -0,0 +1,15 @@ +package io.quarkus.arc; + +import io.quarkus.arc.InjectableContext.ContextState; + +/** + * This factory is used to create a new {@link ContextReference} for a non-shared context of a normal scope, e.g. the request + * context. + * + * @param + */ +public interface ContextReferenceFactory { + + ContextReference create(); + +} diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java index fa85acdd04388..51354f0dca4e0 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java @@ -6,6 +6,7 @@ import io.quarkus.arc.ArcContainer; import io.quarkus.arc.Components; import io.quarkus.arc.ComponentsProvider; +import io.quarkus.arc.ContextReferenceFactory; import io.quarkus.arc.InjectableBean; import io.quarkus.arc.InjectableContext; import io.quarkus.arc.InjectableDecorator; @@ -95,7 +96,9 @@ public class ArcContainerImpl implements ArcContainer { private volatile ExecutorService executorService; - public ArcContainerImpl() { + private final ContextReferenceFactory contextReferenceFactory; + + public ArcContainerImpl(ContextReferenceFactory contextReferenceFactory) { id = String.valueOf(ID_GENERATOR.incrementAndGet()); running = new AtomicBoolean(true); List> beans = new ArrayList<>(); @@ -105,10 +108,12 @@ public ArcContainerImpl() { List> observers = new ArrayList<>(); Map, Set> transitiveInterceptorBindings = new HashMap<>(); Map> qualifierNonbindingMembers = new HashMap<>(); + this.contextReferenceFactory = contextReferenceFactory == null ? new ThreadLocalContextReferenceFactory() + : contextReferenceFactory; applicationContext = new ApplicationContext(); singletonContext = new SingletonContext(); - requestContext = new RequestContext(); + requestContext = new RequestContext(this.contextReferenceFactory.create()); contexts = new HashMap<>(); putContext(requestContext); putContext(applicationContext); @@ -335,6 +340,11 @@ public void setExecutor(ExecutorService executor) { this.executorService = executor; } + @Override + public ContextReferenceFactory getContextReferenceFactory() { + return contextReferenceFactory; + } + @Override public String toString() { return "ArcContainerImpl [id=" + id + ", running=" + running + ", beans=" + beans.size() + ", observers=" diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java index bc2d4cb7ee5aa..d63601f3e8e2e 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java @@ -1,6 +1,7 @@ package io.quarkus.arc.impl; import io.quarkus.arc.ContextInstanceHandle; +import io.quarkus.arc.ContextReference; import io.quarkus.arc.InjectableBean; import io.quarkus.arc.ManagedContext; import io.quarkus.arc.impl.EventImpl.Notifier; @@ -33,14 +34,14 @@ class RequestContext implements ManagedContext { private static final Logger LOGGER = Logger.getLogger(RequestContext.class.getPackage().getName()); - // It's a normal scope so there may be no more than one mapped instance per contextual type per thread - private final ThreadLocal currentContext = new ThreadLocal<>(); + private final ContextReference currentContext; private final LazyValue> initializedNotifier; private final LazyValue> beforeDestroyedNotifier; private final LazyValue> destroyedNotifier; - public RequestContext() { + public RequestContext(ContextReference currentContext) { + this.currentContext = currentContext; this.initializedNotifier = new LazyValue<>(RequestContext::createInitializedNotifier); this.beforeDestroyedNotifier = new LazyValue<>(RequestContext::createBeforeDestroyedNotifier); this.destroyedNotifier = new LazyValue<>(RequestContext::createDestroyedNotifier); @@ -56,19 +57,18 @@ public Class getScope() { public T getIfActive(Contextual contextual, Function, CreationalContext> creationalContextFun) { Objects.requireNonNull(contextual, "Contextual must not be null"); Objects.requireNonNull(creationalContextFun, "CreationalContext supplier must not be null"); - RequestContextState ctxState = currentContext.get(); - if (ctxState == null) { - // Thread local not set - context is not active! + RequestContextState state = currentContext.get(); + if (state == null) { + // Context is not active return null; } - Map, ContextInstanceHandle> ctxMap = currentContext.get().value; - ContextInstanceHandle instance = (ContextInstanceHandle) ctxMap.get(contextual); + ContextInstanceHandle instance = (ContextInstanceHandle) state.map.get(contextual); if (instance == null) { CreationalContext creationalContext = creationalContextFun.apply(contextual); // Bean instance does not exist - create one if we have CreationalContext instance = new ContextInstanceHandleImpl((InjectableBean) contextual, contextual.create(creationalContext), creationalContext); - ctxMap.put(contextual, instance); + state.map.put(contextual, instance); } return instance.get(); } @@ -88,12 +88,12 @@ public T get(Contextual contextual, CreationalContext creationalContex @Override public T get(Contextual contextual) { Objects.requireNonNull(contextual, "Contextual must not be null"); - Map, ContextInstanceHandle> ctx = currentContext.get().value; - if (ctx == null) { - // Thread local not set - context is not active! + RequestContextState state = currentContext.get(); + if (state == null) { + // Context is not active throw new ContextNotActiveException(); } - ContextInstanceHandle instance = (ContextInstanceHandle) ctx.get(contextual); + ContextInstanceHandle instance = (ContextInstanceHandle) state.map.get(contextual); return instance == null ? null : instance.get(); } @@ -104,12 +104,12 @@ public boolean isActive() { @Override public void destroy(Contextual contextual) { - Map, ContextInstanceHandle> ctx = currentContext.get().value; - if (ctx == null) { - // Thread local not set - context is not active! + RequestContextState state = currentContext.get(); + if (state == null) { + // Context is not active throw new ContextNotActiveException(); } - ContextInstanceHandle instance = ctx.remove(contextual); + ContextInstanceHandle instance = state.map.remove(contextual); if (instance != null) { instance.destroy(); } @@ -123,7 +123,7 @@ public void activate(ContextState initialState) { fireIfNotEmpty(initializedNotifier); } else { if (initialState instanceof RequestContextState) { - currentContext.set(((RequestContextState) initialState)); + currentContext.set((RequestContextState) initialState); } else { throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); } @@ -132,20 +132,16 @@ public void activate(ContextState initialState) { @Override public ContextState getState() { - RequestContextState ctx = currentContext.get(); - if (ctx == null) { + RequestContextState state = currentContext.get(); + if (state == null) { // Thread local not set - context is not active! throw new ContextNotActiveException(); } - return ctx; + return state; } public ContextState getStateIfActive() { - RequestContextState ctx = currentContext.get(); - if (ctx == null) { - return null; - } - return ctx; + return currentContext.get(); } @Override @@ -167,9 +163,9 @@ public void destroy(ContextState state) { if (state instanceof RequestContextState) { RequestContextState reqState = ((RequestContextState) state); reqState.isValid.set(false); - destroy(reqState.value); + destroy(reqState.map); } else { - throw new IllegalArgumentException("Invalid state: " + state.getClass().getName()); + throw new IllegalArgumentException("Invalid state implementation: " + state.getClass().getName()); } } @@ -230,17 +226,17 @@ private static Notifier createDestroyedNotifier() { static class RequestContextState implements ContextState { - private final ConcurrentMap, ContextInstanceHandle> value; + private final Map, ContextInstanceHandle> map; private final AtomicBoolean isValid; RequestContextState(ConcurrentMap, ContextInstanceHandle> value) { - this.value = value; + this.map = Objects.requireNonNull(value); this.isValid = new AtomicBoolean(true); } @Override public Map, Object> getContextualInstances() { - return value.values().stream() + return map.values().stream() .collect(Collectors.toUnmodifiableMap(ContextInstanceHandle::getBean, ContextInstanceHandle::get)); } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReference.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReference.java new file mode 100644 index 0000000000000..8e765de12c892 --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReference.java @@ -0,0 +1,30 @@ +package io.quarkus.arc.impl; + +import io.quarkus.arc.ContextReference; +import io.quarkus.arc.InjectableContext.ContextState; + +/** + * {@link ThreadLocal} implementation of {@link ContextReference}. + * + * @param + */ +final class ThreadLocalContextReference implements ContextReference { + + private final ThreadLocal currentContext = new ThreadLocal<>(); + + @Override + public T get() { + return currentContext.get(); + } + + @Override + public void set(T state) { + currentContext.set(state); + } + + @Override + public void remove() { + currentContext.remove(); + } + +} diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReferenceFactory.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReferenceFactory.java new file mode 100644 index 0000000000000..5a29e4fb7192b --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ThreadLocalContextReferenceFactory.java @@ -0,0 +1,19 @@ +package io.quarkus.arc.impl; + +import io.quarkus.arc.ContextReference; +import io.quarkus.arc.ContextReferenceFactory; +import io.quarkus.arc.InjectableContext.ContextState; + +/** + * The default implementation makes use of {@link ThreadLocal} variables. + * + * @see ThreadLocalContextReference + */ +final class ThreadLocalContextReferenceFactory implements ContextReferenceFactory { + + @Override + public ContextReference create() { + return new ThreadLocalContextReference<>(); + } + +}