diff --git a/extensions/vertx/deployment/src/main/java/io/quarkus/vertx/deployment/EventBusConsumer.java b/extensions/vertx/deployment/src/main/java/io/quarkus/vertx/deployment/EventBusConsumer.java index 78b952c475d4f..c4476f08be99a 100644 --- a/extensions/vertx/deployment/src/main/java/io/quarkus/vertx/deployment/EventBusConsumer.java +++ b/extensions/vertx/deployment/src/main/java/io/quarkus/vertx/deployment/EventBusConsumer.java @@ -4,6 +4,8 @@ import static io.quarkus.vertx.deployment.VertxConstants.MESSAGE; import static io.quarkus.vertx.deployment.VertxConstants.MUTINY_MESSAGE; import static io.quarkus.vertx.deployment.VertxConstants.UNI; +import static org.objectweb.asm.Opcodes.ACC_FINAL; +import static org.objectweb.asm.Opcodes.ACC_PRIVATE; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -22,17 +24,14 @@ import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.arc.processor.DotNames; -import io.quarkus.gizmo.AssignableResultHandle; -import io.quarkus.gizmo.BranchResult; -import io.quarkus.gizmo.BytecodeCreator; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; -import io.quarkus.gizmo.FunctionCreator; +import io.quarkus.gizmo.FieldCreator; +import io.quarkus.gizmo.FieldDescriptor; import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.gizmo.ResultHandle; import io.quarkus.runtime.util.HashUtil; -import io.quarkus.vertx.ConsumeEvent; import io.quarkus.vertx.runtime.EventConsumerInvoker; import io.smallrye.common.annotation.Blocking; import io.smallrye.mutiny.Uni; @@ -42,6 +41,8 @@ class EventBusConsumer { private static final String INVOKER_SUFFIX = "_VertxInvoker"; + private static final MethodDescriptor INVOKER_CONSTRUCTOR = MethodDescriptor + .ofConstructor(EventConsumerInvoker.class); private static final MethodDescriptor ARC_CONTAINER = MethodDescriptor .ofMethod(Arc.class, "container", ArcContainer.class); private static final MethodDescriptor INSTANCE_HANDLE_GET = MethodDescriptor.ofMethod(InstanceHandle.class, "get", @@ -57,8 +58,6 @@ class EventBusConsumer { "newInstance", io.vertx.mutiny.core.eventbus.Message.class, Message.class); private static final MethodDescriptor MESSAGE_REPLY = MethodDescriptor.ofMethod(Message.class, "reply", void.class, Object.class); - private static final MethodDescriptor MESSAGE_FAIL = MethodDescriptor.ofMethod(Message.class, "fail", void.class, - Integer.TYPE, String.class); private static final MethodDescriptor MESSAGE_BODY = MethodDescriptor.ofMethod(Message.class, "body", Object.class); private static final MethodDescriptor INSTANCE_HANDLE_DESTROY = MethodDescriptor .ofMethod(InstanceHandle.class, "destroy", @@ -99,32 +98,58 @@ static String generateInvoker(BeanInfo bean, MethodInfo method, blocking = method.hasAnnotation(BLOCKING) || (blockingValue != null && blockingValue.asBoolean()); ClassCreator invokerCreator = ClassCreator.builder().classOutput(classOutput).className(generatedName) - .interfaces(EventConsumerInvoker.class).build(); + .superClass(EventConsumerInvoker.class).build(); - // The method descriptor is: void invokeBean(Object message) - MethodCreator invoke = invokerCreator.getMethodCreator("invokeBean", void.class, Object.class) - .addException(Exception.class); + // Initialized state + FieldCreator beanField = invokerCreator.getFieldCreator("bean", InjectableBean.class) + .setModifiers(ACC_PRIVATE | ACC_FINAL); + FieldCreator containerField = invokerCreator.getFieldCreator("container", ArcContainer.class) + .setModifiers(ACC_PRIVATE | ACC_FINAL); if (blocking) { MethodCreator isBlocking = invokerCreator.getMethodCreator("isBlocking", boolean.class); isBlocking.returnValue(isBlocking.load(true)); } - invoke(bean, method, invoke.getMethodParam(0), invoke); + implementConstructor(bean, invokerCreator, beanField, containerField); + implementInvoke(bean, method, invokerCreator, beanField.getFieldDescriptor(), containerField.getFieldDescriptor()); - invoke.returnValue(null); invokerCreator.close(); return generatedName.replace('/', '.'); } - private static void invoke(BeanInfo bean, MethodInfo method, ResultHandle messageHandle, BytecodeCreator invoke) { - ResultHandle containerHandle = invoke.invokeStaticMethod(ARC_CONTAINER); - ResultHandle beanHandle = invoke.invokeInterfaceMethod(ARC_CONTAINER_BEAN, containerHandle, - invoke.load(bean.getIdentifier())); + static void implementConstructor(BeanInfo bean, ClassCreator invokerCreator, FieldCreator beanField, + FieldCreator containerField) { + MethodCreator constructor = invokerCreator.getMethodCreator("", void.class); + // Invoke super() + constructor.invokeSpecialMethod(INVOKER_CONSTRUCTOR, constructor.getThis()); + + ResultHandle containerHandle = constructor + .invokeStaticMethod(ARC_CONTAINER); + ResultHandle beanHandle = constructor.invokeInterfaceMethod( + ARC_CONTAINER_BEAN, + containerHandle, constructor.load(bean.getIdentifier())); + constructor.writeInstanceField(beanField.getFieldDescriptor(), constructor.getThis(), beanHandle); + constructor.writeInstanceField(containerField.getFieldDescriptor(), constructor.getThis(), containerHandle); + constructor.returnValue(null); + } + + private static void implementInvoke(BeanInfo bean, MethodInfo method, ClassCreator invokerCreator, + FieldDescriptor beanField, + FieldDescriptor containerField) { + + // The method descriptor is: CompletionStage invokeBean(Message message) + MethodCreator invoke = invokerCreator.getMethodCreator("invokeBean", CompletionStage.class, Message.class) + .addException(Exception.class); + + ResultHandle containerHandle = invoke.readInstanceField(containerField, invoke.getThis()); + ResultHandle beanHandle = invoke.readInstanceField(beanField, invoke.getThis()); ResultHandle instanceHandle = invoke.invokeInterfaceMethod(ARC_CONTAINER_INSTANCE_FOR_BEAN, containerHandle, beanHandle); ResultHandle beanInstanceHandle = invoke .invokeInterfaceMethod(INSTANCE_HANDLE_GET, instanceHandle); + ResultHandle messageHandle = invoke.getMethodParam(0); + ResultHandle completionStage; Type paramType = method.parameters().get(0); if (paramType.name().equals(MESSAGE)) { @@ -133,6 +158,7 @@ private static void invoke(BeanInfo bean, MethodInfo method, ResultHandle messag MethodDescriptor .ofMethod(bean.getImplClazz().name().toString(), method.name(), void.class, Message.class), beanInstanceHandle, messageHandle); + completionStage = invoke.loadNull(); } else if (paramType.name().equals(MUTINY_MESSAGE)) { // io.vertx.mutiny.core.eventbus.Message ResultHandle mutinyMessageHandle = invoke.invokeStaticMethod(MUTINY_MESSAGE_NEW_INSTANCE, messageHandle); @@ -140,30 +166,27 @@ private static void invoke(BeanInfo bean, MethodInfo method, ResultHandle messag MethodDescriptor.ofMethod(bean.getImplClazz().name().toString(), method.name(), void.class, io.vertx.mutiny.core.eventbus.Message.class), beanInstanceHandle, mutinyMessageHandle); + completionStage = invoke.loadNull(); } else { // Parameter is payload ResultHandle bodyHandle = invoke.invokeInterfaceMethod(MESSAGE_BODY, messageHandle); - ResultHandle replyHandle = invoke.invokeVirtualMethod( + ResultHandle returnHandle = invoke.invokeVirtualMethod( MethodDescriptor.ofMethod(bean.getImplClazz().name().toString(), method.name(), method.returnType().name().toString(), paramType.name().toString()), beanInstanceHandle, bodyHandle); - if (replyHandle != null) { + if (returnHandle != null) { if (method.returnType().name().equals(COMPLETION_STAGE)) { - FunctionCreator handler = generateWhenCompleteHandler(messageHandle, invoke); - invoke.invokeInterfaceMethod( - WHEN_COMPLETE, - replyHandle, handler.getInstance()); + completionStage = returnHandle; } else if (method.returnType().name().equals(UNI)) { - // If the return type is Uni use uni.subscribeAsCompletionStage().whenComplete(...) - FunctionCreator handler = generateWhenCompleteHandler(messageHandle, invoke); - ResultHandle subscribedCompletionStage = invoke.invokeInterfaceMethod(SUBSCRIBE_AS_COMPLETION_STAGE, - replyHandle); - invoke.invokeInterfaceMethod(WHEN_COMPLETE, - subscribedCompletionStage, handler.getInstance()); + completionStage = invoke.invokeInterfaceMethod(SUBSCRIBE_AS_COMPLETION_STAGE, + returnHandle); } else { // Message.reply(returnValue) - invoke.invokeInterfaceMethod(MESSAGE_REPLY, messageHandle, replyHandle); + invoke.invokeInterfaceMethod(MESSAGE_REPLY, messageHandle, returnHandle); + completionStage = invoke.loadNull(); } + } else { + completionStage = invoke.loadNull(); } } @@ -171,58 +194,8 @@ private static void invoke(BeanInfo bean, MethodInfo method, ResultHandle messag if (BuiltinScope.DEPENDENT.is(bean.getScope())) { invoke.invokeInterfaceMethod(INSTANCE_HANDLE_DESTROY, instanceHandle); } - } - /** - * If the return type is CompletionStage use: - *
-     * cs.whenComplete((whenResult, whenFailure) -> {
-     *  if (failure != null) {
-     *         message.fail(status, whenFailure.getMessage());
-     *  } else {
-     *         message.reply(whenResult);
-     *  }
-     * })
-     * 
- * - * @param messageHandle the message variable - * @param invoke the bytecode creator - * @return the function - */ - private static FunctionCreator generateWhenCompleteHandler(ResultHandle messageHandle, BytecodeCreator invoke) { - FunctionCreator handler = invoke.createFunction(BiConsumer.class); - BytecodeCreator bytecode = handler.getBytecode(); - - // This avoid having to check cast in the branches - AssignableResultHandle whenResult = bytecode.createVariable(Object.class); - bytecode.assign(whenResult, bytecode.getMethodParam(0)); - AssignableResultHandle whenFailure = bytecode.createVariable(Exception.class); - bytecode.assign(whenFailure, bytecode.getMethodParam(1)); - AssignableResultHandle message = bytecode.createVariable(Message.class); - bytecode.assign(message, messageHandle); - - BranchResult ifFailureIfNull = bytecode.ifNull(whenFailure); - // failure is not null branch - message.fail(failureStatus, failure.getMessage()) - // In this branch we use the EXPLICIT FAILURE CODE - BytecodeCreator failureIsNotNull = ifFailureIfNull.falseBranch(); - ResultHandle failureStatus = failureIsNotNull.load(ConsumeEvent.EXPLICIT_FAILURE_CODE); - ResultHandle failureMessage = failureIsNotNull - .invokeVirtualMethod(THROWABLE_GET_MESSAGE, whenFailure); - failureIsNotNull.invokeInterfaceMethod( - MESSAGE_FAIL, - message, - failureStatus, - failureMessage); - - // failure is null branch - message.reply(reply)) - BytecodeCreator failureIsNull = ifFailureIfNull.trueBranch(); - failureIsNull.invokeInterfaceMethod( - MESSAGE_REPLY, - messageHandle, - whenResult); - - bytecode.returnValue(null); - return handler; + invoke.returnValue(completionStage); } private EventBusConsumer() { diff --git a/extensions/vertx/deployment/src/test/java/io/quarkus/vertx/deployment/RequestContextTerminationTest.java b/extensions/vertx/deployment/src/test/java/io/quarkus/vertx/deployment/RequestContextTerminationTest.java new file mode 100644 index 0000000000000..c09f420eb1a2b --- /dev/null +++ b/extensions/vertx/deployment/src/test/java/io/quarkus/vertx/deployment/RequestContextTerminationTest.java @@ -0,0 +1,104 @@ +package io.quarkus.vertx.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.annotation.PreDestroy; +import javax.enterprise.context.RequestScoped; +import javax.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.vertx.ConsumeEvent; +import io.smallrye.mutiny.Uni; +import io.vertx.core.eventbus.EventBus; + +public class RequestContextTerminationTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(SimpleBean.class)); + + @Inject + EventBus eventBus; + + @Test + public void testTermination() throws InterruptedException { + assertTerminated("foo"); + assertTerminated("foo-cs"); + assertTerminated("foo-uni"); + } + + void assertTerminated(String address) throws InterruptedException { + BlockingQueue synchronizer = new LinkedBlockingQueue<>(); + Converter.DESTROYED.set(false); + eventBus.request(address, "bongo", ar -> { + if (ar.succeeded()) { + try { + synchronizer.put(ar.result().body()); + } catch (InterruptedException e) { + fail(e); + } + } else { + fail(ar.cause()); + } + }); + assertEquals("BONGO", synchronizer.poll(2, TimeUnit.SECONDS)); + assertTrue(Converter.DESTROYED.get()); + } + + @Test + public void testFailureNoReplyHandler() throws InterruptedException { + } + + static class SimpleBean { + + @Inject + Converter converter; + + @ConsumeEvent("foo") + String foo(String message) { + return converter.convert(message); + } + + @ConsumeEvent("foo-cs") + CompletionStage asyncFoo(String message) { + return CompletableFuture.completedFuture(converter.convert(message)); + } + + @ConsumeEvent("foo-uni") + Uni asyncFooUni(String message) { + return Uni.createFrom().item(converter.convert(message)); + } + + } + + @RequestScoped + static class Converter { + + static final AtomicBoolean DESTROYED = new AtomicBoolean(); + + String convert(String val) { + return val.toUpperCase(); + } + + @PreDestroy + void destroy() { + DESTROYED.set(true); + } + + } + +} diff --git a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/EventConsumerInvoker.java b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/EventConsumerInvoker.java index beeede5b396b8..4c7e999881b98 100644 --- a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/EventConsumerInvoker.java +++ b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/EventConsumerInvoker.java @@ -1,16 +1,110 @@ package io.quarkus.vertx.runtime; -import io.quarkus.arc.runtime.BeanInvoker; +import java.util.concurrent.CompletionStage; +import java.util.function.BiConsumer; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.InjectableContext.ContextState; +import io.quarkus.arc.ManagedContext; import io.quarkus.vertx.ConsumeEvent; import io.vertx.core.eventbus.Message; /** * Invokes a business method annotated with {@link ConsumeEvent}. */ -public interface EventConsumerInvoker extends BeanInvoker> { +public abstract class EventConsumerInvoker { - default boolean isBlocking() { + public boolean isBlocking() { return false; } + public void invoke(Message message) throws Exception { + ManagedContext requestContext = Arc.container().requestContext(); + if (requestContext.isActive()) { + CompletionStage ret = invokeBean(message); + if (ret != null) { + ret.whenComplete(new RequestActiveConsumer(message)); + } + } else { + // Activate the request context + requestContext.activate(); + CompletionStage ret; + try { + ret = invokeBean(message); + } catch (Exception e) { + // Terminate the request context and re-throw the exception + requestContext.terminate(); + throw e; + } + if (ret == null) { + // No async computation - just terminate + requestContext.terminate(); + } else { + // Capture the state, deactivate and destroy the context when the computation completes + ContextState endState = requestContext.getState(); + requestContext.deactivate(); + ret.whenComplete(new RequestActivatedConsumer(message, requestContext, endState)); + } + } + } + + protected abstract CompletionStage invokeBean(Message message) throws Exception; + + private static class RequestActiveConsumer implements BiConsumer { + + private final Message message; + + RequestActiveConsumer(Message message) { + this.message = message; + } + + @Override + public void accept(Object result, Throwable failure) { + if (failure != null) { + if (message.replyAddress() == null) { + // No reply handler + throw VertxRecorder.wrapIfNecessary(failure); + } else { + message.fail(ConsumeEvent.EXPLICIT_FAILURE_CODE, failure.getMessage()); + } + } else { + message.reply(result); + } + } + + } + + private static class RequestActivatedConsumer implements BiConsumer { + + private final Message message; + private final ManagedContext requestContext; + private final ContextState endState; + + public RequestActivatedConsumer(Message message, ManagedContext requestContext, ContextState endState) { + this.message = message; + this.requestContext = requestContext; + this.endState = endState; + } + + @Override + public void accept(Object result, Throwable failure) { + try { + requestContext.destroy(endState); + } catch (Exception e) { + throw VertxRecorder.wrapIfNecessary(e); + } + if (failure != null) { + if (message.replyAddress() == null) { + // No reply handler + throw VertxRecorder.wrapIfNecessary(failure); + } else { + message.fail(ConsumeEvent.EXPLICIT_FAILURE_CODE, failure.getMessage()); + } + } else { + message.reply(result); + } + } + + } + } diff --git a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/VertxRecorder.java b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/VertxRecorder.java index b84270f741ec8..b340eabd863a7 100644 --- a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/VertxRecorder.java +++ b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/runtime/VertxRecorder.java @@ -133,8 +133,10 @@ public void handle(AsyncResult ar) { } } - private RuntimeException wrapIfNecessary(Exception e) { - if (e instanceof RuntimeException) { + static RuntimeException wrapIfNecessary(Throwable e) { + if (e instanceof Error) { + throw (Error) e; + } else if (e instanceof RuntimeException) { return (RuntimeException) e; } else { return new RuntimeException(e);