diff --git a/modules/gcp/src/main/resources/application.conf b/modules/gcp/src/main/resources/application.conf
index f60e218..aa0fafc 100644
--- a/modules/gcp/src/main/resources/application.conf
+++ b/modules/gcp/src/main/resources/application.conf
@@ -15,9 +15,8 @@
       "durationPerAckExtension": "10 minutes"
       "minRemainingDeadline": 0.1
       "progressTimeout": "10 seconds"
-      "modackOnProgressTimeout": true
-      "cancelOnProgressTimeout": false
-      "consistentClientId": true
+      "prefetchMin": 1
+      "prefetchMax": 10
     }
     "output": {
       "bad": ${snowplow.defaults.sinks.pubsub}
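The two new settings replace the three progress-timeout flags: rather than choosing how to punish a stalled stream, the source now keeps a bounded prefetch buffer per channel. A minimal sketch of the policy these values drive in SourceCoordinator further down (function names here are illustrative, not code from this PR):

    // Below the low-water mark the coordinator eagerly asks the stream for another batch:
    def shouldPrefetchMore(buffered: Int, prefetchMin: Int): Boolean =
      buffered < prefetchMin

    // The keep-alive fiber tops the buffer up until the high-water mark; beyond it the
    // buffered batches are dropped and the streaming pull is cancelled:
    def keepAliveDecision(buffered: Int, prefetchMax: Int): String =
      if (buffered < prefetchMax) "request one more batch to keep the stream alive"
      else "drop the buffer and cancel the stream; the consumer is not keeping up"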
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/LeaseManager.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/LeaseManager.scala
new file mode 100644
index 0000000..7c0b949
--- /dev/null
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/LeaseManager.scala
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
+ *
+ * This program is licensed to you under the Snowplow Community License Version 1.0,
+ * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
+ * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
+ */
+package com.snowplowanalytics.snowplow.sources.pubsub.v2
+
+import cats.effect.{Async, Ref, Resource, Sync}
+import cats.effect.kernel.Unique
+import cats.implicits._
+import cats.effect.implicits._
+import org.typelevel.log4cats.Logger
+import org.typelevel.log4cats.slf4j.Slf4jLogger
+
+import scala.concurrent.duration.DurationLong
+
+import com.google.cloud.pubsub.v1.stub.SubscriberStub
+
+private trait LeaseManager[F[_], A] {
+  def manageLeases(in: A): F[Unique.Token]
+  def stopManagingLeases(tokens: Vector[Unique.Token]): F[Unit]
+}
+
+private object LeaseManager {
+
+  private implicit def logger[F[_]: Sync]: Logger[F] = Slf4jLogger.getLogger[F]
+
+  def resource[F[_]: Async](
+    config: PubsubSourceConfigV2,
+    stub: SubscriberStub,
+    ref: Ref[F, Map[Unique.Token, PubsubBatchState]],
+    channelAffinity: Int
+  ): Resource[F, LeaseManager[F, SubscriberAction.ProcessRecords]] =
+    extendDeadlinesInBackground[F](config, stub, ref, channelAffinity)
+      .as(impl(config, ref, channelAffinity))
+
+  private def impl[F[_]: Sync](
+    config: PubsubSourceConfigV2,
+    ref: Ref[F, Map[Unique.Token, PubsubBatchState]],
+    channelAffinity: Int
+  ): LeaseManager[F, SubscriberAction.ProcessRecords] = new LeaseManager[F, SubscriberAction.ProcessRecords] {
+
+    def manageLeases(in: SubscriberAction.ProcessRecords): F[Unique.Token] =
+      Unique[F].unique.flatMap { token =>
+        val deadline = in.timeReceived.plusMillis(config.durationPerAckExtension.toMillis)
+        val ackIds = in.records.map(_.getAckId)
+        val state = PubsubBatchState(deadline, ackIds, channelAffinity)
+        ref.update(_ + (token -> state)).as(token)
+      }
+
+    def stopManagingLeases(tokens: Vector[Unique.Token]): F[Unit] =
+      ref.update(_.removedAll(tokens))
+  }
+
+  private def extendDeadlinesInBackground[F[_]: Async](
+    config: PubsubSourceConfigV2,
+    stub: SubscriberStub,
+    refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
+    channelAffinity: Int
+  ): Resource[F, Unit] = {
+    def go: F[Unit] = for {
+      now <- Sync[F].realTimeInstant
+      minAllowedDeadline = now.plusMillis((config.minRemainingDeadline * config.durationPerAckExtension.toMillis).toLong)
+      newDeadline = now.plusMillis(config.durationPerAckExtension.toMillis)
+      toExtend <- refStates.modify { m =>
+                    val toExtend = m.filter { case (_, batchState) =>
+                      batchState.channelAffinity === channelAffinity && batchState.currentDeadline.isBefore(minAllowedDeadline)
+                    }
+                    val fixed = toExtend.view
+                      .mapValues(_.copy(currentDeadline = newDeadline))
+                      .toMap
+                    (m ++ fixed, toExtend.values.toVector)
+                  }
+      _ <- if (toExtend.isEmpty)
+             // Nothing is close to expiring: sleep for half the safety margin before the next sweep
+             Sync[F].sleep((0.5 * config.minRemainingDeadline * config.durationPerAckExtension.toMillis).toLong.millis)
+           else {
+             val ackIds = toExtend.sortBy(_.currentDeadline).flatMap(_.ackIds.toVector)
+             Utils.modAck[F](config.subscription, stub, ackIds, config.durationPerAckExtension, channelAffinity)
+           }
+      _ <- go
+    } yield ()
+    go.background.void
+  }
+
+}
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubBatchState.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubBatchState.scala
index 498a968..b51e99b 100644
--- a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubBatchState.scala
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubBatchState.scala
@@ -7,6 +7,8 @@
  */
 package com.snowplowanalytics.snowplow.sources.pubsub.v2
 
+import cats.data.NonEmptyVector
+
 import java.time.Instant
 
 /**
@@ -23,6 +25,6 @@ import java.time.Instant
  */
 private case class PubsubBatchState(
   currentDeadline: Instant,
-  ackIds: Vector[String],
+  ackIds: NonEmptyVector[String],
   channelAffinity: Int
 )
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubCheckpointer.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubCheckpointer.scala
index 631586d..83ec8dc 100644
--- a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubCheckpointer.scala
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubCheckpointer.scala
@@ -60,7 +60,7 @@ class PubsubCheckpointer[F[_]: Async](
       ackDatas <- refAckIds.modify(m => (m.removedAll(c), c.flatMap(m.get)))
       grouped = ackDatas.groupBy(_.channelAffinity)
       _ <- grouped.toVector.parTraverse_ { case (channelAffinity, ackDatas) =>
-        ackDatas.flatMap(_.ackIds).grouped(1000).toVector.traverse_ { ackIds =>
+        ackDatas.flatMap(_.ackIds.toVector).grouped(1000).toVector.traverse_ { ackIds =>
           val request = AcknowledgeRequest.newBuilder.setSubscription(subscription.show).addAllAckIds(ackIds.asJava).build
           val context = GrpcCallContext.createDefault.withChannelAffinity(channelAffinity)
           val attempt = for {
@@ -88,7 +88,7 @@ class PubsubCheckpointer[F[_]: Async](
       ackDatas <- refAckIds.modify(m => (m.removedAll(c), c.flatMap(m.get)))
       grouped = ackDatas.groupBy(_.channelAffinity)
       _ <- grouped.toVector.parTraverse_ { case (channelAffinity, ackDatas) =>
-        val ackIds = ackDatas.flatMap(_.ackIds)
+        val ackIds = ackDatas.flatMap(_.ackIds.toVector)
         // A nack is just a modack with zero duration
         Utils.modAck[F](subscription, stub, ackIds, Duration.Zero, channelAffinity)
       }
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceConfigV2.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceConfigV2.scala
index 700e61a..f16d7e0 100644
--- a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceConfigV2.scala
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceConfigV2.scala
@@ -25,9 +25,8 @@ case class PubsubSourceConfigV2(
   gcpUserAgent: GcpUserAgent,
   maxPullsPerTransportChannel: Int,
   progressTimeout: FiniteDuration,
-  modackOnProgressTimeout: Boolean,
-  cancelOnProgressTimeout: Boolean,
-  consistentClientId: Boolean
+  prefetchMin: Int,
+  prefetchMax: Int
 )
 
 object PubsubSourceConfigV2 {
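LeaseManager's sweep arithmetic is easiest to see with the defaults from application.conf plugged in. A self-contained sketch (values assumed from the config above, not new API):

    import java.time.Instant
    import scala.concurrent.duration._

    val durationPerAckExtension = 10.minutes
    val minRemainingDeadline    = 0.1

    val now = Instant.now()
    // A batch becomes eligible for another modack once its deadline falls inside the
    // final 10% of the extension window, i.e. within 1 minute of expiring:
    val minAllowedDeadline = now.plusMillis((minRemainingDeadline * durationPerAckExtension.toMillis).toLong)
    // When nothing needs extending, the sweeper sleeps for half that margin (30 seconds):
    val sleepBetweenSweeps = (0.5 * minRemainingDeadline * durationPerAckExtension.toMillis).toLong.millis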
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceV2.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceV2.scala
index 2adef3c..64b4b3d 100644
--- a/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceV2.scala
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceV2.scala
@@ -8,11 +8,10 @@
 package com.snowplowanalytics.snowplow.sources.pubsub.v2
 
 import cats.effect.{Async, Deferred, Ref, Resource, Sync}
-import cats.effect.std.{Hotswap, Queue, QueueSink}
 import cats.effect.kernel.Unique
-import cats.effect.implicits._
 import cats.implicits._
-import fs2.{Chunk, Pipe, Stream}
+import cats.effect.implicits._
+import fs2.{Chunk, Stream}
 import org.typelevel.log4cats.Logger
 import org.typelevel.log4cats.slf4j.Slf4jLogger
 
@@ -20,13 +19,10 @@ import java.time.Instant
 
 // pubsub
 import com.google.api.gax.core.{ExecutorProvider, FixedExecutorProvider}
-import com.google.api.gax.grpc.{ChannelPoolSettings, GrpcCallContext}
-import com.google.api.gax.rpc.{ResponseObserver, StreamController}
+import com.google.api.gax.grpc.ChannelPoolSettings
 import com.google.cloud.pubsub.v1.SubscriptionAdminSettings
 import com.google.cloud.pubsub.v1.stub.SubscriberStubSettings
-import com.google.pubsub.v1.{StreamingPullRequest, StreamingPullResponse}
 import com.google.cloud.pubsub.v1.stub.{GrpcSubscriberStub, SubscriberStub}
-import io.grpc.Status
 import org.threeten.bp.{Duration => ThreetenDuration}
 
 // snowplow
@@ -34,11 +30,9 @@ import com.snowplowanalytics.snowplow.pubsub.GcpUserAgent
 import com.snowplowanalytics.snowplow.sources.SourceAndAck
 import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEvents, LowLevelSource}
 
-import scala.concurrent.duration.{Duration, DurationDouble, FiniteDuration}
 import scala.jdk.CollectionConverters._
-
+import scala.concurrent.duration.{Duration, FiniteDuration}
 import java.util.concurrent.{ExecutorService, Executors, LinkedBlockingQueue}
-import java.util.UUID
 
 object PubsubSourceV2 {
 
@@ -72,6 +66,7 @@ object PubsubSourceV2 {
       channelCount = chooseNumTransportChannels(config, parallelPullCount)
       stub <- Stream.resource(stubResource(config, channelCount))
       refStates <- Stream.eval(Ref[F].of(Map.empty[Unique.Token, PubsubBatchState]))
+      _ <- Stream.bracket(Sync[F].unit)(_ => nackRefStatesForShutdown(config, stub, refStates))
       _ <- Stream.eval(deferredResources.complete(PubsubCheckpointer.Resources(stub, refStates)))
     } yield Stream
       .range(0, parallelPullCount)
@@ -83,157 +78,56 @@ object PubsubSourceV2 {
     stub: SubscriberStub,
     refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
     channelAffinity: Int
-  ): Stream[F, LowLevelEvents[Vector[Unique.Token]]] = {
-    val jQueue = new LinkedBlockingQueue[SubscriberAction]()
-    val clientId = UUID.randomUUID
-    val resource = initializeStreamingPull[F](config, stub, jQueue, channelAffinity, clientId)
-
+  ): Stream[F, LowLevelEvents[Vector[Unique.Token]]] =
     for {
-      (hotswap, _) <- Stream.resource(Hotswap(resource))
-      fs2Queue <- Stream.eval(Queue.synchronous[F, SubscriberAction])
-      _ <- extendDeadlines(config, stub, refStates, channelAffinity).spawn
-      _ <- Stream.eval(queueToQueue(config, jQueue, fs2Queue, stub, channelAffinity)).repeat.spawn
-      lle <- Stream
-        .fromQueueUnterminated(fs2Queue)
-        .through(toLowLevelEvents(config, refStates, hotswap, resource, channelAffinity))
-    } yield lle
-  }
-
-  private def queueToQueue[F[_]: Async](
-    config: PubsubSourceConfigV2,
-    jQueue: LinkedBlockingQueue[SubscriberAction],
-    fs2Queue: QueueSink[F, SubscriberAction],
-    stub: SubscriberStub,
-    channelAffinity: Int
-  ): F[Unit] =
-    resolveNextAction(jQueue).flatMap {
-      case action @ SubscriberAction.ProcessRecords(records, controller, _) =>
-        val fallback = if (config.modackOnProgressTimeout) {
-          val ackIds = records.map(_.getAckId)
-          if (config.cancelOnProgressTimeout)
-            Logger[F].debug(s"Cancelling Pubsub channel $channelAffinity for not making progress") *>
-              Sync[F].delay(controller.cancel()) *> Utils.modAck(config.subscription, stub, ackIds, Duration.Zero, channelAffinity)
-          else
-            Logger[F].debug(s"Nacking on Pubsub channel $channelAffinity for not making progress") *>
-              Sync[F].delay(controller.request(1)) *> Utils.modAck(config.subscription, stub, ackIds, Duration.Zero, channelAffinity)
-        } else {
-          if (config.cancelOnProgressTimeout)
-            Logger[F].debug(s"Cancelling Pubsub channel $channelAffinity for not making progress") *>
-              Sync[F].delay(controller.cancel()) *> fs2Queue.offer(action)
-          else
-            fs2Queue.offer(action)
-        }
-        fs2Queue.offer(action).timeoutTo(config.progressTimeout, fallback)
-      case action: SubscriberAction.SubscriberError =>
-        fs2Queue.offer(action)
-    }
+      jQueue <- Stream.emit(new LinkedBlockingQueue[SubscriberAction]())
+      _ <- Stream.bracket(Sync[F].unit)(_ => nackQueueForShutdown(config, stub, jQueue, channelAffinity))
+      streamManager <- Stream.resource(StreamManager.resource(config, stub, jQueue, channelAffinity))
+      leaseManager <- Stream.resource(LeaseManager.resource(config, stub, refStates, channelAffinity))
+      sourceCoordinator <- Stream.resource(SourceCoordinator.resource(config, streamManager, leaseManager, channelAffinity))
+      _ <- pullFromQueue(jQueue, sourceCoordinator, channelAffinity).spawn
+      tokenedAction <- Stream.eval(sourceCoordinator.pull).repeat
+    } yield {
+      val SourceCoordinator.TokenedA(token, SubscriberAction.ProcessRecords(records, _)) = tokenedAction
+      val chunk = Chunk.from(records.toVector.map(_.getMessage.getData.asReadOnlyByteBuffer()))
+      val (tstampSeconds, tstampNanos) =
+        records.toVector.map(r => (r.getMessage.getPublishTime.getSeconds, r.getMessage.getPublishTime.getNanos)).min
+      LowLevelEvents(chunk, Vector(token), Some(Instant.ofEpochSecond(tstampSeconds, tstampNanos.toLong)))
+    }
 
-  /**
-   * Modify ack deadlines if we need more time to process the messages
-   *
-   * @param config
-   *   The Source configuration
-   * @param stub
-   *   The GRPC stub on which we can issue modack requests
-   * @param refStates
-   *   A map from tokens to the data held about a batch of messages received from pubsub. This
-   *   function must update the state if it extends a deadline.
-   * @param channelAffinity
-   *   Identifies the GRPC channel (TCP connection) creating these Actions. Each GRPC channel has
-   *   its own concurrent stream modifying the ack deadlines.
-   */
-  private def extendDeadlines[F[_]: Async](
-    config: PubsubSourceConfigV2,
-    stub: SubscriberStub,
-    refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
+  private def pullFromQueue[F[_]: Sync](
+    queue: LinkedBlockingQueue[SubscriberAction],
+    sourceCoordinator: SourceCoordinator[F, SubscriberAction.ProcessRecords],
     channelAffinity: Int
   ): Stream[F, Nothing] =
     Stream
-      .eval(Sync[F].realTimeInstant)
-      .evalMap { now =>
-        val minAllowedDeadline = now.plusMillis((config.minRemainingDeadline * config.durationPerAckExtension.toMillis).toLong)
-        val newDeadline = now.plusMillis(config.durationPerAckExtension.toMillis)
-        refStates.modify { m =>
-          val toExtend = m.filter { case (_, batchState) =>
-            batchState.channelAffinity === channelAffinity && batchState.currentDeadline.isBefore(minAllowedDeadline)
-          }
-          val fixed = toExtend.view
-            .mapValues(_.copy(currentDeadline = newDeadline))
-            .toMap
-          (m ++ fixed, toExtend.values.toVector)
-        }
-      }
-      .evalMap { toExtend =>
-        if (toExtend.isEmpty)
-          Sync[F].sleep(0.5 * config.minRemainingDeadline * config.durationPerAckExtension)
-        else {
-          val ackIds = toExtend.sortBy(_.currentDeadline).flatMap(_.ackIds)
-          Utils.modAck[F](config.subscription, stub, ackIds, config.durationPerAckExtension, channelAffinity)
-        }
-      }
+      .eval {
+        Sync[F].uncancelable { poll =>
+          poll(resolveNextAction(queue))
+            .flatMap {
+              case SubscriberAction.Ready(controller) =>
+                sourceCoordinator.receiveController(controller)
+              case processRecords: SubscriberAction.ProcessRecords =>
+                sourceCoordinator.receiveItem(processRecords)
+              case SubscriberAction.SubscriberError(t) =>
+                if (PubsubRetryOps.isRetryableException(t)) {
+                  // Log at debug level because retryable errors are very frequent.
+                  // In particular, if the pubsub subscription is empty then a streaming pull returns UNAVAILABLE
+                  Logger[F].debug(s"Retryable error on PubSub channel $channelAffinity: ${t.getMessage}") >>
+                    sourceCoordinator.handleStreamError
+                } else if (t.isInstanceOf[java.util.concurrent.CancellationException]) {
+                  // The SourceCoordinator caused this by cancelling the stream.
+                  // No need to inform the SourceCoordinator.
+                  Logger[F].debug("Cancellation exception on PubSub channel")
+                } else {
+                  Logger[F].error(t)("Exception from PubSub source") >> Sync[F].raiseError[Unit](t)
+                }
+            }
+        }
+      }
       .repeat
       .drain
 
-  /**
-   * Pipe from SubscriberAction to LowLevelEvents TODO: Say what else this does
-   *
-   * @param config
-   *   The source configuration
-   * @param refStates
-   *   A map from tokens to the data held about a batch of messages received from pubsub. This
-   *   function must update the state to add new batches.
-   * @param hotswap
-   *   A Hotswap wrapping the Resource that is populating the queue
-   * @param toSwap
-   *   Initializes the Resource which is populating the queue. If we get an error from the queue
-   *   then need to swap in the new Resource into the Hotswap
-   * @param channelAffinity
-   *   Identifies the GRPC channel (TCP connection) creating these Actions. Each GRPC channel has
-   *   its own queue, observer, and puller.
-   */
-  private def toLowLevelEvents[F[_]: Async](
-    config: PubsubSourceConfigV2,
-    refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
-    hotswap: Hotswap[F, Unit],
-    toSwap: Resource[F, Unit],
-    channelAffinity: Int
-  ): Pipe[F, SubscriberAction, LowLevelEvents[Vector[Unique.Token]]] =
-    _.flatMap {
-      case SubscriberAction.ProcessRecords(records, controller, timeReceived) =>
-        val chunk = Chunk.from(records.map(_.getMessage.getData.asReadOnlyByteBuffer()))
-        val (tstampSeconds, tstampNanos) =
-          records.map(r => (r.getMessage.getPublishTime.getSeconds, r.getMessage.getPublishTime.getNanos)).min
-        val ackIds = records.map(_.getAckId)
-        Stream.eval {
-          for {
-            token <- Unique[F].unique
-            currentDeadline = timeReceived.plusMillis(config.durationPerAckExtension.toMillis)
-            _ <- refStates.update(_ + (token -> PubsubBatchState(currentDeadline, ackIds, channelAffinity)))
-            _ <- Sync[F].delay(controller.request(1))
-          } yield LowLevelEvents(chunk, Vector(token), Some(Instant.ofEpochSecond(tstampSeconds, tstampNanos.toLong)))
-        }
-      case SubscriberAction.SubscriberError(t) =>
-        if (PubsubRetryOps.isRetryableException(t)) {
-          // val nextDelay = (2 * delayOnSubscriberError).min((10 + scala.util.Random.nextDouble()).second)
-          // Log at debug level because retryable errors are very frequent.
-          // In particular, if the pubsub subscription is empty then a streaming pull returns UNAVAILABLE
-          Stream.eval {
-            Logger[F].debug(s"Retryable error on PubSub channel $channelAffinity: ${t.getMessage}") *>
-              hotswap.clear *>
-              Async[F].sleep((1.0 + scala.util.Random.nextDouble()).second) *> // TODO expotential backoff
-              hotswap.swap(toSwap)
-          }.drain
-        } else if (t.isInstanceOf[java.util.concurrent.CancellationException]) {
-          Stream.eval {
-            Logger[F].debug("Cancellation exception on PubSub channel") *>
-              hotswap.clear *>
-              hotswap.swap(toSwap)
-          }.drain
-        } else {
-          Stream.eval(Logger[F].error(t)("Exception from PubSub source")) *> Stream.raiseError[F](t)
-        }
-    }
-
   private def resolveNextAction[F[_]: Sync, A](queue: LinkedBlockingQueue[A]): F[A] =
     Sync[F].delay(Option[A](queue.poll)).flatMap {
       case Some(action) => Sync[F].pure(action)
@@ -276,58 +170,30 @@ object PubsubSourceV2 {
     Resource.make(Sync[F].delay(GrpcSubscriberStub.create(stubSettings)))(stub => Sync[F].blocking(stub.shutdownNow))
   }
 
-  private def initializeStreamingPull[F[_]: Sync](
+  private def nackRefStatesForShutdown[F[_]: Async](
     config: PubsubSourceConfigV2,
-    subStub: SubscriberStub,
-    actionQueue: LinkedBlockingQueue[SubscriberAction],
-    channelAffinity: Int,
-    clientId: UUID
-  ): Resource[F, Unit] = {
-
-    val observer = new ResponseObserver[StreamingPullResponse] {
-      var controller: StreamController = _
-      override def onResponse(response: StreamingPullResponse): Unit = {
-        val messages = response.getReceivedMessagesList.asScala.toVector
-        if (messages.isEmpty) {
-          controller.request(1)
-        } else {
-          val action = SubscriberAction.ProcessRecords(messages, controller, Instant.now())
-          actionQueue.put(action)
-        }
-      }
-
-      override def onStart(c: StreamController): Unit = {
-        controller = c
-        controller.disableAutoInboundFlowControl()
-        controller.request(1)
+    stub: SubscriberStub,
+    refStates: Ref[F, Map[Unique.Token, PubsubBatchState]]
+  ): F[Unit] =
+    refStates.getAndSet(Map.empty).flatMap { m =>
+      m.values.groupBy(_.channelAffinity).toVector.parTraverse_ { case (channelAffinity, batches) =>
+        Utils.modAck(config.subscription, stub, batches.flatMap(_.ackIds.toVector).toVector, Duration.Zero, channelAffinity)
       }
-
-      override def onError(t: Throwable): Unit =
-        actionQueue.put(SubscriberAction.SubscriberError(t))
-
-      override def onComplete(): Unit = ()
     }
-    val context = GrpcCallContext.createDefault.withChannelAffinity(channelAffinity)
-
-    val request = StreamingPullRequest.newBuilder
-      .setSubscription(config.subscription.show)
-      .setStreamAckDeadlineSeconds(config.durationPerAckExtension.toSeconds.toInt)
-      .setClientId(if (config.consistentClientId) clientId.toString else UUID.randomUUID.toString)
-      .setMaxOutstandingMessages(0)
-      .setMaxOutstandingBytes(0)
-      .build
-
-    Resource
-      .make(Sync[F].delay(subStub.streamingPullCallable.splitCall(observer, context))) { stream =>
-        Sync[F].delay(stream.closeSendWithError(Status.CANCELLED.asException))
-      }
-      .evalMap { stream =>
-        Sync[F].delay(stream.send(request))
-      }
-      .void
-
+  private def nackQueueForShutdown[F[_]: Async](
+    config: PubsubSourceConfigV2,
+    stub: SubscriberStub,
+    queue: LinkedBlockingQueue[SubscriberAction],
+    channelAffinity: Int
+  ): F[Unit] = {
+    val ackIds = queue.iterator.asScala.toVector.flatMap {
+      case SubscriberAction.ProcessRecords(records, _) =>
+        records.toVector.map(_.getAckId)
+      case _ =>
+        Vector.empty
+    }
+    Utils.modAck(config.subscription, stub, ackIds, Duration.Zero, channelAffinity)
   }
 
   private def executorResource[F[_]: Sync, E <: ExecutorService](make: F[E]): Resource[F, E] =
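Both shutdown hooks above lean on the same trick called out in PubsubCheckpointer: a nack is just a modack with a zero deadline, which makes Pub/Sub redeliver the messages immediately. A sketch of the request this boils down to, assuming Utils.modAck (not shown in this diff) wraps the ModifyAckDeadline RPC roughly like this:

    import com.google.pubsub.v1.ModifyAckDeadlineRequest
    import scala.jdk.CollectionConverters._

    def nackRequest(subscription: String, ackIds: Vector[String]): ModifyAckDeadlineRequest =
      ModifyAckDeadlineRequest.newBuilder
        .setSubscription(subscription)
        .addAllAckIds(ackIds.asJava)
        .setAckDeadlineSeconds(0) // zero deadline: redeliver straight away, i.e. a nack
        .build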
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/SourceCoordinator.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/SourceCoordinator.scala
new file mode 100644
index 0000000..a9de25f
--- /dev/null
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/SourceCoordinator.scala
@@ -0,0 +1,250 @@
+/*
+ * Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
+ *
+ * This program is licensed to you under the Snowplow Community License Version 1.0,
+ * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
+ * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
+ */
+package com.snowplowanalytics.snowplow.sources.pubsub.v2
+
+import cats.effect.{Async, Deferred, Resource, Sync}
+import cats.effect.kernel.{DeferredSink, Unique}
+import cats.effect.std.AtomicCell
+import cats.implicits._
+import cats.effect.implicits._
+import com.google.api.gax.rpc.StreamController
+import org.typelevel.log4cats.Logger
+import org.typelevel.log4cats.slf4j.Slf4jLogger
+
+import scala.concurrent.duration.{Duration, DurationDouble, DurationInt, FiniteDuration}
+
+private class SourceCoordinator[F[_]: Async, A] private (
+  config: PubsubSourceConfigV2,
+  status: AtomicCell[F, SourceCoordinator.Status[F, A]],
+  streamManager: StreamManager[F],
+  leaseManager: LeaseManager[F, A],
+  channelAffinity: Int
+) {
+  import SourceCoordinator._
+
+  private implicit val logger: Logger[F] = Slf4jLogger.getLogger[F]
+
+  private def raiseForIllegalState[B](status: Status[F, A], thingCannotDo: String): F[B] =
+    Sync[F].raiseError(
+      new IllegalStateException(s"${getClass.getName} cannot $thingCannotDo when in state ${status.getClass.getSimpleName}")
+    )
+
+  private def initialBackoffWithJitter: F[FiniteDuration] =
+    Sync[F].delay {
+      ((1.0 + scala.util.Random.nextDouble()) * 100).millis
+    }
+
+  def pull: F[TokenedA[A]] =
+    status
+      .evalModify[F[TokenedA[A]]] {
+        case Status.Shutdown(since, buffered) =>
+          buffered match {
+            case head +: tail =>
+              Sync[F].pure(Status.Shutdown(since, tail) -> head.pure[F])
+            case _ =>
+              for {
+                deferred <- Deferred[F, TokenedA[A]]
+                now <- Sync[F].realTime
+                nextBackoff <- initialBackoffWithJitter
+              } yield {
+                // Pause for whatever is left of the backoff window before restarting the stream
+                val pause = if (now - since > nextBackoff) Sync[F].unit else Sync[F].sleep(nextBackoff - (now - since))
+                val get = pause >> streamManager.startAgain >> deferred.get
+                Status.Initializing[F, A](nextBackoff, BufferOrAwaiter.Awaiter(deferred)) -> get
+              }
+          }
+        case Status.Initializing(lastBackoff, BufferOrAwaiter.Buffer(buffered)) =>
+          // Downstream wants to pull some message, and by luck we are already initializing a streaming pull
+          buffered match {
+            case head +: tail =>
+              Sync[F].pure(Status.Initializing(lastBackoff, BufferOrAwaiter.Buffer[F, A](tail)) -> head.pure[F])
+            case _ =>
+              for {
+                deferred <- Deferred[F, TokenedA[A]]
+              } yield Status.Initializing(lastBackoff, BufferOrAwaiter.Awaiter(deferred)) -> deferred.get
+          }
+        case status @ Status.Initializing(_, BufferOrAwaiter.Awaiter(_)) =>
+          // Illegal state because we only start initializing when downstream calls `pull`
+          // ...so we only get here if downstream calls `pull` twice without waiting
+          raiseForIllegalState(status, "pull")
+        case status @ Status.Requesting(_, BufferOrAwaiter.Awaiter(_), _) =>
+          // We only get here if downstream calls `pull` twice without waiting
+          raiseForIllegalState(status, "pull")
+        case Status.Requesting(controller, BufferOrAwaiter.Buffer(buffered), lastBackoff) =>
+          // Downstream wants to pull some message, and by luck we are already requesting some
+          buffered match {
+            case head +: tail =>
+              Sync[F].pure(Status.Requesting[F, A](controller, BufferOrAwaiter.Buffer(tail), lastBackoff) -> head.pure[F])
+            case _ =>
+              Deferred[F, TokenedA[A]].map { deferred =>
+                Status.Requesting[F, A](controller, BufferOrAwaiter.Awaiter(deferred), lastBackoff) -> deferred.get
+              }
+          }
+        case Status.AwaitingConsumer(controller, buffered, lastFetch) =>
+          buffered match {
+            case head +: tail if tail.size > config.prefetchMin =>
+              Sync[F].pure(Status.AwaitingConsumer[F, A](controller, tail, lastFetch) -> head.pure[F])
+            case head +: tail =>
+              Sync[F].delay(controller.request(1)).as {
+                Status.Requesting[F, A](controller, BufferOrAwaiter.Buffer(tail), None) -> head.pure[F]
+              }
+            case _ =>
+              for {
+                deferred <- Deferred[F, TokenedA[A]]
+                _ <- Sync[F].delay(controller.request(1))
+              } yield Status.Requesting[F, A](controller, BufferOrAwaiter.Awaiter(deferred), None) -> deferred.get
+          }
+      }
+      .flatten
+
+  def handleStreamError: F[Unit] =
+    status
+      .evalModify[F[Unit]] {
+        case Status.Requesting(_, receive, Some(lastBackoff)) =>
+          val nextBackoff = (lastBackoff * 2).min(10.seconds)
+          val todo = Sync[F].sleep(nextBackoff) >> streamManager.startAgain
+          Sync[F].pure(Status.Initializing[F, A](nextBackoff, receive) -> todo)
+        case Status.Requesting(_, receive, None) =>
+          initialBackoffWithJitter.map { nextBackoff =>
+            val todo = Sync[F].sleep(nextBackoff) >> streamManager.startAgain
+            Status.Initializing[F, A](nextBackoff, receive) -> todo
+          }
+        case shutdown: Status.Shutdown[F, A] =>
+          Sync[F].pure(shutdown -> Sync[F].unit)
+        case Status.AwaitingConsumer(_, buffered, _) =>
+          Sync[F].realTime.map { now =>
+            Status.Shutdown(now, buffered) -> Sync[F].unit
+          }
+        case status: Status.Initializing[F, A] =>
+          raiseForIllegalState(status, "handle stream error")
+      }
+      .flatten
+
+  def receiveItem(in: A): F[Unit] =
+    status.evalUpdate {
+      case Status.Requesting(controller, BufferOrAwaiter.Awaiter(deferred), _) =>
+        if (config.prefetchMin > 0) {
+          for {
+            token <- leaseManager.manageLeases(in)
+            _ <- deferred.complete(TokenedA(token, in))
+            _ <- Sync[F].delay(controller.request(1))
+          } yield Status.Requesting(controller, BufferOrAwaiter.Buffer(Vector.empty), None)
+        } else {
+          for {
+            now <- Sync[F].realTime
+            token <- leaseManager.manageLeases(in)
+            _ <- deferred.complete(TokenedA(token, in))
+          } yield Status.AwaitingConsumer(controller, Vector.empty, now)
+        }
+      case Status.Requesting(controller, BufferOrAwaiter.Buffer(buffered), _) =>
+        if (buffered.size + 1 < config.prefetchMin) {
+          for {
+            token <- leaseManager.manageLeases(in)
+            _ <- Sync[F].delay(controller.request(1))
+          } yield Status.Requesting(controller, BufferOrAwaiter.Buffer(buffered :+ TokenedA(token, in)), None)
+        } else {
+          for {
+            now <- Sync[F].realTime
+            token <- leaseManager.manageLeases(in)
+          } yield Status.AwaitingConsumer(controller, buffered :+ TokenedA(token, in), now)
+        }
+      case status: Status.AwaitingConsumer[F, A] =>
+        raiseForIllegalState(status, "receive item")
+      case status: Status.Initializing[F, A] =>
+        raiseForIllegalState(status, "receive item")
+      case status: Status.Shutdown[F, A] =>
+        raiseForIllegalState(status, "receive item")
+    }
+
+  def receiveController(controller: StreamController): F[Unit] =
+    status.evalUpdate {
+      case Status.Initializing(lastBackoff, receive) =>
+        Sync[F].delay(controller.request(1)).as(Status.Requesting(controller, receive, Some(lastBackoff)))
+      case status =>
+        raiseForIllegalState(status, "receive controller")
+    }
+
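+  /**
+   * Runs forever on a background fiber, waking at least once per `progressTimeout`.
+   *
+   * While the consumer is idle and fewer than `prefetchMax` batches are buffered, it requests one
+   * more batch so the streaming pull does not stall. Once the buffer would exceed `prefetchMax` it
+   * gives up instead: buffered leases are dropped, the stream is cancelled, and the coordinator
+   * stays in `Shutdown` until the consumer pulls again.
+   */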
+  private def keepAlive: F[Unit] =
+    status
+      .evalModify[FiniteDuration] {
+        case Status.AwaitingConsumer(controller, buffered, lastFetch) if buffered.size < config.prefetchMax =>
+          Sync[F].realTime.flatMap[(Status[F, A], FiniteDuration)] { now =>
+            if (now > lastFetch + config.progressTimeout) {
+              for {
+                _ <-
+                  Logger[F].info(
+                    s"Requesting more messages from pubsub Streaming Pull $channelAffinity to avoid a timeout after ${config.progressTimeout} without activity"
+                  )
+                _ <- Sync[F].delay(controller.request(1))
+              } yield Status.Requesting(controller, BufferOrAwaiter.Buffer[F, A](buffered), None) -> config.progressTimeout
+            } else {
+              // Sleep for however long remains until the progress timeout would expire
+              val nextTimeout = lastFetch + config.progressTimeout - now
+              Sync[F].pure(Status.AwaitingConsumer(controller, buffered, lastFetch) -> nextTimeout)
+            }
+          }
+        case Status.AwaitingConsumer(controller, buffered, _) =>
+          for {
+            _ <-
+              Logger[F].info(
+                s"Dropping ${buffered.size} buffered batches from Streaming Pull $channelAffinity. Exceeded ${config.prefetchMax} pre-fetches while attempting to keep the stream alive."
+              )
+            _ <- leaseManager.stopManagingLeases(buffered.map(_.token))
+            _ <- Sync[F].delay(controller.cancel())
+            now <- Sync[F].realTime
+          } yield Status.Shutdown(now, Vector.empty) -> config.progressTimeout
+        case other =>
+          Sync[F].pure(other -> config.progressTimeout)
+      }
+      .flatMap[Unit] { timeout =>
+        Sync[F].sleep(timeout)
+      }
+      .foreverM
+}
+
+private object SourceCoordinator {
+
+  case class TokenedA[A](token: Unique.Token, value: A)
+
+  def resource[F[_]: Async, A](
+    config: PubsubSourceConfigV2,
+    streamManager: StreamManager[F],
+    leaseManager: LeaseManager[F, A],
+    channelAffinity: Int
+  ): Resource[F, SourceCoordinator[F, A]] =
+    Resource
+      .eval(AtomicCell[F].of[Status[F, A]](Status.Shutdown(Duration.Zero, Vector.empty)))
+      .map { atomicCell =>
+        new SourceCoordinator(config, atomicCell, streamManager, leaseManager, channelAffinity)
+      }
+      .flatTap(_.keepAlive.background)
+
+  private sealed trait BufferOrAwaiter[F[_], A]
+  private object BufferOrAwaiter {
+    case class Awaiter[F[_], A](complete: DeferredSink[F, TokenedA[A]]) extends BufferOrAwaiter[F, A]
+    case class Buffer[F[_], A](value: Vector[TokenedA[A]]) extends BufferOrAwaiter[F, A]
+  }
+
+  private sealed trait Status[F[_], A]
+  private object Status {
+    case class Shutdown[F[_], A](since: FiniteDuration, buffer: Vector[TokenedA[A]]) extends Status[F, A]
+
+    case class Initializing[F[_], A](lastBackoff: FiniteDuration, receive: BufferOrAwaiter[F, A]) extends Status[F, A]
+
+    case class Requesting[F[_], A](
+      controller: StreamController,
+      receive: BufferOrAwaiter[F, A],
+      lastBackoff: Option[FiniteDuration]
+    ) extends Status[F, A]
+
+    case class AwaitingConsumer[F[_], A](
+      controller: StreamController,
+      buffered: Vector[TokenedA[A]],
+      lastFetch: FiniteDuration
+    ) extends Status[F, A]
+  }
+}
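SourceCoordinator's thread safety hangs on AtomicCell.evalModify: each transition atomically computes the next state plus a follow-up effect, and the effect is then run outside the critical section via .flatten. A self-contained sketch of that pattern (toy types, not code from this PR):

    import cats.effect.IO
    import cats.effect.std.AtomicCell
    import cats.syntax.all._

    sealed trait Light
    case object Red extends Light
    case object Green extends Light

    // Pick the next state and an effect while holding the cell, then run the effect
    // after the cell is released - the same shape as pull and handleStreamError above:
    def toggle(cell: AtomicCell[IO, Light]): IO[Unit] =
      cell.evalModify[IO[Unit]] {
        case Red   => IO.pure((Green: Light) -> IO.println("was red, now green"))
        case Green => IO.pure((Red: Light) -> IO.println("was green, now red"))
      }.flatten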
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/StreamManager.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/StreamManager.scala
new file mode 100644
index 0000000..ccc5556
--- /dev/null
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/StreamManager.scala
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
+ *
+ * This program is licensed to you under the Snowplow Community License Version 1.0,
+ * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
+ * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
+ */
+package com.snowplowanalytics.snowplow.sources.pubsub.v2
+
+import cats.data.NonEmptyVector
+import cats.effect.{Async, Resource, Sync}
+import cats.effect.std.Hotswap
+import cats.implicits._
+
+// pubsub
+import com.google.api.gax.grpc.GrpcCallContext
+import com.google.api.gax.rpc.{ResponseObserver, StreamController}
+import com.google.pubsub.v1.{StreamingPullRequest, StreamingPullResponse}
+import com.google.cloud.pubsub.v1.stub.SubscriberStub
+import io.grpc.Status
+
+import scala.jdk.CollectionConverters._
+
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.UUID
+import java.time.Instant
+
+private trait StreamManager[F[_]] {
+  def startAgain: F[Unit]
+}
+
+private object StreamManager {
+
+  def resource[F[_]: Async](
+    config: PubsubSourceConfigV2,
+    stub: SubscriberStub,
+    actionQueue: LinkedBlockingQueue[SubscriberAction],
+    channelAffinity: Int
+  ): Resource[F, StreamManager[F]] =
+    for {
+      clientId <- Resource.eval(Sync[F].delay(UUID.randomUUID))
+      hotswap <- Hotswap.create[F, Unit]
+    } yield new StreamManager[F] {
+      def startAgain: F[Unit] =
+        hotswap.swap(initializeStreamingPull(config, stub, actionQueue, channelAffinity, clientId))
+    }
+
+  private def initializeStreamingPull[F[_]: Sync](
+    config: PubsubSourceConfigV2,
+    subStub: SubscriberStub,
+    actionQueue: LinkedBlockingQueue[SubscriberAction],
+    channelAffinity: Int,
+    clientId: UUID
+  ): Resource[F, Unit] = {
+
+    val observer = new ResponseObserver[StreamingPullResponse] {
+      var controller: StreamController = _
+      override def onResponse(response: StreamingPullResponse): Unit = {
+        val messages = response.getReceivedMessagesList.asScala.toVector
+        NonEmptyVector.fromVector(messages) match {
+          case Some(nev) =>
+            val action = SubscriberAction.ProcessRecords(nev, Instant.now())
+            actionQueue.put(action)
+          case None =>
+            // messages was empty
+            controller.request(1)
+        }
+      }
+
+      override def onStart(c: StreamController): Unit = {
+        controller = c
+        controller.disableAutoInboundFlowControl()
+        actionQueue.put(SubscriberAction.Ready(controller))
+      }
+
+      override def onError(t: Throwable): Unit =
+        actionQueue.put(SubscriberAction.SubscriberError(t))
+
+      override def onComplete(): Unit = ()
+
+    }
+
+    val context = GrpcCallContext.createDefault.withChannelAffinity(channelAffinity)
+
+    val request = StreamingPullRequest.newBuilder
+      .setSubscription(config.subscription.show)
+      .setStreamAckDeadlineSeconds(config.durationPerAckExtension.toSeconds.toInt)
+      .setClientId(clientId.toString)
+      .setMaxOutstandingMessages(0)
+      .setMaxOutstandingBytes(0)
+      .build
+
+    Resource
+      .make(Sync[F].delay(subStub.streamingPullCallable.splitCall(observer, context))) { stream =>
+        Sync[F].delay(stream.closeSendWithError(Status.CANCELLED.asException))
+      }
+      .evalMap { stream =>
+        Sync[F].delay(stream.send(request))
+      }
+      .void
+
+  }
+}
diff --git a/modules/gcp/src/main/scala/common-streams-extensions/v2/SubscriberAction.scala b/modules/gcp/src/main/scala/common-streams-extensions/v2/SubscriberAction.scala
index fa61c44..437d648 100644
--- a/modules/gcp/src/main/scala/common-streams-extensions/v2/SubscriberAction.scala
+++ b/modules/gcp/src/main/scala/common-streams-extensions/v2/SubscriberAction.scala
@@ -7,6 +7,8 @@
  */
 package com.snowplowanalytics.snowplow.sources.pubsub.v2
 
+import cats.data.NonEmptyVector
+
 import com.google.pubsub.v1.ReceivedMessage
 import com.google.api.gax.rpc.StreamController
 import java.time.Instant
@@ -24,15 +26,20 @@ private object SubscriberAction {
    *
    * @param records
    *   The received records
-   * @param streamController
-   *   The GRPC stream controller. When this action is handed over to cats-effect/fs2 world then we
-   *   must tell the stream controller we are ready to receive more events
    * @param timeReceived
    *   Timestamp the records were pulled over the GRPC stream
    */
   case class ProcessRecords(
-    records: Vector[ReceivedMessage],
-    streamController: StreamController,
+    records: NonEmptyVector[ReceivedMessage],
     timeReceived: Instant
   ) extends SubscriberAction
+
+  /**
+   * The GRPC stream is ready to send us messages
+   *
+   * @param controller
+   *   The GRPC stream controller. When this action is handed over to cats-effect/fs2 world then we
+   *   must tell the stream controller we are ready to receive more events
+   */
+  case class Ready(controller: StreamController) extends SubscriberAction
 }
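The Vector to NonEmptyVector change in ProcessRecords (and in PubsubBatchState above) encodes an invariant the observer in StreamManager now enforces: an action is only enqueued when the pull returned at least one message. A minimal sketch of the pattern (illustrative values):

    import cats.data.NonEmptyVector

    NonEmptyVector.fromVector(Vector("ackId-1", "ackId-2")) match {
      case Some(ackIds) => println(s"lease ${ackIds.length} ack ids; ackIds.head needs no empty check")
      case None         => println("empty pull response: request more from the stream, nothing to lease")
    }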