Skip to content

Commit

Permalink
[#343] DS light : makes the subscribe blocking like Redis
Browse files Browse the repository at this point in the history
  • Loading branch information
bamthomas committed Jan 9, 2020
1 parent bf4d620 commit 1469975
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import java.util.function.Consumer;

public interface DataBus {
void subscribe(Consumer<Message> subscriber, Channel... channels);
void subscribe(Consumer<Message> subscriber, Runnable subscriptionCallback, Channel... channels);
public interface DataBus extends Publisher {
int subscribe(Consumer<Message> subscriber, Channel... channels) throws InterruptedException;
int subscribe(Consumer<Message> subscriber, Runnable subscriptionCallback, Channel... channels) throws InterruptedException;
void unsubscribe(Consumer<Message> subscriber);
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,82 @@
package org.icij.datashare.com.memory;

import org.icij.datashare.com.Channel;
import org.icij.datashare.com.DataBus;
import org.icij.datashare.com.Message;
import org.icij.datashare.com.Publisher;
import org.icij.datashare.com.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.CollectionUtils.asSet;

public class MemoryDataBus implements Publisher, DataBus {
private final Map<Consumer<Message>, Set<Channel>> subscribers = new ConcurrentHashMap<>();
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Map<Consumer<Message>, MessageListener> subscribers = new ConcurrentHashMap<>();

public void publish(Channel channel, Message message) {
subscribers.entrySet().stream().filter(e -> e.getValue().contains(channel)).forEach(e -> e.getKey().accept(message));
public void publish(final Channel channel, final Message message) {
subscribers.values().stream().filter(l -> l.hasSubscribedTo(channel)).forEach(l -> l.accept(message));
}

@Override
public void subscribe(Consumer<Message> subscriber, Channel... channels) {
subscribers.put(subscriber, asSet(channels));
public int subscribe(final Consumer<Message> subscriber, final Channel... channels) throws InterruptedException {
return subscribe(subscriber, () -> logger.debug("subscribed {} to {}", subscriber, Arrays.toString(channels)), channels);
}

@Override
public void subscribe(Consumer<Message> subscriber, Runnable subscriptionCallback, Channel... channels) {
subscribe(subscriber, channels);
public int subscribe(final Consumer<Message> subscriber, final Runnable subscriptionCallback, final Channel... channels) throws InterruptedException {
MessageListener listener = new MessageListener(subscriber, channels);
subscribers.put(subscriber, listener);
subscriptionCallback.run();

synchronized (listener.message) {
while (!listener.isShutdown()) {
listener.message.wait();
}
}
return listener.nbMessages.get();
}

@Override
public void unsubscribe(Consumer<Message> subscriber) {
subscribers.remove(subscriber);
ofNullable(subscribers.remove(subscriber)).ifPresent(l -> {
l.accept(new ShutdownMessage());
logger.debug("unsubscribed {}", subscriber);
});
}

private static class MessageListener implements Consumer<Message> {
private final Consumer<Message> subscriber;
private final LinkedHashSet<Channel> channels;
final AtomicReference<Message> message = new AtomicReference<>();
final AtomicInteger nbMessages = new AtomicInteger(0);

public MessageListener(Consumer<Message> subscriber, Channel... channels) {
this.subscriber = subscriber;
this.channels = asSet(channels);
}

boolean hasSubscribedTo(Channel channel) {
return channels.contains(channel);
}

@Override
public void accept(Message message) {
this.message.set(message);
subscriber.accept(message);
nbMessages.getAndIncrement();
synchronized (this.message) {
this.message.notify();
}
}

boolean isShutdown() {
return this.message.get() != null && this.message.get().type == Message.Type.SHUTDOWN;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import static java.util.Arrays.stream;
Expand All @@ -36,17 +37,18 @@ public RedisDataBus(PropertiesProvider propertiesProvider) {
}

@Override
public void subscribe(Consumer<Message> subscriber, Channel... channels) {
subscribe(subscriber, () -> logger.debug("subscribed to " + Arrays.toString(channels)), channels);
public int subscribe(Consumer<Message> subscriber, Channel... channels) {
return subscribe(subscriber, () -> logger.debug("subscribed to " + Arrays.toString(channels)), channels);
}

@Override
public void subscribe(Consumer<Message> subscriber, Runnable subscriptionCallback, Channel... channels) {
public int subscribe(Consumer<Message> subscriber, Runnable subscriptionCallback, Channel... channels) {
JedisListener jedisListener = new JedisListener(subscriber, subscriptionCallback);
subscribers.put(subscriber, jedisListener);
try (Jedis jedis = redis.getResource()) {
jedis.subscribe(jedisListener, stream(channels).map(Enum::name).toArray(String[]::new));
}
return jedisListener.nbMessages.get();
}

@Override
Expand All @@ -67,32 +69,34 @@ public void close() {
}

static class JedisListener extends JedisPubSub {
private final Consumer<Message> callback;
private final Runnable subscribedCallback;
private final Consumer<Message> callback;
private final Runnable subscribedCallback;
final AtomicInteger nbMessages = new AtomicInteger(0);

JedisListener(Consumer<Message> callback, Runnable subscribedCallback) {
this.callback = callback;
this.subscribedCallback = subscribedCallback;
}
JedisListener(Consumer<Message> callback, Runnable subscribedCallback) {
this.callback = callback;
this.subscribedCallback = subscribedCallback;
}

@Override
public void onSubscribe(String channel, int subscribedChannels) {
subscribedCallback.run();
}
@Override
public void onSubscribe(String channel, int subscribedChannels) {
subscribedCallback.run();
}

@Override
public void onMessage(String channel, String message) {
try {
HashMap result = new ObjectMapper().readValue(message, HashMap.class);
Message msg = new Message(result);
if (msg.type == SHUTDOWN) {
unsubscribe();
logger.info("Shutdown called. Unsubscribe done.");
}
callback.accept(msg);
} catch (IOException e) {
logger.error("cannot deserialize json message " + message, e);
}
}
}
@Override
public void onMessage(String channel, String message) {
try {
HashMap result = new ObjectMapper().readValue(message, HashMap.class);
Message msg = new Message(result);
if (msg.type == SHUTDOWN) {
unsubscribe();
logger.info("Shutdown called. Unsubscribe done.");
}
callback.accept(msg);
nbMessages.getAndIncrement();
} catch (IOException e) {
logger.error("cannot deserialize json message " + message, e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
package org.icij.datashare.text.nlp;

interface DatashareListener extends Runnable {
import java.util.concurrent.Callable;

interface DatashareListener extends Callable<Integer> {
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ public void run() {
logger.info("running NlpApp for {} pipeline with {} thread(s)", pipeline.getType(), parallelism);
this.threadPool = Executors.newFixedThreadPool(parallelism,
new ThreadFactoryBuilder().setNameFormat(pipeline.getType().name() + "-%d").build());
generate(() -> new NlpConsumer(pipeline, indexer, queue)).limit(parallelism).forEach(l -> threadPool.execute(l));
forwarder.run();
generate(() -> new NlpConsumer(pipeline, indexer, queue)).limit(parallelism).forEach(l -> threadPool.submit(l));
forwarder.call();
logger.info("forwarder exited waiting for consumer(s) to finish");
shutdown();
} catch (Throwable throwable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ public NlpConsumer(AbstractPipeline pipeline, Indexer indexer, BlockingQueue<Mes
}

@Override
public void run() {
public Integer call() {
boolean exitAsked = false;
int nbMessages = 0;
while (! exitAsked) {
try {
Message message = messageQueue.poll(30, TimeUnit.SECONDS);
if (message != null) {
switch (message.type) {
case EXTRACT_NLP:
findNamedEntities(message.content.get(INDEX_NAME), message.content.get(DOC_ID), message.content.get(R_ID));
nbMessages++;
break;
case SHUTDOWN:
exitAsked = true;
Expand All @@ -48,6 +50,7 @@ public void run() {
}
synchronized (messageQueue) {
if (messageQueue.isEmpty()) {
logger.debug("queue is empty notifying messageQueue {}", messageQueue.hashCode());
messageQueue.notify();
}
}
Expand All @@ -57,6 +60,7 @@ public void run() {
}
}
logger.info("exiting main loop");
return nbMessages;
}

void findNamedEntities(final String projectName, final String id, final String routing) throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public class NlpForwarder implements DatashareListener,Monitorable {
}

@Override
public void run() {
dataBus.subscribe(this::onMessage, subscribedCallback, Channel.NLP);
public Integer call() throws InterruptedException {
return dataBus.subscribe(this::onMessage, subscribedCallback, Channel.NLP);
}

void onMessage(final Message message) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package org.icij.datashare.com.memory;

import org.icij.datashare.com.Channel;
import org.icij.datashare.com.Message;
import org.icij.datashare.com.ShutdownMessage;
import org.junit.After;
import org.junit.Test;

import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

Expand All @@ -12,58 +14,70 @@
import static org.icij.datashare.com.Channel.TEST;

public class MemoryDataBusTest {
private ExecutorService executor = Executors.newFixedThreadPool(2);
private MemoryDataBus dataBus = new MemoryDataBus();

@Test
public void test_subscribe_unsubscribe() {
@Test(timeout = 5000)
public void test_subscribe_unsubscribe() throws Exception {
AtomicInteger received = new AtomicInteger();
Consumer<Message> messageConsumer = message -> received.getAndIncrement();
dataBus.subscribe(messageConsumer, TEST);
dataBus.unsubscribe(messageConsumer);
CountDownLatch subscription = new CountDownLatch(1);
Future<Integer> subscribed = executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST));
subscription.await(1, TimeUnit.SECONDS);

dataBus.unsubscribe(messageConsumer);
dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP));

assertThat(received.get()).isEqualTo(0);
}

@Test
public void test_subscribe_with_callback() {
AtomicInteger subscribed = new AtomicInteger();
AtomicInteger received = new AtomicInteger();

dataBus.subscribe(message -> received.getAndIncrement(), subscribed::getAndIncrement, Channel.TEST);

assertThat(received.get()).isEqualTo(1); // shutdown message
assertThat(subscribed.get()).isEqualTo(1);
}

@Test
public void test_pub_sub_one_subscriber() {
@Test(timeout = 5000)
public void test_pub_sub_one_subscriber() throws InterruptedException {
AtomicInteger received = new AtomicInteger();
dataBus.subscribe(message -> received.getAndIncrement(), TEST);
Consumer<Message> messageConsumer = message -> received.getAndIncrement();
CountDownLatch subscription = new CountDownLatch(1);
executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST));
subscription.await(1, TimeUnit.SECONDS);

dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP));
dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP));
dataBus.unsubscribe(messageConsumer);

assertThat(received.get()).isEqualTo(1);
assertThat(received.get()).isEqualTo(2); // +shutdown
}

@Test
public void test_pub_sub_one_subscriber_other_channel() {
@Test(timeout = 5000)
public void test_pub_sub_one_subscriber_other_channel() throws InterruptedException {
AtomicInteger received = new AtomicInteger();
dataBus.subscribe(message -> received.getAndIncrement(), TEST);
Consumer<Message> messageConsumer = message -> received.getAndIncrement();
CountDownLatch subscription = new CountDownLatch(1);
executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST));
subscription.await(1, TimeUnit.SECONDS);

dataBus.publish(NLP, new Message(Message.Type.EXTRACT_NLP));
dataBus.unsubscribe(messageConsumer);

assertThat(received.get()).isEqualTo(0);
assertThat(received.get()).isEqualTo(1); //shutdown
}

@Test
public void test_pub_sub_two_subscribers() {
@Test(timeout = 5000)
public void test_pub_sub_two_subscribers() throws InterruptedException {
AtomicInteger received = new AtomicInteger();
dataBus.subscribe(message -> received.getAndIncrement(), TEST);
dataBus.subscribe(message -> received.getAndIncrement(), TEST);
CountDownLatch subscriptions = new CountDownLatch(2);

executor.submit(() -> dataBus.subscribe(message -> received.getAndIncrement(), subscriptions::countDown, TEST));
executor.submit(() -> dataBus.subscribe(message -> received.getAndIncrement(), subscriptions::countDown, TEST));
subscriptions.await(1, TimeUnit.SECONDS);

dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP));
dataBus.publish(TEST, new ShutdownMessage());

assertThat(received.get()).isEqualTo(4); // EXTRACT received by 2 subscribers + 2 shutdown messages
}

assertThat(received.get()).isEqualTo(2);
@After
public void tearDown() throws Exception {
executor.shutdown();
executor.awaitTermination(2, TimeUnit.SECONDS);
}
}
Loading

0 comments on commit 1469975

Please sign in to comment.