From dd7c4ce98ca40b0f154a618a6d70a9699e675fcc Mon Sep 17 00:00:00 2001 From: akarnokd Date: Fri, 10 Jan 2014 15:42:40 +0100 Subject: [PATCH 1/2] MergeMap with Iterable and resultSelector overloads --- rxjava-core/src/main/java/rx/Observable.java | 45 ++++ .../java/rx/operators/OperationFlatMap.java | 208 ++++++++++++++++++ .../rx/operators/OperationFlatMapTest.java | 145 ++++++++++++ 3 files changed, 398 insertions(+) create mode 100644 rxjava-core/src/main/java/rx/operators/OperationFlatMap.java create mode 100644 rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index 2ac2fb7b5d..a1227b9ef6 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -53,6 +53,7 @@ import rx.operators.OperationElementAt; import rx.operators.OperationFilter; import rx.operators.OperationFinally; +import rx.operators.OperationFlatMap; import rx.operators.OperationGroupBy; import rx.operators.OperationGroupByUntil; import rx.operators.OperationGroupJoin; @@ -3856,6 +3857,50 @@ public Observable mergeMap(Func1 the element type of the collection Observable + * @param the result type + * @param collectionSelector function that returns an Observable sequence for each value in the source Observable + * @param resultSelector function that combines the values of the source and collection Observable + * @return an Observable that applies a function to the pair of values from the source + * Observable and the collection Observable. + */ + public Observable mergeMap(Func1> collectionSelector, + Func2 resultSelector) { + return create(OperationFlatMap.flatMap(this, collectionSelector, resultSelector)); + } + + /** + * Create an Observable that merges the values of the iterables returned by the + * collectionSelector for each source value. + * @param the result value type + * @param collectionSelector function that returns an Iterable sequence of values for + * each source value. + * @return an Observable that merges the values of the iterables returned by the + * collectionSelector for each source value. + */ + public Observable mergeMapIterable(Func1> collectionSelector) { + return merge(map(OperationFlatMap.flatMapIterableFunc(collectionSelector))); + } + + /** + * Create an Observable that applies a function to the pair of values from the source + * Observable and the collection Iterable sequence. + * @param the collection element type + * @param the result type + * @param collectionSelector function that returns an Iterable sequence of values for + * each source value. + * @param resultSelector function that combines the values of the source and collection Iterable + * @return n Observable that applies a function to the pair of values from the source + * Observable and the collection Iterable sequence. + */ + public Observable mergeMapIterable(Func1> collectionSelector, + Func2 resultSelector) { + return mergeMap(OperationFlatMap.flatMapIterableFunc(collectionSelector), resultSelector); + } + /** * Creates a new Observable by applying a function that you supply to each * item emitted by the source Observable, where that function returns an diff --git a/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java b/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java new file mode 100644 index 0000000000..20c12993cb --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java @@ -0,0 +1,208 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.concurrent.atomic.AtomicInteger; +import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.subscriptions.CompositeSubscription; +import rx.subscriptions.SerialSubscription; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +/** + * Additional flatMap operators. + */ +public final class OperationFlatMap { + /** Utility class. */ + private OperationFlatMap() { throw new IllegalStateException("No instances!"); } + + /** + * Observable that pairs up the source values and all the derived collection + * values and projects them via the selector. + */ + public static OnSubscribeFunc flatMap(Observable source, + Func1> collectionSelector, + Func2 resultSelector + ) { + return new FlatMapPairSelector(source, collectionSelector, resultSelector); + } + /** + * Converts the result Iterable of a function into an Observable. + */ + public static Func1> flatMapIterableFunc( + Func1> collectionSelector) { + return new IterableToObservableFunc(collectionSelector); + } + /** + * Converts the result Iterable of a function into an Observable. + * @param the parameter type + * @param the result type + */ + private static final class IterableToObservableFunc implements Func1> { + final Func1> func; + + public IterableToObservableFunc(Func1> func) { + this.func = func; + } + + @Override + public Observable call(T t1) { + return Observable.from(func.call(t1)); + } + } + /** + * Pairs up the source value with each of the associated observable values + * and uses a selector function to calculate the result sequence. + * @param the source value type + * @param the collection value type + * @param the result type + */ + private static final class FlatMapPairSelector implements OnSubscribeFunc { + final Observable source; + final Func1> collectionSelector; + final Func2 resultSelector; + + public FlatMapPairSelector(Observable source, Func1> collectionSelector, Func2 resultSelector) { + this.source = source; + this.collectionSelector = collectionSelector; + this.resultSelector = resultSelector; + } + + @Override + public Subscription onSubscribe(Observer t1) { + CompositeSubscription csub = new CompositeSubscription(); + + csub.add(source.subscribe(new SourceObserver(t1, collectionSelector, resultSelector, csub))); + + return csub; + } + + /** Observes the source, starts the collections and projects the result. */ + private static final class SourceObserver implements Observer { + final Observer observer; + final Func1> collectionSelector; + final Func2 resultSelector; + final CompositeSubscription csub; + final AtomicInteger wip; + /** Don't let various events run at the same time. */ + final Object guard; + boolean done; + + public SourceObserver(Observer observer, Func1> collectionSelector, Func2 resultSelector, CompositeSubscription csub) { + this.observer = observer; + this.collectionSelector = collectionSelector; + this.resultSelector = resultSelector; + this.csub = csub; + this.wip = new AtomicInteger(1); + this.guard = new Object(); + } + + @Override + public void onNext(T args) { + Observable coll; + try { + coll = collectionSelector.call(args); + } catch (Throwable e) { + onError(e); + return; + } + + SerialSubscription ssub = new SerialSubscription(); + csub.add(ssub); + wip.incrementAndGet(); + + ssub.set(coll.subscribe(new CollectionObserver(this, args, ssub))); + } + + @Override + public void onError(Throwable e) { + synchronized (guard) { + if (done) { + return; + } + done = true; + observer.onError(e); + } + csub.unsubscribe(); + } + + @Override + public void onCompleted() { + if (wip.decrementAndGet() == 0) { + synchronized (guard) { + if (done) { + return; + } + done = true; + observer.onCompleted(); + } + csub.unsubscribe(); + } + } + + void complete(Subscription s) { + csub.remove(s); + onCompleted(); + } + + void emit(T t, U u) { + R r; + try { + r = resultSelector.call(t, u); + } catch (Throwable e) { + onError(e); + return; + } + synchronized (guard) { + if (done) { + return; + } + observer.onNext(r); + } + } + } + /** Observe a collection and call emit with the pair of the key and the value. */ + private static final class CollectionObserver implements Observer { + final SourceObserver so; + final Subscription cancel; + final T value; + + public CollectionObserver(SourceObserver so, T value, Subscription cancel) { + this.so = so; + this.value = value; + this.cancel = cancel; + } + + @Override + public void onNext(U args) { + so.emit(value, args); + } + + @Override + public void onError(Throwable e) { + so.onError(e); + } + + @Override + public void onCompleted() { + so.complete(cancel); + } + }; + } +} diff --git a/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java b/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java new file mode 100644 index 0000000000..9c3afd73f7 --- /dev/null +++ b/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java @@ -0,0 +1,145 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import static org.mockito.Mockito.*; +import rx.Observable; +import rx.Observer; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +public class OperationFlatMapTest { + @Test + public void testNormal() { + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + final List list = Arrays.asList(1, 2, 3); + + Func1> func = new Func1>() { + @Override + public List call(Integer t1) { + return list; + } + }; + Func2 resFunc = new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 | t2; + } + }; + + List source = Arrays.asList(16, 32, 64); + + Observable.from(source).mergeMapIterable(func, resFunc).subscribe(o); + + for (Integer s : source) { + for (Integer v : list) { + verify(o).onNext(s | v); + } + } + verify(o).onCompleted(); + verify(o, never()).onError(any(Throwable.class)); + } + @Test + public void testCollectionFunctionThrows() { + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + Func1> func = new Func1>() { + @Override + public List call(Integer t1) { + throw new OperationReduceTest.CustomException(); + } + }; + Func2 resFunc = new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 | t2; + } + }; + + List source = Arrays.asList(16, 32, 64); + + Observable.from(source).mergeMapIterable(func, resFunc).subscribe(o); + + verify(o, never()).onCompleted(); + verify(o, never()).onNext(any()); + verify(o).onError(any(OperationReduceTest.CustomException.class)); + } + + @Test + public void testResultFunctionThrows() { + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + final List list = Arrays.asList(1, 2, 3); + + Func1> func = new Func1>() { + @Override + public List call(Integer t1) { + return list; + } + }; + Func2 resFunc = new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + throw new OperationReduceTest.CustomException(); + } + }; + + List source = Arrays.asList(16, 32, 64); + + Observable.from(source).mergeMapIterable(func, resFunc).subscribe(o); + + verify(o, never()).onCompleted(); + verify(o, never()).onNext(any()); + verify(o).onError(any(OperationReduceTest.CustomException.class)); + } + @Test + public void testMergeError() { + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + Func1> func = new Func1>() { + @Override + public Observable call(Integer t1) { + return Observable.error(new OperationReduceTest.CustomException()); + } + }; + Func2 resFunc = new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 | t2; + } + }; + + List source = Arrays.asList(16, 32, 64); + + Observable.from(source).mergeMap(func, resFunc).subscribe(o); + + verify(o, never()).onCompleted(); + verify(o, never()).onNext(any()); + verify(o).onError(any(OperationReduceTest.CustomException.class)); + } +} From 892e27a04c3c48990e8d8023420c8d6fada6ac58 Mon Sep 17 00:00:00 2001 From: akarnokd Date: Fri, 10 Jan 2014 16:42:44 +0100 Subject: [PATCH 2/2] Added event-merger overload --- rxjava-core/src/main/java/rx/Observable.java | 17 ++ .../java/rx/operators/OperationFlatMap.java | 162 ++++++++++++++++++ .../rx/operators/OperationFlatMapTest.java | 150 ++++++++++++++++ 3 files changed, 329 insertions(+) diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index a1227b9ef6..0171917258 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -3901,6 +3901,23 @@ public Observable mergeMapIterable(Func1 the result type + * @param onNext function returning a collection to merge for each onNext event of the source + * @param onError function returning a collection to merge for an onError event + * @param onCompleted function returning a collection to merge for an onCompleted event + * @return an Observable that projects the notification of an observable sequence to an observable + * sequence and merges the results into one. + */ + public Observable mergeMap( + Func1> onNext, + Func1> onError, + Func0> onCompleted) { + return create(OperationFlatMap.flatMap(this, onNext, onError, onCompleted)); + } + /** * Creates a new Observable by applying a function that you supply to each * item emitted by the source Observable, where that function returns an diff --git a/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java b/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java index 20c12993cb..a1a39f5b2b 100644 --- a/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java +++ b/rxjava-core/src/main/java/rx/operators/OperationFlatMap.java @@ -22,6 +22,7 @@ import rx.Subscription; import rx.subscriptions.CompositeSubscription; import rx.subscriptions.SerialSubscription; +import rx.util.functions.Func0; import rx.util.functions.Func1; import rx.util.functions.Func2; @@ -205,4 +206,165 @@ public void onCompleted() { } }; } + + /** + * Projects the notification of an observable sequence to an observable + * sequence and merges the results into one. + */ + public static OnSubscribeFunc flatMap(Observable source, + Func1> onNext, + Func1> onError, + Func0> onCompleted) { + return new FlatMapTransform(source, onNext, onError, onCompleted); + } + + /** + * Projects the notification of an observable sequence to an observable + * sequence and merges the results into one. + * @param the source value type + * @param the result value type + */ + private static final class FlatMapTransform implements OnSubscribeFunc { + final Observable source; + final Func1> onNext; + final Func1> onError; + final Func0> onCompleted; + + public FlatMapTransform(Observable source, Func1> onNext, Func1> onError, Func0> onCompleted) { + this.source = source; + this.onNext = onNext; + this.onError = onError; + this.onCompleted = onCompleted; + } + + @Override + public Subscription onSubscribe(Observer t1) { + CompositeSubscription csub = new CompositeSubscription(); + + csub.add(source.subscribe(new SourceObserver(t1, onNext, onError, onCompleted, csub))); + + return csub; + } + /** + * Observe the source and merge the values. + * @param the source value type + * @param the result value type + */ + private static final class SourceObserver implements Observer { + final Observer observer; + final Func1> onNext; + final Func1> onError; + final Func0> onCompleted; + final CompositeSubscription csub; + final AtomicInteger wip; + volatile boolean done; + final Object guard; + + public SourceObserver(Observer observer, Func1> onNext, Func1> onError, Func0> onCompleted, CompositeSubscription csub) { + this.observer = observer; + this.onNext = onNext; + this.onError = onError; + this.onCompleted = onCompleted; + this.csub = csub; + this.guard = new Object(); + this.wip = new AtomicInteger(1); + } + + @Override + public void onNext(T args) { + Observable o; + try { + o = onNext.call(args); + } catch (Throwable t) { + synchronized (guard) { + observer.onError(t); + } + csub.unsubscribe(); + return; + } + subscribeInner(o); + } + + @Override + public void onError(Throwable e) { + Observable o; + try { + o = onError.call(e); + } catch (Throwable t) { + synchronized (guard) { + observer.onError(t); + } + csub.unsubscribe(); + return; + } + subscribeInner(o); + done = true; + finish(); + } + + @Override + public void onCompleted() { + Observable o; + try { + o = onCompleted.call(); + } catch (Throwable t) { + synchronized (guard) { + observer.onError(t); + } + csub.unsubscribe(); + return; + } + subscribeInner(o); + done = true; + finish(); + } + + void subscribeInner(Observable o) { + SerialSubscription ssub = new SerialSubscription(); + wip.incrementAndGet(); + csub.add(ssub); + + ssub.set(o.subscribe(new CollectionObserver(this, ssub))); + } + void finish() { + if (wip.decrementAndGet() == 0) { + synchronized (guard) { + observer.onCompleted(); + } + csub.unsubscribe(); + } + } + } + /** Observes the collections. */ + private static final class CollectionObserver implements Observer { + final SourceObserver parent; + final Subscription cancel; + + public CollectionObserver(SourceObserver parent, Subscription cancel) { + this.parent = parent; + this.cancel = cancel; + } + + @Override + public void onNext(R args) { + synchronized (parent.guard) { + parent.observer.onNext(args); + } + } + + @Override + public void onError(Throwable e) { + synchronized (parent.guard) { + parent.observer.onError(e); + } + parent.csub.unsubscribe(); + } + + @Override + public void onCompleted() { + parent.csub.remove(cancel); + parent.finish(); + } + } + } } diff --git a/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java b/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java index 9c3afd73f7..f081f0f8e8 100644 --- a/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java +++ b/rxjava-core/src/test/java/rx/operators/OperationFlatMapTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.*; import rx.Observable; import rx.Observer; +import rx.util.functions.Func0; import rx.util.functions.Func1; import rx.util.functions.Func2; @@ -142,4 +143,153 @@ public Integer call(Integer t1, Integer t2) { verify(o, never()).onNext(any()); verify(o).onError(any(OperationReduceTest.CustomException.class)); } + Func1 just(final R value) { + return new Func1() { + + @Override + public R call(T t1) { + return value; + } + }; + } + Func0 just0(final R value) { + return new Func0() { + + @Override + public R call() { + return value; + } + }; + } + @Test + public void testFlatMapTransformsNormal() { + Observable onNext = Observable.from(Arrays.asList(1, 2, 3)); + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.from(Arrays.asList(10, 20, 30)); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(just(onNext), just(onError), just0(onCompleted)).subscribe(o); + + verify(o, times(3)).onNext(1); + verify(o, times(3)).onNext(2); + verify(o, times(3)).onNext(3); + verify(o).onNext(4); + verify(o).onCompleted(); + + verify(o, never()).onNext(5); + verify(o, never()).onError(any(Throwable.class)); + } + @Test + public void testFlatMapTransformsException() { + Observable onNext = Observable.from(Arrays.asList(1, 2, 3)); + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.concat( + Observable.from(Arrays.asList(10, 20, 30)) + , Observable.error(new RuntimeException("Forced failure!")) + ); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(just(onNext), just(onError), just0(onCompleted)).subscribe(o); + + verify(o, times(3)).onNext(1); + verify(o, times(3)).onNext(2); + verify(o, times(3)).onNext(3); + verify(o).onNext(5); + verify(o).onCompleted(); + verify(o, never()).onNext(4); + + verify(o, never()).onError(any(Throwable.class)); + } + Func0 funcThrow0(R r) { + return new Func0() { + @Override + public R call() { + throw new OperationReduceTest.CustomException(); + } + }; + } + Func1 funcThrow(T t, R r) { + return new Func1() { + @Override + public R call(T t) { + throw new OperationReduceTest.CustomException(); + } + }; + } + @Test + public void testFlatMapTransformsOnNextFuncThrows() { + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.from(Arrays.asList(10, 20, 30)); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(funcThrow(1, onError), just(onError), just0(onCompleted)).subscribe(o); + + verify(o).onError(any(OperationReduceTest.CustomException.class)); + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + } + @Test + public void testFlatMapTransformsOnErrorFuncThrows() { + Observable onNext = Observable.from(Arrays.asList(1, 2, 3)); + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.error(new OperationReduceTest.CustomException()); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(just(onNext), funcThrow((Throwable)null, onError), just0(onCompleted)).subscribe(o); + + verify(o).onError(any(OperationReduceTest.CustomException.class)); + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + } + + @Test + public void testFlatMapTransformsOnCompletedFuncThrows() { + Observable onNext = Observable.from(Arrays.asList(1, 2, 3)); + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.from(Arrays.asList()); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(just(onNext), just(onError), funcThrow0(onCompleted)).subscribe(o); + + verify(o).onError(any(OperationReduceTest.CustomException.class)); + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + } + @Test + public void testFlatMapTransformsMergeException() { + Observable onNext = Observable.error(new OperationReduceTest.CustomException()); + Observable onCompleted = Observable.from(Arrays.asList(4)); + Observable onError = Observable.from(Arrays.asList(5)); + + Observable source = Observable.from(Arrays.asList(10, 20, 30)); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + source.mergeMap(just(onNext), just(onError), funcThrow0(onCompleted)).subscribe(o); + + verify(o).onError(any(OperationReduceTest.CustomException.class)); + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + } }