diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e8bba7..c0e7801 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,7 @@ jobs: uses: coursier/cache-action@v6 - name: Check Document Generation run: sbt docs/compileDocs - + publish: runs-on: ubuntu-20.04 needs: [lint, test, website] diff --git a/build.sbt b/build.sbt index 4d55ac0..a135eff 100644 --- a/build.sbt +++ b/build.sbt @@ -55,12 +55,13 @@ lazy val interopReactiveStreams = project .settings(testFrameworks += new TestFramework("zio.test.sbt.ZTestFramework")) .settings( libraryDependencies ++= Seq( - "dev.zio" %% "zio" % zioVersion, - "dev.zio" %% "zio-streams" % zioVersion, - "dev.zio" %% "zio-test" % zioVersion % Test, - "dev.zio" %% "zio-test-sbt" % zioVersion % Test, - "org.reactivestreams" % "reactive-streams" % rsVersion, - "org.reactivestreams" % "reactive-streams-tck" % rsVersion % Test + "dev.zio" %% "zio" % zioVersion, + "dev.zio" %% "zio-streams" % zioVersion, + "dev.zio" %% "zio-test" % zioVersion % Test, + "dev.zio" %% "zio-test-sbt" % zioVersion % Test, + "org.reactivestreams" % "reactive-streams" % rsVersion, + "org.reactivestreams" % "reactive-streams-tck" % rsVersion % Test, + "net.sourceforge.streamsupport" % "streamsupport-flow" % "1.7.4" % Test ), libraryDependencies ++= { if (scalaVersion.value == ScalaDotty) diff --git a/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/Adapters.scala b/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/Adapters.scala index c0859cf..3cbfbd3 100644 --- a/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/Adapters.scala +++ b/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/Adapters.scala @@ -3,304 +3,457 @@ package zio.interop.reactivestreams import org.reactivestreams.Publisher import org.reactivestreams.Subscriber import org.reactivestreams.Subscription +import org.reactivestreams.Processor import zio._ import zio.Unsafe._ import zio.internal.RingBuffer import zio.stream._ -import zio.stream.ZStream.Pull -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference +import zio.stream.internal.AsyncInputConsumer +import zio.stream.internal.AsyncInputProducer +import java.util.concurrent.atomic.AtomicBoolean +import zio.UIO +import scala.util.control.NoStackTrace object Adapters { def streamToPublisher[R, E <: Throwable, O]( stream: => ZStream[R, E, O] - )(implicit trace: Trace): ZIO[R, Nothing, Publisher[O]] = - ZIO.runtime.map { runtime => subscriber => - if (subscriber == null) { - throw new NullPointerException("Subscriber must not be null.") - } else - unsafe { implicit unsafe => - val subscription = new DemandTrackingSubscription(subscriber) - runtime.unsafe.fork( - for { - _ <- ZIO.succeed(subscriber.onSubscribe(subscription)) - _ <- stream - .run(demandUnfoldSink(subscriber, subscription)) - .catchAll(e => ZIO.succeed(subscriber.onError(e))) - .forkDaemon - } yield () - ) - () - } - } + )(implicit trace: Trace): ZIO[R, Nothing, Publisher[O]] = ZIO.runtime[R].map { runtime => subscriber => + if (subscriber == null) { + throw new NullPointerException("Subscriber must not be null.") + } else + unsafe { implicit unsafe => + val subscription = new SubscriptionProducer[O](subscriber) + subscriber.onSubscribe(subscription) + runtime.unsafe.fork( + (stream.toChannel >>> ZChannel.fromZIO(subscription.awaitCompletion).embedInput(subscription)).runDrain + ) + () + } + } def subscriberToSink[E <: Throwable, I]( subscriber: => Subscriber[I] )(implicit trace: Trace): ZIO[Scope, Nothing, (E => UIO[Unit], ZSink[Any, Nothing, I, I, Unit])] = - unsafe { implicit unsafe => - val sub = subscriber - for { - error <- Promise.make[E, Nothing] - subscription = new DemandTrackingSubscription(sub) - _ <- ZIO.succeed(sub.onSubscribe(subscription)) - fiber <- error.await.interruptible.catchAll(t => ZIO.succeed(sub.onError(t))).forkScoped - } yield (error.fail(_) *> fiber.join, demandUnfoldSink(sub, subscription)) + ZIO.suspendSucceed { + unsafe { implicit unsafe => + val subscription = new SubscriptionProducer[I](subscriber) + + def reader: ZChannel[Any, ZNothing, Chunk[I], Any, Nothing, Chunk[I], Unit] = ZChannel.readWithCause( + i => ZChannel.fromZIO(subscription.emit(i)) *> reader, + e => ZChannel.failCause(e), + d => ZChannel.fromZIO(subscription.done(d)) *> ZChannel.succeed(()) + ) + + for { + _ <- ZIO.acquireRelease(ZIO.succeed(subscriber.onSubscribe(subscription)))(_ => + ZIO.succeed(subscription.cancel()) + ) + } yield ((e: E) => subscription.error(Cause.fail(e)).sandbox.ignore, ZSink.fromChannel(reader)) + } } def publisherToStream[O]( publisher: => Publisher[O], bufferSize: => Int - )(implicit trace: Trace): ZStream[Any, Throwable, O] = { - - val pullOrFail = - for { - subscriberP <- makeSubscriber[O](bufferSize) - (subscriber, p) = subscriberP - _ <- ZIO.acquireRelease(ZIO.succeed(publisher.subscribe(subscriber)))(_ => ZIO.succeed(subscriber.interrupt())) - subQ <- p.await - (sub, q) = subQ - process <- process(sub, q, () => subscriber.await(), () => subscriber.isDone) - } yield process - val pull = pullOrFail.catchAll(e => ZIO.succeed(Pull.fail(e))) - fromPull[Any, Throwable, O](pull) - } + )(implicit trace: Trace): ZStream[Any, Throwable, O] = publisherToChannel(publisher, bufferSize).toStream def sinkToSubscriber[R, I, L, Z]( sink: => ZSink[R, Throwable, I, L, Z], bufferSize: => Int - )(implicit trace: Trace): ZIO[R with Scope, Throwable, (Subscriber[I], IO[Throwable, Z])] = + )(implicit trace: Trace): ZIO[R with Scope, Throwable, (Subscriber[I], IO[Throwable, Z])] = { + def promChannel(prom: Promise[Throwable, Z]): ZChannel[R, Throwable, Chunk[L], Z, Any, Any, Any] = + ZChannel.readWithCause( + ZChannel.write(_) *> promChannel(prom), + e => ZChannel.fromZIO(prom.failCause(e)) *> ZChannel.failCause(e), + d => ZChannel.fromZIO(prom.succeed(d)) *> ZChannel.succeed(d) + ) + for { - subscriberP <- makeSubscriber[I](bufferSize) - (subscriber, p) = subscriberP - pull = p.await.flatMap { case (subscription, q) => - process(subscription, q, () => subscriber.await(), () => subscriber.isDone, bufferSize) - } - .catchAll(e => ZIO.succeedNow(Pull.fail(e))) - fiber <- fromPull(pull).run(sink).forkScoped - } yield (subscriber, fiber.join) - - private def process[A]( - sub: Subscription, - q: RingBuffer[A], - await: () => IO[Option[Throwable], Unit], - isDone: () => Boolean, - maxChunkSize: Int = Int.MaxValue - ): ZIO[Scope, Nothing, ZIO[Any, Option[Throwable], Chunk[A]]] = + prom <- Promise.make[Throwable, Z] + subscriber <- channelToSubscriber( + (ZChannel.identity[Throwable, Chunk[I], Any] pipeToOrFail sink.channel) >>> promChannel(prom), + bufferSize + ) + } yield (subscriber, prom.await) + } + + /** Upstream errors will not be passed to the processor. If you want errors to be passed, convert the processor to a + * channel instead. + */ + def processorToPipeline[I, O]( + processor: Processor[I, O], + bufferSize: Int = 16 + )(implicit trace: Trace): ZPipeline[Any, Throwable, I, O] = ZPipeline.unwrapScoped { + val subscription = new SubscriptionProducer[I](processor)(unsafe) + val subscriber = new SubscriberConsumer[O](bufferSize)(unsafe) + val passthrough = new PassthroughAsyncInput(subscription, subscriber) + for { - _ <- ZIO.succeed(sub.request(q.capacity.toLong)) - requestedRef <- Ref.make(q.capacity.toLong) // TODO: maybe turn into unfold? - } yield { - def pull: Pull[Any, Throwable, A] = - for { - requested <- requestedRef.get - pollSize = Math.min(requested, maxChunkSize.toLong).toInt - chunk <- ZIO.succeed(q.pollUpTo(pollSize)) - r <- - if (chunk.isEmpty) - await() *> pull - else - (if (chunk.size == pollSize && !isDone()) - ZIO.succeed(sub.request(q.capacity.toLong)) *> requestedRef.set(q.capacity.toLong) - else requestedRef.set(requested - chunk.size)) *> - Pull.emit(chunk) - } yield r - - pull + _ <- + ZIO.acquireRelease(ZIO.succeed(processor.subscribe(subscriber)))(_ => subscriber.cancelSubscription) + _ <- ZIO.acquireRelease(ZIO.succeed(processor.onSubscribe(subscription)))(_ => ZIO.succeed(subscription.cancel())) + } yield ZPipeline.fromChannel(ZChannel.fromInput(passthrough).embedInput(passthrough)) + } + + def pipelineToProcessor[R <: Scope, I, O]( + pipeline: ZPipeline[R, Throwable, I, O], + bufferSize: Int = 16 + )(implicit trace: Trace): ZIO[R, Nothing, Processor[I, O]] = + channelToProcessor(ZChannel.identity pipeToOrFail pipeline.channel, bufferSize) + + def channelToPublisher[R <: Scope, O]( + channel: ZChannel[R, Any, Any, Any, Throwable, Chunk[O], Any] + ): ZIO[R, Nothing, Publisher[O]] = ZIO.runtime[R].map { runtime => + new Publisher[O] { + def subscribe(subscriber: Subscriber[_ >: O]): Unit = + if (subscriber == null) { + throw new NullPointerException("Subscriber must not be null.") + } else { + val subscription = new SubscriptionProducer[O](subscriber)(unsafe) + unsafe { implicit u => + runtime.unsafe.run { + for { + _ <- ZIO.acquireRelease(ZIO.succeed(subscriber.onSubscribe(subscription)))(_ => + ZIO.succeed(subscription.cancel()) + ) + _ <- (channel >>> ZChannel + .fromZIO(subscription.awaitCompletion) + .embedInput(subscription)).runDrain.forkScoped + } yield () + }.getOrThrow() + } + } } + } - private trait InterruptibleSubscriber[A] extends Subscriber[A] { - def interrupt(): Unit - def await(): IO[Option[Throwable], Unit] - def isDone: Boolean + def channelToSubscriber[R <: Scope, I]( + channel: ZChannel[R, Throwable, Chunk[I], Any, Any, Any, Any], + bufferSize: Int = 16 + ): ZIO[R, Nothing, Subscriber[I]] = ZIO.suspendSucceed { + val subscriber = new SubscriberConsumer[I](bufferSize)(unsafe) + for { + _ <- ZIO.addFinalizer(subscriber.cancelSubscription) + _ <- (ZChannel.fromInput(subscriber) >>> channel).runDrain.forkScoped + } yield subscriber } - private def makeSubscriber[A]( - capacity: Int - ): ZIO[ - Scope, - Nothing, - ( - InterruptibleSubscriber[A], - Promise[Throwable, (Subscription, RingBuffer[A])] - ) - ] = + def channelToProcessor[R <: Scope, I, O]( + channel: ZChannel[R, Throwable, Chunk[I], Any, Throwable, Chunk[O], Any], + bufferSize: Int = 16 + ): ZIO[R, Nothing, Processor[I, O]] = for { - q <- ZIO.succeed(RingBuffer[A](capacity)) - p <- ZIO.acquireRelease( - Promise - .make[Throwable, (Subscription, RingBuffer[A])] - )( - _.poll.flatMap(_.fold(ZIO.unit)(_.foldZIO(_ => ZIO.unit, { case (sub, _) => ZIO.succeed(sub.cancel()) }))) - ) - } yield unsafe { implicit unsafe => - val subscriber = - new InterruptibleSubscriber[A] { - - val isSubscribedOrInterrupted = new AtomicBoolean - @volatile - var done: Option[Option[Throwable]] = None - @volatile - var toNotify: Option[Promise[Option[Throwable], Unit]] = None - - override def interrupt(): Unit = - isSubscribedOrInterrupted.set(true) - - override def await(): IO[Option[Throwable], Unit] = - done match { - case Some(value) => - if (q.isEmpty()) ZIO.fail(value) else ZIO.unit - case None => - val p = Promise.unsafe.make[Option[Throwable], Unit](FiberId.None) - toNotify = Some(p) - // An element has arrived in the meantime, we do not need to start waiting. - if (!q.isEmpty()) { - toNotify = None - ZIO.unit - } else - done.fold(p.await) { e => - // The producer has canceled or errored in the meantime. - toNotify = None - if (q.isEmpty()) ZIO.fail(e) else ZIO.unit - } - } + runtime <- ZIO.runtime[R] + subscriber = new SubscriberConsumer[I](bufferSize)(unsafe) + _ <- ZIO.addFinalizer(subscriber.cancelSubscription) + } yield new Processor[I, O] { + def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) - override def isDone: Boolean = done.isDefined + def onNext(t: I): Unit = subscriber.onNext(t) - override def onSubscribe(s: Subscription): Unit = - if (s == null) { - val e = new NullPointerException("s was null in onSubscribe") - p.unsafe.done(ZIO.fail(e)) - throw e - } else { - val shouldCancel = isSubscribedOrInterrupted.getAndSet(true) - if (shouldCancel) - s.cancel() - else - p.unsafe.done(ZIO.succeedNow((s, q))) - } + def onError(t: Throwable): Unit = subscriber.onError(t) - override def onNext(t: A): Unit = - if (t == null) { - failNPE("t was null in onNext") - } else { - q.offer(t) - toNotify.foreach(_.unsafe.done(ZIO.unit)) - } + def onComplete(): Unit = subscriber.onComplete() - override def onError(e: Throwable): Unit = - if (e == null) - failNPE("t was null in onError") - else - fail(e) + def subscribe(s: Subscriber[_ >: O]): Unit = { + val subscription = new SubscriptionProducer[O](s)(unsafe) + unsafe { implicit u => + runtime.unsafe.run { + for { + finalizerRef <- Ref.make(ZIO.unit) + _ <- ZIO.addFinalizer(finalizerRef.get.flatten) + _ <- (ZIO.succeed(s.onSubscribe(subscription)) *> finalizerRef.set( + ZIO.succeed(subscription.cancel()) + )).uninterruptible + _ <- + (ZChannel.fromInput(subscriber) >>> channel >>> ZChannel + .fromZIO(subscription.awaitCompletion *> finalizerRef.set(ZIO.unit) *> subscriber.cancelSubscription) + .embedInput(subscription)).runDrain.forkScoped + } yield () + }.getOrThrow() + } + } + } - override def onComplete(): Unit = { - done = Some(None) - toNotify.foreach(_.unsafe.done(ZIO.fail(None))) - } + def publisherToChannel[O]( + publisher: Publisher[O], + bufferSize: Int = 16 + )(implicit trace: Trace): ZChannel[Any, Any, Any, Any, Throwable, Chunk[O], Any] = ZChannel.unwrapScoped[Any] { + val subscriber = new SubscriberConsumer[O](bufferSize)(unsafe) - private def failNPE(msg: String) = { - val e = new NullPointerException(msg) - fail(e) - throw e - } + for { + _ <- + ZIO.acquireRelease(ZIO.succeed(publisher.subscribe(subscriber)))(_ => subscriber.cancelSubscription) + } yield ZChannel.fromInput(subscriber) + } - private def fail(e: Throwable) = { - done = Some(Some(e)) - toNotify.foreach(_.unsafe.done(ZIO.fail(Some(e)))) - } + def subscriberToChannel[I]( + consumer: Subscriber[I] + )(implicit trace: Trace): ZChannel[Any, Throwable, Chunk[I], Any, Any, Any, Any] = ZChannel.unwrapScoped[Any] { + val subscription = new SubscriptionProducer[I](consumer)(unsafe) + + for { + _ <- ZIO.acquireRelease(ZIO.succeed(consumer.onSubscribe(subscription)))(_ => ZIO.succeed(subscription.cancel())) + } yield ZChannel.fromZIO(subscription.awaitCompletion).embedInput(subscription) + } + + def processorToChannel[I, O]( + processor: Processor[I, O], + bufferSize: Int = 16 + )(implicit trace: Trace): ZChannel[Any, Throwable, Chunk[I], Any, Throwable, Chunk[O], Any] = ZChannel.unwrapScoped { + val subscription = new SubscriptionProducer[I](processor)(unsafe) + val subscriber = new SubscriberConsumer[O](bufferSize)(unsafe) + + for { + _ <- + ZIO.acquireRelease(ZIO.succeed(processor.subscribe(subscriber)))(_ => subscriber.cancelSubscription) + _ <- ZIO.acquireRelease(ZIO.succeed(processor.onSubscribe(subscription)))(_ => ZIO.succeed(subscription.cancel())) + } yield ZChannel.fromInput(subscriber).embedInput(subscription) + } + private class SubscriptionProducer[A](sub: Subscriber[_ >: A])(implicit unsafe: Unsafe) + extends Subscription + with AsyncInputProducer[Throwable, Chunk[A], Any] { + import SubscriptionProducer.State + + private val state: AtomicReference[State[A]] = new AtomicReference(State.initial[A]) + private val completed: Promise[Nothing, Unit] = Promise.unsafe.make(FiberId.None) + + val awaitCompletion: UIO[Unit] = completed.await + + def request(n: Long): Unit = + if (n <= 0) sub.onError(new IllegalArgumentException("non-positive subscription request")) + else { + state.getAndUpdate { + case State.Running(demand) => State.Running(demand + n) + case State.Waiting(_) => State.Running(n) + case other => other + } match { + case State.Waiting(resume) => resume.unsafe.done(ZIO.unit) + case _ => () } + } - (subscriber, p) + def cancel(): Unit = { + state.getAndSet(State.Cancelled) match { + case State.Waiting(resume) => resume.unsafe.done(ZIO.interrupt) + case _ => () + } + completed.unsafe.done(ZIO.unit) } - private def demandUnfoldSink[I]( - subscriber: Subscriber[I], - subscription: DemandTrackingSubscription - ): ZSink[Any, Nothing, I, I, Unit] = - ZSink - .foldChunksZIO[Any, Nothing, I, Boolean](true)(identity) { (_, chunk) => - ZIO - .iterate(chunk)(!_.isEmpty) { chunk => - subscription - .offer(chunk.size) - .flatMap { acceptedCount => - ZIO - .foreach(chunk.take(acceptedCount))(a => ZIO.succeed(subscriber.onNext(a))) - .as(chunk.drop(acceptedCount)) - } + def emit(el: Chunk[A])(implicit trace: zio.Trace): UIO[Any] = ZIO.suspendSucceed { + if (el.isEmpty) ZIO.unit + else + ZIO.suspendSucceed { + state.getAndUpdate { + case State.Running(demand) => + if (demand > el.size) + State.Running(demand - el.size) + else + State.Waiting(Promise.unsafe.make[Nothing, Unit](FiberId.None)) + case other => other + } match { + case State.Waiting(resume) => + resume.await *> emit(el) + case State.Running(demand) => + if (demand > el.size) + ZIO.succeed(el.foreach(sub.onNext(_))) + else + ZIO.succeed(el.take(demand.toInt).foreach(sub.onNext(_))) *> emit(el.drop(demand.toInt)) + case State.Cancelled => + ZIO.interrupt } - .fold( - _ => false, // canceled - _ => true - ) + } + } + + def done(a: Any)(implicit trace: zio.Trace): UIO[Any] = ZIO.suspendSucceed { + state.getAndSet(State.Cancelled) match { + case State.Running(_) => ZIO.succeed(sub.onComplete()) *> completed.succeed(()) + case State.Cancelled => ZIO.interrupt + case State.Waiting(resume) => ZIO.succeed(sub.onComplete()) *> resume.interrupt *> completed.succeed(()) } - .map(_ => if (!subscription.isCanceled) subscriber.onComplete()) - - private class DemandTrackingSubscription(subscriber: Subscriber[_])(implicit val unsafe: Unsafe) - extends Subscription { - - private case class State( - requestedCount: Long, // -1 when cancelled - toNotify: Option[(Int, Promise[Unit, Int])] - ) - - private val initial = State(0L, None) - private val canceled = State(-1, None) - private def requested(n: Long) = State(n, None) - private def awaiting(n: Int, p: Promise[Unit, Int]) = State(0L, Some((n, p))) - - private val state = new AtomicReference(initial) - - def offer(n: Int): IO[Unit, Int] = { - var result: IO[Unit, Int] = null - state.updateAndGet { - case `canceled` => - result = ZIO.fail(()) - canceled - case State(0L, _) => - val p = Promise.unsafe.make[Unit, Int](FiberId.None) - result = p.await - awaiting(n, p) - case State(requestedCount, _) => - val newRequestedCount = Math.max(requestedCount - n, 0L) - val accepted = Math.min(requestedCount, n.toLong).toInt - result = ZIO.succeedNow(accepted) - requested(newRequestedCount) + } + + def error(cause: Cause[Throwable])(implicit trace: zio.Trace): UIO[Any] = ZIO.suspendSucceed { + state.getAndSet(State.Cancelled) match { + case State.Running(_) => + ZIO.succeed { + cause.failureOrCause.fold( + sub.onError, + c => sub.onError(UpstreamDefect(c)) + ) + } *> completed.succeed(()) + case State.Cancelled => ZIO.interrupt + case State.Waiting(resume) => + ZIO.succeed { + cause.failureOrCause.fold( + sub.onError, + c => sub.onError(UpstreamDefect(c)) + ) + } *> resume.interrupt *> completed.succeed(()) } - result } - def isCanceled: Boolean = state.get().requestedCount < 0 - - override def request(n: Long): Unit = { - if (n <= 0) subscriber.onError(new IllegalArgumentException("non-positive subscription request")) - var notification: () => Unit = () => () - state.getAndUpdate { - case `canceled` => - canceled - case State(requestedCount, Some((offered, toNotify))) => - val newRequestedCount = requestedCount + n - val accepted = Math.min(offered.toLong, newRequestedCount) - val remaining = newRequestedCount - accepted - notification = () => toNotify.unsafe.done(ZIO.succeedNow(accepted.toInt)) - requested(remaining) - case State(requestedCount, _) if ((Long.MaxValue - n) > requestedCount) => - requested(requestedCount + n) - case _ => - requested(Long.MaxValue) + def awaitRead(implicit trace: zio.Trace): UIO[Any] = ZIO.unit + } + + private object SubscriptionProducer { + sealed trait State[+A] + object State { + def initial[A](implicit unsafe: Unsafe): State[A] = Waiting(Promise.unsafe.make[Nothing, Unit](FiberId.None)) + + final case class Waiting(resume: Promise[Nothing, Unit]) extends State[Nothing] + final case class Running(demand: Long) extends State[Nothing] + case object Cancelled extends State[Nothing] + } + } + + private class SubscriberConsumer[A](capacity: Int)(implicit unsafe: Unsafe) + extends Subscriber[A] + with AsyncInputConsumer[Throwable, Chunk[A], Unit] { + import SubscriberConsumer.State + + private val subscription: Promise[Nothing, Subscription] = Promise.unsafe.make(FiberId.None) + private val buffer: RingBuffer[A] = RingBuffer(capacity) + private val state: AtomicReference[State] = new AtomicReference(State.Drained) + private val isSubscribedOrCanceled: AtomicBoolean = new AtomicBoolean(false) + + def onSubscribe(s: Subscription): Unit = + if (!isSubscribedOrCanceled.compareAndSet(false, true)) { + s.cancel() + } else { + subscription.unsafe.done(ZIO.succeedNow(s)) + s.request(buffer.capacity.toLong) + } + + def onNext(t: A): Unit = + if (t == null) { + throw new NullPointerException("t was null in onNext") + } else if (!buffer.offer(t)) { + throw new IllegalStateException("buffer is full") + } else { + state.getAndUpdate { + case State.Drained => State.Full + case State.Waiting(_) => State.Full + case other => other + } match { + case State.Waiting(promise) => promise.unsafe.done(ZIO.unit) + case _ => () + } + } + + def onError(t: Throwable): Unit = + if (t == null) { + throw new NullPointerException("t was null in onError") + } else { + state.getAndSet(State.Failed(t)) match { + case State.Waiting(promise) => promise.unsafe.done(ZIO.unit) + case _ => () + } + } + + def onComplete(): Unit = + state.getAndSet(State.Ended) match { + case State.Waiting(promise) => promise.unsafe.done(ZIO.unit) + case _ => () + } + + def cancelSubscription: UIO[Unit] = + ZIO.succeed(isSubscribedOrCanceled.set(true)) *> + subscription.poll.flatMap(ZIO.foreachDiscard(_)(_.map(_.cancel()).exit)) *> + subscription.interrupt.unit *> + ZIO.succeed(state.getAndSet(State.Canceled) match { + case State.Waiting(promise) => promise.unsafe.done(ZIO.unit) + case _ => () + }) + + def takeWith[B](onError: Cause[Throwable] => B, onElement: Chunk[A] => B, onDone: Unit => B)(implicit + trace: zio.Trace + ): UIO[B] = subscription.await.flatMap { sub => + ZIO.suspendSucceed { + state.updateAndGet { + case State.Drained => State.Waiting(Promise.unsafe.make[Nothing, Unit](FiberId.None)) + case State.Full => State.Drained + case other => other + } match { + case State.Drained => + val data = buffer.pollUpTo(buffer.capacity) + val dataSize = data.size.toLong + if (dataSize > 0) { + sub.request(data.size.toLong) + ZIO.succeedNow(onElement(data)) + } else { + ZIO.succeedNow(onElement(Chunk.empty)) + } + + case State.Full => ??? // impossible + + case State.Waiting(promise) => + promise.await *> takeWith(onError, onElement, onDone) + + case State.Failed(t) => + // drain remaining data before failing + val data = buffer.pollUpTo(buffer.capacity) + if (data.nonEmpty) ZIO.succeedNow(onElement(data)) + else { + t match { + case UpstreamDefect(cause) => ZIO.succeedNow(onError(cause)) + case err => ZIO.succeedNow(onError(Cause.fail(err))) + } + } + case State.Ended => + // drain remaining data before failing + val data = buffer.pollUpTo(buffer.capacity) + if (data.nonEmpty) ZIO.succeedNow(onElement(data)) else ZIO.succeedNow(onDone(())) + + case State.Canceled => + ZIO.interrupt + } } - notification() } + } + + private object SubscriberConsumer { + + sealed trait State + + object State { + final case class Waiting(promise: Promise[Nothing, Unit]) extends State + case object Drained extends State + case object Full extends State + final case class Failed(cause: Throwable) extends State + case object Ended extends State + case object Canceled extends State + + } + + } - override def cancel(): Unit = - state.getAndSet(canceled).toNotify.foreach { case (_, p) => p.unsafe.done(ZIO.fail(())) } + private final case class UpstreamDefect(cause: Cause[Nothing]) extends NoStackTrace { + override def getMessage(): String = s"Upsteam defect: ${cause.prettyPrint}" } - private def fromPull[R, E, A](zio: ZIO[R with Scope, Nothing, ZIO[R, Option[E], Chunk[A]]])(implicit - trace: Trace - ): ZStream[R, E, A] = - ZStream.unwrapScoped[R](zio.map(pull => ZStream.repeatZIOChunkOption(pull))) + class PassthroughAsyncInput[I, O]( + producer: AsyncInputProducer[Nothing, I, Any], + consumer: AsyncInputConsumer[Throwable, O, Any] + ) extends AsyncInputProducer[Throwable, I, Any] + with AsyncInputConsumer[Throwable, O, Any] { + private val error: Promise[Nothing, Cause[Throwable]] = unsafe(implicit u => Promise.unsafe.make(FiberId.None)) + + def takeWith[A](onError: Cause[Throwable] => A, onElement: O => A, onDone: Any => A)(implicit + trace: zio.Trace + ): UIO[A] = + consumer.takeWith(onError, onElement, onDone) race error.await.map(onError) + + def emit(el: I)(implicit trace: zio.Trace): UIO[Any] = producer.emit(el) + + def done(a: Any)(implicit trace: zio.Trace): UIO[Any] = producer.done(a) + + def error(cause: Cause[Throwable])(implicit trace: zio.Trace): UIO[Any] = error.succeed(cause) + def awaitRead(implicit trace: zio.Trace): UIO[Any] = producer.awaitRead + + } } diff --git a/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/package.scala b/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/package.scala index 6ac7b4a..879930e 100644 --- a/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/package.scala +++ b/zio-interop-reactivestreams/src/main/scala/zio/interop/reactivestreams/package.scala @@ -2,9 +2,13 @@ package zio.interop import org.reactivestreams.Publisher import org.reactivestreams.Subscriber -import zio.{ Scope, UIO, Task, ZIO, Trace } +import zio.{ Scope, UIO, Task, ZIO, Trace, URIO } import zio.stream.ZSink import zio.stream.ZStream +import org.reactivestreams.Processor +import zio.stream.ZPipeline +import zio.stream.ZChannel +import zio.Chunk package object reactivestreams { @@ -39,6 +43,11 @@ package object reactivestreams { */ def toZIOStream(qSize: Int = 16)(implicit trace: Trace): ZStream[Any, Throwable, O] = Adapters.publisherToStream(publisher, qSize) + + def toPublisherZIOChannel(bufferSize: Int = 16)(implicit + trace: Trace + ): ZChannel[Any, Any, Any, Any, Throwable, Chunk[O], Any] = + Adapters.publisherToChannel(publisher, bufferSize) } final implicit class subscriberToSink[I](private val subscriber: Subscriber[I]) extends AnyVal { @@ -57,6 +66,26 @@ package object reactivestreams { trace: Trace ): ZIO[Scope, Nothing, (E => UIO[Unit], ZSink[Any, Nothing, I, I, Unit])] = Adapters.subscriberToSink(subscriber) + + def toSubscriberZIOChannel(implicit trace: Trace): ZChannel[Any, Throwable, Chunk[I], Any, Any, Any, Any] = + Adapters.subscriberToChannel(subscriber) + } + + final implicit class processorToPipeline[I, O](private val processor: Processor[I, O]) extends AnyVal { + + def toZIOPipeline(implicit trace: Trace): ZPipeline[Any, Throwable, I, O] = + Adapters.processorToPipeline(processor) + + def toProcessorZIOChannel(implicit + trace: Trace + ): ZChannel[Any, Throwable, Chunk[I], Any, Throwable, Chunk[O], Any] = + Adapters.processorToChannel(processor) } + final implicit class pipelineToProcessor[R <: Scope, I, O](private val pipeline: ZPipeline[R, Throwable, I, O]) + extends AnyVal { + + def toProcessor(implicit trace: Trace): URIO[R, Processor[I, O]] = + Adapters.pipelineToProcessor(pipeline) + } } diff --git a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PipelineToProcessorSpec.scala b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PipelineToProcessorSpec.scala new file mode 100644 index 0000000..bfb19a8 --- /dev/null +++ b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PipelineToProcessorSpec.scala @@ -0,0 +1,78 @@ +package zio.interop.reactivestreams + +import zio.test.Assertion._ +import zio.test._ +import org.reactivestreams.tck.IdentityProcessorVerification +import org.testng.annotations.Test +import zio._ +import zio.stream.ZPipeline +import org.reactivestreams.tck +import zio.Unsafe.unsafe +import org.reactivestreams.Processor +import java.util.concurrent.Executors +import java.lang.reflect.InvocationTargetException +import org.testng.SkipException + +object PipelineToProcessorSpec extends ZIOSpecDefault { + + override def spec = + suite("Converting a `Pipeline` to a `Processor`")( + suite("passes all required and optional TCK tests")( + tckTests: _* + ) + ) + + val managedVerification = + for { + runtime <- ZIO.runtime[Scope] + executor <- ZIO.succeed(Executors.newFixedThreadPool(4)) + _ <- ZIO.addFinalizer(ZIO.succeed(executor.shutdown())) + env = new tck.TestEnvironment(1000, 500) + ver = new IdentityProcessorVerification[Int](env) { + override def createIdentityProcessor( + bufferSize: Int + ): Processor[Int, Int] = + unsafe { implicit u => + runtime.unsafe.run(Adapters.pipelineToProcessor(ZPipeline.identity[Int], bufferSize)).getOrThrow() + } + + override def createElement(n: Int): Int = n + + override def createFailedPublisher() = null + + override def publisherExecutorService() = executor + + override def maxSupportedSubscribers() = 1 + + override def boundedDepthOfOnNextAndRequestRecursion() = 1 + } + _ <- ZIO.succeed(ver.setUp()) + } yield ver + + val tckTests = + classOf[IdentityProcessorVerification[Int]] + .getMethods() + .toList + .filter { method => + method + .getAnnotations() + .exists(annotation => classOf[Test].isAssignableFrom(annotation.annotationType())) + } + .collect { + case method if method.getName().startsWith("untested") => + test(method.getName())(assert(())(anything)) @@ TestAspect.ignore + case method => + test(method.getName())( + ZIO.scoped[Any] { + for { + ver <- managedVerification + r <- ZIO + .attemptBlockingInterrupt(method.invoke(ver)) + .unit + .refineOrDie { case e: InvocationTargetException => e.getTargetException() } + .exit + } yield assert(r)(fails(isSubtype[SkipException](anything)) || succeeds(isUnit)) + } + ) + } +} diff --git a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/ProcessorToPipelineSpec.scala b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/ProcessorToPipelineSpec.scala new file mode 100644 index 0000000..e002066 --- /dev/null +++ b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/ProcessorToPipelineSpec.scala @@ -0,0 +1,145 @@ +package zio.interop.reactivestreams + +import zio.Chunk +import zio.UIO +import zio.ZIO +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test._ +import scala.collection.mutable.ListBuffer +import java8.util.concurrent.SubmissionPublisher +import java8.util.concurrent.{ Flow => Flow8 } +import org.reactivestreams.{ Processor, Subscriber, Subscription } + +object ProcessorToPipelineSpec extends ZIOSpecDefault { + + override def spec = + suite("Converting a `Processor` to a `Pipeline`")( + test("works with a well behaved `Publisher`") { + val processor = new TestProcessor((i: Int) => i.toString()) + + val effect = ZStream(1, 2, 3, 4, 5).via(processor.toZIOPipeline).runCollect + + for { + result <- effect + events <- processor.getEvents + } yield assert(result)(equalTo(Chunk("1", "2", "3", "4", "5"))) && + assert(events)( + equalTo( + List( + ProcessorEvent.OnSubscribe, + ProcessorEvent.OnNext(1), + ProcessorEvent.OnNext(2), + ProcessorEvent.OnNext(3), + ProcessorEvent.OnNext(4), + ProcessorEvent.OnNext(5), + ProcessorEvent.OnComplete + ) + ) + ) + }, + test("passes through errors without offering them to the processor") { + val processor = new TestProcessor((i: Int) => i.toString()) + val err = new RuntimeException() + + val effect = (ZStream(1, 2) ++ ZStream.fail(err)).via(processor.toZIOPipeline).runCollect + + for { + result <- effect.exit + events <- processor.getEvents + } yield assert(result)(fails(equalTo(err))) && + assert(events)( + equalTo( + List( + ProcessorEvent.OnSubscribe, + ProcessorEvent.OnNext(1), + ProcessorEvent.OnNext(2) + ) + ) + ) + }, + test("passes through errors when converting to a raw channel") { + val processor = new TestProcessor((i: Int) => i.toString()) + val err = new RuntimeException() + + val effect = ((ZStream(1, 2) ++ ZStream.fail(err)).channel >>> processor.toProcessorZIOChannel).runCollect + + for { + result <- effect.exit + events <- processor.getEvents + } yield assert(result)(fails(equalTo(err))) && + assert(events)( + equalTo( + List( + ProcessorEvent.OnSubscribe, + ProcessorEvent.OnNext(1), + ProcessorEvent.OnNext(2), + ProcessorEvent.OnError(err) + ) + ) + ) + } + ) @@ TestAspect.withLiveClock + + sealed trait ProcessorEvent[+A] + object ProcessorEvent { + case object OnSubscribe extends ProcessorEvent[Nothing] + final case class OnNext[A](item: A) extends ProcessorEvent[A] + final case class OnError(error: Throwable) extends ProcessorEvent[Nothing] + case object OnComplete extends ProcessorEvent[Nothing] + } + + final class TestProcessor[A, B](f: A => B) extends Processor[A, B] { + + private var subscription: Subscription = null + private val submissionPublisher = new SubmissionPublisher[B]() + private val events = ListBuffer[ProcessorEvent[A]]() + + def onSubscribe(subscription: Subscription): Unit = { + this.events += ProcessorEvent.OnSubscribe + this.subscription = subscription; + subscription.request(1); + } + + def onNext(item: A): Unit = { + this.events += ProcessorEvent.OnNext(item) + submissionPublisher.submit(f(item)); + subscription.request(1); + } + + def onError(error: Throwable): Unit = { + this.events += ProcessorEvent.OnError(error) + submissionPublisher.closeExceptionally(error); + } + + def onComplete(): Unit = { + this.events += ProcessorEvent.OnComplete + submissionPublisher.close(); + } + + def getEvents: UIO[List[ProcessorEvent[A]]] = + ZIO.succeed(this.events.toList) + + def subscribe(subscriber: Subscriber[_ >: B]): Unit = + submissionPublisher.subscribe(new CompatSubscriber[B](subscriber)) + } + + final class CompatSubscriber[B](underlying: Subscriber[_ >: B]) extends Flow8.Subscriber[B] { + def onSubscribe(subscription: Flow8.Subscription): Unit = + underlying.onSubscribe(new CompatSubscription(subscription)) + + def onNext(item: B): Unit = underlying.onNext(item) + + def onError(throwable: Throwable): Unit = underlying.onError(throwable) + + def onComplete(): Unit = underlying.onComplete() + + } + + final class CompatSubscription(underlying: Flow8.Subscription) extends Subscription { + def request(n: Long): Unit = + underlying.request(n) + def cancel(): Unit = + underlying.cancel() + } +} diff --git a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PublisherToStreamSpec.scala b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PublisherToStreamSpec.scala index bef3763..18404a5 100644 --- a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PublisherToStreamSpec.scala +++ b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/PublisherToStreamSpec.scala @@ -193,6 +193,7 @@ object PublisherToStreamSpec extends ZIOSpecDefault { Adapters.publisherToStream(new NumberIterablePublisher(0, 1, executor.asJava), 16).runCount } .map(_.sum) + } yield assert(sum)(equalTo(10000L)) } ) diff --git a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SinkToSubscriberSpec.scala b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SinkToSubscriberSpec.scala index 9c735d0..553e923 100644 --- a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SinkToSubscriberSpec.scala +++ b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SinkToSubscriberSpec.scala @@ -4,7 +4,7 @@ import org.reactivestreams.{ Publisher, Subscriber, Subscription } import org.reactivestreams.tck.SubscriberWhiteboxVerification.{ SubscriberPuppet, WhiteboxSubscriberProbe } import org.reactivestreams.tck.{ SubscriberWhiteboxVerification, TestEnvironment } import org.testng.annotations.Test -import zio.{ Chunk, Promise, ZIO, durationInt, durationLong } +import zio.{ Promise, ZIO, durationInt, durationLong, Chunk } import zio.stream.ZSink import zio.test.Assertion._ import zio.test._ diff --git a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SubscriberToSinkSpec.scala b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SubscriberToSinkSpec.scala index c7128c8..c55c536 100644 --- a/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SubscriberToSinkSpec.scala +++ b/zio-interop-reactivestreams/src/test/scala/zio/interop/reactivestreams/SubscriberToSinkSpec.scala @@ -88,7 +88,22 @@ object SubscriberToSinkSpec extends ZIOSpecDefault { err2 <- probe.expectError.timeout(100.millis).exit } yield assert(err)(succeeds(equalTo(e))) && assert(err2)(fails(anything)) } - } + }, + test("transports errors when transforming to channel") { + makeSubscriber.flatMap(probe => + ZIO.scoped[Any] { + val channel = probe.underlying.toSubscriberZIOChannel + for { + fiber <- ((ZStream.fromIterable(seq) ++ ZStream.fail(e)).channel >>> channel).runDrain.fork + _ <- ZIO.sleep(100.millis) + _ <- probe.request(length + 1) + elements <- probe.nextElements(length).exit + err <- probe.expectError.exit + _ <- fiber.join + } yield assert(elements)(succeeds(equalTo(seq))) && assert(err)(succeeds(equalTo(e))) + } + ) + } @@ TestAspect.withLiveClock ) val seq: List[Int] = List.range(0, 31)