Skip to content

Commit

Permalink
GroupBy GroupedObservables should not re-subscribe to parent sequence
Browse files Browse the repository at this point in the history
ReactiveX/RxJava#282

Refactored to maintain a single subscription that propagates events to the correct child GroupedObservables.
  • Loading branch information
benjchristensen committed May 31, 2013
1 parent 448d778 commit bca0b15
Showing 1 changed file with 237 additions and 63 deletions.
300 changes: 237 additions & 63 deletions rxjava-core/src/main/java/rx/operators/OperationGroupBy.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,24 @@

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Test;

import rx.Observable;
import rx.Observer;
import rx.Subscription;
import rx.observables.GroupedObservable;
import rx.subscriptions.Subscriptions;
import rx.util.functions.Action1;
import rx.util.functions.Func1;
import rx.util.functions.Functions;

Expand All @@ -55,69 +60,137 @@ public static <K, T> Func1<Observer<GroupedObservable<K, T>>, Subscription> grou
}

private static class GroupBy<K, V> implements Func1<Observer<GroupedObservable<K, V>>, Subscription> {

private final Observable<KeyValue<K, V>> source;
private final ConcurrentHashMap<K, GroupedSubject<K, V>> groupedObservables = new ConcurrentHashMap<K, GroupedSubject<K, V>>();

private GroupBy(Observable<KeyValue<K, V>> source) {
this.source = source;
}

@Override
public Subscription call(final Observer<GroupedObservable<K, V>> observer) {
return source.subscribe(new GroupByObserver(observer));
return source.subscribe(new Observer<KeyValue<K, V>>() {

@Override
public void onCompleted() {
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
for (GroupedSubject<K, V> o : groupedObservables.values()) {
o.onCompleted();
}
// now the parent
observer.onCompleted();
}

@Override
public void onError(Exception e) {
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
for (GroupedSubject<K, V> o : groupedObservables.values()) {
o.onError(e);
}
// now the parent
observer.onError(e);
}

@Override
public void onNext(KeyValue<K, V> value) {
GroupedSubject<K, V> gs = groupedObservables.get(value.key);
if (gs == null) {
/*
* Technically the source should be single-threaded so we shouldn't need to do this but I am
* programming defensively as most operators are so this can work with a concurrent sequence
* if it ends up receiving one.
*/
GroupedSubject<K, V> newGs = GroupedSubject.<K, V> create(value.key);
GroupedSubject<K, V> existing = groupedObservables.putIfAbsent(value.key, newGs);
if (existing == null) {
// we won so use the one we created
gs = newGs;
// since we won the creation we emit this new GroupedObservable
observer.onNext(gs);
} else {
// another thread beat us so use the existing one
gs = existing;
}
}
gs.onNext(value.value);
}
});
}
}

private class GroupByObserver implements Observer<KeyValue<K, V>> {
private final Observer<GroupedObservable<K, V>> underlying;
private static class GroupedSubject<K, T> extends GroupedObservable<K, T> implements Observer<T> {

private final ConcurrentHashMap<K, Boolean> keys = new ConcurrentHashMap<K, Boolean>();
static <K, T> GroupedSubject<K, T> create(K key) {
@SuppressWarnings("unchecked")
final AtomicReference<Observer<T>> subscribedObserver = new AtomicReference<Observer<T>>(EMPTY_OBSERVER);

private GroupByObserver(Observer<GroupedObservable<K, V>> underlying) {
this.underlying = underlying;
}
return new GroupedSubject<K, T>(key, new Func1<Observer<T>, Subscription>() {

@Override
public void onCompleted() {
underlying.onCompleted();
}
@Override
public Subscription call(Observer<T> observer) {
// register Observer
subscribedObserver.set(observer);

@Override
public void onError(Exception e) {
underlying.onError(e);
}
return new Subscription() {

@Override
public void onNext(final KeyValue<K, V> args) {
K key = args.key;
boolean newGroup = keys.putIfAbsent(key, true) == null;
if (newGroup) {
underlying.onNext(buildObservableFor(source, key));
@SuppressWarnings("unchecked")
@Override
public void unsubscribe() {
// we remove the Observer so we stop emitting further events (they will be ignored if parent continues to send)
subscribedObserver.set(EMPTY_OBSERVER);
// I don't believe we need to worry about the parent here as it's a separate sequence that would
// be unsubscribed to directly if that needs to happen.
}
};
}
}
}, subscribedObserver);
}
}

private static <K, R> GroupedObservable<K, R> buildObservableFor(Observable<KeyValue<K, R>> source, final K key) {
final Observable<R> observable = source.filter(new Func1<KeyValue<K, R>, Boolean>() {
@Override
public Boolean call(KeyValue<K, R> pair) {
return key.equals(pair.key);
}
}).map(new Func1<KeyValue<K, R>, R>() {
@Override
public R call(KeyValue<K, R> pair) {
return pair.value;
}
});
return new GroupedObservable<K, R>(key, new Func1<Observer<R>, Subscription>() {
private final AtomicReference<Observer<T>> subscribedObserver;

@Override
public Subscription call(Observer<R> observer) {
return observable.subscribe(observer);
}
public GroupedSubject(K key, Func1<Observer<T>, Subscription> onSubscribe, AtomicReference<Observer<T>> subscribedObserver) {
super(key, onSubscribe);
this.subscribedObserver = subscribedObserver;
}

@Override
public void onCompleted() {
subscribedObserver.get().onCompleted();
}

@Override
public void onError(Exception e) {
subscribedObserver.get().onError(e);
}

@Override
public void onNext(T v) {
subscribedObserver.get().onNext(v);
}

});
}

@SuppressWarnings("rawtypes")
private static Observer EMPTY_OBSERVER = new Observer() {

@Override
public void onCompleted() {
// do nothing
}

@Override
public void onError(Exception e) {
// do nothing
}

@Override
public void onNext(Object args) {
// do nothing
}

};

private static class KeyValue<K, V> {
private final K key;
private final V value;
Expand All @@ -141,45 +214,146 @@ public void testGroupBy() {
Observable<String> source = Observable.from("one", "two", "three", "four", "five", "six");
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));

Map<Integer, List<String>> map = toMap(grouped);
Map<Integer, Collection<String>> map = toMap(grouped);

assertEquals(3, map.size());
assertEquals(Arrays.asList("one", "two", "six"), map.get(3));
assertEquals(Arrays.asList("four", "five"), map.get(4));
assertEquals(Arrays.asList("three"), map.get(5));

assertArrayEquals(Arrays.asList("one", "two", "six").toArray(), map.get(3).toArray());
assertArrayEquals(Arrays.asList("four", "five").toArray(), map.get(4).toArray());
assertArrayEquals(Arrays.asList("three").toArray(), map.get(5).toArray());
}

@Test
public void testEmpty() {
Observable<String> source = Observable.from();
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));

