From 1469975d385f222724388ec0baf5dcbc18965258 Mon Sep 17 00:00:00 2001 From: Bruno Thomas Date: Thu, 9 Jan 2020 15:06:28 +0000 Subject: [PATCH] [#343] DS light : makes the subscribe blocking like Redis --- .../java/org/icij/datashare/com/DataBus.java | 6 +- .../datashare/com/memory/MemoryDataBus.java | 70 ++++++++++++++---- .../datashare/com/redis/RedisDataBus.java | 60 ++++++++-------- .../datashare/text/nlp/DatashareListener.java | 4 +- .../org/icij/datashare/text/nlp/NlpApp.java | 4 +- .../icij/datashare/text/nlp/NlpConsumer.java | 6 +- .../icij/datashare/text/nlp/NlpForwarder.java | 4 +- .../com/memory/MemoryDataBusTest.java | 72 +++++++++++-------- .../icij/datashare/text/nlp/NlpAppTest.java | 53 ++++++++------ 9 files changed, 180 insertions(+), 99 deletions(-) diff --git a/datashare-api/src/main/java/org/icij/datashare/com/DataBus.java b/datashare-api/src/main/java/org/icij/datashare/com/DataBus.java index abb4d68ce..2240701d7 100644 --- a/datashare-api/src/main/java/org/icij/datashare/com/DataBus.java +++ b/datashare-api/src/main/java/org/icij/datashare/com/DataBus.java @@ -2,8 +2,8 @@ import java.util.function.Consumer; -public interface DataBus { - void subscribe(Consumer subscriber, Channel... channels); - void subscribe(Consumer subscriber, Runnable subscriptionCallback, Channel... channels); +public interface DataBus extends Publisher { + int subscribe(Consumer subscriber, Channel... channels) throws InterruptedException; + int subscribe(Consumer subscriber, Runnable subscriptionCallback, Channel... channels) throws InterruptedException; void unsubscribe(Consumer subscriber); } diff --git a/datashare-api/src/main/java/org/icij/datashare/com/memory/MemoryDataBus.java b/datashare-api/src/main/java/org/icij/datashare/com/memory/MemoryDataBus.java index f8bc06636..7f21230f5 100644 --- a/datashare-api/src/main/java/org/icij/datashare/com/memory/MemoryDataBus.java +++ b/datashare-api/src/main/java/org/icij/datashare/com/memory/MemoryDataBus.java @@ -1,38 +1,82 @@ package org.icij.datashare.com.memory; -import org.icij.datashare.com.Channel; -import org.icij.datashare.com.DataBus; -import org.icij.datashare.com.Message; -import org.icij.datashare.com.Publisher; +import org.icij.datashare.com.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static java.util.Optional.ofNullable; import static org.icij.datashare.CollectionUtils.asSet; public class MemoryDataBus implements Publisher, DataBus { - private final Map, Set> subscribers = new ConcurrentHashMap<>(); + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final Map, MessageListener> subscribers = new ConcurrentHashMap<>(); - public void publish(Channel channel, Message message) { - subscribers.entrySet().stream().filter(e -> e.getValue().contains(channel)).forEach(e -> e.getKey().accept(message)); + public void publish(final Channel channel, final Message message) { + subscribers.values().stream().filter(l -> l.hasSubscribedTo(channel)).forEach(l -> l.accept(message)); } @Override - public void subscribe(Consumer subscriber, Channel... channels) { - subscribers.put(subscriber, asSet(channels)); + public int subscribe(final Consumer subscriber, final Channel... channels) throws InterruptedException { + return subscribe(subscriber, () -> logger.debug("subscribed {} to {}", subscriber, Arrays.toString(channels)), channels); } @Override - public void subscribe(Consumer subscriber, Runnable subscriptionCallback, Channel... channels) { - subscribe(subscriber, channels); + public int subscribe(final Consumer subscriber, final Runnable subscriptionCallback, final Channel... channels) throws InterruptedException { + MessageListener listener = new MessageListener(subscriber, channels); + subscribers.put(subscriber, listener); subscriptionCallback.run(); + + synchronized (listener.message) { + while (!listener.isShutdown()) { + listener.message.wait(); + } + } + return listener.nbMessages.get(); } @Override public void unsubscribe(Consumer subscriber) { - subscribers.remove(subscriber); + ofNullable(subscribers.remove(subscriber)).ifPresent(l -> { + l.accept(new ShutdownMessage()); + logger.debug("unsubscribed {}", subscriber); + }); } + private static class MessageListener implements Consumer { + private final Consumer subscriber; + private final LinkedHashSet channels; + final AtomicReference message = new AtomicReference<>(); + final AtomicInteger nbMessages = new AtomicInteger(0); + + public MessageListener(Consumer subscriber, Channel... channels) { + this.subscriber = subscriber; + this.channels = asSet(channels); + } + + boolean hasSubscribedTo(Channel channel) { + return channels.contains(channel); + } + + @Override + public void accept(Message message) { + this.message.set(message); + subscriber.accept(message); + nbMessages.getAndIncrement(); + synchronized (this.message) { + this.message.notify(); + } + } + + boolean isShutdown() { + return this.message.get() != null && this.message.get().type == Message.Type.SHUTDOWN; + } + } } diff --git a/datashare-api/src/main/java/org/icij/datashare/com/redis/RedisDataBus.java b/datashare-api/src/main/java/org/icij/datashare/com/redis/RedisDataBus.java index 696fe8899..994ca5c48 100644 --- a/datashare-api/src/main/java/org/icij/datashare/com/redis/RedisDataBus.java +++ b/datashare-api/src/main/java/org/icij/datashare/com/redis/RedisDataBus.java @@ -20,6 +20,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import static java.util.Arrays.stream; @@ -36,17 +37,18 @@ public RedisDataBus(PropertiesProvider propertiesProvider) { } @Override - public void subscribe(Consumer subscriber, Channel... channels) { - subscribe(subscriber, () -> logger.debug("subscribed to " + Arrays.toString(channels)), channels); + public int subscribe(Consumer subscriber, Channel... channels) { + return subscribe(subscriber, () -> logger.debug("subscribed to " + Arrays.toString(channels)), channels); } @Override - public void subscribe(Consumer subscriber, Runnable subscriptionCallback, Channel... channels) { + public int subscribe(Consumer subscriber, Runnable subscriptionCallback, Channel... channels) { JedisListener jedisListener = new JedisListener(subscriber, subscriptionCallback); subscribers.put(subscriber, jedisListener); try (Jedis jedis = redis.getResource()) { jedis.subscribe(jedisListener, stream(channels).map(Enum::name).toArray(String[]::new)); } + return jedisListener.nbMessages.get(); } @Override @@ -67,32 +69,34 @@ public void close() { } static class JedisListener extends JedisPubSub { - private final Consumer callback; - private final Runnable subscribedCallback; + private final Consumer callback; + private final Runnable subscribedCallback; + final AtomicInteger nbMessages = new AtomicInteger(0); - JedisListener(Consumer callback, Runnable subscribedCallback) { - this.callback = callback; - this.subscribedCallback = subscribedCallback; - } + JedisListener(Consumer callback, Runnable subscribedCallback) { + this.callback = callback; + this.subscribedCallback = subscribedCallback; + } - @Override - public void onSubscribe(String channel, int subscribedChannels) { - subscribedCallback.run(); - } + @Override + public void onSubscribe(String channel, int subscribedChannels) { + subscribedCallback.run(); + } - @Override - public void onMessage(String channel, String message) { - try { - HashMap result = new ObjectMapper().readValue(message, HashMap.class); - Message msg = new Message(result); - if (msg.type == SHUTDOWN) { - unsubscribe(); - logger.info("Shutdown called. Unsubscribe done."); - } - callback.accept(msg); - } catch (IOException e) { - logger.error("cannot deserialize json message " + message, e); - } - } - } + @Override + public void onMessage(String channel, String message) { + try { + HashMap result = new ObjectMapper().readValue(message, HashMap.class); + Message msg = new Message(result); + if (msg.type == SHUTDOWN) { + unsubscribe(); + logger.info("Shutdown called. Unsubscribe done."); + } + callback.accept(msg); + nbMessages.getAndIncrement(); + } catch (IOException e) { + logger.error("cannot deserialize json message " + message, e); + } + } + } } diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/DatashareListener.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/DatashareListener.java index cb68db7e9..5d75fb061 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/DatashareListener.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/DatashareListener.java @@ -1,4 +1,6 @@ package org.icij.datashare.text.nlp; -interface DatashareListener extends Runnable { +import java.util.concurrent.Callable; + +interface DatashareListener extends Callable { } diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpApp.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpApp.java index f93373855..10f3e17d4 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpApp.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpApp.java @@ -70,8 +70,8 @@ public void run() { logger.info("running NlpApp for {} pipeline with {} thread(s)", pipeline.getType(), parallelism); this.threadPool = Executors.newFixedThreadPool(parallelism, new ThreadFactoryBuilder().setNameFormat(pipeline.getType().name() + "-%d").build()); - generate(() -> new NlpConsumer(pipeline, indexer, queue)).limit(parallelism).forEach(l -> threadPool.execute(l)); - forwarder.run(); + generate(() -> new NlpConsumer(pipeline, indexer, queue)).limit(parallelism).forEach(l -> threadPool.submit(l)); + forwarder.call(); logger.info("forwarder exited waiting for consumer(s) to finish"); shutdown(); } catch (Throwable throwable) { diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpConsumer.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpConsumer.java index 319142555..e05774ece 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpConsumer.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpConsumer.java @@ -30,8 +30,9 @@ public NlpConsumer(AbstractPipeline pipeline, Indexer indexer, BlockingQueue messageConsumer = message -> received.getAndIncrement(); - dataBus.subscribe(messageConsumer, TEST); - dataBus.unsubscribe(messageConsumer); + CountDownLatch subscription = new CountDownLatch(1); + Future subscribed = executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST)); + subscription.await(1, TimeUnit.SECONDS); + dataBus.unsubscribe(messageConsumer); dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP)); - assertThat(received.get()).isEqualTo(0); - } - - @Test - public void test_subscribe_with_callback() { - AtomicInteger subscribed = new AtomicInteger(); - AtomicInteger received = new AtomicInteger(); - - dataBus.subscribe(message -> received.getAndIncrement(), subscribed::getAndIncrement, Channel.TEST); - + assertThat(received.get()).isEqualTo(1); // shutdown message assertThat(subscribed.get()).isEqualTo(1); } - @Test - public void test_pub_sub_one_subscriber() { + @Test(timeout = 5000) + public void test_pub_sub_one_subscriber() throws InterruptedException { AtomicInteger received = new AtomicInteger(); - dataBus.subscribe(message -> received.getAndIncrement(), TEST); + Consumer messageConsumer = message -> received.getAndIncrement(); + CountDownLatch subscription = new CountDownLatch(1); + executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST)); + subscription.await(1, TimeUnit.SECONDS); - dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP)); + dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP)); + dataBus.unsubscribe(messageConsumer); - assertThat(received.get()).isEqualTo(1); + assertThat(received.get()).isEqualTo(2); // +shutdown } - @Test - public void test_pub_sub_one_subscriber_other_channel() { + @Test(timeout = 5000) + public void test_pub_sub_one_subscriber_other_channel() throws InterruptedException { AtomicInteger received = new AtomicInteger(); - dataBus.subscribe(message -> received.getAndIncrement(), TEST); + Consumer messageConsumer = message -> received.getAndIncrement(); + CountDownLatch subscription = new CountDownLatch(1); + executor.submit(() -> dataBus.subscribe(messageConsumer, subscription::countDown, TEST)); + subscription.await(1, TimeUnit.SECONDS); dataBus.publish(NLP, new Message(Message.Type.EXTRACT_NLP)); + dataBus.unsubscribe(messageConsumer); - assertThat(received.get()).isEqualTo(0); + assertThat(received.get()).isEqualTo(1); //shutdown } - @Test - public void test_pub_sub_two_subscribers() { + @Test(timeout = 5000) + public void test_pub_sub_two_subscribers() throws InterruptedException { AtomicInteger received = new AtomicInteger(); - dataBus.subscribe(message -> received.getAndIncrement(), TEST); - dataBus.subscribe(message -> received.getAndIncrement(), TEST); + CountDownLatch subscriptions = new CountDownLatch(2); + + executor.submit(() -> dataBus.subscribe(message -> received.getAndIncrement(), subscriptions::countDown, TEST)); + executor.submit(() -> dataBus.subscribe(message -> received.getAndIncrement(), subscriptions::countDown, TEST)); + subscriptions.await(1, TimeUnit.SECONDS); dataBus.publish(TEST, new Message(Message.Type.EXTRACT_NLP)); + dataBus.publish(TEST, new ShutdownMessage()); + + assertThat(received.get()).isEqualTo(4); // EXTRACT received by 2 subscribers + 2 shutdown messages + } - assertThat(received.get()).isEqualTo(2); + @After + public void tearDown() throws Exception { + executor.shutdown(); + executor.awaitTermination(2, TimeUnit.SECONDS); } } diff --git a/datashare-api/src/test/java/org/icij/datashare/text/nlp/NlpAppTest.java b/datashare-api/src/test/java/org/icij/datashare/text/nlp/NlpAppTest.java index a75ab33cc..d10f17efe 100644 --- a/datashare-api/src/test/java/org/icij/datashare/text/nlp/NlpAppTest.java +++ b/datashare-api/src/test/java/org/icij/datashare/text/nlp/NlpAppTest.java @@ -1,24 +1,27 @@ package org.icij.datashare.text.nlp; import org.icij.datashare.PropertiesProvider; -import org.icij.datashare.com.Channel; -import org.icij.datashare.com.Message; -import org.icij.datashare.com.ShutdownMessage; +import org.icij.datashare.com.*; +import org.icij.datashare.com.memory.MemoryDataBus; import org.icij.datashare.com.redis.RedisDataBus; import org.icij.datashare.text.Language; import org.icij.datashare.text.indexing.Indexer; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.mockito.Mock; import org.mockito.stubbing.Answer; +import java.util.Collection; import java.util.Properties; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.IntStream; +import static java.util.Arrays.asList; import static java.util.concurrent.TimeUnit.SECONDS; import static org.fest.assertions.Assertions.assertThat; import static org.icij.datashare.com.Message.Field.*; @@ -35,18 +38,28 @@ import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; +@RunWith(Parameterized.class) public class NlpAppTest { + @Parameterized.Parameters + public static Collection dataBuses() { + return asList(new Object[][]{ + {new MemoryDataBus()}, + {new RedisDataBus(new PropertiesProvider())} + }); + } @Mock private AbstractPipeline pipeline; @Mock private Indexer indexer; - private RedisDataBus publisher = new RedisDataBus(new PropertiesProvider()); + private DataBus dataBus; private final ExecutorService executor = Executors.newFixedThreadPool(3); + public NlpAppTest(DataBus dataBus) { this.dataBus = dataBus;} + @Test(timeout = 5000) public void test_subscriber_mode_for_standalone_extraction() throws Exception { runNlpApp("1", 0); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id").add(R_ID, "routing").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new ShutdownMessage()); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id").add(R_ID, "routing").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new ShutdownMessage()); shutdownNlpApp(); verify(pipeline, times(1)).process(anyString(), anyString(), any(Language.class)); @@ -56,9 +69,9 @@ public void test_subscriber_mode_for_standalone_extraction() throws Exception { public void test_consumer_mode_for_multithreaded_server_extraction() throws Exception { runNlpApp("2", 0); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id1").add(R_ID, "routing1").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id2").add(R_ID, "routing2").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new ShutdownMessage()); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id1").add(R_ID, "routing1").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id2").add(R_ID, "routing2").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new ShutdownMessage()); shutdownNlpApp(); verify(pipeline, times(2)).process(anyString(), anyString(), any(Language.class)); @@ -68,8 +81,8 @@ public void test_consumer_mode_for_multithreaded_server_extraction() throws Exce public void test_nlp_app_should_wait_queue_to_be_empty_to_shutdown() throws Exception { runNlpApp("1", 200); - IntStream.range(1,4).forEach(i -> publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id" + i).add(R_ID, "routing" + i).add(INDEX_NAME, local().id))); - publisher.publish(Channel.NLP, new ShutdownMessage()); + IntStream.range(1,4).forEach(i -> dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id" + i).add(R_ID, "routing" + i).add(INDEX_NAME, local().id))); + dataBus.publish(Channel.NLP, new ShutdownMessage()); shutdownNlpApp(); verify(pipeline, times(3)).process(anyString(), anyString(), any(Language.class)); @@ -80,10 +93,10 @@ public void test_nlp_app_progress_rate() throws Exception { NlpApp nlpApp = runNlpApp("1", 0); assertThat(nlpApp.getProgressRate()).isEqualTo(-1); - publisher.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "4")); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id1").add(R_ID, "routing1").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id2").add(R_ID, "routing2").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new ShutdownMessage()); + dataBus.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "4")); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id1").add(R_ID, "routing1").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id2").add(R_ID, "routing2").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new ShutdownMessage()); shutdownNlpApp(); assertThat(nlpApp.getProgressRate()).isEqualTo(0.5); @@ -94,10 +107,10 @@ public void test_nlp_app_progress_rate__two_init_add_values() throws Exception { NlpApp nlpApp = runNlpApp("1", 0); assertThat(nlpApp.getProgressRate()).isEqualTo(-1); - publisher.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "4")); - publisher.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "6")); - publisher.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id").add(R_ID, "routing").add(INDEX_NAME, local().id)); - publisher.publish(Channel.NLP, new ShutdownMessage()); + dataBus.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "4")); + dataBus.publish(Channel.NLP, new Message(INIT_MONITORING).add(VALUE, "6")); + dataBus.publish(Channel.NLP, new Message(EXTRACT_NLP).add(DOC_ID, "doc_id").add(R_ID, "routing").add(INDEX_NAME, local().id)); + dataBus.publish(Channel.NLP, new ShutdownMessage()); shutdownNlpApp(); assertThat(nlpApp.getProgressRate()).isEqualTo(0.1); @@ -113,7 +126,7 @@ private NlpApp runNlpApp(String parallelism, int nlpProcessDelayMillis) throws I if (nlpProcessDelayMillis > 0) Thread.sleep(nlpProcessDelayMillis); return new Annotations("docid_mock", Pipeline.Type.CORENLP, Language.FRENCH); }); - NlpApp nlpApp = new NlpApp(publisher, indexer, pipeline, properties, latch::countDown,1, local()); + NlpApp nlpApp = new NlpApp(dataBus, indexer, pipeline, properties, latch::countDown,1, local()); executor.execute(nlpApp); latch.await(2, SECONDS); return nlpApp;