diff --git a/implementation/src/main/java/io/smallrye/mutiny/operators/multi/MultiOnSubscribeCall.java b/implementation/src/main/java/io/smallrye/mutiny/operators/multi/MultiOnSubscribeCall.java index 0107c86f3..27b4c38ab 100644 --- a/implementation/src/main/java/io/smallrye/mutiny/operators/multi/MultiOnSubscribeCall.java +++ b/implementation/src/main/java/io/smallrye/mutiny/operators/multi/MultiOnSubscribeCall.java @@ -1,9 +1,11 @@ package io.smallrye.mutiny.operators.multi; import static io.smallrye.mutiny.helpers.Subscriptions.CANCELLED; +import static io.smallrye.mutiny.helpers.Subscriptions.empty; import java.util.Objects; import java.util.concurrent.Flow; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Function; import io.smallrye.mutiny.Multi; @@ -39,6 +41,11 @@ public void subscribe(MultiSubscriber actual) { private final class OnSubscribeSubscriber extends MultiOperatorProcessor { + private final ReentrantLock lock = new ReentrantLock(); + private Throwable failure; + private boolean terminatedEarly; + private boolean uniHasTerminated; + OnSubscribeSubscriber(MultiSubscriber downstream) { super(downstream); } @@ -48,13 +55,9 @@ public void onSubscribe(Flow.Subscription s) { if (compareAndSetUpstreamSubscription(null, s)) { try { Uni uni = Objects.requireNonNull(onSubscribe.apply(s), "The produced Uni must not be `null`"); - uni - .subscribe().with( - ignored -> downstream.onSubscribe(this), - failure -> { - Subscriptions.fail(downstream, failure); - getAndSetUpstreamSubscription(CANCELLED).cancel(); - }); + uni.subscribe().with( + ignored -> uniCompleted(), + err -> uniFailed(err)); } catch (Throwable e) { Subscriptions.fail(downstream, e); getAndSetUpstreamSubscription(CANCELLED).cancel(); @@ -63,6 +66,84 @@ public void onSubscribe(Flow.Subscription s) { s.cancel(); } } + + /* + * A note on locks. + * + * The methods below use a lock, but most don't use the idiomatic pattern: + * + * lock.lock(); + * try { + * // -- Critical section here -- + * } finally { + * lock.unlock(); + * } + * + * This is being done on purpose, and not just to make sure static analysis tools + * have something to complain about. If all you do is updating fields, and you don't + * call any method that might throw, then you can take more freedom. + * + * Most notably, we need to make sure that we don't dispatch signals (e.g., onFailure()) + * while we hold a lock. + */ + + @Override + public void onFailure(Throwable throwable) { + lock.lock(); + if (!uniHasTerminated) { + terminatedEarly = true; + this.failure = throwable; + lock.unlock(); + } else { + lock.unlock(); + super.onFailure(throwable); + } + } + + @Override + public void onCompletion() { + lock.lock(); + if (!uniHasTerminated) { + terminatedEarly = true; + lock.unlock(); + } else { + lock.unlock(); + super.onCompletion(); + } + } + + private void uniFailed(Throwable failure) { + getAndSetUpstreamSubscription(CANCELLED).cancel(); + lock.lock(); + try { + uniHasTerminated = true; + if (this.failure == null) { + this.failure = failure; + } else { + this.failure.addSuppressed(failure); + } + } finally { + lock.unlock(); + } + Subscriptions.fail(downstream, this.failure); + } + + private void uniCompleted() { + lock.lock(); + uniHasTerminated = true; + lock.unlock(); + if (terminatedEarly) { + getAndSetUpstreamSubscription(CANCELLED).cancel(); + downstream.onSubscribe(empty()); + if (this.failure != null) { + downstream.onFailure(failure); + } else { + downstream.onComplete(); + } + } else { + downstream.onSubscribe(this); + } + } } } diff --git a/implementation/src/test/java/io/smallrye/mutiny/operators/MultiOnSubscribeTest.java b/implementation/src/test/java/io/smallrye/mutiny/operators/MultiOnSubscribeTest.java index edd6f864f..373aabd02 100644 --- a/implementation/src/test/java/io/smallrye/mutiny/operators/MultiOnSubscribeTest.java +++ b/implementation/src/test/java/io/smallrye/mutiny/operators/MultiOnSubscribeTest.java @@ -5,12 +5,18 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import java.io.IOException; -import java.util.concurrent.*; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Flow.Subscription; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.parallel.ResourceAccessMode; @@ -317,4 +323,82 @@ public void testRunSubscriptionOnShutdownExecutorRequests() { subscriber.assertFailedWith(RejectedExecutionException.class, ""); } + @Test + public void testNoEarlyCompletionBeforeCallCompletes() { + AtomicReference callSubscriptionRef = new AtomicReference<>(); + + AssertSubscriber testSubscriber = Multi.createFrom().items(Stream::empty) + .onSubscription().call(sub -> Uni.createFrom().item(sub) + .onItem().delayIt().by(Duration.ofSeconds(1L)) + .onItem().invoke(callSubscriptionRef::set)) + .subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE)); + + testSubscriber.awaitCompletion().assertSubscribed().assertHasNotReceivedAnyItem(); + assertThat(callSubscriptionRef.get()).isNotNull(); + } + + @Test + public void testNoEarlyFailureBeforeCallCompletes() { + AtomicReference callSubscriptionRef = new AtomicReference<>(); + + AssertSubscriber testSubscriber = Multi.createFrom().emitter(emitter -> { + emitter.fail(new IOException("boom")); + }).onSubscription().call(sub -> Uni.createFrom().item(sub) + .onItem().delayIt().by(Duration.ofSeconds(1L)) + .onItem().invoke(callSubscriptionRef::set)) + .subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE)); + + testSubscriber + .awaitFailure() + .assertSubscribed() + .assertHasNotReceivedAnyItem() + .assertFailedWith(IOException.class, "boom"); + assertThat(callSubscriptionRef.get()).isNotNull(); + } + + @Test + public void testNoEarlyFailureBeforeCallFails() { + AtomicBoolean sync = new AtomicBoolean(); + AssertSubscriber testSubscriber = Multi.createFrom().emitter(emitter -> { + emitter.fail(new IOException("boom")); + sync.set(true); + }) + .onSubscription().call(sub -> Uni.createFrom() + .emitter(uniEmitter -> { + await().untilTrue(sync); + uniEmitter.fail(new RuntimeException("woops")); + }) + .runSubscriptionOn(Infrastructure.getDefaultExecutor())) + .subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE)); + + testSubscriber + .awaitFailure() + .assertSubscribed() + .assertHasNotReceivedAnyItem() + .assertFailedWith(IOException.class, "boom"); + Throwable failure = testSubscriber.getFailure(); + assertThat(failure).hasSuppressedException(new RuntimeException("woops")); + } + + @Test + public void testNoEarlyCompletionAfterCallFails() { + AtomicBoolean sync = new AtomicBoolean(); + AssertSubscriber testSubscriber = Multi.createFrom().emitter(emitter -> { + await().untilTrue(sync); + emitter.complete(); + }) + .onSubscription().call(sub -> Uni.createFrom() + .emitter(uniEmitter -> { + uniEmitter.fail(new RuntimeException("woops")); + sync.set(true); + }) + .runSubscriptionOn(Infrastructure.getDefaultExecutor())) + .subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE)); + + testSubscriber + .awaitFailure() + .assertSubscribed() + .assertHasNotReceivedAnyItem() + .assertFailedWith(RuntimeException.class, "woops"); + } }