Moving things around
pondzix authored and istreeter committed Sep 13, 2024
1 parent a6b96d0 commit 1c37867
Showing 6 changed files with 406 additions and 299 deletions.
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sources.kinesis

import cats.implicits._
import cats.{Order, Semigroup}
import software.amazon.kinesis.processor.RecordProcessorCheckpointer
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber

import java.util.concurrent.CountDownLatch

// Describes how far a shard can be checkpointed once a batch of events has been fully
// processed: either at an ordinary record position, or at SHARD_END when the shard is closed.
private sealed trait Checkpointable {
  def extendedSequenceNumber: ExtendedSequenceNumber
}

private object Checkpointable {
  final case class Record(extendedSequenceNumber: ExtendedSequenceNumber, checkpointer: RecordProcessorCheckpointer) extends Checkpointable

  final case class ShardEnd(checkpointer: RecordProcessorCheckpointer, release: CountDownLatch) extends Checkpointable {
    override def extendedSequenceNumber: ExtendedSequenceNumber = ExtendedSequenceNumber.SHARD_END
  }

  // Ordering by extended sequence number lets two checkpointables for the same shard be
  // combined by keeping the more advanced one.
  implicit def checkpointableOrder: Order[Checkpointable] = Order.from { case (a, b) =>
    a.extendedSequenceNumber.compareTo(b.extendedSequenceNumber)
  }

  implicit def checkpointableSemigroup: Semigroup[Checkpointable] = new Semigroup[Checkpointable] {
    def combine(x: Checkpointable, y: Checkpointable): Checkpointable =
      x.max(y)
  }
}
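
Below is a minimal usage sketch, not part of this commit, showing how the two instances are meant to work together: combining per-shard maps with cats' |+| keeps, for each shard, the checkpointable with the greater extended sequence number. The CheckpointableExample object and mergeBatches helper are hypothetical names, assumed to live in the same package.

import cats.implicits._

private object CheckpointableExample {
  // Hypothetical helper: per-key combination of the two maps delegates to
  // Semigroup[Checkpointable], so the entry with the higher sequence number wins.
  def mergeBatches(
    older: Map[String, Checkpointable],
    newer: Map[String, Checkpointable]
  ): Map[String, Checkpointable] =
    older |+| newer
}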
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sources.kinesis

import software.amazon.kinesis.lifecycle.events.{ProcessRecordsInput, ShardEndedInput}

import java.util.concurrent.CountDownLatch

// The events handed over by the KCL worker threads: a batch of records to process, the end
// of a shard, or a fatal error raised by the KCL itself.
private sealed trait KCLAction

private object KCLAction {

  final case class ProcessRecords(shardId: String, processRecordsInput: ProcessRecordsInput) extends KCLAction
  final case class ShardEnd(
    shardId: String,
    await: CountDownLatch,
    shardEndedInput: ShardEndedInput
  ) extends KCLAction
  final case class KCLError(t: Throwable) extends KCLAction

}
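
A hypothetical sketch of the consuming side, to illustrate the ADT: whatever drains the queue of actions can dispatch exhaustively on the three variants because KCLAction is sealed. The KCLActionExample object and describe function are illustrative only.

private object KCLActionExample {
  import KCLAction._

  // Hypothetical dispatcher, not in this commit; the match is exhaustive because KCLAction is sealed.
  def describe(action: KCLAction): String =
    action match {
      case ProcessRecords(shardId, _) => s"received a batch of records for shard $shardId"
      case ShardEnd(shardId, _, _)    => s"shard $shardId has ended"
      case KCLError(t)                => s"KCL failed: ${t.getMessage}"
    }
}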
@@ -0,0 +1,150 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sources.kinesis

import cats.effect.implicits._
import cats.effect.{Async, Resource, Sync}
import cats.implicits._
import com.snowplowanalytics.snowplow.sources.kinesis.KCLAction.KCLError
import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient
import software.amazon.kinesis.common.{ConfigsBuilder, InitialPositionInStream, InitialPositionInStreamExtended}
import software.amazon.kinesis.coordinator.{Scheduler, WorkerStateChangeListener}
import software.amazon.kinesis.metrics.MetricsLevel
import software.amazon.kinesis.processor.SingleStreamTracker
import software.amazon.kinesis.retrieval.fanout.FanOutConfig
import software.amazon.kinesis.retrieval.polling.PollingConfig

import java.net.URI
import java.util.Date
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.atomic.AtomicReference

// Builds the KCL Scheduler, wires its output to a SynchronousQueue of KCLActions, and runs it
// in the background for the lifetime of the Resource.
private[kinesis] object KCLScheduler {

  def populateQueue[F[_]: Async](
    config: KinesisSourceConfig,
    queue: SynchronousQueue[KCLAction]
  ): Resource[F, Unit] =
    for {
      kinesis <- mkKinesisClient[F](config.customEndpoint)
      dynamo <- mkDynamoDbClient[F](config.dynamodbCustomEndpoint)
      cloudWatch <- mkCloudWatchClient[F](config.cloudwatchCustomEndpoint)
      scheduler <- Resource.eval(mkScheduler(kinesis, dynamo, cloudWatch, config, queue))
      _ <- runInBackground(scheduler)
    } yield ()

  private def mkScheduler[F[_]: Sync](
    kinesisClient: KinesisAsyncClient,
    dynamoDbClient: DynamoDbAsyncClient,
    cloudWatchClient: CloudWatchAsyncClient,
    kinesisConfig: KinesisSourceConfig,
    queue: SynchronousQueue[KCLAction]
  ): F[Scheduler] =
    Sync[F].delay {
      val configsBuilder =
        new ConfigsBuilder(
          kinesisConfig.streamName,
          kinesisConfig.appName,
          kinesisClient,
          dynamoDbClient,
          cloudWatchClient,
          kinesisConfig.workerIdentifier,
          () => ShardRecordProcessor(queue, new AtomicReference(Set.empty[String]))
        )

      val retrievalConfig =
        configsBuilder.retrievalConfig
          .streamTracker(new SingleStreamTracker(kinesisConfig.streamName, initialPositionOf(kinesisConfig.initialPosition)))
          .retrievalSpecificConfig {
            kinesisConfig.retrievalMode match {
              case KinesisSourceConfig.Retrieval.FanOut =>
                new FanOutConfig(kinesisClient).streamName(kinesisConfig.streamName).applicationName(kinesisConfig.appName)
              case KinesisSourceConfig.Retrieval.Polling(maxRecords) =>
                new PollingConfig(kinesisConfig.streamName, kinesisClient).maxRecords(maxRecords)
            }
          }

      val leaseManagementConfig =
        configsBuilder.leaseManagementConfig
          .failoverTimeMillis(kinesisConfig.leaseDuration.toMillis)

      // We ask to see empty batches, so that we can update the health check even when there are no records in the stream
      val processorConfig =
        configsBuilder.processorConfig
          .callProcessRecordsEvenForEmptyRecordList(true)

      val coordinatorConfig = configsBuilder.coordinatorConfig
        .workerStateChangeListener(new WorkerStateChangeListener {
          def onWorkerStateChange(newState: WorkerStateChangeListener.WorkerState): Unit = ()
          override def onAllInitializationAttemptsFailed(e: Throwable): Unit =
            queue.put(KCLError(e))
        })

      new Scheduler(
        configsBuilder.checkpointConfig,
        coordinatorConfig,
        leaseManagementConfig,
        configsBuilder.lifecycleConfig,
        configsBuilder.metricsConfig.metricsLevel(MetricsLevel.NONE),
        processorConfig,
        retrievalConfig
      )
    }

  private def runInBackground[F[_]: Async](scheduler: Scheduler): Resource[F, Unit] =
    Sync[F].blocking(scheduler.run()).background *> Resource.onFinalize(Sync[F].blocking(scheduler.shutdown()))

  private def initialPositionOf(config: KinesisSourceConfig.InitialPosition): InitialPositionInStreamExtended =
    config match {
      case KinesisSourceConfig.InitialPosition.Latest => InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST)
      case KinesisSourceConfig.InitialPosition.TrimHorizon =>
        InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON)
      case KinesisSourceConfig.InitialPosition.AtTimestamp(instant) =>
        InitialPositionInStreamExtended.newInitialPositionAtTimestamp(Date.from(instant))
    }

  private def mkKinesisClient[F[_]: Sync](customEndpoint: Option[URI]): Resource[F, KinesisAsyncClient] =
    Resource.fromAutoCloseable {
      Sync[F].blocking { // Blocking because this might dial the EC2 metadata endpoint
        val builder =
          KinesisAsyncClient
            .builder()
            .defaultsMode(DefaultsMode.AUTO)
        val customized = customEndpoint.map(builder.endpointOverride).getOrElse(builder)
        customized.build
      }
    }

  private def mkDynamoDbClient[F[_]: Sync](customEndpoint: Option[URI]): Resource[F, DynamoDbAsyncClient] =
    Resource.fromAutoCloseable {
      Sync[F].blocking { // Blocking because this might dial the EC2 metadata endpoint
        val builder =
          DynamoDbAsyncClient
            .builder()
            .defaultsMode(DefaultsMode.AUTO)
        val customized = customEndpoint.map(builder.endpointOverride).getOrElse(builder)
        customized.build
      }
    }

  private def mkCloudWatchClient[F[_]: Sync](customEndpoint: Option[URI]): Resource[F, CloudWatchAsyncClient] =
    Resource.fromAutoCloseable {
      Sync[F].blocking { // Blocking because this might dial the EC2 metadata endpoint
        val builder =
          CloudWatchAsyncClient
            .builder()
            .defaultsMode(DefaultsMode.AUTO)
        val customized = customEndpoint.map(builder.endpointOverride).getOrElse(builder)
        customized.build
      }
    }

}
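
A minimal wiring sketch, not part of this commit and assuming the same package: populateQueue keeps the KCL Scheduler running for the lifetime of the Resource, so a caller typically allocates the queue, mounts the scheduler, and drains the queue while the Resource is open. KCLSchedulerExample and kclActions are hypothetical names.

import cats.effect.{Async, Resource}
import java.util.concurrent.SynchronousQueue

private object KCLSchedulerExample {
  // Hypothetical wiring: once the Resource is released, runInBackground shuts the scheduler down.
  def kclActions[F[_]: Async](config: KinesisSourceConfig): Resource[F, SynchronousQueue[KCLAction]] =
    for {
      queue <- Resource.eval(Async[F].delay(new SynchronousQueue[KCLAction]()))
      _     <- KCLScheduler.populateQueue[F](config, queue)
    } yield queue
}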
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sources.kinesis

import cats.effect.{Async, Sync}
import cats.implicits._
import cats.effect.implicits._
import com.snowplowanalytics.snowplow.sources.internal.Checkpointer
import org.typelevel.log4cats.Logger
import software.amazon.kinesis.exceptions.ShutdownException
import software.amazon.kinesis.processor.RecordProcessorCheckpointer
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber

import java.util.concurrent.CountDownLatch

// A Checkpointer backed by the KCL: acking checkpoints each shard at its most advanced
// checkpointable position; nacking is a no-op, so unacked records are simply not checkpointed.
private class KinesisCheckpointer[F[_]: Async: Logger] extends Checkpointer[F, Map[String, Checkpointable]] {

  override val empty: Map[String, Checkpointable] = Map.empty

  override def combine(x: Map[String, Checkpointable], y: Map[String, Checkpointable]): Map[String, Checkpointable] =
    x |+| y

  override def ack(c: Map[String, Checkpointable]): F[Unit] =
    c.toList.parTraverse_ {
      case (shardId, Checkpointable.Record(extendedSequenceNumber, checkpointer)) =>
        checkpointRecord(shardId, extendedSequenceNumber, checkpointer)
      case (shardId, Checkpointable.ShardEnd(checkpointer, release)) =>
        checkpointShardEnd(shardId, checkpointer, release)
    }

  override def nack(c: Map[String, Checkpointable]): F[Unit] =
    Sync[F].unit

  private def checkpointShardEnd(
    shardId: String,
    checkpointer: RecordProcessorCheckpointer,
    release: CountDownLatch
  ) =
    Logger[F].debug(s"Checkpointing shard $shardId at SHARD_END") *>
      Sync[F].blocking(checkpointer.checkpoint()).recoverWith(ignoreShutdownExceptions(shardId)) *>
      Sync[F].delay(release.countDown())

  private def checkpointRecord(
    shardId: String,
    extendedSequenceNumber: ExtendedSequenceNumber,
    checkpointer: RecordProcessorCheckpointer
  ) =
    Logger[F].debug(s"Checkpointing shard $shardId at $extendedSequenceNumber") *>
      Sync[F]
        .blocking(
          checkpointer.checkpoint(extendedSequenceNumber.sequenceNumber, extendedSequenceNumber.subSequenceNumber)
        )
        .recoverWith(ignoreShutdownExceptions(shardId))

  private def ignoreShutdownExceptions(shardId: String): PartialFunction[Throwable, F[Unit]] = { case _: ShutdownException =>
    // The ShardRecordProcessor instance has been shut down. This just means another KCL
    // worker has stolen our lease. It is expected during autoscaling of instances, and is
    // safe to ignore.
    Logger[F].warn(s"Skipping checkpointing of shard $shardId because this worker no longer owns the lease")
  }
}
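
A hypothetical usage sketch, again assuming the same package: checkpointable state from two processed batches is merged with combine, which relies on the Semigroup from earlier in this commit, and then acked once so each shard is checkpointed at its highest position. KinesisCheckpointerExample and ackBoth are illustrative names.

import cats.effect.Async
import org.typelevel.log4cats.Logger

private object KinesisCheckpointerExample {
  // Hypothetical helper: merge two batches and checkpoint each shard once.
  def ackBoth[F[_]: Async: Logger](
    checkpointer: KinesisCheckpointer[F],
    first: Map[String, Checkpointable],
    second: Map[String, Checkpointable]
  ): F[Unit] =
    checkpointer.ack(checkpointer.combine(first, second))
}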