From 14b701d022e4b67d24be5b7751e8836fb6360fbd Mon Sep 17 00:00:00 2001 From: akarnokd Date: Mon, 23 Dec 2013 19:10:34 +0100 Subject: [PATCH] Operations Aggregate, Average and Sum with selector --- rxjava-core/src/main/java/rx/Observable.java | 140 +++++++++++ .../java/rx/operators/OperationAggregate.java | 158 ++++++++++++ .../java/rx/operators/OperationAverage.java | 227 +++++++++++++++++ .../main/java/rx/operators/OperationSum.java | 229 ++++++++++++++++++ .../rx/operators/OperationAggregateTest.java | 140 +++++++++++ .../rx/operators/OperationAverageTest.java | 206 +++++++++++++++- .../java/rx/operators/OperationSumTest.java | 205 ++++++++++++++++ 7 files changed, 1304 insertions(+), 1 deletion(-) create mode 100644 rxjava-core/src/main/java/rx/operators/OperationAggregate.java create mode 100644 rxjava-core/src/test/java/rx/operators/OperationAggregateTest.java diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index 04fa8e5132..96c18c4d1c 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -32,6 +32,7 @@ import rx.observables.BlockingObservable; import rx.observables.ConnectableObservable; import rx.observables.GroupedObservable; +import rx.operators.OperationAggregate; import rx.operators.OperationAll; import rx.operators.OperationAmb; import rx.operators.OperationAny; @@ -4118,6 +4119,54 @@ public static Observable sumDoubles(Observable source) { return OperationSum.sumDoubles(source); } + /** + * Create an Observable that extracts integer values from this Observable via + * the provided function and computes the integer sum of the value sequence. + * + * @param valueExtractor the function to extract an integer from this Observable + * @return an Observable that extracts integer values from this Observable via + * the provided function and computes the integer sum of the value sequence. + */ + public Observable sumInteger(Func1 valueExtractor) { + return create(new OperationSum.SumIntegerExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts long values from this Observable via + * the provided function and computes the long sum of the value sequence. + * + * @param valueExtractor the function to extract an long from this Observable + * @return an Observable that extracts long values from this Observable via + * the provided function and computes the long sum of the value sequence. + */ + public Observable sumLong(Func1 valueExtractor) { + return create(new OperationSum.SumLongExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts float values from this Observable via + * the provided function and computes the float sum of the value sequence. + * + * @param valueExtractor the function to extract an float from this Observable + * @return an Observable that extracts float values from this Observable via + * the provided function and computes the float sum of the value sequence. + */ + public Observable sumFloat(Func1 valueExtractor) { + return create(new OperationSum.SumFloatExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts double values from this Observable via + * the provided function and computes the double sum of the value sequence. + * + * @param valueExtractor the function to extract an double from this Observable + * @return an Observable that extracts double values from this Observable via + * the provided function and computes the double sum of the value sequence. + */ + public Observable sumDouble(Func1 valueExtractor) { + return create(new OperationSum.SumDoubleExtractor(this, valueExtractor)); + } + /** * Returns an Observable that computes the average of the Integers emitted * by the source Observable. @@ -4183,6 +4232,54 @@ public static Observable averageDoubles(Observable source) { return OperationAverage.averageDoubles(source); } + /** + * Create an Observable that extracts integer values from this Observable via + * the provided function and computes the integer average of the value sequence. + * + * @param valueExtractor the function to extract an integer from this Observable + * @return an Observable that extracts integer values from this Observable via + * the provided function and computes the integer average of the value sequence. + */ + public Observable averageInteger(Func1 valueExtractor) { + return create(new OperationAverage.AverageIntegerExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts long values from this Observable via + * the provided function and computes the long average of the value sequence. + * + * @param valueExtractor the function to extract an long from this Observable + * @return an Observable that extracts long values from this Observable via + * the provided function and computes the long average of the value sequence. + */ + public Observable averageLong(Func1 valueExtractor) { + return create(new OperationAverage.AverageLongExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts float values from this Observable via + * the provided function and computes the float average of the value sequence. + * + * @param valueExtractor the function to extract an float from this Observable + * @return an Observable that extracts float values from this Observable via + * the provided function and computes the float average of the value sequence. + */ + public Observable averageFloat(Func1 valueExtractor) { + return create(new OperationAverage.AverageFloatExtractor(this, valueExtractor)); + } + + /** + * Create an Observable that extracts double values from this Observable via + * the provided function and computes the double average of the value sequence. + * + * @param valueExtractor the function to extract an double from this Observable + * @return an Observable that extracts double values from this Observable via + * the provided function and computes the double average of the value sequence. + */ + public Observable averageDouble(Func1 valueExtractor) { + return create(new OperationAverage.AverageDoubleExtractor(this, valueExtractor)); + } + /** * Returns an Observable that emits the minimum item emitted by the source * Observable. If there is more than one such item, it returns the @@ -4954,6 +5051,49 @@ public Observable reduce(R initialValue, Func2 accumulat public Observable aggregate(R initialValue, Func2 accumulator) { return reduce(initialValue, accumulator); } + + /** + * Create an Observable that aggregates the source values with the given accumulator + * function and projects the final result via the resultselector. + *

+ * Works like the {@link #aggregate(java.lang.Object, rx.util.functions.Func2)} projected + * with {@link #map(rx.util.functions.Func1)} without the overhead of some helper + * operators. + * @param the intermediate (accumulator) type + * @param the result type + * @param seed the initial value of the accumulator + * @param accumulator the function that takes the current accumulator value, + * the current emitted value and returns a (new) accumulated value. + * @param resultSelector the selector to project the final value of the accumulator + * @return an Observable that aggregates the source values with the given accumulator + * function and projects the final result via the resultselector + */ + public Observable aggregate( + U seed, Func2 accumulator, + Func1 resultSelector) { + return create(new OperationAggregate.AggregateSelector(this, seed, accumulator, resultSelector)); + } + + /** + * Create an Observable that aggregates the source values with the given indexed accumulator + * function and projects the final result via the indexed resultselector. + * + * @param the intermediate (accumulator) type + * @param the result type + * @param seed the initial value of the accumulator + * @param accumulator the function that takes the current accumulator value, + * the current emitted value and returns a (new) accumulated value. + * @param resultSelector the selector to project the final value of the accumulator, where + * the second argument is the total number of elements accumulated + * @return an Observable that aggregates the source values with the given indexed accumulator + * function and projects the final result via the indexed resultselector. + */ + public Observable aggregateIndexed( + U seed, Func3 accumulator, + Func2 resultSelector + ) { + return create(new OperationAggregate.AggregateIndexedSelector(this, seed, accumulator, resultSelector)); + } /** * Returns an Observable that applies a function of your choosing to the diff --git a/rxjava-core/src/main/java/rx/operators/OperationAggregate.java b/rxjava-core/src/main/java/rx/operators/OperationAggregate.java new file mode 100644 index 0000000000..71778cccf1 --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationAggregate.java @@ -0,0 +1,158 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rx.operators; + +import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.util.functions.Func1; +import rx.util.functions.Func2; +import rx.util.functions.Func3; + +/** + * Aggregate overloads with index and selector functions. + */ +public final class OperationAggregate { + /** Utility class. */ + private OperationAggregate() { throw new IllegalStateException("No instances!"); } + + /** + * Aggregate and emit a value after running it through a selector. + * @param the input value type + * @param the intermediate value type + * @param the result value type + */ + public static final class AggregateSelector implements OnSubscribeFunc { + final Observable source; + final U seed; + final Func2 aggregator; + final Func1 resultSelector; + + public AggregateSelector( + Observable source, U seed, + Func2 aggregator, + Func1 resultSelector) { + this.source = source; + this.seed = seed; + this.aggregator = aggregator; + this.resultSelector = resultSelector; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AggregatorObserver(t1, seed)); + } + /** The aggregator observer of the source. */ + private final class AggregatorObserver implements Observer { + final Observer observer; + U accumulator; + public AggregatorObserver(Observer observer, U seed) { + this.observer = observer; + this.accumulator = seed; + } + + @Override + public void onNext(T args) { + accumulator = aggregator.call(accumulator, args); + } + + @Override + public void onError(Throwable e) { + accumulator = null; + observer.onError(e); + } + + @Override + public void onCompleted() { + U a = accumulator; + accumulator = null; + try { + observer.onNext(resultSelector.call(a)); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } + } + } + /** + * Indexed aggregate and emit a value after running it through an indexed selector. + * @param the input value type + * @param the intermediate value type + * @param the result value type + */ + public static final class AggregateIndexedSelector implements OnSubscribeFunc { + final Observable source; + final U seed; + final Func3 aggregator; + final Func2 resultSelector; + + public AggregateIndexedSelector( + Observable source, + U seed, + Func3 aggregator, + Func2 resultSelector) { + this.source = source; + this.seed = seed; + this.aggregator = aggregator; + this.resultSelector = resultSelector; + } + + + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AggregatorObserver(t1, seed)); + } + /** The aggregator observer of the source. */ + private final class AggregatorObserver implements Observer { + final Observer observer; + U accumulator; + int index; + public AggregatorObserver(Observer observer, U seed) { + this.observer = observer; + this.accumulator = seed; + } + + @Override + public void onNext(T args) { + accumulator = aggregator.call(accumulator, args, index++); + } + + @Override + public void onError(Throwable e) { + accumulator = null; + observer.onError(e); + } + + @Override + public void onCompleted() { + U a = accumulator; + accumulator = null; + try { + observer.onNext(resultSelector.call(a, index)); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } + } + } +} diff --git a/rxjava-core/src/main/java/rx/operators/OperationAverage.java b/rxjava-core/src/main/java/rx/operators/OperationAverage.java index 35abc99eb5..29acf784d0 100644 --- a/rxjava-core/src/main/java/rx/operators/OperationAverage.java +++ b/rxjava-core/src/main/java/rx/operators/OperationAverage.java @@ -16,6 +16,9 @@ package rx.operators; import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; import rx.util.functions.Func1; import rx.util.functions.Func2; @@ -102,4 +105,228 @@ public Double call(Tuple2 result) { } }); } + + /** + * Compute the average by extracting integer values from the source via an + * extractor function. + * @param the source value type + */ + public static final class AverageIntegerExtractor implements OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public AverageIntegerExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AverageObserver(t1)); + } + /** Computes the average. */ + private final class AverageObserver implements Observer { + final Observer observer; + int sum; + int count; + public AverageObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + count++; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (count > 0) { + try { + observer.onNext(sum / count); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the average by extracting long values from the source via an + * extractor function. + * @param the source value type + */ + public static final class AverageLongExtractor implements OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public AverageLongExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AverageObserver(t1)); + } + /** Computes the average. */ + private final class AverageObserver implements Observer { + final Observer observer; + long sum; + int count; + public AverageObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + count++; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (count > 0) { + try { + observer.onNext(sum / count); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the average by extracting float values from the source via an + * extractor function. + * @param the source value type + */ + public static final class AverageFloatExtractor implements OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public AverageFloatExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AverageObserver(t1)); + } + /** Computes the average. */ + private final class AverageObserver implements Observer { + final Observer observer; + float sum; + int count; + public AverageObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + count++; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (count > 0) { + try { + observer.onNext(sum / count); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the average by extracting double values from the source via an + * extractor function. + * @param the source value type + */ + public static final class AverageDoubleExtractor implements OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public AverageDoubleExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new AverageObserver(t1)); + } + /** Computes the average. */ + private final class AverageObserver implements Observer { + final Observer observer; + double sum; + int count; + public AverageObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + count++; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (count > 0) { + try { + observer.onNext(sum / count); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } } diff --git a/rxjava-core/src/main/java/rx/operators/OperationSum.java b/rxjava-core/src/main/java/rx/operators/OperationSum.java index fef81a2625..8f419bd222 100644 --- a/rxjava-core/src/main/java/rx/operators/OperationSum.java +++ b/rxjava-core/src/main/java/rx/operators/OperationSum.java @@ -16,6 +16,10 @@ package rx.operators; import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.util.functions.Func1; import rx.util.functions.Func2; /** @@ -59,4 +63,229 @@ public Double call(Double accu, Double next) { } }); } + + /** + * Compute the sum by extracting integer values from the source via an + * extractor function. + * @param the source value type + */ + public static final class SumIntegerExtractor implements Observable.OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public SumIntegerExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new SumObserver(t1)); + } + /** Computes the average. */ + private final class SumObserver implements Observer { + final Observer observer; + int sum; + boolean hasValue; + public SumObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + hasValue = true; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (hasValue) { + try { + observer.onNext(sum); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the sum by extracting long values from the source via an + * extractor function. + * @param the source value type + */ + public static final class SumLongExtractor implements Observable.OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public SumLongExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new SumObserver(t1)); + } + /** Computes the average. */ + private final class SumObserver implements Observer { + final Observer observer; + long sum; + boolean hasValue; + public SumObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + hasValue = true; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (hasValue) { + try { + observer.onNext(sum); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the sum by extracting float values from the source via an + * extractor function. + * @param the source value type + */ + public static final class SumFloatExtractor implements Observable.OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public SumFloatExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new SumObserver(t1)); + } + /** Computes the average. */ + private final class SumObserver implements Observer { + final Observer observer; + float sum; + boolean hasValue; + public SumObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + hasValue = true; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (hasValue) { + try { + observer.onNext(sum); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + + /** + * Compute the sum by extracting float values from the source via an + * extractor function. + * @param the source value type + */ + public static final class SumDoubleExtractor implements Observable.OnSubscribeFunc { + final Observable source; + final Func1 valueExtractor; + + public SumDoubleExtractor(Observable source, Func1 valueExtractor) { + this.source = source; + this.valueExtractor = valueExtractor; + } + + @Override + public Subscription onSubscribe(Observer t1) { + return source.subscribe(new SumObserver(t1)); + } + /** Computes the average. */ + private final class SumObserver implements Observer { + final Observer observer; + double sum; + boolean hasValue; + public SumObserver(Observer observer) { + this.observer = observer; + } + + @Override + public void onNext(T args) { + sum += valueExtractor.call(args); + hasValue = true; + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + if (hasValue) { + try { + observer.onNext(sum); + } catch (Throwable t) { + observer.onError(t); + return; + } + observer.onCompleted(); + } else { + observer.onError(new IllegalArgumentException("Sequence contains no elements")); + } + } + + } + } + } diff --git a/rxjava-core/src/test/java/rx/operators/OperationAggregateTest.java b/rxjava-core/src/test/java/rx/operators/OperationAggregateTest.java new file mode 100644 index 0000000000..749af5b235 --- /dev/null +++ b/rxjava-core/src/test/java/rx/operators/OperationAggregateTest.java @@ -0,0 +1,140 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rx.operators; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import static org.mockito.Mockito.*; +import rx.Observable; +import rx.Observer; +import rx.util.functions.Func1; +import rx.util.functions.Func2; +import rx.util.functions.Func3; +import rx.util.functions.Functions; + +public class OperationAggregateTest { + @Mock + Observer observer; + @Before + public void before() { + MockitoAnnotations.initMocks(this); + } + Func2 sum = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }; + + @Test + public void testAggregateAsIntSum() { + + Observable result = Observable.from(1, 2, 3, 4, 5).aggregate(0, sum, Functions.identity()); + + result.subscribe(observer); + + verify(observer).onNext(1 + 2 + 3 + 4 + 5); + verify(observer).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + + } + + @Test + public void testAggregateIndexedAsAverage() { + Func3 sumIndex = new Func3() { + @Override + public Integer call(Integer acc, Integer value, Integer index) { + return acc + (index + 1) + value; + } + }; + Func2 selectIndex = new Func2() { + @Override + public Integer call(Integer t1, Integer count) { + return t1 + count; + } + + }; + + Observable result = Observable.from(1, 2, 3, 4, 5) + .aggregateIndexed(0, sumIndex, selectIndex); + + result.subscribe(observer); + + verify(observer).onNext(2 + 4 + 6 + 8 + 10 + 5); + verify(observer).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + + } + + static class CustomException extends RuntimeException { } + + @Test + public void testAggregateAsIntSumSourceThrows() { + Observable result = Observable.concat(Observable.from(1, 2, 3, 4, 5), + Observable.error(new CustomException())) + .aggregate(0, sum, Functions.identity()); + + result.subscribe(observer); + + verify(observer, never()).onNext(any()); + verify(observer, never()).onCompleted(); + verify(observer, times(1)).onError(any(CustomException.class)); + } + + @Test + public void testAggregateAsIntSumAccumulatorThrows() { + Func2 sumErr = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + throw new CustomException(); + } + }; + + Observable result = Observable.from(1, 2, 3, 4, 5) + .aggregate(0, sumErr, Functions.identity()); + + result.subscribe(observer); + + verify(observer, never()).onNext(any()); + verify(observer, never()).onCompleted(); + verify(observer, times(1)).onError(any(CustomException.class)); + } + + @Test + public void testAggregateAsIntSumResultSelectorThrows() { + + Func1 error = new Func1() { + + @Override + public Integer call(Integer t1) { + throw new CustomException(); + } + }; + + Observable result = Observable.from(1, 2, 3, 4, 5) + .aggregate(0, sum, error); + + result.subscribe(observer); + + verify(observer, never()).onNext(any()); + verify(observer, never()).onCompleted(); + verify(observer, times(1)).onError(any(CustomException.class)); + } + +} diff --git a/rxjava-core/src/test/java/rx/operators/OperationAverageTest.java b/rxjava-core/src/test/java/rx/operators/OperationAverageTest.java index 357743da17..8655c37437 100644 --- a/rxjava-core/src/test/java/rx/operators/OperationAverageTest.java +++ b/rxjava-core/src/test/java/rx/operators/OperationAverageTest.java @@ -15,7 +15,6 @@ */ package rx.operators; -import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; import static rx.operators.OperationAverage.*; @@ -23,6 +22,8 @@ import rx.Observable; import rx.Observer; +import rx.operators.OperationAggregateTest.CustomException; +import rx.util.functions.Func1; public class OperationAverageTest { @@ -118,4 +119,207 @@ public void testEmptyAverageDoubles() throws Throwable { verify(wd, times(1)).onError(isA(IllegalArgumentException.class)); verify(wd, never()).onCompleted(); } + + void testThrows(Observer o, Class errorClass) { + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + verify(o, times(1)).onError(any(errorClass)); + } + void testValue(Observer o, N value) { + verify(o, times(1)).onNext(value); + verify(o, times(1)).onCompleted(); + verify(o, never()).onError(any(Throwable.class)); + } + @Test + public void testIntegerAverageSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + return t1.length(); + } + }; + + Observable result = source.averageInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 2); + } + @Test + public void testLongAverageSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + return (long)t1.length(); + } + }; + + Observable result = source.averageLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 2L); + } + @Test + public void testFloatAverageSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + return (float)t1.length(); + } + }; + + Observable result = source.averageFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 2.5f); + } + @Test + public void testDoubleAverageSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + return (double)t1.length(); + } + }; + + Observable result = source.averageDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 2.5d); + } + @Test + public void testIntegerAverageSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + return t1.length(); + } + }; + + Observable result = source.averageInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testLongAverageSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + return (long)t1.length(); + } + }; + + Observable result = source.averageLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testFloatAverageSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + return (float)t1.length(); + } + }; + + Observable result = source.averageFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testDoubleAverageSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + return (double)t1.length(); + } + }; + + Observable result = source.averageDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testIntegerAverageSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + throw new CustomException(); + } + }; + + Observable result = source.averageInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, CustomException.class); + } + @Test + public void testLongAverageSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + throw new CustomException(); + } + }; + + Observable result = source.averageLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, CustomException.class); + } + @Test + public void testFloatAverageSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + throw new CustomException(); + } + }; + + Observable result = source.averageFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, CustomException.class); + } + @Test + public void testDoubleAverageSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + throw new CustomException(); + } + }; + + Observable result = source.averageDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, CustomException.class); + } } diff --git a/rxjava-core/src/test/java/rx/operators/OperationSumTest.java b/rxjava-core/src/test/java/rx/operators/OperationSumTest.java index e124ad13d5..2fcc4611e1 100644 --- a/rxjava-core/src/test/java/rx/operators/OperationSumTest.java +++ b/rxjava-core/src/test/java/rx/operators/OperationSumTest.java @@ -23,6 +23,7 @@ import rx.Observable; import rx.Observer; +import rx.util.functions.Func1; public class OperationSumTest { @@ -122,4 +123,208 @@ public void testEmptySumDoubles() throws Throwable { verify(wd, never()).onError(any(Throwable.class)); verify(wd, times(1)).onCompleted(); } + + void testThrows(Observer o, Class errorClass) { + verify(o, never()).onNext(any()); + verify(o, never()).onCompleted(); + verify(o, times(1)).onError(any(errorClass)); + } + void testValue(Observer o, N value) { + verify(o, times(1)).onNext(value); + verify(o, times(1)).onCompleted(); + verify(o, never()).onError(any(Throwable.class)); + } + + @Test + public void testIntegerSumSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + return t1.length(); + } + }; + + Observable result = source.sumInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 10); + } + @Test + public void testLongSumSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + return (long)t1.length(); + } + }; + + Observable result = source.sumLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 10L); + } + @Test + public void testFloatSumSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + return (float)t1.length(); + } + }; + + Observable result = source.sumFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 10f); + } + @Test + public void testDoubleSumSelector() { + Observable source = Observable.from("a", "bb", "ccc", "dddd"); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + return (double)t1.length(); + } + }; + + Observable result = source.sumDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testValue(o, 10d); + } + @Test + public void testIntegerSumSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + return t1.length(); + } + }; + + Observable result = source.sumInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testLongSumSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + return (long)t1.length(); + } + }; + + Observable result = source.sumLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testFloatSumSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + return (float)t1.length(); + } + }; + + Observable result = source.sumFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testDoubleSumSelectorEmpty() { + Observable source = Observable.empty(); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + return (double)t1.length(); + } + }; + + Observable result = source.sumDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, IllegalArgumentException.class); + } + @Test + public void testIntegerSumSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Integer call(String t1) { + throw new OperationAggregateTest.CustomException(); + } + }; + + Observable result = source.sumInteger(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, OperationAggregateTest.CustomException.class); + } + @Test + public void testLongSumSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Long call(String t1) { + throw new OperationAggregateTest.CustomException(); + } + }; + + Observable result = source.sumLong(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, OperationAggregateTest.CustomException.class); + } + @Test + public void testFloatSumSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Float call(String t1) { + throw new OperationAggregateTest.CustomException(); + } + }; + + Observable result = source.sumFloat(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, OperationAggregateTest.CustomException.class); + } + @Test + public void testDoubleSumSelectorThrows() { + Observable source = Observable.from("a"); + Func1 length = new Func1() { + @Override + public Double call(String t1) { + throw new OperationAggregateTest.CustomException(); + } + }; + + Observable result = source.sumDouble(length); + Observer o = mock(Observer.class); + result.subscribe(o); + + testThrows(o, OperationAggregateTest.CustomException.class); + } }