Map<Integer, List<String>> map = toMap(grouped);
Map<Integer, Collection<String>> map = toMap(grouped);

assertTrue(map.isEmpty());
}

private static <K, V> Map<K, List<V>> toMap(Observable<GroupedObservable<K, V>> observable) {
Map<K, List<V>> result = new HashMap<K, List<V>>();
for (GroupedObservable<K, V> g : observable.toBlockingObservable().toIterable()) {
K key = g.getKey();
private static <K, V> Map<K, Collection<V>> toMap(Observable<GroupedObservable<K, V>> observable) {

for (V value : g.toBlockingObservable().toIterable()) {
List<V> values = result.get(key);
if (values == null) {
values = new ArrayList<V>();
result.put(key, values);
}
final ConcurrentHashMap<K, Collection<V>> result = new ConcurrentHashMap<K, Collection<V>>();

values.add(value);
}
observable.forEach(new Action1<GroupedObservable<K, V>>() {

}
@Override
public void call(final GroupedObservable<K, V> o) {
result.put(o.getKey(), new ConcurrentLinkedQueue<V>());
o.subscribe(new Action1<V>() {

@Override
public void call(V v) {
result.get(o.getKey()).add(v);
}

});
}
});

return result;
}

/**
* Assert that only a single subscription to a stream occurs and that all events are received.
*
* @throws Exception
*/
@Test
public void testGroupedEventStream() throws Exception {

final AtomicInteger eventCounter = new AtomicInteger();
final AtomicInteger subscribeCounter = new AtomicInteger();
final AtomicInteger groupCounter = new AtomicInteger();
final CountDownLatch latch = new CountDownLatch(1);
final int count = 100;
final int groupCount = 2;

Observable<Event> es = Observable.create(new Func1<Observer<Event>, Subscription>() {

@Override
public Subscription call(final Observer<Event> observer) {
System.out.println("*** Subscribing to EventStream ***");
subscribeCounter.incrementAndGet();
new Thread(new Runnable() {

@Override
public void run() {
for (int i = 0; i < count; i++) {
Event e = new Event();
e.source = i % groupCount;
e.message = "Event-" + i;
observer.onNext(e);
}
observer.onCompleted();
}

}).start();
return Subscriptions.empty();
}

});

es.groupBy(new Func1<Event, Integer>() {

@Override
public Integer call(Event e) {
return e.source;
}
}).mapMany(new Func1<GroupedObservable<Integer, Event>, Observable<String>>() {

@Override
public Observable<String> call(GroupedObservable<Integer, Event> eventGroupedObservable) {
System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey());
groupCounter.incrementAndGet();

return eventGroupedObservable.map(new Func1<Event, String>() {

@Override
public String call(Event event) {
return "Source: " + event.source + " Message: " + event.message;
}
});

};
}).subscribe(new Observer<String>() {

@Override
public void onCompleted() {
latch.countDown();
}

@Override
public void onError(Exception e) {
e.printStackTrace();
latch.countDown();
}

@Override
public void onNext(String outputMessage) {
System.out.println(outputMessage);
eventCounter.incrementAndGet();
}
});

latch.await(5000, TimeUnit.MILLISECONDS);
assertEquals(1, subscribeCounter.get());
assertEquals(groupCount, groupCounter.get());
assertEquals(count, eventCounter.get());

}

private static class Event {
int source;
String message;

@Override
public String toString() {
return "Event => source: " + source + " message: " + message;
}
}

}

}

0 comments on commit bca0b15

Please sign in to comment.