Kafka: authenticate with Event Hubs using OAuth2 #58

Merged: 1 commit, Feb 16, 2024
modules/kafka/src/main/resources/reference.conf (6 additions, 0 deletions)
@@ -5,6 +5,9 @@ snowplow.defaults {
"group.id": null # invalid value MUST be overridden by the applicaion
"allow.auto.create.topics": "false"
"auto.offset.reset": "latest"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
}
}
@@ -13,6 +16,9 @@ snowplow.defaults {
kafka: {
producerConf: {
"client.id": null # invalid value MUST be overriden by the application
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
maxRecordSize: 1000000
}
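With these defaults in reference.conf, an application only has to override the keys that have no usable default ("group.id" for consumers, "client.id" for producers); the SASL/OAuth properties are inherited. As a hedged illustration (mirroring the specs further down), the merged consumer properties once an application supplies only "group.id":

// Illustrative sketch: consumer properties after the reference.conf defaults apply
val consumerConf: Map[String, String] = Map(
  "group.id"                 -> "my-consumer-group", // supplied by the application
  "allow.auto.create.topics" -> "false",             // from reference.conf
  "auto.offset.reset"        -> "latest",            // from reference.conf
  "security.protocol"        -> "SASL_SSL",          // from reference.conf
  "sasl.mechanism"           -> "OAUTHBEARER",       // from reference.conf
  "sasl.jaas.config"         -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)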
@@ -0,0 +1,83 @@ (new file)
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.azure

import java.net.URI
import java.{lang, util}

import com.nimbusds.jwt.JWTParser

import javax.security.auth.callback.Callback
import javax.security.auth.callback.UnsupportedCallbackException
import javax.security.auth.login.AppConfigurationEntry

import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback

import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.core.credential.TokenRequestContext

trait AzureAuthenticationCallbackHandler extends AuthenticateCallbackHandler {

val credentials = new DefaultAzureCredentialBuilder().build()

var sbUri: String = ""

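// Derives the OAuth scope from the configured bootstrap server: a (hypothetical)
// "myns.servicebus.windows.net:9093" yields "https://myns.servicebus.windows.net/.default"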
override def configure(
configs: util.Map[String, _],
saslMechanism: String,
jaasConfigEntries: util.List[AppConfigurationEntry]
): Unit = {
val bootstrapServer =
configs
.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG)
.toString
.replaceAll("\\[|\\]", "")
.split(",")
.toList
.headOption match {
case Some(s) => s
case None => throw new Exception("Empty bootstrap servers list")
}
val uri = URI.create("https://" + bootstrapServer)
// Workload identity works with '.default' scope
this.sbUri = s"${uri.getScheme}://${uri.getHost}/.default"
}

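// Kafka invokes this whenever it needs a token; only OAuthBearerTokenCallback
// is supported, anything else fails fast with UnsupportedCallbackException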
override def handle(callbacks: Array[Callback]): Unit =
callbacks.foreach {
case callback: OAuthBearerTokenCallback =>
val token = getOAuthBearerToken()
callback.token(token)
case callback => throw new UnsupportedCallbackException(callback)
}

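// Fetches a token for the derived scope via DefaultAzureCredential; the token
// lifetime comes from the JWT `exp` claim so the Kafka client can refresh it in time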
def getOAuthBearerToken(): OAuthBearerToken = {
val reqContext = new TokenRequestContext()
reqContext.addScopes(sbUri)
val accessToken = credentials.getTokenSync(reqContext).getToken
val jwt = JWTParser.parse(accessToken)
val claims = jwt.getJWTClaimsSet

new OAuthBearerToken {
override def value(): String = accessToken

override def lifetimeMs(): Long = claims.getExpirationTime.getTime

override def scope(): util.Set[String] = null

override def principalName(): String = null

override def startTimeMs(): lang.Long = null
}
}

override def close(): Unit = ()
}
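Kafka instantiates the login callback handler reflectively from the class name wired into "sasl.login.callback.handler.class" (see the sink and source changes below), so an application extends this trait with a concrete class. A minimal sketch, using the hypothetical name MyAuthHandler:

// Illustrative only: the name MyAuthHandler is not part of this PR. The class
// needs a public no-arg constructor because Kafka creates it via reflection.
class MyAuthHandler extends AzureAuthenticationCallbackHandler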
@@ -14,13 +14,21 @@ import cats.Monad
import com.snowplowanalytics.snowplow.sinks.{Sink, Sinkable}
import fs2.kafka._

import scala.reflect._

import java.util.UUID

import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler

object KafkaSink {

def resource[F[_]: Async](config: KafkaSinkConfig): Resource[F, Sink[F]] = {
def resource[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSinkConfig,
authHandlerClass: ClassTag[T]
): Resource[F, Sink[F]] = {
val producerSettings =
ProducerSettings[F, String, Array[Byte]]
.withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName)
.withBootstrapServers(config.bootstrapServers)
.withProperties(config.producerConf)

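Callers now pass the handler class alongside the config so its name can be set as the "sasl.login.callback.handler.class" property. A hedged usage sketch, reusing the hypothetical MyAuthHandler from above:

import cats.effect.{IO, Resource}
import scala.reflect.classTag
import com.snowplowanalytics.snowplow.sinks.Sink

// Illustrative only: assumes an in-scope `config: KafkaSinkConfig`
val sink: Resource[IO, Sink[IO]] =
  KafkaSink.resource[IO, MyAuthHandler](config, classTag[MyAuthHandler])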
@@ -16,6 +16,8 @@ import fs2.Stream
import org.typelevel.log4cats.{Logger, SelfAwareStructuredLogger}
import org.typelevel.log4cats.slf4j.Slf4jLogger

import scala.reflect._

import java.nio.ByteBuffer
import java.time.Instant

@@ -26,20 +28,27 @@ import org.apache.kafka.common.TopicPartition
// snowplow
import com.snowplowanalytics.snowplow.sources.SourceAndAck
import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEvents, LowLevelSource}
import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler

object KafkaSource {

private implicit def logger[F[_]: Sync]: SelfAwareStructuredLogger[F] = Slf4jLogger.getLogger[F]

def build[F[_]: Async](config: KafkaSourceConfig): F[SourceAndAck[F]] =
LowLevelSource.toSourceAndAck(lowLevel(config))
def build[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): F[SourceAndAck[F]] =
LowLevelSource.toSourceAndAck(lowLevel(config, authHandlerClass))

private def lowLevel[F[_]: Async](config: KafkaSourceConfig): LowLevelSource[F, KafkaCheckpoints[F]] =
private def lowLevel[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): LowLevelSource[F, KafkaCheckpoints[F]] =
new LowLevelSource[F, KafkaCheckpoints[F]] {
def checkpointer: Checkpointer[F, KafkaCheckpoints[F]] = kafkaCheckpointer

def stream: Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
kafkaStream(config)
kafkaStream(config, authHandlerClass)
}

case class OffsetAndCommit[F[_]](offset: Long, commit: F[Unit])
@@ -59,9 +68,12 @@ object KafkaSource {
def nack(c: KafkaCheckpoints[F]): F[Unit] = Applicative[F].unit
}

private def kafkaStream[F[_]: Async](config: KafkaSourceConfig): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
private def kafkaStream[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
KafkaConsumer
.stream(consumerSettings[F](config))
.stream(consumerSettings[F, T](config, authHandlerClass))
.evalTap(_.subscribeTo(config.topicName))
.flatMap { consumer =>
consumer.partitionsMapStream
@@ -124,8 +136,12 @@ object KafkaSource {
private implicit def byteBufferDeserializer[F[_]: Sync]: Resource[F, ValueDeserializer[F, ByteBuffer]] =
Resource.pure(Deserializer.lift(arr => Sync[F].pure(ByteBuffer.wrap(arr))))

private def consumerSettings[F[_]: Async](config: KafkaSourceConfig): ConsumerSettings[F, Array[Byte], ByteBuffer] =
private def consumerSettings[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): ConsumerSettings[F, Array[Byte], ByteBuffer] =
ConsumerSettings[F, Array[Byte], ByteBuffer]
.withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName)
.withBootstrapServers(config.bootstrapServers)
.withProperties(config.consumerConf)
.withEnableAutoCommit(false)
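The source mirrors the sink. A hedged sketch, under the same assumptions as the sink example:

import cats.effect.IO
import scala.reflect.classTag
import com.snowplowanalytics.snowplow.sources.SourceAndAck

// Illustrative only: assumes an in-scope `config: KafkaSourceConfig`
val source: IO[SourceAndAck[IO]] =
  KafkaSource.build[IO, MyAuthHandler](config, classTag[MyAuthHandler])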
@@ -44,7 +44,10 @@ class KafkaSinkConfigSpec extends Specification {
topicName = "my-topic",
bootstrapServers = "my-bootstrap-server:9092",
producerConf = Map(
"client.id" -> "my-client-id"
"client.id" -> "my-client-id",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)

@@ -46,7 +46,10 @@ class KafkaSourceConfigSpec extends Specification {
consumerConf = Map(
"group.id" -> "my-consumer-group",
"allow.auto.create.topics" -> "false",
"auto.offset.reset" -> "latest"
"auto.offset.reset" -> "latest",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)

project/Dependencies.scala (1 addition, 0 deletions)
@@ -153,6 +153,7 @@ object Dependencies {
fs2Kafka,
circeConfig,
circeGeneric,
azureIdentity,
snappy,
specs2
)