diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-974558b.json b/.changes/next-release/bugfix-AWSSDKforJavav2-974558b.json new file mode 100644 index 000000000000..77b4001abc2d --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-974558b.json @@ -0,0 +1,6 @@ +{ + "category": "AWS SDK for Java v2", + "contributor": "", + "type": "bugfix", + "description": "Fixed an issue where event streams might fail with ClassCastException or NoSuchElementExceptions" +} diff --git a/.changes/next-release/feature-AWSSDKforJavav2-d2a3922.json b/.changes/next-release/feature-AWSSDKforJavav2-d2a3922.json new file mode 100644 index 000000000000..5396cf9ffcc5 --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-d2a3922.json @@ -0,0 +1,6 @@ +{ + "category": "AWS SDK for Java v2", + "contributor": "", + "type": "feature", + "description": "Added new convenience methods to SdkPublisher: doAfterOnError, doAfterOnComplete, and doAfterCancel." +} diff --git a/core/aws-core/pom.xml b/core/aws-core/pom.xml index be41b0b84d10..907e79959fa9 100644 --- a/core/aws-core/pom.xml +++ b/core/aws-core/pom.xml @@ -73,10 +73,6 @@ utils ${awsjavasdk.version} - - org.slf4j - slf4j-api - software.amazon.eventstream eventstream diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformer.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformer.java index d8437707427e..95b664774423 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformer.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformer.java @@ -15,41 +15,35 @@ package software.amazon.awssdk.awscore.eventstream; +import static java.util.Collections.emptyList; +import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static software.amazon.awssdk.core.http.HttpResponseHandler.X_AMZN_REQUEST_ID_HEADER; import static software.amazon.awssdk.core.http.HttpResponseHandler.X_AMZN_REQUEST_ID_HEADERS; import static software.amazon.awssdk.core.http.HttpResponseHandler.X_AMZ_ID_2_HEADER; -import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError; import java.io.ByteArrayInputStream; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import java.util.function.Supplier; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.SdkResponse; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.http.AbortableInputStream; -import software.amazon.awssdk.http.SdkCancellationException; import software.amazon.awssdk.http.SdkHttpFullResponse; -import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.http.SdkHttpUtils; import software.amazon.eventstream.Message; import software.amazon.eventstream.MessageDecoder; @@ -64,12 +58,7 @@ @SdkProtectedApi public final class EventStreamAsyncResponseTransformer implements AsyncResponseTransformer { - - private static final Logger log = LoggerFactory.getLogger(EventStreamAsyncResponseTransformer.class); - - private static final Object ON_COMPLETE_EVENT = new Object(); - - private static final ExecutionAttributes EMPTY_EXECUTION_ATTRIBUTES = new ExecutionAttributes(); + private static final Logger log = Logger.loggerFor(EventStreamAsyncResponseTransformer.class); /** * {@link EventStreamResponseHandler} provided by customer. @@ -91,51 +80,7 @@ public final class EventStreamAsyncResponseTransformer */ private final HttpResponseHandler exceptionResponseHandler; - /** - * Remaining demand (i.e number of unmarshalled events) we need to provide to the customers subscriber. - */ - private final AtomicLong remainingDemand = new AtomicLong(0); - - /** - * Reference to customers subscriber to events. - */ - private final AtomicReference> subscriberRef = new AtomicReference<>(); - - private final AtomicReference dataSubscription = new AtomicReference<>(); - - /** - * Event stream message decoder that decodes the binary data into "frames". These frames are then passed to the - * unmarshaller to produce the event POJO. - */ - private final MessageDecoder decoder = new MessageDecoder(this::handleMessage); - - /** - * Tracks whether we have delivered a terminal notification to the subscriber and response handler - * (i.e. exception or completion). - */ - private volatile boolean isDone = false; - - /** - * Executor to deliver events to the subscriber - */ - private final Executor executor; - - /** - * Queue of events to deliver to downstream subscriber. Will contain mostly objects - * of type EventT, the special {@link #ON_COMPLETE_EVENT} will be added when all events - * have been added to the queue. - */ - private final Queue eventsToDeliver = new LinkedList<>(); - - /** - * Flag to indicate we are currently delivering events to the subscriber. - */ - private final AtomicBoolean isDelivering = new AtomicBoolean(false); - - /** - * Flag to indicate we are currently requesting demand from the data publisher. - */ - private final AtomicBoolean isRequesting = new AtomicBoolean(false); + private final Supplier attributesFactory; /** * Future to notify on completion. Note that we do not notify this future in the event of an error, that @@ -146,164 +91,184 @@ public final class EventStreamAsyncResponseTransformer private final CompletableFuture future; /** - * The name of the aws service + * Whether exceptions may be sent to the downstream event stream response handler. This prevents multiple exception + * deliveries from being performed. */ - private final String serviceName; + private final AtomicBoolean exceptionsMayBeSent = new AtomicBoolean(true); + + /** + * The future generated via {@link #prepare()}. + */ + private volatile CompletableFuture transformFuture; /** * Request Id for the streaming request. The value is populated when the initial response is received from the service. * As request id is not sent in event messages (including exceptions), this can be returned by the SDK along with * received exception details. */ - private String requestId = null; - - private volatile CompletableFuture transformFuture; + private volatile String requestId = null; /** * Extended Request Id for the streaming request. The value is populated when the initial response is received from the * service. As request id is not sent in event messages (including exceptions), this can be returned by the SDK along with * received exception details. */ - private String extendedRequestId = null; + private volatile String extendedRequestId = null; private EventStreamAsyncResponseTransformer( EventStreamResponseHandler eventStreamResponseHandler, HttpResponseHandler initialResponseHandler, HttpResponseHandler eventResponseHandler, HttpResponseHandler exceptionResponseHandler, - Executor executor, CompletableFuture future, String serviceName) { - this.eventStreamResponseHandler = eventStreamResponseHandler; this.initialResponseHandler = initialResponseHandler; this.eventResponseHandler = eventResponseHandler; this.exceptionResponseHandler = exceptionResponseHandler; - this.executor = executor; this.future = future; - this.serviceName = serviceName; + this.attributesFactory = () -> new ExecutionAttributes().putAttribute(SdkExecutionAttribute.SERVICE_NAME, serviceName); + } + + /** + * Creates a {@link Builder} used to create {@link EventStreamAsyncResponseTransformer}. + * + * @param Initial response type. + * @param Event type being delivered. + * @return New {@link Builder} instance. + */ + public static Builder builder() { + return new Builder<>(); } @Override public CompletableFuture prepare() { transformFuture = new CompletableFuture<>(); - subscriberRef.set(null); - isDone = false; return transformFuture; } @Override public void onResponse(SdkResponse response) { + // Capture the request IDs from the initial response, so that we can include them in each event. if (response != null && response.sdkHttpResponse() != null) { this.requestId = SdkHttpUtils.firstMatchingHeaderFromCollection(response.sdkHttpResponse().headers(), X_AMZN_REQUEST_ID_HEADERS) .orElse(null); - this.extendedRequestId = response.sdkHttpResponse() .firstMatchingHeader(X_AMZ_ID_2_HEADER) .orElse(null); + + log.debug(() -> getLogPrefix() + "Received HTTP response headers: " + response); } } @Override public void onStream(SdkPublisher publisher) { - CompletableFuture dataSubscriptionFuture = new CompletableFuture<>(); - publisher.subscribe(new ByteSubscriber(dataSubscriptionFuture)); - dataSubscriptionFuture.thenAccept(dataSubscription -> { - SdkPublisher eventPublisher = new EventPublisher(dataSubscription); - try { - eventStreamResponseHandler.onEventStream(eventPublisher); - } catch (Throwable t) { - exceptionOccurred(t); - dataSubscription.cancel(); - } - }); + Validate.isTrue(transformFuture != null, "onStream() invoked without prepare()."); + + exceptionsMayBeSent.set(true); + + SynchronousMessageDecoder decoder = new SynchronousMessageDecoder(); + eventStreamResponseHandler.onEventStream(publisher.flatMapIterable(decoder::decode) + .flatMapIterable(this::transformMessage) + .doAfterOnComplete(this::handleOnStreamComplete) + .doAfterOnError(this::handleOnStreamError) + .doAfterOnCancel(this::handleOnStreamCancel)); } @Override public void exceptionOccurred(Throwable throwable) { - synchronized (this) { - if (!isDone) { - isDone = true; - // If we have a Subscriber at this point notify it as well - if (subscriberRef.get() != null && shouldSurfaceErrorToEventSubscriber(throwable)) { - runAndLogError(log, "Error thrown from Subscriber#onError, ignoring.", - () -> subscriberRef.get().onError(throwable)); - } + if (exceptionsMayBeSent.compareAndSet(true, false)) { + try { eventStreamResponseHandler.exceptionOccurred(throwable); - transformFuture.completeExceptionally(throwable); + } catch (RuntimeException e) { + log.warn(() -> "Exception raised by exceptionOccurred. Ignoring.", e); } + transformFuture.completeExceptionally(throwable); } } - /** - * Called when all events have been delivered to the downstream subscriber. - */ - private void onEventComplete() { - synchronized (this) { - // No op if it's already done - if (isDone) { - return; - } + private void handleOnStreamComplete() { + log.trace(() -> getLogPrefix() + "Event stream completed successfully."); + exceptionsMayBeSent.set(false); + eventStreamResponseHandler.complete(); + transformFuture.complete(null); + future.complete(null); + } + + private void handleOnStreamError(Throwable throwable) { + log.trace(() -> getLogPrefix() + "Event stream failed.", throwable); + exceptionOccurred(throwable); + } - isDone = true; - runAndLogError(log, "Error thrown from Subscriber#onComplete, ignoring.", - () -> subscriberRef.get().onComplete()); - eventStreamResponseHandler.complete(); - future.complete(null); + private void handleOnStreamCancel() { + log.trace(() -> getLogPrefix() + "Event stream cancelled."); + exceptionsMayBeSent.set(false); + transformFuture.complete(null); + future.complete(null); + } + + private static final class SynchronousMessageDecoder { + private final MessageDecoder decoder = new MessageDecoder(); + + private Iterable decode(ByteBuffer bytes) { + decoder.feed(bytes); + return decoder.getDecodedMessages(); } } - /** - * Handle the event stream message according to it's type. - * - * @param m Decoded message. - */ - private void handleMessage(Message m) { + private Iterable transformMessage(Message message) { try { - if (isEvent(m)) { - if (m.getHeaders().get(":event-type").getString().equals("initial-response")) { - eventStreamResponseHandler.responseReceived( - initialResponseHandler.handle(adaptMessageToResponse(m, false), - EMPTY_EXECUTION_ATTRIBUTES)); - } else { - // Add to queue to be delivered later by the executor - eventsToDeliver.add(eventResponseHandler.handle(adaptMessageToResponse(m, false), - EMPTY_EXECUTION_ATTRIBUTES)); - } - } else if (isError(m) || isException(m)) { - SdkHttpFullResponse errorResponse = adaptMessageToResponse(m, true); - Throwable exception = exceptionResponseHandler.handle( - errorResponse, new ExecutionAttributes().putAttribute(SdkExecutionAttribute.SERVICE_NAME, serviceName)); - runAndLogError(log, "Error thrown from exceptionOccurred, ignoring.", () -> exceptionOccurred(exception)); + if (isEvent(message)) { + return transformEventMessage(message); + } else if (isError(message) || isException(message)) { + throw transformErrorMessage(message); + } else { + log.debug(() -> getLogPrefix() + "Decoded a message of an unknown type, it will be dropped: " + message); + return emptyList(); } - } catch (Exception e) { + } catch (Error | SdkException e) { + throw e; + } catch (Throwable e) { throw SdkClientException.builder().cause(e).build(); } } - /** - * @param m Message frame. - * @return True if frame is an event frame, false if not. - */ - private boolean isEvent(Message m) { - return "event".equals(m.getHeaders().get(":message-type").getString()); + private Iterable transformEventMessage(Message message) throws Exception { + SdkHttpFullResponse response = adaptMessageToResponse(message, false); + if (message.getHeaders().get(":event-type").getString().equals("initial-response")) { + ResponseT initialResponse = initialResponseHandler.handle(response, attributesFactory.get()); + eventStreamResponseHandler.responseReceived(initialResponse); + log.debug(() -> getLogPrefix() + "Decoded initial response: " + initialResponse); + return emptyList(); + } + + EventT event = eventResponseHandler.handle(response, attributesFactory.get()); + log.debug(() -> getLogPrefix() + "Decoded event: " + event); + return singleton(event); } - /** - * @param m Message frame. - * @return True if frame is an error frame, false if not. - */ - private boolean isError(Message m) { - return "error".equals(m.getHeaders().get(":message-type").getString()); + private Throwable transformErrorMessage(Message message) throws Exception { + SdkHttpFullResponse errorResponse = adaptMessageToResponse(message, true); + Throwable exception = exceptionResponseHandler.handle(errorResponse, attributesFactory.get()); + log.debug(() -> getLogPrefix() + "Decoded error or exception: " + exception, exception); + return exception; } - /** - * @param m Message frame. - * @return True if frame is an exception frame, false if not. - */ - private boolean isException(Message m) { - return "exception".equals(m.getHeaders().get(":message-type").getString()); + private String getLogPrefix() { + if (requestId == null) { + return ""; + } + + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("("); + stringBuilder.append("RequestId: ").append(requestId); + if (extendedRequestId != null) { + stringBuilder.append(", ExtendedRequestId: ").append(extendedRequestId); + } + stringBuilder.append(") "); + + return stringBuilder.toString(); } /** @@ -312,7 +277,6 @@ private boolean isException(Message m) { * @param message Message to transform. */ private SdkHttpFullResponse adaptMessageToResponse(Message message, boolean isException) { - Map> headers = message.getHeaders() .entrySet() @@ -322,7 +286,6 @@ private SdkHttpFullResponse adaptMessageToResponse(Message message, boolean isEx if (requestId != null) { headers.put(X_AMZN_REQUEST_ID_HEADER, singletonList(requestId)); } - if (extendedRequestId != null) { headers.put(X_AMZ_ID_2_HEADER, singletonList(extendedRequestId)); } @@ -339,202 +302,28 @@ private SdkHttpFullResponse adaptMessageToResponse(Message message, boolean isEx return builder.build(); } - private static boolean shouldSurfaceErrorToEventSubscriber(Throwable t) { - return !(t instanceof SdkCancellationException); - } - - /** - * Subscriber for the raw bytes from the stream. Feeds them to the {@link MessageDecoder} as they arrive - * and will request as much as needed to fulfill any outstanding demand. - */ - private class ByteSubscriber implements Subscriber { - - private final CompletableFuture dataSubscriptionFuture; - - /** - * @param dataSubscriptionFuture Future to notify when the {@link Subscription} object is available. - */ - private ByteSubscriber(CompletableFuture dataSubscriptionFuture) { - this.dataSubscriptionFuture = dataSubscriptionFuture; - } - - @Override - public void onSubscribe(Subscription subscription) { - dataSubscription.set(subscription); - dataSubscriptionFuture.complete(subscription); - } - - @Override - public void onNext(ByteBuffer buffer) { - // Bail out if we've already delivered an exception to the downstream subscriber - if (isDone) { - return; - } - synchronized (eventsToDeliver) { - decoder.feed(BinaryUtils.copyBytesFrom(buffer)); - // If we have things to deliver, do so. - if (!eventsToDeliver.isEmpty()) { - isRequesting.compareAndSet(true, false); - drainEventsIfNotAlready(); - } else { - // If we still haven't fulfilled the outstanding demand then keep requesting byte chunks until we do - if (remainingDemand.get() > 0) { - dataSubscription.get().request(1); - } - } - } - } - - @Override - public void onError(Throwable throwable) { - // Notified in response handler exceptionOccurred because we have more context on what we've delivered to - // the event stream subscriber there. - } - - @Override - public void onComplete() { - // Add the special on complete event to signal drainEvents to complete the subscriber - eventsToDeliver.add(ON_COMPLETE_EVENT); - drainEventsIfNotAlready(); - transformFuture.complete(null); - } - } - - /** - * Publisher of event stream events. Tracks outstanding demand and requests raw data from the stream until that demand is - * fulfilled. - */ - private class EventPublisher implements SdkPublisher { - - private final Subscription dataSubscription; - - private EventPublisher(Subscription dataSubscription) { - this.dataSubscription = dataSubscription; - } - - @Override - public void subscribe(Subscriber subscriber) { - if (subscriberRef.compareAndSet(null, subscriber)) { - subscriber.onSubscribe(new Subscription() { - @Override - public void request(long l) { - if (isDone) { - return; - } - synchronized (eventsToDeliver) { - remainingDemand.addAndGet(l); - if (!eventsToDeliver.isEmpty()) { - drainEventsIfNotAlready(); - } else { - requestDataIfNotAlready(); - } - } - } - - @Override - public void cancel() { - dataSubscription.cancel(); - - // Need to complete the futures, otherwise the downstream subscriber will never - // get notified - future.complete(null); - transformFuture.complete(null); - } - }); - } else { - log.error("Event stream publishers can only be subscribed to once."); - throw new IllegalStateException("This publisher may only be subscribed to once"); - } - } - } - - /** - * Requests data from the {@link ByteBuffer} {@link Publisher} until we have enough data to fulfill demand. If we are - * already requesting data this is a no-op. - */ - private void requestDataIfNotAlready() { - if (isRequesting.compareAndSet(false, true)) { - dataSubscription.get().request(1); - } - } - /** - * Drains events from the queue until the demand is met or all events are delivered. If we are already - * in the process of delivering events this is a no-op. - */ - private void drainEventsIfNotAlready() { - if (isDelivering.compareAndSet(false, true)) { - drainEvents(); - } - } - - /** - * Drains events from the queue until the demand is met or all events are delivered. This differs - * from {@link #drainEventsIfNotAlready()} in that it assumes it has the {@link #isDelivering} 'lease' already. - */ - private void drainEvents() { - // If we've already delivered an exception to the subscriber than bail out - if (isDone) { - return; - } - - if (isCompletedOrDeliverEvent()) { - onEventComplete(); - } - } - - /** - * Checks whether the eventsToDeliver is completed and if it is not completed, - * deliver more events - * - * @return true if the eventsToDeliver is completed, otherwise false. + * @param m Message frame. + * @return True if frame is an event frame, false if not. */ - private boolean isCompletedOrDeliverEvent() { - synchronized (eventsToDeliver) { - if (eventsToDeliver.peek() == ON_COMPLETE_EVENT) { - return true; - } - - if (eventsToDeliver.isEmpty() || remainingDemand.get() == 0) { - isDelivering.compareAndSet(true, false); - // If we still have demand to fulfill then request more if we aren't already requesting - if (remainingDemand.get() > 0) { - requestDataIfNotAlready(); - } - } else { - // Deliver the event and recursively call ourselves after it's delivered - Object event = eventsToDeliver.remove(); - remainingDemand.decrementAndGet(); - CompletableFuture.runAsync(() -> deliverEvent(event), executor) - .thenRunAsync(this::drainEvents, executor) - .whenComplete((v, t) -> { - if (t != null) { - log.error("Error occurred when delivering an event", t); - throw SdkClientException.create("fail to deliver events", t); - } - }); - } - } - return false; + private boolean isEvent(Message m) { + return "event".equals(m.getHeaders().get(":message-type").getString()); } /** - * Delivers the event to the downstream subscriber. We already know the type so the cast is safe. + * @param m Message frame. + * @return True if frame is an error frame, false if not. */ - @SuppressWarnings("unchecked") - private void deliverEvent(Object event) { - subscriberRef.get().onNext((EventT) event); + private boolean isError(Message m) { + return "error".equals(m.getHeaders().get(":message-type").getString()); } /** - * Creates a {@link Builder} used to create {@link EventStreamAsyncResponseTransformer}. - * - * @param Initial response type. - * @param Event type being delivered. - * @return New {@link Builder} instance. + * @param m Message frame. + * @return True if frame is an exception frame, false if not. */ - public static Builder builder() { - return new Builder<>(); + private boolean isException(Message m) { + return "exception".equals(m.getHeaders().get(":message-type").getString()); } /** @@ -549,7 +338,6 @@ public static final class Builder { private HttpResponseHandler initialResponseHandler; private HttpResponseHandler eventResponseHandler; private HttpResponseHandler exceptionResponseHandler; - private Executor executor; private CompletableFuture future; private String serviceName; @@ -596,11 +384,10 @@ public Builder exceptionResponseHandler( } /** - * @param executor Executor used to deliver events. - * @return This object for method chaining. + * This is no longer being used, but is left behind because this is a protected API. */ + @Deprecated public Builder executor(Executor executor) { - this.executor = executor; return this; } @@ -627,10 +414,8 @@ public EventStreamAsyncResponseTransformer build() { initialResponseHandler, eventResponseHandler, exceptionResponseHandler, - executor, future, serviceName); } } - } diff --git a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformerTest.java b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformerTest.java index d8ca1069a97b..ff81a7eec486 100644 --- a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformerTest.java +++ b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/eventstream/EventStreamAsyncResponseTransformerTest.java @@ -87,6 +87,7 @@ public void onComplete() { .executor(Executors.newSingleThreadExecutor()) .future(new CompletableFuture<>()) .build(); + transformer.prepare(); transformer.onStream(SdkPublisher.adapt(bytePublisher)); latch.await(); assertThat(numEvents) @@ -327,9 +328,10 @@ private void verifyExceptionThrown(Map headers) { Flowable bytePublisher = Flowable.just(exceptionMessage.toByteBuffer()); + SubscribingResponseHandler handler = new SubscribingResponseHandler(); AsyncResponseTransformer transformer = EventStreamAsyncResponseTransformer.builder() - .eventStreamResponseHandler(new SubscribingResponseHandler()) + .eventStreamResponseHandler(handler) .exceptionResponseHandler((response, executionAttributes) -> exception) .executor(Executors.newSingleThreadExecutor()) .future(new CompletableFuture<>()) @@ -343,13 +345,16 @@ private void verifyExceptionThrown(Map headers) { cf.join(); } catch (CompletionException e) { if (e.getCause() instanceof SdkServiceException) { - throw ((SdkServiceException) e.getCause()); + throw e.getCause(); } } }).isSameAs(exception); + + assertThat(handler.exceptionOccurredCalled).isTrue(); } private static class SubscribingResponseHandler implements EventStreamResponseHandler { + private volatile boolean exceptionOccurredCalled = false; @Override public void responseReceived(Object response) { @@ -363,6 +368,7 @@ public void onEventStream(SdkPublisher publisher) { @Override public void exceptionOccurred(Throwable throwable) { + exceptionOccurredCalled = true; } @Override diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java index 978fe7aa8389..2066110ae5d1 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java @@ -22,8 +22,10 @@ import java.util.function.Predicate; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.utils.async.BufferingSubscriber; +import software.amazon.awssdk.utils.async.EventListeningSubscriber; import software.amazon.awssdk.utils.async.FilteringSubscriber; import software.amazon.awssdk.utils.async.FlatteningSubscriber; import software.amazon.awssdk.utils.async.LimitingSubscriber; @@ -116,6 +118,36 @@ default SdkPublisher limit(int limit) { return subscriber -> subscribe(new LimitingSubscriber<>(subscriber, limit)); } + /** + * Add a callback that will be invoked after this publisher invokes {@link Subscriber#onComplete()}. + * + * @param afterOnComplete The logic that should be run immediately after onComplete. + * @return New publisher that invokes the requested callback. + */ + default SdkPublisher doAfterOnComplete(Runnable afterOnComplete) { + return subscriber -> subscribe(new EventListeningSubscriber<>(subscriber, afterOnComplete, null, null)); + } + + /** + * Add a callback that will be invoked after this publisher invokes {@link Subscriber#onError(Throwable)}. + * + * @param afterOnError The logic that should be run immediately after onError. + * @return New publisher that invokes the requested callback. + */ + default SdkPublisher doAfterOnError(Consumer afterOnError) { + return subscriber -> subscribe(new EventListeningSubscriber<>(subscriber, null, afterOnError, null)); + } + + /** + * Add a callback that will be invoked after this publisher invokes {@link Subscription#cancel()}. + * + * @param afterOnCancel The logic that should be run immediately after cancellation of the subscription. + * @return New publisher that invokes the requested callback. + */ + default SdkPublisher doAfterOnCancel(Runnable afterOnCancel) { + return subscriber -> subscribe(new EventListeningSubscriber<>(subscriber, null, null, afterOnCancel)); + } + /** * Subscribes to the publisher with the given {@link Consumer}. This consumer will be called for each event * published. There is no backpressure using this method if the Consumer dispatches processing asynchronously. If more diff --git a/services/kinesis/src/test/java/software/amazon/awssdk/services/kinesis/SubscribeToShardUnmarshallingTest.java b/services/kinesis/src/test/java/software/amazon/awssdk/services/kinesis/SubscribeToShardUnmarshallingTest.java index b287ef83a754..3468ad438c9d 100644 --- a/services/kinesis/src/test/java/software/amazon/awssdk/services/kinesis/SubscribeToShardUnmarshallingTest.java +++ b/services/kinesis/src/test/java/software/amazon/awssdk/services/kinesis/SubscribeToShardUnmarshallingTest.java @@ -29,6 +29,8 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; @@ -201,9 +203,9 @@ private List subscribeToShard() throws Throwable { SubscribeToShardResponseHandler.builder() .subscriber(events::add) .build()) - .join(); + .get(10, TimeUnit.SECONDS); return events; - } catch (CompletionException e) { + } catch (ExecutionException e) { throw e.getCause(); } } @@ -234,9 +236,6 @@ public void request(long l) { @Override public void cancel() { - RuntimeException e = new RuntimeException(); - subscriber.onError(e); - value.onError(e); } })); return cf; diff --git a/utils/pom.xml b/utils/pom.xml index 6cea787b20b7..162b419133b5 100644 --- a/utils/pom.xml +++ b/utils/pom.xml @@ -84,6 +84,11 @@ commons-io test + + org.reactivestreams + reactive-streams-tck + test + diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingSubscriber.java index 72b2fbe9269c..04e4725fd670 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingSubscriber.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingSubscriber.java @@ -15,14 +15,15 @@ package software.amazon.awssdk.utils.async; +import java.util.concurrent.atomic.AtomicBoolean; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkProtectedApi; @SdkProtectedApi public abstract class DelegatingSubscriber implements Subscriber { - protected final Subscriber subscriber; + private final AtomicBoolean complete = new AtomicBoolean(false); protected DelegatingSubscriber(Subscriber subscriber) { this.subscriber = subscriber; @@ -35,12 +36,15 @@ public void onSubscribe(Subscription subscription) { @Override public void onError(Throwable throwable) { - subscriber.onError(throwable); + if (complete.compareAndSet(false, true)) { + subscriber.onError(throwable); + } } @Override public void onComplete() { - subscriber.onComplete(); + if (complete.compareAndSet(false, true)) { + subscriber.onComplete(); + } } - } diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/EventListeningSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/EventListeningSubscriber.java new file mode 100644 index 000000000000..7639c57086be --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/EventListeningSubscriber.java @@ -0,0 +1,91 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; + +/** + * A {@link Subscriber} that can invoke callbacks during various parts of the subscriber and subscription lifecycle. + */ +@SdkProtectedApi +public final class EventListeningSubscriber extends DelegatingSubscriber { + private static final Logger log = Logger.loggerFor(EventListeningSubscriber.class); + + private final Runnable afterCompleteListener; + private final Consumer afterErrorListener; + private final Runnable afterCancelListener; + + public EventListeningSubscriber(Subscriber subscriber, + Runnable afterCompleteListener, + Consumer afterErrorListener, + Runnable afterCancelListener) { + super(subscriber); + this.afterCompleteListener = afterCompleteListener; + this.afterErrorListener = afterErrorListener; + this.afterCancelListener = afterCancelListener; + } + + @Override + public void onNext(T t) { + super.subscriber.onNext(t); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(new CancelListeningSubscriber(subscription)); + } + + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + if (afterErrorListener != null) { + callListener(() -> afterErrorListener.accept(throwable), + "Post-onError callback failed. This exception will be dropped."); + } + } + + @Override + public void onComplete() { + super.onComplete(); + callListener(afterCompleteListener, "Post-onComplete callback failed. This exception will be dropped."); + } + + private class CancelListeningSubscriber extends DelegatingSubscription { + protected CancelListeningSubscriber(Subscription s) { + super(s); + } + + @Override + public void cancel() { + super.cancel(); + callListener(afterCompleteListener, "Post-cancel callback failed. This exception will be dropped."); + } + } + + private void callListener(Runnable listener, String listenerFailureMessage) { + if (listener != null) { + try { + listener.run(); + } catch (RuntimeException e) { + log.error(() -> listenerFailureMessage, e); + } + } + } +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java index 08c11556a836..3303485e1250 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java @@ -15,48 +15,82 @@ package software.amazon.awssdk.utils.async; -import java.util.LinkedList; -import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; +import java.util.concurrent.atomic.AtomicReference; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; @SdkProtectedApi public class FlatteningSubscriber extends DelegatingSubscriber, U> { + private static final Logger log = Logger.loggerFor(FlatteningSubscriber.class); - private final AtomicLong demand = new AtomicLong(0); - private final Object lock = new Object(); + /** + * The amount of unfulfilled demand open against the upstream subscriber. + */ + private final AtomicLong upstreamDemand = new AtomicLong(0); - private boolean requestedNextBatch; - private Queue currentBatch; - private boolean onCompleteCalled = false; - private Subscription sourceSubscription; + /** + * The amount of unfulfilled demand the downstream subscriber has opened against us. + */ + private final AtomicLong downstreamDemand = new AtomicLong(0); + + /** + * A flag that is used to ensure that only one thread is handling updates to the state of this subscriber at a time. This + * allows us to ensure that the downstream onNext, onComplete and onError are only ever invoked serially. + */ + private final AtomicBoolean handlingStateUpdate = new AtomicBoolean(false); + + /** + * Items given to us by the upstream subscriber that we will use to fulfill demand of the downstream subscriber. + */ + private final LinkedBlockingQueue allItems = new LinkedBlockingQueue<>(); + + /** + * Whether the upstream subscriber has called onError on us. If this is null, we haven't gotten an onError. If it's non-null + * this will be the exception that the upstream passed to our onError. After we get an onError, we'll call onError on the + * downstream subscriber as soon as possible. + */ + private final AtomicReference onErrorFromUpstream = new AtomicReference<>(null); + + /** + * Whether we have called onComplete or onNext on the downstream subscriber. + */ + private volatile boolean terminalCallMadeDownstream = false; + + /** + * Whether the upstream subscriber has called onComplete on us. After this happens, we'll drain any outstanding items in the + * allItems queue and then call onComplete on the downstream subscriber. + */ + private volatile boolean onCompleteCalledByUpstream = false; + + /** + * The subscription to the upstream subscriber. + */ + private Subscription upstreamSubscription; public FlatteningSubscriber(Subscriber subscriber) { super(subscriber); - currentBatch = new LinkedList<>(); } @Override public void onSubscribe(Subscription subscription) { - sourceSubscription = subscription; + if (upstreamSubscription != null) { + log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); + subscription.cancel(); + return; + } + + upstreamSubscription = subscription; subscriber.onSubscribe(new Subscription() { @Override public void request(long l) { - synchronized (lock) { - demand.addAndGet(l); - // Execution goes into `if` block only once for the initial request - // After that requestedNextBatch is always true and more requests are made in fulfillDemand() - if (!requestedNextBatch) { - requestedNextBatch = true; - sourceSubscription.request(1); - } else { - fulfillDemand(); - } - } + addDownstreamDemand(l); + handleStateUpdate(); } @Override @@ -68,34 +102,165 @@ public void cancel() { @Override public void onNext(Iterable nextItems) { - synchronized (lock) { - currentBatch = StreamSupport.stream(nextItems.spliterator(), false) - .collect(Collectors.toCollection(LinkedList::new)); - fulfillDemand(); + try { + nextItems.forEach(item -> { + Validate.notNull(nextItems, "Collections flattened by the flattening subscriber must not contain null."); + allItems.add(item); + }); + } catch (NullPointerException e) { + upstreamSubscription.cancel(); + onError(e); + throw e; + } + + upstreamDemand.decrementAndGet(); + handleStateUpdate(); + } + + @Override + public void onError(Throwable throwable) { + onErrorFromUpstream.compareAndSet(null, throwable); + handleStateUpdate(); + } + + @Override + public void onComplete() { + onCompleteCalledByUpstream = true; + handleStateUpdate(); + } + + /** + * Increment the downstream demand by the provided value, accounting for overflow. + */ + private void addDownstreamDemand(long l) { + Validate.isTrue(l > 0, "Demand must not be negative."); + downstreamDemand.getAndUpdate(current -> { + long newValue = current + l; + return newValue >= 0 ? newValue : Long.MAX_VALUE; + }); + } + + /** + * This is invoked after each downstream request or upstream onNext, onError or onComplete. + */ + private void handleStateUpdate() { + do { + // Anything that happens after this if statement and before we set handlingStateUpdate to false is guaranteed to only + // happen on one thread. For that reason, we should only invoke onNext, onComplete or onError within that block. + if (!handlingStateUpdate.compareAndSet(false, true)) { + return; + } + + try { + // If we've already called onComplete or onError, don't do anything. + if (terminalCallMadeDownstream) { + return; + } + + // Call onNext, onComplete and onError as needed based on the current subscriber state. + handleOnNextState(); + handleUpstreamDemandState(); + handleOnCompleteState(); + handleOnErrorState(); + } catch (Error e) { + throw e; + } catch (Throwable e) { + log.error(() -> "Unexpected exception encountered that violates the reactive streams specification. Attempting " + + "to terminate gracefully.", e); + upstreamSubscription.cancel(); + onError(e); + } finally { + handlingStateUpdate.set(false); + } + + // It's possible we had an important state change between when we decided to release the state update flag, and we + // actually released it. If that seems to have happened, try to handle that state change on this thread, because + // another thread is not guaranteed to come around and do so. + } while (onNextNeeded() || upstreamDemandNeeded() || onCompleteNeeded() || onErrorNeeded()); + } + + /** + * Fulfill downstream demand by pulling items out of the item queue and sending them downstream. + */ + private void handleOnNextState() { + while (onNextNeeded() && !onErrorNeeded()) { + downstreamDemand.decrementAndGet(); + subscriber.onNext(allItems.poll()); } } - private void fulfillDemand() { - while (demand.get() > 0 && !currentBatch.isEmpty()) { - demand.decrementAndGet(); - subscriber.onNext(currentBatch.poll()); + /** + * Returns true if we need to call onNext downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onNextNeeded() { + return !allItems.isEmpty() && downstreamDemand.get() > 0; + } + + /** + * Request more upstream demand if it's needed. + */ + private void handleUpstreamDemandState() { + if (upstreamDemandNeeded()) { + ensureUpstreamDemandExists(); } + } - if (onCompleteCalled && currentBatch.isEmpty()) { + /** + * Returns true if we need to increase our upstream demand. + */ + private boolean upstreamDemandNeeded() { + return upstreamDemand.get() <= 0 && downstreamDemand.get() > 0 && allItems.isEmpty(); + } + + /** + * If there are zero pending items in the queue and the upstream has called onComplete, then tell the downstream + * we're done. + */ + private void handleOnCompleteState() { + if (onCompleteNeeded()) { + terminalCallMadeDownstream = true; subscriber.onComplete(); - } else if (currentBatch.isEmpty() && demand.get() > 0) { - requestedNextBatch = true; - sourceSubscription.request(1); } } - @Override - public void onComplete() { - synchronized (lock) { - onCompleteCalled = true; - if (currentBatch.isEmpty()) { - subscriber.onComplete(); - } + /** + * Returns true if we need to call onNext downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onCompleteNeeded() { + return allItems.isEmpty() && onCompleteCalledByUpstream && !terminalCallMadeDownstream; + } + + /** + * If the upstream has called onError, then tell the downstream we're done, no matter what state the queue is in. + */ + private void handleOnErrorState() { + if (onErrorNeeded()) { + terminalCallMadeDownstream = true; + subscriber.onError(onErrorFromUpstream.get()); + } + } + + /** + * Returns true if we need to call onError downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onErrorNeeded() { + return onErrorFromUpstream.get() != null && !terminalCallMadeDownstream; + } + + /** + * Ensure that we have at least 1 demand upstream, so that we can get more items. + */ + private void ensureUpstreamDemandExists() { + if (this.upstreamDemand.get() < 0) { + log.error(() -> "Upstream delivered more data than requested. Resetting state to prevent a frozen stream.", + new IllegalStateException()); + upstreamDemand.set(1); + upstreamSubscription.request(1); + } else if (this.upstreamDemand.compareAndSet(0, 1)) { + upstreamSubscription.request(1); } } } diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/SequentialSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/SequentialSubscriber.java index e66afb50d2bd..77db8e2d15b5 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/SequentialSubscriber.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/SequentialSubscriber.java @@ -28,7 +28,6 @@ */ @SdkProtectedApi public class SequentialSubscriber implements Subscriber { - private final Consumer consumer; private final CompletableFuture future; private Subscription subscription; diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTckTest.java new file mode 100644 index 000000000000..b4c27f991849 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTckTest.java @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class FlatteningSubscriberTckTest extends SubscriberWhiteboxVerification> { + protected FlatteningSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber> createSubscriber(WhiteboxSubscriberProbe> probe) { + Subscriber foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>()); + return new FlatteningSubscriber(foo) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Iterable nextItems) { + super.onNext(nextItems); + probe.registerOnNext(nextItems); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public Iterable createElement(int element) { + return Arrays.asList(element, element); + } +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTest.java new file mode 100644 index 000000000000..fc03fdead024 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/FlatteningSubscriberTest.java @@ -0,0 +1,204 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.mockito.Mockito.times; + +import java.util.Arrays; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +public class FlatteningSubscriberTest { + private Subscriber mockDelegate; + private Subscription mockUpstream; + private FlatteningSubscriber flatteningSubscriber; + + @Before + @SuppressWarnings("unchecked") + public void setup() { + mockDelegate = Mockito.mock(Subscriber.class); + mockUpstream = Mockito.mock(Subscription.class); + flatteningSubscriber = new FlatteningSubscriber<>(mockDelegate); + } + + @Test + public void requestOne() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + + Mockito.verify(mockDelegate).onNext("foo"); + + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void requestTwo() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(2); + + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + + Mockito.verify(mockDelegate).onNext("foo"); + Mockito.verify(mockDelegate).onNext("bar"); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void requestThree() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(3); + + Mockito.verify(mockUpstream, times(1)).request(1); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + Mockito.reset(mockUpstream, mockDelegate); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + + Mockito.verify(mockDelegate).onNext("foo"); + Mockito.verify(mockDelegate).onNext("bar"); + Mockito.verify(mockUpstream).request(1); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + Mockito.reset(mockUpstream, mockDelegate); + + flatteningSubscriber.onNext(Arrays.asList("baz")); + + Mockito.verify(mockDelegate).onNext("baz"); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void requestInfinite() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + downstream.request(Long.MAX_VALUE); + downstream.request(Long.MAX_VALUE); + downstream.request(Long.MAX_VALUE); + downstream.request(Long.MAX_VALUE); + + Mockito.verify(mockUpstream, times(1)).request(1); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + flatteningSubscriber.onComplete(); + + Mockito.verify(mockDelegate).onNext("foo"); + Mockito.verify(mockDelegate).onNext("bar"); + Mockito.verify(mockDelegate).onComplete(); + Mockito.verifyNoMoreInteractions(mockDelegate); + } + + @Test + public void onCompleteDelayedUntilAllDataDelivered() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + flatteningSubscriber.onComplete(); + + Mockito.verify(mockDelegate).onNext("foo"); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + Mockito.reset(mockUpstream, mockDelegate); + + downstream.request(1); + Mockito.verify(mockDelegate).onNext("bar"); + Mockito.verify(mockDelegate).onComplete(); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void onErrorDropsBufferedData() { + Throwable t = new Throwable(); + + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onNext(Arrays.asList("foo", "bar")); + flatteningSubscriber.onError(t); + + Mockito.verify(mockDelegate).onNext("foo"); + Mockito.verify(mockDelegate).onError(t); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void requestsFromDownstreamDoNothingAfterOnComplete() { + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onComplete(); + + Mockito.verify(mockDelegate).onComplete(); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + + downstream.request(1); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + @Test + public void requestsFromDownstreamDoNothingAfterOnError() { + Throwable t = new Throwable(); + + flatteningSubscriber.onSubscribe(mockUpstream); + + Subscription downstream = getDownstreamFromDelegate(); + downstream.request(1); + + Mockito.verify(mockUpstream).request(1); + + flatteningSubscriber.onError(t); + + Mockito.verify(mockDelegate).onError(t); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + + downstream.request(1); + Mockito.verifyNoMoreInteractions(mockUpstream, mockDelegate); + } + + private Subscription getDownstreamFromDelegate() { + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + Mockito.verify(mockDelegate).onSubscribe(subscriptionCaptor.capture()); + return subscriptionCaptor.getValue(); + } + +} \ No newline at end of file