diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index 5a703c5019..06a08ef576 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -53,6 +53,7 @@ import rx.operators.OperationFirstOrDefault; import rx.operators.OperationGroupBy; import rx.operators.OperationInterval; +import rx.operators.OperationJoin; import rx.operators.OperationJoinPatterns; import rx.operators.OperationLast; import rx.operators.OperationMap; @@ -5942,5 +5943,26 @@ public static Observable when(Plan0 p1, Plan0 p2, Plan0 p3, Plan public static Observable when(Plan0 p1, Plan0 p2, Plan0 p3, Plan0 p4, Plan0 p5, Plan0 p6, Plan0 p7, Plan0 p8, Plan0 p9) { return create(OperationJoinPatterns.when(p1, p2, p3, p4, p5, p6, p7, p8, p9)); } + /** + * Correlates the elements of two sequences based on overlapping durations. + * @param right The right observable sequence to join elements for. + * @param leftDurationSelector A function to select the duration of each + * element of this observable sequence, used to + * determine overlap. + * @param rightDurationSelector A function to select the duration of each + * element of the right observable sequence, + * used to determine overlap. + * @param resultSelector A function invoked to compute a result element + * for any two overlapping elements of the left and + * right observable sequences. + * @return An observable sequence that contains result elements computed + * from source elements that have an overlapping duration. + * @see MSDN: Observable.Join + */ + public Observable join(Observable right, Func1> leftDurationSelector, + Func1> rightDurationSelector, + Func2 resultSelector) { + return create(new OperationJoin(this, right, leftDurationSelector, rightDurationSelector, resultSelector)); + } } diff --git a/rxjava-core/src/main/java/rx/operators/OperationJoin.java b/rxjava-core/src/main/java/rx/operators/OperationJoin.java new file mode 100644 index 0000000000..b75b8498b0 --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationJoin.java @@ -0,0 +1,277 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.HashMap; +import java.util.Map; +import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.subscriptions.CompositeSubscription; +import rx.subscriptions.SerialSubscription; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +/** + * Correlates the elements of two sequences based on overlapping durations. + */ +public class OperationJoin implements OnSubscribeFunc { + final Observable left; + final Observable right; + final Func1> leftDurationSelector; + final Func1> rightDurationSelector; + final Func2 resultSelector; + public OperationJoin( + Observable left, + Observable right, + Func1> leftDurationSelector, + Func1> rightDurationSelector, + Func2 resultSelector) { + this.left = left; + this.right = right; + this.leftDurationSelector = leftDurationSelector; + this.rightDurationSelector = rightDurationSelector; + this.resultSelector = resultSelector; + } + + @Override + public Subscription onSubscribe(Observer t1) { + SerialSubscription cancel = new SerialSubscription(); + ResultSink result = new ResultSink(t1, cancel); + cancel.setSubscription(result.run()); + return cancel; + } + /** Manage the left and right sources. */ + class ResultSink { + final Object gate = new Object(); + final CompositeSubscription group = new CompositeSubscription(); + boolean leftDone; + int leftId; + final Map leftMap = new HashMap(); + boolean rightDone; + int rightId; + final Map rightMap = new HashMap(); + final Observer observer; + final Subscription cancel; + public ResultSink(Observer observer, Subscription cancel) { + this.observer = observer; + this.cancel = cancel; + } + public Subscription run() { + SerialSubscription leftCancel = new SerialSubscription(); + SerialSubscription rightCancel = new SerialSubscription(); + + group.add(leftCancel); + group.add(rightCancel); + + leftCancel.setSubscription(left.subscribe(new LeftObserver(leftCancel))); + rightCancel.setSubscription(right.subscribe(new RightObserver(rightCancel))); + + return group; + } + /** Observes the left values. */ + class LeftObserver implements Observer { + final Subscription self; + public LeftObserver(Subscription self) { + this.self = self; + } + protected void expire(int id, Subscription resource) { + synchronized (gate) { + if (leftMap.remove(id) != null && leftMap.isEmpty() && leftDone) { + observer.onCompleted(); + cancel.unsubscribe(); + } + } + group.remove(resource); + } + @Override + public void onNext(TLeft args) { + int id; + synchronized (gate) { + id = leftId++; + leftMap.put(id, args); + } + SerialSubscription md = new SerialSubscription(); + group.add(md); + + Observable duration; + try { + duration = leftDurationSelector.call(args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + + md.setSubscription(duration.subscribe(new LeftDurationObserver(id, md))); + + synchronized (gate) { + for (TRight r : rightMap.values()) { + R result; + try { + result = resultSelector.call(args, r); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + observer.onNext(result); + } + } + } + @Override + public void onError(Throwable e) { + synchronized (gate) { + observer.onError(e); + cancel.unsubscribe(); + } + } + @Override + public void onCompleted() { + synchronized (gate) { + leftDone = true; + if (rightDone || leftMap.isEmpty()) { + observer.onCompleted(); + cancel.unsubscribe(); + } else { + self.unsubscribe(); + } + } + } + /** Observes the left duration. */ + class LeftDurationObserver implements Observer { + final int id; + final Subscription handle; + public LeftDurationObserver(int id, Subscription handle) { + this.id = id; + this.handle = handle; + } + + @Override + public void onNext(TLeftDuration args) { + expire(id, handle); + } + + @Override + public void onError(Throwable e) { + LeftObserver.this.onError(e); + } + + @Override + public void onCompleted() { + expire(id, handle); + } + + } + } + /** Observes the right values. */ + class RightObserver implements Observer { + final Subscription self; + public RightObserver(Subscription self) { + this.self = self; + } + void expire(int id, Subscription resource) { + synchronized (gate) { + if (rightMap.remove(id) != null && rightMap.isEmpty() && rightDone) { + observer.onCompleted(); + cancel.unsubscribe(); + } + } + group.remove(resource); + } + @Override + public void onNext(TRight args) { + int id = 0; + synchronized (gate) { + id = rightId++; + rightMap.put(id, args); + } + SerialSubscription md = new SerialSubscription(); + group.add(md); + + Observable duration; + try { + duration = rightDurationSelector.call(args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + + md.setSubscription(duration.subscribe(new RightDurationObserver(id, md))); + + synchronized (gate) { + for (TLeft lv : leftMap.values()) { + R result; + try { + result = resultSelector.call(lv, args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + observer.onNext(result); + } + } + } + @Override + public void onError(Throwable e) { + synchronized (gate) { + observer.onError(e); + cancel.unsubscribe(); + } + } + @Override + public void onCompleted() { + synchronized (gate) { + rightDone = true; + if (leftDone || rightMap.isEmpty()) { + observer.onCompleted(); + cancel.unsubscribe(); + } else { + self.unsubscribe(); + } + } + } + /** Observe the right duration. */ + class RightDurationObserver implements Observer { + final int id; + final Subscription handle; + public RightDurationObserver(int id, Subscription handle) { + this.id = id; + this.handle = handle; + } + + @Override + public void onNext(TRightDuration args) { + expire(id, handle); + } + + @Override + public void onError(Throwable e) { + RightObserver.this.onError(e); + } + + @Override + public void onCompleted() { + expire(id, handle); + } + + } + } + } +} diff --git a/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java b/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java new file mode 100644 index 0000000000..8f841ebb50 --- /dev/null +++ b/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java @@ -0,0 +1,302 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.Collection; +import org.junit.Before; +import org.junit.Test; +import static org.mockito.Matchers.any; +import org.mockito.Mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import org.mockito.MockitoAnnotations; +import rx.Observable; +import rx.Observer; +import rx.subjects.PublishSubject; +import rx.util.functions.Action1; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +public class OperationJoinTest { + @Mock + Observer observer; + + Func2 add = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }; + Func1> just(final Observable observable) { + return new Func1>() { + @Override + public Observable call(Integer t1) { + return observable; + } + }; + } + @Before + public void before() { + MockitoAnnotations.initMocks(this); + } + @Test + public void normal1() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source1.onNext(4); + + source2.onNext(16); + source2.onNext(32); + source2.onNext(64); + + source1.onCompleted(); + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(20); + verify(observer, times(1)).onNext(33); + verify(observer, times(1)).onNext(34); + verify(observer, times(1)).onNext(36); + verify(observer, times(1)).onNext(65); + verify(observer, times(1)).onNext(66); + verify(observer, times(1)).onNext(68); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + } + @Test + public void normal1WithDuration() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + PublishSubject duration1 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(duration1), + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source2.onNext(16); + + duration1.onNext(1); + + source1.onNext(4); + source1.onNext(8); + + source1.onCompleted(); + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(20); + verify(observer, times(1)).onNext(24); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + + } + @Test + public void normal2() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source1.onCompleted(); + + source2.onNext(16); + source2.onNext(32); + source2.onNext(64); + + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(33); + verify(observer, times(1)).onNext(34); + verify(observer, times(1)).onNext(65); + verify(observer, times(1)).onNext(66); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + } + @Test + public void leftThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source2.onNext(1); + source1.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source2.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void leftDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable m = source1.join(source2, + just(duration1), + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable m = source1.join(source2, + just(Observable.never()), + just(duration1), add); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void leftDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + fail, + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + just(Observable.never()), + fail, add); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void resultSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func2 fail = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), fail); + m.subscribe(observer); + + source1.onNext(1); + source2.onNext(2); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } +}