Skip to content

Commit

Permalink
Merge pull request #3078 from davidmoten/switch-fix
Browse files Browse the repository at this point in the history
switchOnNext() - fix lost requests race condition
  • Loading branch information
akarnokd committed Aug 12, 2015
2 parents cb1712d + e22acf6 commit 9a84006
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 107 deletions.
150 changes: 53 additions & 97 deletions src/main/java/rx/internal/operators/OperatorSwitch.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import rx.Observable.Operator;
import rx.Producer;
import rx.Subscriber;
import rx.internal.producers.ProducerArbiter;
import rx.observers.SerializedSubscriber;
import rx.subscriptions.SerialSubscription;

Expand All @@ -46,7 +47,9 @@ private static final class Holder {
public static <T> OperatorSwitch<T> instance() {
return (OperatorSwitch<T>)Holder.INSTANCE;
}

private OperatorSwitch() { }

@Override
public Subscriber<? super Observable<? extends T>> call(final Subscriber<? super T> child) {
SwitchSubscriber<T> sws = new SwitchSubscriber<T>(child);
Expand All @@ -55,10 +58,12 @@ public Subscriber<? super Observable<? extends T>> call(final Subscriber<? super
}

private static final class SwitchSubscriber<T> extends Subscriber<Observable<? extends T>> {
final SerializedSubscriber<T> s;
final SerializedSubscriber<T> serializedChild;
final SerialSubscription ssub;
final Object guard = new Object();
final NotificationLite<?> nl = NotificationLite.instance();
final ProducerArbiter arbiter;

/** Guarded by guard. */
int index;
/** Guarded by guard. */
Expand All @@ -70,50 +75,19 @@ private static final class SwitchSubscriber<T> extends Subscriber<Observable<? e
/** Guarded by guard. */
boolean emitting;
/** Guarded by guard. */
InnerSubscriber currentSubscriber;
/** Guarded by guard. */
long initialRequested;

volatile boolean infinite = false;
InnerSubscriber<T> currentSubscriber;

public SwitchSubscriber(Subscriber<? super T> child) {
s = new SerializedSubscriber<T>(child);
SwitchSubscriber(Subscriber<? super T> child) {
serializedChild = new SerializedSubscriber<T>(child);
arbiter = new ProducerArbiter();
ssub = new SerialSubscription();
child.add(ssub);
child.setProducer(new Producer(){

@Override
public void request(long n) {
if (infinite) {
return;
}
if(n == Long.MAX_VALUE) {
infinite = true;
}
InnerSubscriber localSubscriber;
synchronized (guard) {
localSubscriber = currentSubscriber;
if (currentSubscriber == null) {
long r = initialRequested + n;
if (r < 0) {
infinite = true;
} else {
initialRequested = r;
}
} else {
long r = currentSubscriber.requested + n;
if (r < 0) {
infinite = true;
} else {
currentSubscriber.requested = r;
}
}
}
if (localSubscriber != null) {
if (infinite)
localSubscriber.requestMore(Long.MAX_VALUE);
else
localSubscriber.requestMore(n);
if (n > 0) {
arbiter.request(n);
}
}
});
Expand All @@ -122,26 +96,18 @@ public void request(long n) {
@Override
public void onNext(Observable<? extends T> t) {
final int id;
long remainingRequest;
synchronized (guard) {
id = ++index;
active = true;
if (infinite) {
remainingRequest = Long.MAX_VALUE;
} else {
remainingRequest = currentSubscriber == null ? initialRequested : currentSubscriber.requested;
}
currentSubscriber = new InnerSubscriber(id, remainingRequest);
currentSubscriber.requested = remainingRequest;
currentSubscriber = new InnerSubscriber<T>(id, arbiter, this);
}
ssub.set(currentSubscriber);

t.unsafeSubscribe(currentSubscriber);
}

@Override
public void onError(Throwable e) {
s.onError(e);
serializedChild.onError(e);
unsubscribe();
}

Expand All @@ -165,10 +131,10 @@ public void onCompleted() {
emitting = true;
}
drain(localQueue);
s.onCompleted();
serializedChild.onCompleted();
unsubscribe();
}
void emit(T value, int id, InnerSubscriber innerSubscriber) {
void emit(T value, int id, InnerSubscriber<T> innerSubscriber) {
List<Object> localQueue;
synchronized (guard) {
if (id != index) {
Expand All @@ -178,8 +144,6 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
if (queue == null) {
queue = new ArrayList<Object>();
}
if (innerSubscriber.requested != Long.MAX_VALUE)
innerSubscriber.requested--;
queue.add(value);
return;
}
Expand All @@ -194,11 +158,8 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
drain(localQueue);
if (once) {
once = false;
synchronized (guard) {
if (innerSubscriber.requested != Long.MAX_VALUE)
innerSubscriber.requested--;
}
s.onNext(value);
serializedChild.onNext(value);
arbiter.produced(1);
}
synchronized (guard) {
localQueue = queue;
Expand All @@ -209,7 +170,7 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
break;
}
}
} while (!s.isUnsubscribed());
} while (!serializedChild.isUnsubscribed());
} finally {
if (!skipFinal) {
synchronized (guard) {
Expand All @@ -224,16 +185,17 @@ void drain(List<Object> localQueue) {
}
for (Object o : localQueue) {
if (nl.isCompleted(o)) {
s.onCompleted();
serializedChild.onCompleted();
break;
} else
if (nl.isError(o)) {
s.onError(nl.getError(o));
serializedChild.onError(nl.getError(o));
break;
} else {
@SuppressWarnings("unchecked")
T t = (T)o;
s.onNext(t);
serializedChild.onNext(t);
arbiter.produced(1);
}
}
}
Expand All @@ -258,7 +220,7 @@ void error(Throwable e, int id) {
}

drain(localQueue);
s.onError(e);
serializedChild.onError(e);
unsubscribe();
}
void complete(int id) {
Expand All @@ -285,51 +247,45 @@ void complete(int id) {
}

drain(localQueue);
s.onCompleted();
serializedChild.onCompleted();
unsubscribe();
}

final class InnerSubscriber extends Subscriber<T> {

/**
* The number of request that is not acknowledged.
*
* Guarded by guard.
*/
private long requested = 0;

private final int id;
}

private static final class InnerSubscriber<T> extends Subscriber<T> {

private final long initialRequested;
private final int id;

public InnerSubscriber(int id, long initialRequested) {
this.id = id;
this.initialRequested = initialRequested;
}
private final ProducerArbiter arbiter;

@Override
public void onStart() {
requestMore(initialRequested);
}
private final SwitchSubscriber<T> parent;

public void requestMore(long n) {
request(n);
}
InnerSubscriber(int id, ProducerArbiter arbiter, SwitchSubscriber<T> parent) {
this.id = id;
this.arbiter = arbiter;
this.parent = parent;
}

@Override
public void setProducer(Producer p) {
arbiter.setProducer(p);
}

@Override
public void onNext(T t) {
emit(t, id, this);
}
@Override
public void onNext(T t) {
parent.emit(t, id, this);
}

@Override
public void onError(Throwable e) {
error(e, id);
}
@Override
public void onError(Throwable e) {
parent.error(e, id);
}

@Override
public void onCompleted() {
complete(id);
}
@Override
public void onCompleted() {
parent.complete(id);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import rx.Observable;
import rx.Observable.OnSubscribe;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.observers.TestSubscriber;
import rx.schedulers.Schedulers;
import rx.subscriptions.Subscriptions;
Expand Down
19 changes: 10 additions & 9 deletions src/test/java/rx/internal/operators/OperatorSwitchTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
Expand Down Expand Up @@ -642,32 +641,34 @@ public Observable<Long> call(Long t) {
}

@Test(timeout = 10000)
public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream() throws InterruptedException {
public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream()
throws InterruptedException {
final List<Long> requests = new CopyOnWriteArrayList<Long>();
final Action1<Long> addRequest = new Action1<Long>() {

@Override
public void call(Long n) {
requests.add(n);
}};
TestSubscriber<Long> ts = new TestSubscriber<Long>(0);
}
};
TestSubscriber<Long> ts = new TestSubscriber<Long>(1);
Observable.switchOnNext(
Observable.interval(100, TimeUnit.MILLISECONDS)
.map(new Func1<Long, Observable<Long>>() {
@Override
public Observable<Long> call(Long t) {
return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(addRequest);
return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(
addRequest);
}
}).take(3)).subscribe(ts);
ts.requestMore(1);
//we will miss two of the first observable
// we will miss two of the first observables
Thread.sleep(250);
ts.requestMore(Long.MAX_VALUE - 1);
ts.requestMore(Long.MAX_VALUE - 1);
ts.awaitTerminalEvent();
assertTrue(ts.getOnNextEvents().size() > 0);
assertEquals(5, (int) requests.size());
assertEquals(Long.MAX_VALUE, (long) requests.get(3));
assertEquals(Long.MAX_VALUE, (long) requests.get(4));
assertEquals(Long.MAX_VALUE, (long) requests.get(requests.size()-1));
}

}

0 comments on commit 9a84006

Please sign in to comment.