diff --git a/modules/kafka/src/main/resources/reference.conf b/modules/kafka/src/main/resources/reference.conf index 4783cdb..81621fa 100644 --- a/modules/kafka/src/main/resources/reference.conf +++ b/modules/kafka/src/main/resources/reference.conf @@ -5,6 +5,9 @@ snowplow.defaults { "group.id": null # invalid value MUST be overridden by the applicaion "allow.auto.create.topics": "false" "auto.offset.reset": "latest" + "security.protocol": "SASL_SSL" + "sasl.mechanism": "OAUTHBEARER" + "sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;" } } } @@ -13,6 +16,9 @@ snowplow.defaults { kafka: { producerConf: { "client.id": null # invalid value MUST be overriden by the application + "security.protocol": "SASL_SSL" + "sasl.mechanism": "OAUTHBEARER" + "sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;" } maxRecordSize: 1000000 } diff --git a/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/azure/AzureAuthenticationCallbackHandler.scala b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/azure/AzureAuthenticationCallbackHandler.scala new file mode 100644 index 0000000..7eb6f9f --- /dev/null +++ b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/azure/AzureAuthenticationCallbackHandler.scala @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved. + * + * This program is licensed to you under the Snowplow Community License Version 1.0, + * and you may not use this file except in compliance with the Snowplow Community License Version 1.0. + * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0 + */ +package com.snowplowanalytics.snowplow.azure + +import java.net.URI +import java.{lang, util} + +import com.nimbusds.jwt.JWTParser + +import javax.security.auth.callback.Callback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.auth.login.AppConfigurationEntry + +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback + +import com.azure.identity.DefaultAzureCredentialBuilder +import com.azure.core.credential.TokenRequestContext + +trait AzureAuthenticationCallbackHandler extends AuthenticateCallbackHandler { + + val credentials = new DefaultAzureCredentialBuilder().build() + + var sbUri: String = "" + + override def configure( + configs: util.Map[String, _], + saslMechanism: String, + jaasConfigEntries: util.List[AppConfigurationEntry] + ): Unit = { + val bootstrapServer = + configs + .get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + .toString + .replaceAll("\\[|\\]", "") + .split(",") + .toList + .headOption match { + case Some(s) => s + case None => throw new Exception("Empty bootstrap servers list") + } + val uri = URI.create("https://" + bootstrapServer) + // Workload identity works with '.default' scope + this.sbUri = s"${uri.getScheme}://${uri.getHost}/.default" + } + + override def handle(callbacks: Array[Callback]): Unit = + callbacks.foreach { + case callback: OAuthBearerTokenCallback => + val token = getOAuthBearerToken() + callback.token(token) + case callback => throw new UnsupportedCallbackException(callback) + } + + def getOAuthBearerToken(): OAuthBearerToken = { + val reqContext = new TokenRequestContext() + reqContext.addScopes(sbUri) + val accessToken = credentials.getTokenSync(reqContext).getToken + val jwt = JWTParser.parse(accessToken) + val claims = jwt.getJWTClaimsSet + + new OAuthBearerToken { + override def value(): String = accessToken + + override def lifetimeMs(): Long = claims.getExpirationTime.getTime + + override def scope(): util.Set[String] = null + + override def principalName(): String = null + + override def startTimeMs(): lang.Long = null + } + } + + override def close(): Unit = () +} diff --git a/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSink.scala b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSink.scala index 07dc15d..376db27 100644 --- a/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSink.scala +++ b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSink.scala @@ -14,13 +14,21 @@ import cats.Monad import com.snowplowanalytics.snowplow.sinks.{Sink, Sinkable} import fs2.kafka._ +import scala.reflect._ + import java.util.UUID +import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler + object KafkaSink { - def resource[F[_]: Async](config: KafkaSinkConfig): Resource[F, Sink[F]] = { + def resource[F[_]: Async, T <: AzureAuthenticationCallbackHandler]( + config: KafkaSinkConfig, + authHandlerClass: ClassTag[T] + ): Resource[F, Sink[F]] = { val producerSettings = ProducerSettings[F, String, Array[Byte]] + .withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName) .withBootstrapServers(config.bootstrapServers) .withProperties(config.producerConf) diff --git a/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSource.scala b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSource.scala index e979392..0e76894 100644 --- a/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSource.scala +++ b/modules/kafka/src/main/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSource.scala @@ -16,6 +16,8 @@ import fs2.Stream import org.typelevel.log4cats.{Logger, SelfAwareStructuredLogger} import org.typelevel.log4cats.slf4j.Slf4jLogger +import scala.reflect._ + import java.nio.ByteBuffer import java.time.Instant @@ -26,20 +28,27 @@ import org.apache.kafka.common.TopicPartition // snowplow import com.snowplowanalytics.snowplow.sources.SourceAndAck import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEvents, LowLevelSource} +import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler object KafkaSource { private implicit def logger[F[_]: Sync]: SelfAwareStructuredLogger[F] = Slf4jLogger.getLogger[F] - def build[F[_]: Async](config: KafkaSourceConfig): F[SourceAndAck[F]] = - LowLevelSource.toSourceAndAck(lowLevel(config)) + def build[F[_]: Async, T <: AzureAuthenticationCallbackHandler]( + config: KafkaSourceConfig, + authHandlerClass: ClassTag[T] + ): F[SourceAndAck[F]] = + LowLevelSource.toSourceAndAck(lowLevel(config, authHandlerClass)) - private def lowLevel[F[_]: Async](config: KafkaSourceConfig): LowLevelSource[F, KafkaCheckpoints[F]] = + private def lowLevel[F[_]: Async, T <: AzureAuthenticationCallbackHandler]( + config: KafkaSourceConfig, + authHandlerClass: ClassTag[T] + ): LowLevelSource[F, KafkaCheckpoints[F]] = new LowLevelSource[F, KafkaCheckpoints[F]] { def checkpointer: Checkpointer[F, KafkaCheckpoints[F]] = kafkaCheckpointer def stream: Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] = - kafkaStream(config) + kafkaStream(config, authHandlerClass) } case class OffsetAndCommit[F[_]](offset: Long, commit: F[Unit]) @@ -59,9 +68,12 @@ object KafkaSource { def nack(c: KafkaCheckpoints[F]): F[Unit] = Applicative[F].unit } - private def kafkaStream[F[_]: Async](config: KafkaSourceConfig): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] = + private def kafkaStream[F[_]: Async, T <: AzureAuthenticationCallbackHandler]( + config: KafkaSourceConfig, + authHandlerClass: ClassTag[T] + ): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] = KafkaConsumer - .stream(consumerSettings[F](config)) + .stream(consumerSettings[F, T](config, authHandlerClass)) .evalTap(_.subscribeTo(config.topicName)) .flatMap { consumer => consumer.partitionsMapStream @@ -124,8 +136,12 @@ object KafkaSource { private implicit def byteBufferDeserializer[F[_]: Sync]: Resource[F, ValueDeserializer[F, ByteBuffer]] = Resource.pure(Deserializer.lift(arr => Sync[F].pure(ByteBuffer.wrap(arr)))) - private def consumerSettings[F[_]: Async](config: KafkaSourceConfig): ConsumerSettings[F, Array[Byte], ByteBuffer] = + private def consumerSettings[F[_]: Async, T <: AzureAuthenticationCallbackHandler]( + config: KafkaSourceConfig, + authHandlerClass: ClassTag[T] + ): ConsumerSettings[F, Array[Byte], ByteBuffer] = ConsumerSettings[F, Array[Byte], ByteBuffer] + .withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName) .withBootstrapServers(config.bootstrapServers) .withProperties(config.consumerConf) .withEnableAutoCommit(false) diff --git a/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSinkConfigSpec.scala b/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSinkConfigSpec.scala index 1e439f1..cc92656 100644 --- a/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSinkConfigSpec.scala +++ b/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sinks/kafka/KafkaSinkConfigSpec.scala @@ -44,7 +44,10 @@ class KafkaSinkConfigSpec extends Specification { topicName = "my-topic", bootstrapServers = "my-bootstrap-server:9092", producerConf = Map( - "client.id" -> "my-client-id" + "client.id" -> "my-client-id", + "security.protocol" -> "SASL_SSL", + "sasl.mechanism" -> "OAUTHBEARER", + "sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;" ) ) diff --git a/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSourceConfigSpec.scala b/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSourceConfigSpec.scala index 954eb57..c87a45e 100644 --- a/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSourceConfigSpec.scala +++ b/modules/kafka/src/test/scala/com/snowplowanalytics/snowplow/sources/kafka/KafkaSourceConfigSpec.scala @@ -46,7 +46,10 @@ class KafkaSourceConfigSpec extends Specification { consumerConf = Map( "group.id" -> "my-consumer-group", "allow.auto.create.topics" -> "false", - "auto.offset.reset" -> "latest" + "auto.offset.reset" -> "latest", + "security.protocol" -> "SASL_SSL", + "sasl.mechanism" -> "OAUTHBEARER", + "sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;" ) ) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index db67aeb..a22fa15 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -153,6 +153,7 @@ object Dependencies { fs2Kafka, circeConfig, circeGeneric, + azureIdentity, snappy, specs2 )