From 695e44a9da4657325176f9bf9f5b42ce9713e101 Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Tue, 17 Dec 2024 09:41:50 +0100 Subject: [PATCH 1/4] ArC: introduce CurrentManagedContext - share the logic used in RequestContext and WebSocketSessionContext --- .../next/runtime/ContextSupport.java | 6 +- .../websockets/next/runtime/Endpoints.java | 4 +- .../next/runtime/WebSocketSessionContext.java | 230 ++------------- .../arc/impl/CurrentManagedContext.java | 270 ++++++++++++++++++ .../io/quarkus/arc/impl/RequestContext.java | 245 +--------------- 5 files changed, 304 insertions(+), 451 deletions(-) create mode 100644 independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index 2b60536dfc45b..0698723d361dd 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -6,8 +6,8 @@ import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.arc.ManagedContext; +import io.quarkus.arc.impl.CurrentManagedContext.CurrentContextState; import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; -import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; @@ -21,11 +21,11 @@ public class ContextSupport { static final String WEB_SOCKET_CONN_KEY = WebSocketConnectionBase.class.getName(); private final WebSocketConnectionBase connection; - private final SessionContextState sessionContextState; + private final CurrentContextState sessionContextState; private final WebSocketSessionContext sessionContext; private final ManagedContext requestContext; - ContextSupport(WebSocketConnectionBase connection, SessionContextState sessionContextState, + ContextSupport(WebSocketConnectionBase connection, CurrentContextState sessionContextState, WebSocketSessionContext sessionContext, ManagedContext requestContext) { this.connection = connection; diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index 349ccd7a75aff..95d7e2158d7bc 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -11,13 +11,13 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.quarkus.arc.ArcContainer; import io.quarkus.arc.InjectableContext; +import io.quarkus.arc.impl.CurrentManagedContext.CurrentContextState; import io.quarkus.runtime.LaunchMode; import io.quarkus.security.AuthenticationFailedException; import io.quarkus.security.ForbiddenException; import io.quarkus.security.UnauthorizedException; import io.quarkus.websockets.next.CloseReason; import io.quarkus.websockets.next.WebSocketException; -import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.quarkus.websockets.next.runtime.config.UnhandledFailureStrategy; import io.quarkus.websockets.next.runtime.telemetry.ErrorInterceptor; import io.quarkus.websockets.next.runtime.telemetry.TelemetrySupport; @@ -44,7 +44,7 @@ static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSo // Initialize and capture the session context state that will be activated // during message processing WebSocketSessionContext sessionContext = null; - SessionContextState sessionContextState = null; + CurrentContextState sessionContextState = null; if (activateSessionContext) { sessionContext = sessionContext(container); sessionContextState = sessionContext.initializeContextState(); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java index 3d6c488289c41..4d5a91b170a7d 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java @@ -1,51 +1,31 @@ package io.quarkus.websockets.next.runtime; import java.lang.annotation.Annotation; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.VarHandle; -import java.util.Map; -import java.util.Objects; +import java.util.function.Consumer; import java.util.function.Supplier; -import java.util.stream.Collectors; import jakarta.enterprise.context.BeforeDestroyed; import jakarta.enterprise.context.ContextNotActiveException; import jakarta.enterprise.context.Destroyed; import jakarta.enterprise.context.Initialized; import jakarta.enterprise.context.SessionScoped; -import jakarta.enterprise.context.spi.Contextual; -import jakarta.enterprise.context.spi.CreationalContext; import jakarta.enterprise.event.Event; import jakarta.enterprise.inject.Any; -import org.jboss.logging.Logger; - import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; -import io.quarkus.arc.ContextInstanceHandle; -import io.quarkus.arc.CurrentContext; import io.quarkus.arc.CurrentContextFactory; -import io.quarkus.arc.InjectableBean; -import io.quarkus.arc.ManagedContext; import io.quarkus.arc.impl.ComputingCacheContextInstances; -import io.quarkus.arc.impl.ContextInstanceHandleImpl; -import io.quarkus.arc.impl.ContextInstances; +import io.quarkus.arc.impl.CurrentManagedContext; import io.quarkus.arc.impl.LazyValue; -public class WebSocketSessionContext implements ManagedContext { - - private static final Logger LOG = Logger.getLogger(WebSocketSessionContext.class); - - private final CurrentContext currentContext; - private final LazyValue> initializedEvent; - private final LazyValue> beforeDestroyEvent; - private final LazyValue> destroyEvent; +public class WebSocketSessionContext extends CurrentManagedContext { public WebSocketSessionContext(CurrentContextFactory currentContextFactory) { - this.currentContext = currentContextFactory.create(SessionScoped.class); - this.initializedEvent = newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE); - this.beforeDestroyEvent = newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE); - this.destroyEvent = newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE); + super(currentContextFactory.create(SessionScoped.class), ComputingCacheContextInstances::new, + newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE), + newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE), + newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE)); } @Override @@ -53,147 +33,12 @@ public Class getScope() { return SessionScoped.class; } - @Override - public ContextState getState() { - SessionContextState state = currentState(); - if (state == null) { - throw notActive(); - } - return state; - } - - @Override - public ContextState activate(ContextState initialState) { - if (initialState == null) { - SessionContextState state = initializeContextState(); - currentContext.set(state); - return state; - } else { - if (initialState instanceof SessionContextState) { - currentContext.set((SessionContextState) initialState); - return initialState; - } else { - throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); - } - } - } - - @Override - public void deactivate() { - currentContext.remove(); - } - - @SuppressWarnings("unchecked") - @Override - public T get(Contextual contextual, CreationalContext creationalContext) { - Objects.requireNonNull(contextual, "Contextual must not be null"); - Objects.requireNonNull(creationalContext, "CreationalContext must not be null"); - InjectableBean bean = (InjectableBean) contextual; - if (!SessionScoped.class.getName().equals(bean.getScope().getName())) { - throw invalidScope(); - } - SessionContextState state = currentState(); - if (state == null || !state.isValid()) { - throw notActive(); - } - return (T) state.contextInstances.computeIfAbsent(bean.getIdentifier(), new Supplier>() { - - @Override - public ContextInstanceHandle get() { - return new ContextInstanceHandleImpl<>(bean, contextual.create(creationalContext), creationalContext); - } - }).get(); - } - - @Override - public T get(Contextual contextual) { - Objects.requireNonNull(contextual, "Contextual must not be null"); - InjectableBean bean = (InjectableBean) contextual; - if (!SessionScoped.class.getName().equals(bean.getScope().getName())) { - throw invalidScope(); - } - SessionContextState state = currentState(); - if (state == null || !state.isValid()) { - throw notActive(); - } - @SuppressWarnings("unchecked") - ContextInstanceHandle instance = (ContextInstanceHandle) state.contextInstances - .getIfPresent(bean.getIdentifier()); - return instance == null ? null : instance.get(); - } - - @Override - public boolean isActive() { - SessionContextState contextState = currentState(); - return contextState == null ? false : contextState.isValid(); - } - - @Override - public void destroy() { - destroy(currentState()); - } - - @Override - public void destroy(Contextual contextual) { - SessionContextState state = currentState(); - if (state == null || !state.isValid()) { - throw notActive(); - } - InjectableBean bean = (InjectableBean) contextual; - ContextInstanceHandle instance = state.contextInstances.remove(bean.getIdentifier()); - if (instance != null) { - instance.destroy(); - } - } - - @Override - public void destroy(ContextState state) { - if (state == null) { - // nothing to destroy - return; - } - if (state instanceof SessionContextState) { - SessionContextState sessionState = ((SessionContextState) state); - if (sessionState.invalidate()) { - fireIfNotNull(beforeDestroyEvent.get()); - sessionState.contextInstances.removeEach(ContextInstanceHandle::destroy); - fireIfNotNull(destroyEvent.get()); - } - } else { - throw new IllegalArgumentException("Invalid state implementation: " + state.getClass().getName()); - } - } - - SessionContextState initializeContextState() { - SessionContextState state = new SessionContextState(new ComputingCacheContextInstances()); - fireIfNotNull(initializedEvent.get()); - return state; - } - - private SessionContextState currentState() { - return currentContext.get(); - } - - private IllegalArgumentException invalidScope() { - throw new IllegalArgumentException("The bean does not declare @SessionScoped"); - } - - private ContextNotActiveException notActive() { + protected ContextNotActiveException notActive() { return new ContextNotActiveException("Session context is not active"); } - private void fireIfNotNull(Event event) { - if (event != null) { - try { - event.fire(toString()); - } catch (Exception e) { - LOG.warn("An error occurred during delivery of the context lifecycle event for " + toString(), e); - } - } - } - - private static LazyValue> newEvent(Annotation... qualifiers) { - return new LazyValue<>(new Supplier>() { + private static Consumer newEvent(Annotation... qualifiers) { + LazyValue> event = new LazyValue<>(new Supplier>() { @Override public Event get() { ArcContainer container = Arc.container(); @@ -203,54 +48,15 @@ public Event get() { return container.beanManager().getEvent().select(qualifiers); } }); - } - - static class SessionContextState implements ContextState { + return new Consumer() { - // Using 0 as default value enable removing an initialization - // in the constructor, piggybacking on the default value. - // As per https://docs.oracle.com/javase/specs/jls/se8/html/jls-12.html#jls-12.5 - // the default field values are set before 'this' is accessible, hence - // they should be the very first value observable even in presence of - // unsafe publication of this object. - private static final int VALID = 0; - private static final int INVALID = 1; - private static final VarHandle IS_VALID; - - static { - try { - IS_VALID = MethodHandles.lookup().findVarHandle(SessionContextState.class, "isValid", int.class); - } catch (ReflectiveOperationException e) { - throw new Error(e); + @Override + public void accept(Object t) { + Event e = event.get(); + if (e != null) { + e.fire(t); + } } - } - - private final ContextInstances contextInstances; - private volatile int isValid; - - SessionContextState(ContextInstances contextInstances) { - this.contextInstances = contextInstances; - } - - @Override - public Map, Object> getContextualInstances() { - return contextInstances.getAllPresent().stream() - .collect(Collectors.toUnmodifiableMap(ContextInstanceHandle::getBean, ContextInstanceHandle::get)); - } - - /** - * @return {@code true} if the state was successfully invalidated, {@code false} otherwise - */ - boolean invalidate() { - // Atomically sets the value just like AtomicBoolean.compareAndSet(boolean, boolean) - return IS_VALID.compareAndSet(this, VALID, INVALID); - } - - @Override - public boolean isValid() { - return isValid == VALID; - } - + }; } - } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java new file mode 100644 index 0000000000000..7d5da765a94fc --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java @@ -0,0 +1,270 @@ +package io.quarkus.arc.impl; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import jakarta.enterprise.context.ContextNotActiveException; +import jakarta.enterprise.context.spi.Contextual; +import jakarta.enterprise.context.spi.CreationalContext; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.ContextInstanceHandle; +import io.quarkus.arc.CurrentContext; +import io.quarkus.arc.InjectableBean; +import io.quarkus.arc.ManagedContext; + +/** + * A managed context backed by the {@link CurrentContext}. + */ +public abstract class CurrentManagedContext implements ManagedContext { + + private static final Logger LOG = Logger.getLogger(CurrentManagedContext.class); + + private final CurrentContext currentContext; + + private final Supplier contextInstances; + + private final Consumer initializedNotifier; + private final Consumer beforeDestroyedNotifier; + private final Consumer destroyedNotifier; + + protected CurrentManagedContext(CurrentContext currentContext, + Supplier contextInstances, Consumer initializedNotifier, + Consumer beforeDestroyedNotifier, Consumer destroyedNotifier) { + this.currentContext = currentContext; + this.contextInstances = contextInstances; + this.initializedNotifier = initializedNotifier; + this.beforeDestroyedNotifier = beforeDestroyedNotifier; + this.destroyedNotifier = destroyedNotifier; + } + + @Override + public ContextState getState() { + CurrentContextState state = currentState(); + if (state == null) { + throw notActive(); + } + return state; + } + + @Override + public ContextState activate(ContextState initialState) { + if (traceLog().isTraceEnabled()) { + traceActivate(initialState); + } + if (initialState == null) { + CurrentContextState state = initializeContextState(); + currentContext.set(state); + return state; + } else { + if (initialState instanceof CurrentContextState) { + currentContext.set((CurrentContextState) initialState); + return initialState; + } else { + throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); + } + } + } + + @Override + public void deactivate() { + if (traceLog().isTraceEnabled()) { + traceDeactivate(); + } + currentContext.remove(); + } + + @SuppressWarnings("unchecked") + @Override + public T getIfActive(Contextual contextual, Function, CreationalContext> creationalContextFun) { + Objects.requireNonNull(contextual, "Contextual must not be null"); + Objects.requireNonNull(creationalContextFun, "CreationalContext function must not be null"); + InjectableBean bean = (InjectableBean) contextual; + if (!Scopes.scopeMatches(this, bean)) { + throw Scopes.scopeDoesNotMatchException(this, bean); + } + CurrentContextState state = currentState(); + if (state == null || !state.isValid()) { + return null; + } + ContextInstances contextInstances = state.contextInstances; + ContextInstanceHandle instance = (ContextInstanceHandle) contextInstances.getIfPresent(bean.getIdentifier()); + if (instance == null) { + CreationalContext creationalContext = creationalContextFun.apply(contextual); + return (T) contextInstances.computeIfAbsent(bean.getIdentifier(), new Supplier>() { + + @Override + public ContextInstanceHandle get() { + return new ContextInstanceHandleImpl<>(bean, contextual.create(creationalContext), creationalContext); + } + }).get(); + } + return instance.get(); + } + + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + T result = getIfActive(contextual, + CreationalContextImpl.unwrap(Objects.requireNonNull(creationalContext, "CreationalContext must not be null"))); + if (result == null) { + throw notActive(); + } + return result; + } + + @Override + public T get(Contextual contextual) { + Objects.requireNonNull(contextual, "Contextual must not be null"); + InjectableBean bean = (InjectableBean) contextual; + if (!Scopes.scopeMatches(this, bean)) { + throw Scopes.scopeDoesNotMatchException(this, bean); + } + CurrentContextState state = currentState(); + if (state == null || !state.isValid()) { + throw notActive(); + } + @SuppressWarnings("unchecked") + ContextInstanceHandle instance = (ContextInstanceHandle) state.contextInstances + .getIfPresent(bean.getIdentifier()); + return instance == null ? null : instance.get(); + } + + @Override + public boolean isActive() { + CurrentContextState contextState = currentState(); + return contextState == null ? false : contextState.isValid(); + } + + @Override + public void destroy() { + destroy(currentState()); + } + + @Override + public void destroy(Contextual contextual) { + CurrentContextState state = currentState(); + if (state == null || !state.isValid()) { + throw notActive(); + } + InjectableBean bean = (InjectableBean) contextual; + ContextInstanceHandle instance = state.contextInstances.remove(bean.getIdentifier()); + if (instance != null) { + instance.destroy(); + } + } + + @Override + public void destroy(ContextState state) { + if (traceLog().isTraceEnabled()) { + traceDestroy(state); + } + if (state == null) { + // nothing to destroy + return; + } + if (state instanceof CurrentContextState) { + CurrentContextState currentState = ((CurrentContextState) state); + if (currentState.invalidate()) { + fireIfNotNull(beforeDestroyedNotifier); + currentState.contextInstances.removeEach(ContextInstanceHandle::destroy); + fireIfNotNull(destroyedNotifier); + } + } else { + throw new IllegalArgumentException("Invalid state implementation: " + state.getClass().getName()); + } + } + + public CurrentContextState initializeContextState() { + CurrentContextState state = new CurrentContextState(contextInstances.get()); + fireIfNotNull(initializedNotifier); + return state; + } + + protected Logger traceLog() { + return LOG; + } + + protected void traceActivate(ContextState initialState) { + // Noop + } + + protected void traceDeactivate() { + // Noop + } + + protected void traceDestroy(ContextState state) { + // Noop + } + + private CurrentContextState currentState() { + return currentContext.get(); + } + + protected abstract ContextNotActiveException notActive(); + + private void fireIfNotNull(Consumer notifier) { + if (notifier != null) { + try { + notifier.accept(toString()); + } catch (Exception e) { + LOG.warn("An error occurred during delivery of the context lifecycle event for " + toString(), e); + } + } + } + + public static class CurrentContextState implements ContextState { + + // Using 0 as default value enable removing an initialization + // in the constructor, piggybacking on the default value. + // As per https://docs.oracle.com/javase/specs/jls/se8/html/jls-12.html#jls-12.5 + // the default field values are set before 'this' is accessible, hence + // they should be the very first value observable even in presence of + // unsafe publication of this object. + private static final int VALID = 0; + private static final int INVALID = 1; + private static final VarHandle IS_VALID; + + static { + try { + IS_VALID = MethodHandles.lookup().findVarHandle(CurrentContextState.class, "isValid", int.class); + } catch (ReflectiveOperationException e) { + throw new Error(e); + } + } + + private final ContextInstances contextInstances; + private volatile int isValid; + + CurrentContextState(ContextInstances contextInstances) { + this.contextInstances = Objects.requireNonNull(contextInstances); + } + + @Override + public Map, Object> getContextualInstances() { + return contextInstances.getAllPresent().stream() + .collect(Collectors.toUnmodifiableMap(ContextInstanceHandle::getBean, ContextInstanceHandle::get)); + } + + /** + * @return {@code true} if the state was successfully invalidated, {@code false} otherwise + */ + boolean invalidate() { + // Atomically sets the value just like AtomicBoolean.compareAndSet(boolean, boolean) + return IS_VALID.compareAndSet(this, VALID, INVALID); + } + + @Override + public boolean isValid() { + return isValid == VALID; + } + + } + +} diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java index 5164d84c374f3..62a1ac4abf829 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/RequestContext.java @@ -1,26 +1,16 @@ package io.quarkus.arc.impl; import java.lang.annotation.Annotation; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.VarHandle; import java.util.Arrays; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import jakarta.enterprise.context.ContextNotActiveException; import jakarta.enterprise.context.RequestScoped; -import jakarta.enterprise.context.spi.Contextual; -import jakarta.enterprise.context.spi.CreationalContext; import org.jboss.logging.Logger; -import io.quarkus.arc.ContextInstanceHandle; import io.quarkus.arc.CurrentContext; -import io.quarkus.arc.InjectableBean; -import io.quarkus.arc.ManagedContext; import io.quarkus.arc.impl.EventImpl.Notifier; /** @@ -28,25 +18,16 @@ * * @author Martin Kouba */ -class RequestContext implements ManagedContext { +class RequestContext extends CurrentManagedContext { private static final Logger LOG = Logger.getLogger("io.quarkus.arc.requestContext"); - private final CurrentContext currentContext; - - private final Notifier initializedNotifier; - private final Notifier beforeDestroyedNotifier; - private final Notifier destroyedNotifier; - private final Supplier contextInstances; - - public RequestContext(CurrentContext currentContext, Notifier initializedNotifier, + public RequestContext(CurrentContext currentContext, Notifier initializedNotifier, Notifier beforeDestroyedNotifier, Notifier destroyedNotifier, Supplier contextInstances) { - this.currentContext = currentContext; - this.initializedNotifier = initializedNotifier; - this.beforeDestroyedNotifier = beforeDestroyedNotifier; - this.destroyedNotifier = destroyedNotifier; - this.contextInstances = contextInstances; + super(currentContext, contextInstances, initializedNotifier != null ? initializedNotifier::notify : null, + beforeDestroyedNotifier != null ? beforeDestroyedNotifier::notify : null, + destroyedNotifier != null ? destroyedNotifier::notify : null); } @Override @@ -54,107 +35,12 @@ public Class getScope() { return RequestScoped.class; } - @SuppressWarnings("unchecked") - @Override - public T getIfActive(Contextual contextual, Function, CreationalContext> creationalContextFun) { - Objects.requireNonNull(contextual, "Contextual must not be null"); - Objects.requireNonNull(creationalContextFun, "CreationalContext supplier must not be null"); - InjectableBean bean = (InjectableBean) contextual; - if (!Scopes.scopeMatches(this, bean)) { - throw Scopes.scopeDoesNotMatchException(this, bean); - } - RequestContextState ctxState = currentContext.get(); - if (!isActive(ctxState)) { - // Context is not active! - return null; - } - ContextInstances contextInstances = ctxState.contextInstances; - ContextInstanceHandle instance = (ContextInstanceHandle) contextInstances.getIfPresent(bean.getIdentifier()); - if (instance == null) { - CreationalContext creationalContext = creationalContextFun.apply(contextual); - return (T) contextInstances.computeIfAbsent(bean.getIdentifier(), new Supplier>() { - - @Override - public ContextInstanceHandle get() { - return new ContextInstanceHandleImpl<>(bean, contextual.create(creationalContext), creationalContext); - } - }).get(); - } - return instance.get(); - } - - @Override - public T get(Contextual contextual, CreationalContext creationalContext) { - T result = getIfActive(contextual, - CreationalContextImpl.unwrap(Objects.requireNonNull(creationalContext, "CreationalContext must not be null"))); - if (result == null) { - throw notActive(); - } - return result; - } - - @SuppressWarnings("unchecked") - @Override - public T get(Contextual contextual) { - Objects.requireNonNull(contextual, "Contextual must not be null"); - InjectableBean bean = (InjectableBean) contextual; - if (!Scopes.scopeMatches(this, bean)) { - throw Scopes.scopeDoesNotMatchException(this, bean); - } - RequestContextState state = currentContext.get(); - if (!isActive(state)) { - throw notActive(); - } - ContextInstanceHandle instance = (ContextInstanceHandle) state.contextInstances - .getIfPresent(bean.getIdentifier()); - return instance == null ? null : instance.get(); - } - @Override - public boolean isActive() { - return isActive(currentContext.get()); + protected Logger traceLog() { + return LOG; } - private boolean isActive(RequestContextState state) { - return state == null ? false : state.isValid(); - } - - @Override - public void destroy(Contextual contextual) { - RequestContextState state = currentContext.get(); - if (!isActive(state)) { - // Context is not active - throw notActive(); - } - InjectableBean bean = (InjectableBean) contextual; - ContextInstanceHandle instance = state.contextInstances.remove(bean.getIdentifier()); - if (instance != null) { - instance.destroy(); - } - } - - @Override - public ContextState activate(ContextState initialState) { - if (LOG.isTraceEnabled()) { - traceActivate(initialState); - } - if (initialState == null) { - RequestContextState state = new RequestContextState(contextInstances.get()); - currentContext.set(state); - // Fire an event with qualifier @Initialized(RequestScoped.class) if there are any observers for it - fireIfNotEmpty(initializedNotifier); - return state; - } else { - if (initialState instanceof RequestContextState) { - currentContext.set((RequestContextState) initialState); - return initialState; - } else { - throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); - } - } - } - - private void traceActivate(ContextState initialState) { + protected void traceActivate(ContextState initialState) { String stack = Arrays.stream(Thread.currentThread().getStackTrace()) .skip(2) .limit(7) @@ -164,29 +50,7 @@ private void traceActivate(ContextState initialState) { initialState != null ? Integer.toHexString(initialState.hashCode()) : "new", stack); } - @Override - public ContextState getState() { - RequestContextState state = currentContext.get(); - if (!isActive(state)) { - throw notActive(); - } - return state; - } - - public ContextState getStateIfActive() { - ContextState state = currentContext.get(); - return state != null && state.isValid() ? state : null; - } - - @Override - public void deactivate() { - if (LOG.isTraceEnabled()) { - traceDeactivate(); - } - currentContext.remove(); - } - - private static void traceDeactivate() { + protected void traceDeactivate() { String stack = Arrays.stream(Thread.currentThread().getStackTrace()) .skip(2) .limit(7) @@ -195,35 +59,7 @@ private static void traceDeactivate() { LOG.tracef("Deactivate%s\n\t...", stack); } - @Override - public void destroy() { - destroy(currentContext.get()); - } - - @Override - public void destroy(ContextState state) { - if (LOG.isTraceEnabled()) { - traceDestroy(state); - } - if (state == null) { - // nothing to destroy - return; - } - if (state instanceof RequestContextState) { - RequestContextState reqState = ((RequestContextState) state); - if (reqState.invalidate()) { - // Fire an event with qualifier @BeforeDestroyed(RequestScoped.class) if there are any observers for it - fireIfNotEmpty(beforeDestroyedNotifier); - reqState.contextInstances.removeEach(ContextInstanceHandle::destroy); - // Fire an event with qualifier @Destroyed(RequestScoped.class) if there are any observers for it - fireIfNotEmpty(destroyedNotifier); - } - } else { - throw new IllegalArgumentException("Invalid state implementation: " + state.getClass().getName()); - } - } - - private static void traceDestroy(ContextState state) { + protected void traceDestroy(ContextState state) { String stack = Arrays.stream(Thread.currentThread().getStackTrace()) .skip(2) .limit(7) @@ -232,68 +68,9 @@ private static void traceDestroy(ContextState state) { LOG.tracef("Destroy %s%s\n\t...", state != null ? Integer.toHexString(state.hashCode()) : "", stack); } - private void fireIfNotEmpty(Notifier notifier) { - if (notifier != null && !notifier.isEmpty()) { - try { - notifier.notify(toString()); - } catch (Exception e) { - LOG.warn("An error occurred during delivery of the container lifecycle event for qualifiers " - + notifier.eventMetadata.getQualifiers(), e); - } - } - } - - private ContextNotActiveException notActive() { + protected ContextNotActiveException notActive() { String msg = "Request context is not active - you can activate the request context for a specific method using the @ActivateRequestContext interceptor binding"; return new ContextNotActiveException(msg); } - static class RequestContextState implements ContextState { - - // Using 0 as default value enable removing an initialization - // in the constructor, piggybacking on the default value. - // As per https://docs.oracle.com/javase/specs/jls/se8/html/jls-12.html#jls-12.5 - // the default field values are set before 'this' is accessible, hence - // they should be the very first value observable even in presence of - // unsafe publication of this object. - private static final int VALID = 0; - private static final int INVALID = 1; - private static final VarHandle IS_VALID; - - static { - try { - IS_VALID = MethodHandles.lookup().findVarHandle(RequestContextState.class, "isValid", int.class); - } catch (ReflectiveOperationException e) { - throw new Error(e); - } - } - - private final ContextInstances contextInstances; - private volatile int isValid; - - RequestContextState(ContextInstances contextInstances) { - this.contextInstances = Objects.requireNonNull(contextInstances); - } - - @Override - public Map, Object> getContextualInstances() { - return contextInstances.getAllPresent().stream() - .collect(Collectors.toUnmodifiableMap(ContextInstanceHandle::getBean, ContextInstanceHandle::get)); - } - - /** - * @return {@code true} if the state was successfully invalidated, {@code false} otherwise - */ - boolean invalidate() { - // Atomically sets the value just like AtomicBoolean.compareAndSet(boolean, boolean) - return IS_VALID.compareAndSet(this, VALID, INVALID); - } - - @Override - public boolean isValid() { - return isValid == VALID; - } - - } - } From 1cd124ea0ce7571ec61d89b83b35b675c591625e Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Tue, 17 Dec 2024 11:05:33 +0100 Subject: [PATCH 2/4] ArC: introduce built-in session context --- .../quarkus/arc/processor/BuiltinScope.java | 4 +- .../java/io/quarkus/arc/ArcContainer.java | 7 ++ .../io/quarkus/arc/impl/ArcContainerImpl.java | 11 ++ .../java/io/quarkus/arc/impl/Contexts.java | 30 ++++-- .../io/quarkus/arc/impl/SessionContext.java | 35 ++++++ .../contexts/session/ContextObserver.java | 34 ++++++ .../arc/test/contexts/session/Controller.java | 30 ++++++ .../contexts/session/SessionContextTest.java | 100 ++++++++++++++++++ 8 files changed, 242 insertions(+), 9 deletions(-) create mode 100644 independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/SessionContext.java create mode 100644 independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/ContextObserver.java create mode 100644 independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/Controller.java create mode 100644 independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/SessionContextTest.java diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BuiltinScope.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BuiltinScope.java index 29ade50172ec4..d543ac90be81f 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BuiltinScope.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BuiltinScope.java @@ -5,6 +5,7 @@ import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.SessionScoped; import jakarta.inject.Singleton; import org.jboss.jandex.AnnotationInstance; @@ -16,7 +17,8 @@ public enum BuiltinScope { DEPENDENT(Dependent.class, false), SINGLETON(Singleton.class, false), APPLICATION(ApplicationScoped.class, true), - REQUEST(RequestScoped.class, true); + REQUEST(RequestScoped.class, true), + SESSION(SessionScoped.class, true); private ScopeInfo info; diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java index 1ed13a4a95040..c4f3a4a3ea654 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java @@ -204,6 +204,13 @@ public interface ArcContainer { */ ManagedContext requestContext(); + /** + * This method never throws {@link ContextNotActiveException}. + * + * @return the built-in context for {@link jakarta.enterprise.context.SessionScoped} + */ + ManagedContext sessionContext(); + /** * NOTE: Not all methods are supported! * diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java index c63710b5cbd86..d9843805156a6 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java @@ -36,6 +36,7 @@ import jakarta.enterprise.context.Initialized; import jakarta.enterprise.context.NormalScope; import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.SessionScoped; import jakarta.enterprise.event.Event; import jakarta.enterprise.inject.AmbiguousResolutionException; import jakarta.enterprise.inject.Any; @@ -207,9 +208,14 @@ public List get() { notifierOrNull(Set.of(BeforeDestroyed.Literal.REQUEST, Any.Literal.INSTANCE)), notifierOrNull(Set.of(Destroyed.Literal.REQUEST, Any.Literal.INSTANCE)), requestContextInstances != null ? requestContextInstances : ComputingCacheContextInstances::new); + SessionContext sessionContext = new SessionContext(this.currentContextFactory.create(SessionScoped.class), + notifierOrNull(Set.of(Initialized.Literal.SESSION, Any.Literal.INSTANCE)), + notifierOrNull(Set.of(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE)), + notifierOrNull(Set.of(Destroyed.Literal.SESSION, Any.Literal.INSTANCE)), ComputingCacheContextInstances::new); Contexts.Builder contextsBuilder = new Contexts.Builder( requestContext, + sessionContext, applicationContext, new SingletonContext(), new DependentContext()); @@ -399,6 +405,11 @@ public ManagedContext requestContext() { return contexts.requestContext; } + @Override + public ManagedContext sessionContext() { + return contexts.sessionContext; + } + @Override public BeanManager beanManager() { return BeanManagerImpl.INSTANCE.get(); diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/Contexts.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/Contexts.java index 5858b982dcc87..b0ec74902fa0b 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/Contexts.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/Contexts.java @@ -12,6 +12,7 @@ import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.SessionScoped; import jakarta.inject.Singleton; import io.quarkus.arc.InjectableContext; @@ -26,6 +27,7 @@ class Contexts { // Built-in contexts final ManagedContext requestContext; + final ManagedContext sessionContext; final InjectableContext applicationContext; final InjectableContext singletonContext; final InjectableContext dependentContext; @@ -35,6 +37,7 @@ class Contexts { private final List singletonContextSingleton; private final List dependentContextSingleton; private final List requestContextSingleton; + private final List sessionContextSingleton; // Lazily computed list of contexts for a scope private final ClassValue> unoptimizedContexts; @@ -42,18 +45,22 @@ class Contexts { // Precomputed values final Set> scopes; - Contexts(ManagedContext requestContext, InjectableContext applicationContext, InjectableContext singletonContext, + Contexts(ManagedContext requestContext, ManagedContext sessionContext, InjectableContext applicationContext, + InjectableContext singletonContext, InjectableContext dependentContext, Map, List> contexts) { this.requestContext = requestContext; + this.sessionContext = sessionContext; this.applicationContext = applicationContext; this.singletonContext = singletonContext; this.dependentContext = dependentContext; - this.applicationContextSingleton = Collections.singletonList(applicationContext); - this.singletonContextSingleton = Collections.singletonList(singletonContext); - this.dependentContextSingleton = Collections.singletonList(dependentContext); + this.applicationContextSingleton = List.of(applicationContext); + this.singletonContextSingleton = List.of(singletonContext); + this.dependentContextSingleton = List.of(dependentContext); List requestContexts = contexts.get(RequestScoped.class); - this.requestContextSingleton = requestContexts != null ? requestContexts : Collections.singletonList(requestContext); + this.requestContextSingleton = requestContexts != null ? requestContexts : List.of(requestContext); + List sessionContexts = contexts.get(SessionScoped.class); + this.sessionContextSingleton = sessionContexts != null ? sessionContexts : List.of(sessionContext); if (!contexts.isEmpty()) { // At least one custom context is registered @@ -84,11 +91,13 @@ protected List computeValue(Class type) { all.add(Singleton.class); all.add(Dependent.class); all.add(RequestScoped.class); + all.add(SessionScoped.class); this.scopes = Set.copyOf(all); } else { // No custom context is registered this.unoptimizedContexts = null; - this.scopes = Set.of(ApplicationScoped.class, Singleton.class, Dependent.class, RequestScoped.class); + this.scopes = Set.of(ApplicationScoped.class, Singleton.class, Dependent.class, RequestScoped.class, + SessionScoped.class); } } @@ -125,6 +134,8 @@ List getContexts(Class scopeType) { return singletonContextSingleton; } else if (Dependent.class.equals(scopeType)) { return dependentContextSingleton; + } else if (SessionScoped.class.equals(scopeType)) { + return sessionContextSingleton; } return unoptimizedContexts != null ? unoptimizedContexts.get(scopeType) : Collections.emptyList(); } @@ -132,14 +143,16 @@ List getContexts(Class scopeType) { static class Builder { private final ManagedContext requestContext; + private final ManagedContext sessionContext; private final InjectableContext applicationContext; private final InjectableContext singletonContext; private final InjectableContext dependentContext; private final Map, List> contexts = new HashMap<>(); - public Builder(ManagedContext requestContext, InjectableContext applicationContext, + public Builder(ManagedContext requestContext, ManagedContext sessionContext, InjectableContext applicationContext, InjectableContext singletonContext, InjectableContext dependentContext) { this.requestContext = requestContext; + this.sessionContext = sessionContext; this.applicationContext = applicationContext; this.singletonContext = singletonContext; this.dependentContext = dependentContext; @@ -163,7 +176,8 @@ Contexts build() { // If a custom request context is registered then add the built-in context as well putContext(requestContext); } - return new Contexts(requestContext, applicationContext, singletonContext, dependentContext, contexts); + return new Contexts(requestContext, sessionContext, applicationContext, singletonContext, dependentContext, + contexts); } } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/SessionContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/SessionContext.java new file mode 100644 index 0000000000000..fb44039c68305 --- /dev/null +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/SessionContext.java @@ -0,0 +1,35 @@ +package io.quarkus.arc.impl; + +import java.lang.annotation.Annotation; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ContextNotActiveException; +import jakarta.enterprise.context.SessionScoped; + +import io.quarkus.arc.CurrentContext; +import io.quarkus.arc.impl.EventImpl.Notifier; + +/** + * The built-in context for {@link SessionScoped}. + */ +public class SessionContext extends CurrentManagedContext { + + public SessionContext(CurrentContext currentContext, Notifier initializedNotifier, + Notifier beforeDestroyedNotifier, Notifier destroyedNotifier, + Supplier contextInstances) { + super(currentContext, contextInstances, initializedNotifier != null ? initializedNotifier::notify : null, + beforeDestroyedNotifier != null ? beforeDestroyedNotifier::notify : null, + destroyedNotifier != null ? destroyedNotifier::notify : null); + } + + @Override + public Class getScope() { + return SessionScoped.class; + } + + @Override + protected ContextNotActiveException notActive() { + return new ContextNotActiveException("Session context is not active"); + } + +} diff --git a/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/ContextObserver.java b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/ContextObserver.java new file mode 100644 index 0000000000000..375d944058227 --- /dev/null +++ b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/ContextObserver.java @@ -0,0 +1,34 @@ +package io.quarkus.arc.test.contexts.session; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.event.Observes; + +@ApplicationScoped +public class ContextObserver { + + static volatile int initializedObserved = 0; + static volatile int beforeDestroyedObserved = 0; + static volatile int destroyedObserved = 0; + + static void reset() { + initializedObserved = 0; + beforeDestroyedObserved = 0; + destroyedObserved = 0; + } + + void observeContextInit(@Observes @Initialized(SessionScoped.class) Object event) { + initializedObserved++; + } + + void observeContextBeforeDestroyed(@Observes @BeforeDestroyed(SessionScoped.class) Object event) { + beforeDestroyedObserved++; + } + + void observeContextDestroyed(@Observes @Destroyed(SessionScoped.class) Object event) { + destroyedObserved++; + } +} diff --git a/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/Controller.java b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/Controller.java new file mode 100644 index 0000000000000..5fabd0f130e50 --- /dev/null +++ b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/Controller.java @@ -0,0 +1,30 @@ +package io.quarkus.arc.test.contexts.session; + +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.SessionScoped; + +@SessionScoped +public class Controller { + + static final AtomicBoolean DESTROYED = new AtomicBoolean(); + + private String id; + + @PostConstruct + void init() { + id = UUID.randomUUID().toString(); + } + + @PreDestroy + void destroy() { + DESTROYED.set(true); + } + + String getId() { + return id; + } +} diff --git a/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/SessionContextTest.java b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/SessionContextTest.java new file mode 100644 index 0000000000000..871deb5d45be3 --- /dev/null +++ b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/contexts/session/SessionContextTest.java @@ -0,0 +1,100 @@ +package io.quarkus.arc.test.contexts.session; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import jakarta.enterprise.context.ContextNotActiveException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; +import io.quarkus.arc.test.ArcTestContainer; + +public class SessionContextTest { + + @RegisterExtension + public ArcTestContainer container = new ArcTestContainer(Controller.class, ContextObserver.class); + + @Test + public void testSessionContext() { + Controller.DESTROYED.set(false); + ArcContainer arc = Arc.container(); + ManagedContext sessionContext = arc.sessionContext(); + + try { + arc.instance(Controller.class).get().getId(); + fail(); + } catch (ContextNotActiveException expected) { + } + + sessionContext.activate(); + assertFalse(Controller.DESTROYED.get()); + Controller controller1 = arc.instance(Controller.class).get(); + Controller controller2 = arc.instance(Controller.class).get(); + String controller2Id = controller2.getId(); + assertEquals(controller1.getId(), controller2Id); + sessionContext.terminate(); + assertTrue(Controller.DESTROYED.get()); + + try { + arc.instance(Controller.class).get().getId(); + fail(); + } catch (ContextNotActiveException expected) { + } + + // Id must be different in a different context + Controller.DESTROYED.set(false); + sessionContext.activate(); + assertNotEquals(controller2Id, arc.instance(Controller.class).get().getId()); + sessionContext.terminate(); + assertTrue(Controller.DESTROYED.get()); + + Controller.DESTROYED.set(false); + sessionContext.activate(); + assertNotEquals(controller2Id, arc.instance(Controller.class).get().getId()); + sessionContext.terminate(); + assertTrue(Controller.DESTROYED.get()); + } + + @Test + public void testSessionContextEvents() { + // reset counters since other tests might have triggered it already + ContextObserver.reset(); + + // firstly test manual activation + ArcContainer arc = Arc.container(); + ManagedContext sessionContext = arc.sessionContext(); + + try { + arc.instance(Controller.class).get().getId(); + fail(); + } catch (ContextNotActiveException expected) { + } + + sessionContext.activate(); + assertEquals(1, ContextObserver.initializedObserved); + assertEquals(0, ContextObserver.beforeDestroyedObserved); + assertEquals(0, ContextObserver.destroyedObserved); + + // dummy check that bean is available + arc.instance(Controller.class).get().getId(); + + sessionContext.terminate(); + assertEquals(1, ContextObserver.initializedObserved); + assertEquals(1, ContextObserver.beforeDestroyedObserved); + assertEquals(1, ContextObserver.destroyedObserved); + + try { + arc.instance(Controller.class).get().getId(); + fail(); + } catch (ContextNotActiveException expected) { + } + } + +} From a3f0d32b7ef52cdbd8d2ac4ee075f82225a799be Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Tue, 17 Dec 2024 11:24:36 +0100 Subject: [PATCH 3/4] WebSockets Next: replace WebSocketSessionContext with built-in context --- .../next/deployment/WebSocketProcessor.java | 17 ----- .../next/runtime/ContextSupport.java | 9 ++- .../websockets/next/runtime/Endpoints.java | 20 ++---- .../next/runtime/WebSocketSessionContext.java | 62 ------------------- .../java/io/quarkus/arc/ManagedContext.java | 6 ++ .../arc/impl/CurrentManagedContext.java | 5 +- 6 files changed, 18 insertions(+), 101 deletions(-) delete mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 1c5ecbb6df7ce..9c8c02a2deadc 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -45,9 +45,6 @@ import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem; import io.quarkus.arc.deployment.BeanDefiningAnnotationBuildItem; import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem; -import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem; -import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; -import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.InvokerFactoryBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem; @@ -126,7 +123,6 @@ import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; import io.quarkus.websockets.next.runtime.WebSocketHttpServerOptionsCustomizer; import io.quarkus.websockets.next.runtime.WebSocketServerRecorder; -import io.quarkus.websockets.next.runtime.WebSocketSessionContext; import io.quarkus.websockets.next.runtime.kotlin.ApplicationCoroutineScope; import io.quarkus.websockets.next.runtime.kotlin.CoroutineInvoker; import io.quarkus.websockets.next.runtime.telemetry.ErrorInterceptor; @@ -229,19 +225,6 @@ void produceCoroutineScope(BuildProducer additionalBean .build()); } - @BuildStep - ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) { - return new ContextConfiguratorBuildItem(phase.getContext() - .configure(SessionScoped.class) - .normal() - .contextClass(WebSocketSessionContext.class)); - } - - @BuildStep - CustomScopeBuildItem registerSessionScope() { - return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class)); - } - @BuildStep void builtinCallbackArguments(BuildProducer providers) { providers.produce(new CallbackArgumentBuildItem(new MessageCallbackArgument())); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index 0698723d361dd..26823e8b9b5c7 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -6,7 +6,6 @@ import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.arc.ManagedContext; -import io.quarkus.arc.impl.CurrentManagedContext.CurrentContextState; import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; @@ -21,12 +20,12 @@ public class ContextSupport { static final String WEB_SOCKET_CONN_KEY = WebSocketConnectionBase.class.getName(); private final WebSocketConnectionBase connection; - private final CurrentContextState sessionContextState; - private final WebSocketSessionContext sessionContext; + private final ContextState sessionContextState; + private final ManagedContext sessionContext; private final ManagedContext requestContext; - ContextSupport(WebSocketConnectionBase connection, CurrentContextState sessionContextState, - WebSocketSessionContext sessionContext, + ContextSupport(WebSocketConnectionBase connection, ContextState sessionContextState, + ManagedContext sessionContext, ManagedContext requestContext) { this.connection = connection; this.sessionContext = sessionContext; diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index 95d7e2158d7bc..64f43ee3a7377 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -4,14 +4,12 @@ import java.util.Optional; import java.util.function.Consumer; -import jakarta.enterprise.context.SessionScoped; - import org.jboss.logging.Logger; import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.quarkus.arc.ArcContainer; import io.quarkus.arc.InjectableContext; -import io.quarkus.arc.impl.CurrentManagedContext.CurrentContextState; +import io.quarkus.arc.ManagedContext; import io.quarkus.runtime.LaunchMode; import io.quarkus.security.AuthenticationFailedException; import io.quarkus.security.ForbiddenException; @@ -43,11 +41,11 @@ static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSo // Initialize and capture the session context state that will be activated // during message processing - WebSocketSessionContext sessionContext = null; - CurrentContextState sessionContextState = null; + ManagedContext sessionContext = null; + InjectableContext.ContextState sessionContextState = null; if (activateSessionContext) { - sessionContext = sessionContext(container); - sessionContextState = sessionContext.initializeContextState(); + sessionContext = container.sessionContext(); + sessionContextState = sessionContext.initializeState(); } ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, sessionContext, activateRequestContext ? container.requestContext() : null); @@ -406,12 +404,4 @@ private static WebSocketEndpoint createEndpoint(String endpointClassName, Contex } } - private static WebSocketSessionContext sessionContext(ArcContainer container) { - for (InjectableContext injectableContext : container.getContexts(SessionScoped.class)) { - if (WebSocketSessionContext.class.equals(injectableContext.getClass())) { - return (WebSocketSessionContext) injectableContext; - } - } - throw new WebSocketException("CDI session context not registered"); - } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java deleted file mode 100644 index 4d5a91b170a7d..0000000000000 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java +++ /dev/null @@ -1,62 +0,0 @@ -package io.quarkus.websockets.next.runtime; - -import java.lang.annotation.Annotation; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import jakarta.enterprise.context.BeforeDestroyed; -import jakarta.enterprise.context.ContextNotActiveException; -import jakarta.enterprise.context.Destroyed; -import jakarta.enterprise.context.Initialized; -import jakarta.enterprise.context.SessionScoped; -import jakarta.enterprise.event.Event; -import jakarta.enterprise.inject.Any; - -import io.quarkus.arc.Arc; -import io.quarkus.arc.ArcContainer; -import io.quarkus.arc.CurrentContextFactory; -import io.quarkus.arc.impl.ComputingCacheContextInstances; -import io.quarkus.arc.impl.CurrentManagedContext; -import io.quarkus.arc.impl.LazyValue; - -public class WebSocketSessionContext extends CurrentManagedContext { - - public WebSocketSessionContext(CurrentContextFactory currentContextFactory) { - super(currentContextFactory.create(SessionScoped.class), ComputingCacheContextInstances::new, - newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE), - newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE), - newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE)); - } - - @Override - public Class getScope() { - return SessionScoped.class; - } - - protected ContextNotActiveException notActive() { - return new ContextNotActiveException("Session context is not active"); - } - - private static Consumer newEvent(Annotation... qualifiers) { - LazyValue> event = new LazyValue<>(new Supplier>() { - @Override - public Event get() { - ArcContainer container = Arc.container(); - if (container.resolveObserverMethods(Object.class, qualifiers).isEmpty()) { - return null; - } - return container.beanManager().getEvent().select(qualifiers); - } - }); - return new Consumer() { - - @Override - public void accept(Object t) { - Event e = event.get(); - if (e != null) { - e.fire(t); - } - } - }; - } -} diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ManagedContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ManagedContext.java index 69e7d09cbc08c..2cfbf4913628f 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ManagedContext.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ManagedContext.java @@ -50,4 +50,10 @@ default void terminate() { destroy(); deactivate(); } + + /** + * + * @return a new initialized context state + */ + ContextState initializeState(); } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java index 7d5da765a94fc..266185912d5a1 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/CurrentManagedContext.java @@ -60,7 +60,7 @@ public ContextState activate(ContextState initialState) { traceActivate(initialState); } if (initialState == null) { - CurrentContextState state = initializeContextState(); + CurrentContextState state = initializeState(); currentContext.set(state); return state; } else { @@ -181,7 +181,8 @@ public void destroy(ContextState state) { } } - public CurrentContextState initializeContextState() { + @Override + public CurrentContextState initializeState() { CurrentContextState state = new CurrentContextState(contextInstances.get()); fireIfNotNull(initializedNotifier); return state; From 09706b9a85799b73d9c8cf16232702e4c5d01f6b Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Tue, 17 Dec 2024 12:53:52 +0100 Subject: [PATCH 4/4] ArC: introduce the ActivateSessionContext interceptor binding - it's only available in tests - fixes #45146 Co-authored-by: Matej Novotny --- .../asciidoc/getting-started-testing.adoc | 2 + .../quarkus/arc/deployment/ArcProcessor.java | 24 ------- .../quarkus/arc/deployment/ArcTestSteps.java | 71 +++++++++++++++++++ .../arc/test/context/session/Client.java | 32 +++++++++ .../context/session/SessionContextTest.java | 49 +++++++++++++ .../arc/test/context/session/SimpleBean.java | 30 ++++++++ .../ActivateSessionContextInterceptor.java | 30 ++++++++ .../quarkus/test/ActivateSessionContext.java | 30 ++++++++ 8 files changed, 244 insertions(+), 24 deletions(-) create mode 100644 extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcTestSteps.java create mode 100644 extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/Client.java create mode 100644 extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SessionContextTest.java create mode 100644 extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SimpleBean.java create mode 100644 extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/test/ActivateSessionContextInterceptor.java create mode 100644 test-framework/common/src/main/java/io/quarkus/test/ActivateSessionContext.java diff --git a/docs/src/main/asciidoc/getting-started-testing.adoc b/docs/src/main/asciidoc/getting-started-testing.adoc index e2c0fac89ccbe..35c28cc97071a 100644 --- a/docs/src/main/asciidoc/getting-started-testing.adoc +++ b/docs/src/main/asciidoc/getting-started-testing.adoc @@ -369,6 +369,8 @@ public class GreetingServiceTest { ---- <1> The `GreetingService` bean will be injected into the test +TIP: If you want to inject/test a `@SessionScoped` bean then it's very likely that the session context is not active and you would receive the `ContextNotActiveException` when a method of the injected bean is invoked. However, it's possible to use the `@io.quarkus.test.ActivateSessionContext` interceptor binding to activate the session context for a specific business method. Please read the javadoc for futher limitations. + == Applying Interceptors to Tests As mentioned above Quarkus tests are actually full CDI beans, and as such you can apply CDI interceptors as you would diff --git a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java index ff175046a6960..a89b988cb9184 100644 --- a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java +++ b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcProcessor.java @@ -68,16 +68,13 @@ import io.quarkus.arc.runtime.LoggerProducer; import io.quarkus.arc.runtime.appcds.AppCDSRecorder; import io.quarkus.arc.runtime.context.ArcContextProvider; -import io.quarkus.arc.runtime.test.PreloadedTestApplicationClassPredicate; import io.quarkus.bootstrap.BootstrapDebug; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Capability; import io.quarkus.deployment.Feature; -import io.quarkus.deployment.IsTest; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; import io.quarkus.deployment.annotations.Consume; -import io.quarkus.deployment.annotations.ExecutionTime; import io.quarkus.deployment.annotations.Produce; import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.AdditionalApplicationArchiveMarkerBuildItem; @@ -653,27 +650,6 @@ public void signalBeanContainerReady(AppCDSRecorder recorder, PreBeanContainerBu beanContainerProducer.produce(new BeanContainerBuildItem(bi.getValue())); } - @BuildStep(onlyIf = IsTest.class) - public AdditionalBeanBuildItem testApplicationClassPredicateBean() { - // We need to register the bean implementation for TestApplicationClassPredicate - // TestApplicationClassPredicate is used programmatically in the ArC recorder when StartupEvent is fired - return AdditionalBeanBuildItem.unremovableOf(PreloadedTestApplicationClassPredicate.class); - } - - @BuildStep(onlyIf = IsTest.class) - @Record(ExecutionTime.STATIC_INIT) - void initTestApplicationClassPredicateBean(ArcRecorder recorder, BeanContainerBuildItem beanContainer, - BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, - CompletedApplicationClassPredicateBuildItem predicate) { - Set applicationBeanClasses = new HashSet<>(); - for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) { - if (predicate.test(bean.getBeanClass())) { - applicationBeanClasses.add(bean.getBeanClass().toString()); - } - } - recorder.initTestApplicationClassPredicate(applicationBeanClasses); - } - @BuildStep List marker() { return Arrays.asList(new AdditionalApplicationArchiveMarkerBuildItem("META-INF/beans.xml"), diff --git a/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcTestSteps.java b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcTestSteps.java new file mode 100644 index 0000000000000..66c86e0b055e5 --- /dev/null +++ b/extensions/arc/deployment/src/main/java/io/quarkus/arc/deployment/ArcTestSteps.java @@ -0,0 +1,71 @@ +package io.quarkus.arc.deployment; + +import java.util.HashSet; +import java.util.Set; +import java.util.function.Predicate; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationTransformation; +import org.jboss.jandex.DotName; + +import io.quarkus.arc.processor.BeanInfo; +import io.quarkus.arc.runtime.ArcRecorder; +import io.quarkus.arc.runtime.test.ActivateSessionContextInterceptor; +import io.quarkus.arc.runtime.test.PreloadedTestApplicationClassPredicate; +import io.quarkus.deployment.IsTest; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.BuildSteps; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.ApplicationClassPredicateBuildItem; + +@BuildSteps(onlyIf = IsTest.class) +public class ArcTestSteps { + + @BuildStep + public void additionalBeans(BuildProducer additionalBeans) { + // We need to register the bean implementation for TestApplicationClassPredicate + // TestApplicationClassPredicate is used programmatically in the ArC recorder when StartupEvent is fired + additionalBeans.produce(AdditionalBeanBuildItem.unremovableOf(PreloadedTestApplicationClassPredicate.class)); + // In tests, register the ActivateSessionContextInterceptor and ActivateSessionContext interceptor binding + additionalBeans.produce(new AdditionalBeanBuildItem(ActivateSessionContextInterceptor.class)); + additionalBeans.produce(new AdditionalBeanBuildItem("io.quarkus.test.ActivateSessionContext")); + } + + @BuildStep + AnnotationsTransformerBuildItem addInterceptorBinding() { + return new AnnotationsTransformerBuildItem( + AnnotationTransformation.forClasses().whenClass(ActivateSessionContextInterceptor.class).transform(tc -> tc.add( + AnnotationInstance.builder(DotName.createSimple("io.quarkus.test.ActivateSessionContext")).build()))); + } + + // For some reason the annotation literal generated for io.quarkus.test.ActivateSessionContext lives in app class loader. + // This predicates ensures that the generated bean is considered an app class too. + // As a consequence, the type and all methods of ActivateSessionContextInterceptor must be public. + @BuildStep + ApplicationClassPredicateBuildItem appClassPredicate() { + return new ApplicationClassPredicateBuildItem(new Predicate() { + + @Override + public boolean test(String name) { + return name.startsWith(ActivateSessionContextInterceptor.class.getName()); + } + }); + } + + @BuildStep + @Record(ExecutionTime.STATIC_INIT) + void initTestApplicationClassPredicateBean(ArcRecorder recorder, BeanContainerBuildItem beanContainer, + BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + CompletedApplicationClassPredicateBuildItem predicate) { + Set applicationBeanClasses = new HashSet<>(); + for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) { + if (predicate.test(bean.getBeanClass())) { + applicationBeanClasses.add(bean.getBeanClass().toString()); + } + } + recorder.initTestApplicationClassPredicate(applicationBeanClasses); + } + +} diff --git a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/Client.java b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/Client.java new file mode 100644 index 0000000000000..2be51f601a58f --- /dev/null +++ b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/Client.java @@ -0,0 +1,32 @@ +package io.quarkus.arc.test.context.session; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.context.SessionScoped; +import jakarta.inject.Inject; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.ActivateSessionContext; + +@Dependent +class Client { + + @Inject + SimpleBean bean; + + @ActivateSessionContext + public String ping() { + assertTrue(Arc.container().sessionContext().isActive()); + if (bean instanceof ClientProxy proxy) { + assertEquals(SessionScoped.class, proxy.arc_bean().getScope()); + } else { + fail("Not a client proxy"); + } + return bean.ping(); + } + +} diff --git a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SessionContextTest.java b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SessionContextTest.java new file mode 100644 index 0000000000000..b45fbfa2979af --- /dev/null +++ b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SessionContextTest.java @@ -0,0 +1,49 @@ +package io.quarkus.arc.test.context.session; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.QuarkusUnitTest; + +public class SessionContextTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root + .addClasses(SimpleBean.class, Client.class)); + + @Inject + Client client; + + @Inject + SimpleBean simpleBean; + + @Test + public void testContexts() { + assertFalse(Arc.container().sessionContext().isActive()); + assertNotNull(client.ping()); + assertTrue(SimpleBean.DESTROYED.get()); + assertFalse(Arc.container().sessionContext().isActive()); + SimpleBean.DESTROYED.set(false); + + ManagedContext sessionContext = Arc.container().sessionContext(); + try { + sessionContext.activate(); + String id = simpleBean.ping(); + assertEquals(id, client.ping()); + assertFalse(SimpleBean.DESTROYED.get()); + } finally { + sessionContext.terminate(); + } + assertTrue(SimpleBean.DESTROYED.get()); + } +} diff --git a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SimpleBean.java b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SimpleBean.java new file mode 100644 index 0000000000000..77bb8f5f81483 --- /dev/null +++ b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/session/SimpleBean.java @@ -0,0 +1,30 @@ +package io.quarkus.arc.test.context.session; + +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.SessionScoped; + +@SessionScoped +class SimpleBean { + + static final AtomicBoolean DESTROYED = new AtomicBoolean(); + + private String id; + + @PostConstruct + void init() { + id = UUID.randomUUID().toString(); + } + + public String ping() { + return id; + } + + @PreDestroy + void destroy() { + DESTROYED.set(true); + } +} \ No newline at end of file diff --git a/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/test/ActivateSessionContextInterceptor.java b/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/test/ActivateSessionContextInterceptor.java new file mode 100644 index 0000000000000..ea5452e2babc7 --- /dev/null +++ b/extensions/arc/runtime/src/main/java/io/quarkus/arc/runtime/test/ActivateSessionContextInterceptor.java @@ -0,0 +1,30 @@ +package io.quarkus.arc.runtime.test; + +import jakarta.annotation.Priority; +import jakarta.interceptor.AroundInvoke; +import jakarta.interceptor.Interceptor; +import jakarta.interceptor.InvocationContext; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; + +// The @ActivateSessionContext interceptor binding is added by the extension +@Interceptor +@Priority(Interceptor.Priority.PLATFORM_BEFORE + 100) +public class ActivateSessionContextInterceptor { + + @AroundInvoke + public Object aroundInvoke(InvocationContext ctx) throws Exception { + ManagedContext sessionContext = Arc.container().sessionContext(); + if (sessionContext.isActive()) { + return ctx.proceed(); + } + try { + sessionContext.activate(); + return ctx.proceed(); + } finally { + sessionContext.terminate(); + } + } + +} diff --git a/test-framework/common/src/main/java/io/quarkus/test/ActivateSessionContext.java b/test-framework/common/src/main/java/io/quarkus/test/ActivateSessionContext.java new file mode 100644 index 0000000000000..3acd8b00525af --- /dev/null +++ b/test-framework/common/src/main/java/io/quarkus/test/ActivateSessionContext.java @@ -0,0 +1,30 @@ +package io.quarkus.test; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.concurrent.CompletionStage; + +import jakarta.interceptor.InterceptorBinding; + +/** + * Activates the session context before the intercepted method is called, and terminates the context when the method invocation + * completes (regardless of any exceptions being thrown). + *

+ * If the context is already active, it's a noop - the context is neither activated nor deactivated. + *

+ * Keep in mind that if the method returns an asynchronous type (such as {@link CompletionStage} then the session context is + * still terminated when the invocation completes and not at the time the asynchronous type is completed. Also note that session + * context is not propagated by MicroProfile Context Propagation. + *

+ * This interceptor binding is only available in tests. + */ +@InterceptorBinding +@Target({ METHOD, TYPE }) +@Retention(RUNTIME) +public @interface ActivateSessionContext { + +}