From e22acf63e83e483c875ddb1ab15f21eda2f00d85 Mon Sep 17 00:00:00 2001 From: Dave Moten Date: Mon, 13 Jul 2015 16:02:22 +1000 Subject: [PATCH] OperatorSwitch - fix lost requests race condition using ProducerArbiter --- .../rx/internal/operators/OperatorSwitch.java | 150 +++++++----------- .../operators/OperatorSwitchIfEmptyTest.java | 1 - .../operators/OperatorSwitchTest.java | 19 +-- 3 files changed, 63 insertions(+), 107 deletions(-) diff --git a/src/main/java/rx/internal/operators/OperatorSwitch.java b/src/main/java/rx/internal/operators/OperatorSwitch.java index afd35e477d..cbd02e1b58 100644 --- a/src/main/java/rx/internal/operators/OperatorSwitch.java +++ b/src/main/java/rx/internal/operators/OperatorSwitch.java @@ -22,6 +22,7 @@ import rx.Observable.Operator; import rx.Producer; import rx.Subscriber; +import rx.internal.producers.ProducerArbiter; import rx.observers.SerializedSubscriber; import rx.subscriptions.SerialSubscription; @@ -46,7 +47,9 @@ private static final class Holder { public static OperatorSwitch instance() { return (OperatorSwitch)Holder.INSTANCE; } + private OperatorSwitch() { } + @Override public Subscriber> call(final Subscriber child) { SwitchSubscriber sws = new SwitchSubscriber(child); @@ -55,10 +58,12 @@ public Subscriber> call(final Subscriber extends Subscriber> { - final SerializedSubscriber s; + final SerializedSubscriber serializedChild; final SerialSubscription ssub; final Object guard = new Object(); final NotificationLite nl = NotificationLite.instance(); + final ProducerArbiter arbiter; + /** Guarded by guard. */ int index; /** Guarded by guard. */ @@ -70,50 +75,19 @@ private static final class SwitchSubscriber extends Subscriber currentSubscriber; - public SwitchSubscriber(Subscriber child) { - s = new SerializedSubscriber(child); + SwitchSubscriber(Subscriber child) { + serializedChild = new SerializedSubscriber(child); + arbiter = new ProducerArbiter(); ssub = new SerialSubscription(); child.add(ssub); child.setProducer(new Producer(){ @Override public void request(long n) { - if (infinite) { - return; - } - if(n == Long.MAX_VALUE) { - infinite = true; - } - InnerSubscriber localSubscriber; - synchronized (guard) { - localSubscriber = currentSubscriber; - if (currentSubscriber == null) { - long r = initialRequested + n; - if (r < 0) { - infinite = true; - } else { - initialRequested = r; - } - } else { - long r = currentSubscriber.requested + n; - if (r < 0) { - infinite = true; - } else { - currentSubscriber.requested = r; - } - } - } - if (localSubscriber != null) { - if (infinite) - localSubscriber.requestMore(Long.MAX_VALUE); - else - localSubscriber.requestMore(n); + if (n > 0) { + arbiter.request(n); } } }); @@ -122,26 +96,18 @@ public void request(long n) { @Override public void onNext(Observable t) { final int id; - long remainingRequest; synchronized (guard) { id = ++index; active = true; - if (infinite) { - remainingRequest = Long.MAX_VALUE; - } else { - remainingRequest = currentSubscriber == null ? initialRequested : currentSubscriber.requested; - } - currentSubscriber = new InnerSubscriber(id, remainingRequest); - currentSubscriber.requested = remainingRequest; + currentSubscriber = new InnerSubscriber(id, arbiter, this); } ssub.set(currentSubscriber); - t.unsafeSubscribe(currentSubscriber); } @Override public void onError(Throwable e) { - s.onError(e); + serializedChild.onError(e); unsubscribe(); } @@ -165,10 +131,10 @@ public void onCompleted() { emitting = true; } drain(localQueue); - s.onCompleted(); + serializedChild.onCompleted(); unsubscribe(); } - void emit(T value, int id, InnerSubscriber innerSubscriber) { + void emit(T value, int id, InnerSubscriber innerSubscriber) { List localQueue; synchronized (guard) { if (id != index) { @@ -178,8 +144,6 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) { if (queue == null) { queue = new ArrayList(); } - if (innerSubscriber.requested != Long.MAX_VALUE) - innerSubscriber.requested--; queue.add(value); return; } @@ -194,11 +158,8 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) { drain(localQueue); if (once) { once = false; - synchronized (guard) { - if (innerSubscriber.requested != Long.MAX_VALUE) - innerSubscriber.requested--; - } - s.onNext(value); + serializedChild.onNext(value); + arbiter.produced(1); } synchronized (guard) { localQueue = queue; @@ -209,7 +170,7 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) { break; } } - } while (!s.isUnsubscribed()); + } while (!serializedChild.isUnsubscribed()); } finally { if (!skipFinal) { synchronized (guard) { @@ -224,16 +185,17 @@ void drain(List localQueue) { } for (Object o : localQueue) { if (nl.isCompleted(o)) { - s.onCompleted(); + serializedChild.onCompleted(); break; } else if (nl.isError(o)) { - s.onError(nl.getError(o)); + serializedChild.onError(nl.getError(o)); break; } else { @SuppressWarnings("unchecked") T t = (T)o; - s.onNext(t); + serializedChild.onNext(t); + arbiter.produced(1); } } } @@ -258,7 +220,7 @@ void error(Throwable e, int id) { } drain(localQueue); - s.onError(e); + serializedChild.onError(e); unsubscribe(); } void complete(int id) { @@ -285,51 +247,45 @@ void complete(int id) { } drain(localQueue); - s.onCompleted(); + serializedChild.onCompleted(); unsubscribe(); } - final class InnerSubscriber extends Subscriber { - - /** - * The number of request that is not acknowledged. - * - * Guarded by guard. - */ - private long requested = 0; - - private final int id; + } + + private static final class InnerSubscriber extends Subscriber { - private final long initialRequested; + private final int id; - public InnerSubscriber(int id, long initialRequested) { - this.id = id; - this.initialRequested = initialRequested; - } + private final ProducerArbiter arbiter; - @Override - public void onStart() { - requestMore(initialRequested); - } + private final SwitchSubscriber parent; - public void requestMore(long n) { - request(n); - } + InnerSubscriber(int id, ProducerArbiter arbiter, SwitchSubscriber parent) { + this.id = id; + this.arbiter = arbiter; + this.parent = parent; + } + + @Override + public void setProducer(Producer p) { + arbiter.setProducer(p); + } - @Override - public void onNext(T t) { - emit(t, id, this); - } + @Override + public void onNext(T t) { + parent.emit(t, id, this); + } - @Override - public void onError(Throwable e) { - error(e, id); - } + @Override + public void onError(Throwable e) { + parent.error(e, id); + } - @Override - public void onCompleted() { - complete(id); - } + @Override + public void onCompleted() { + parent.complete(id); } } + } diff --git a/src/test/java/rx/internal/operators/OperatorSwitchIfEmptyTest.java b/src/test/java/rx/internal/operators/OperatorSwitchIfEmptyTest.java index 2534613ab4..332924ba68 100644 --- a/src/test/java/rx/internal/operators/OperatorSwitchIfEmptyTest.java +++ b/src/test/java/rx/internal/operators/OperatorSwitchIfEmptyTest.java @@ -27,7 +27,6 @@ import rx.Observable; import rx.Observable.OnSubscribe; import rx.functions.Action0; -import rx.functions.Action1; import rx.observers.TestSubscriber; import rx.schedulers.Schedulers; import rx.subscriptions.Subscriptions; diff --git a/src/test/java/rx/internal/operators/OperatorSwitchTest.java b/src/test/java/rx/internal/operators/OperatorSwitchTest.java index 6b5d3a1f79..63de5d0d81 100644 --- a/src/test/java/rx/internal/operators/OperatorSwitchTest.java +++ b/src/test/java/rx/internal/operators/OperatorSwitchTest.java @@ -25,7 +25,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; @@ -642,32 +641,34 @@ public Observable call(Long t) { } @Test(timeout = 10000) - public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream() throws InterruptedException { + public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream() + throws InterruptedException { final List requests = new CopyOnWriteArrayList(); final Action1 addRequest = new Action1() { @Override public void call(Long n) { requests.add(n); - }}; - TestSubscriber ts = new TestSubscriber(0); + } + }; + TestSubscriber ts = new TestSubscriber(1); Observable.switchOnNext( Observable.interval(100, TimeUnit.MILLISECONDS) .map(new Func1>() { @Override public Observable call(Long t) { - return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(addRequest); + return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest( + addRequest); } }).take(3)).subscribe(ts); - ts.requestMore(1); - //we will miss two of the first observable + // we will miss two of the first observables Thread.sleep(250); ts.requestMore(Long.MAX_VALUE - 1); ts.requestMore(Long.MAX_VALUE - 1); ts.awaitTerminalEvent(); assertTrue(ts.getOnNextEvents().size() > 0); assertEquals(5, (int) requests.size()); - assertEquals(Long.MAX_VALUE, (long) requests.get(3)); - assertEquals(Long.MAX_VALUE, (long) requests.get(4)); + assertEquals(Long.MAX_VALUE, (long) requests.get(requests.size()-1)); } + }