From b9d29cee399ed33ce57dbc2aa1be2a7b06000e58 Mon Sep 17 00:00:00 2001 From: Ben Christensen Date: Fri, 15 Nov 2013 18:55:01 -0800 Subject: [PATCH] BugFix: AsyncSubject - it was not emitting values to observers that subscribed after onCompleted/onError --- .../main/java/rx/subjects/AsyncSubject.java | 121 ++++++++++---- .../java/rx/subjects/AsyncSubjectTest.java | 150 ++++++++++++++++++ 2 files changed, 239 insertions(+), 32 deletions(-) diff --git a/rxjava-core/src/main/java/rx/subjects/AsyncSubject.java b/rxjava-core/src/main/java/rx/subjects/AsyncSubject.java index 183efdfda9..3f7f5d0a0f 100644 --- a/rxjava-core/src/main/java/rx/subjects/AsyncSubject.java +++ b/rxjava-core/src/main/java/rx/subjects/AsyncSubject.java @@ -18,10 +18,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import rx.Notification; import rx.Observer; import rx.Subscription; import rx.operators.SafeObservableSubscription; +import rx.subscriptions.Subscriptions; /** * Subject that publishes only the last event to each {@link Observer} that has subscribed when the @@ -60,61 +63,115 @@ public class AsyncSubject extends Subject { * @return a new AsyncSubject */ public static AsyncSubject create() { - final ConcurrentHashMap> observers = new ConcurrentHashMap>(); + final AsyncSubjectState state = new AsyncSubjectState(); OnSubscribeFunc onSubscribe = new OnSubscribeFunc() { @Override public Subscription onSubscribe(Observer observer) { - final SafeObservableSubscription subscription = new SafeObservableSubscription(); - - subscription.wrap(new Subscription() { - @Override - public void unsubscribe() { - // on unsubscribe remove it from the map of outbound observers to notify - observers.remove(subscription); + /* + * Subscription needs to be synchronized with terminal states to ensure + * race conditions are handled. When subscribing we must make sure + * onComplete/onError is correctly emitted to all observers, even if it + * comes in while the onComplete/onError is being propagated. + */ + state.SUBSCRIPTION_LOCK.lock(); + try { + if (state.completed.get()) { + emitNotificationToObserver(state, observer); + return Subscriptions.empty(); + } else { + // the subject is not completed so we subscribe + final SafeObservableSubscription subscription = new SafeObservableSubscription(); + + subscription.wrap(new Subscription() { + @Override + public void unsubscribe() { + // on unsubscribe remove it from the map of outbound observers to notify + state.observers.remove(subscription); + } + }); + + // on subscribe add it to the map of outbound observers to notify + state.observers.put(subscription, observer); + + return subscription; } - }); + } finally { + state.SUBSCRIPTION_LOCK.unlock(); + } - // on subscribe add it to the map of outbound observers to notify - observers.put(subscription, observer); - return subscription; } + }; - return new AsyncSubject(onSubscribe, observers); + return new AsyncSubject(onSubscribe, state); } - private final ConcurrentHashMap> observers; - private final AtomicReference currentValue; - private final AtomicBoolean hasValue = new AtomicBoolean(); + private static void emitNotificationToObserver(final AsyncSubjectState state, Observer observer) { + Notification finalValue = state.currentValue.get(); + + // if null that means onNext was never invoked (no Notification set) + if (finalValue != null) { + if (finalValue.isOnNext()) { + observer.onNext(finalValue.getValue()); + } else if (finalValue.isOnError()) { + observer.onError(finalValue.getThrowable()); + } + } + observer.onCompleted(); + } - protected AsyncSubject(OnSubscribeFunc onSubscribe, ConcurrentHashMap> observers) { + /** + * State externally constructed and passed in so the onSubscribe function has access to it. + * + * @param + */ + private static class AsyncSubjectState { + private final ConcurrentHashMap> observers = new ConcurrentHashMap>(); + private final AtomicReference> currentValue = new AtomicReference>(); + private final AtomicBoolean completed = new AtomicBoolean(); + private final ReentrantLock SUBSCRIPTION_LOCK = new ReentrantLock(); + } + + private final AsyncSubjectState state; + + protected AsyncSubject(OnSubscribeFunc onSubscribe, AsyncSubjectState state) { super(onSubscribe); - this.observers = observers; - this.currentValue = new AtomicReference(); + this.state = state; } @Override public void onCompleted() { - T finalValue = currentValue.get(); - for (Observer observer : observers.values()) { - if (hasValue.get()) { - observer.onNext(finalValue); - } - observer.onCompleted(); - } + terminalState(); } @Override public void onError(Throwable e) { - for (Observer observer : observers.values()) { - observer.onError(e); - } + state.currentValue.set(new Notification(e)); + terminalState(); } @Override - public void onNext(T args) { - hasValue.set(true); - currentValue.set(args); + public void onNext(T v) { + state.currentValue.set(new Notification(v)); + } + + private void terminalState() { + /* + * We can not allow new subscribers to be added while we execute the terminal state. + */ + state.SUBSCRIPTION_LOCK.lock(); + try { + if (state.completed.compareAndSet(false, true)) { + for (Subscription s : state.observers.keySet()) { + // emit notifications to this observer + emitNotificationToObserver(state, state.observers.get(s)); + // remove the subscription as it is completed + state.observers.remove(s); + } + } + } finally { + state.SUBSCRIPTION_LOCK.unlock(); + } } } diff --git a/rxjava-core/src/test/java/rx/subjects/AsyncSubjectTest.java b/rxjava-core/src/test/java/rx/subjects/AsyncSubjectTest.java index 321b7b422a..b483b9c99f 100644 --- a/rxjava-core/src/test/java/rx/subjects/AsyncSubjectTest.java +++ b/rxjava-core/src/test/java/rx/subjects/AsyncSubjectTest.java @@ -15,9 +15,13 @@ */ package rx.subjects; +import static org.junit.Assert.*; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + import org.junit.Test; import org.mockito.InOrder; import org.mockito.Mockito; @@ -66,6 +70,62 @@ public void testCompleted() { verify(aObserver, times(1)).onCompleted(); } + @Test + public void testNull() { + AsyncSubject subject = AsyncSubject.create(); + + @SuppressWarnings("unchecked") + Observer aObserver = mock(Observer.class); + subject.subscribe(aObserver); + + subject.onNext(null); + subject.onCompleted(); + + verify(aObserver, times(1)).onNext(null); + verify(aObserver, Mockito.never()).onError(any(Throwable.class)); + verify(aObserver, times(1)).onCompleted(); + } + + @Test + public void testSubscribeAfterCompleted() { + AsyncSubject subject = AsyncSubject.create(); + + @SuppressWarnings("unchecked") + Observer aObserver = mock(Observer.class); + + subject.onNext("one"); + subject.onNext("two"); + subject.onNext("three"); + subject.onCompleted(); + + subject.subscribe(aObserver); + + verify(aObserver, times(1)).onNext("three"); + verify(aObserver, Mockito.never()).onError(any(Throwable.class)); + verify(aObserver, times(1)).onCompleted(); + } + + @Test + public void testSubscribeAfterError() { + AsyncSubject subject = AsyncSubject.create(); + + @SuppressWarnings("unchecked") + Observer aObserver = mock(Observer.class); + + subject.onNext("one"); + subject.onNext("two"); + subject.onNext("three"); + + RuntimeException re = new RuntimeException("failed"); + subject.onError(re); + + subject.subscribe(aObserver); + + verify(aObserver, times(1)).onError(re); + verify(aObserver, Mockito.never()).onNext(any(String.class)); + verify(aObserver, Mockito.never()).onCompleted(); + } + @Test public void testError() { AsyncSubject subject = AsyncSubject.create(); @@ -151,4 +211,94 @@ public void testEmptySubjectCompleted() { inOrder.verify(aObserver, times(1)).onCompleted(); inOrder.verifyNoMoreInteractions(); } + + /** + * Can receive timeout if subscribe never receives an onError/onCompleted ... which reveals a race condition. + */ + @Test + public void testSubscribeCompletionRaceCondition() { + /* + * With non-threadsafe code this fails most of the time on my dev laptop and is non-deterministic enough + * to act as a unit test to the race conditions. + * + * With the synchronization code in place I can not get this to fail on my laptop. + */ + for (int i = 0; i < 50; i++) { + final AsyncSubject subject = AsyncSubject.create(); + final AtomicReference value1 = new AtomicReference(); + + subject.subscribe(new Action1() { + + @Override + public void call(String t1) { + try { + // simulate a slow observer + Thread.sleep(50); + } catch (InterruptedException e) { + e.printStackTrace(); + } + value1.set(t1); + } + + }); + + Thread t1 = new Thread(new Runnable() { + + @Override + public void run() { + subject.onNext("value"); + subject.onCompleted(); + } + }); + + SubjectObserverThread t2 = new SubjectObserverThread(subject); + SubjectObserverThread t3 = new SubjectObserverThread(subject); + SubjectObserverThread t4 = new SubjectObserverThread(subject); + SubjectObserverThread t5 = new SubjectObserverThread(subject); + + t2.start(); + t3.start(); + t1.start(); + t4.start(); + t5.start(); + try { + t1.join(); + t2.join(); + t3.join(); + t4.join(); + t5.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + assertEquals("value", value1.get()); + assertEquals("value", t2.value.get()); + assertEquals("value", t3.value.get()); + assertEquals("value", t4.value.get()); + assertEquals("value", t5.value.get()); + } + + } + + private static class SubjectObserverThread extends Thread { + + private final AsyncSubject subject; + private final AtomicReference value = new AtomicReference(); + + public SubjectObserverThread(AsyncSubject subject) { + this.subject = subject; + } + + @Override + public void run() { + try { + // a timeout exception will happen if we don't get a terminal state + String v = subject.timeout(2000, TimeUnit.MILLISECONDS).toBlockingObservable().single(); + value.set(v); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + }