Skip to content

Commit

Permalink
fix: Threadsafety fixes for transactional stages (akka#1738)
Browse files Browse the repository at this point in the history
  • Loading branch information
johanandren authored May 3, 2024
1 parent 274a44f commit 50fda13
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private class SubSourceLogic[K, V, Msg](

/** We have created a source for these partitions, but it has not started up and is not in subSources yet. */
private var partitionsInStartup: immutable.Set[TopicPartition] = immutable.Set.empty
protected var subSources: Map[TopicPartition, SubSourceStageLogicControl] = immutable.Map.empty
@volatile protected var subSources: Map[TopicPartition, SubSourceStageLogicControl] = immutable.Map.empty

/** Kafka has signalled these partitions are revoked, but some may be re-assigned just after revoking. */
private var partitionsToRevoke: Set[TopicPartition] = Set.empty
Expand Down
40 changes: 28 additions & 12 deletions core/src/main/scala/akka/kafka/internal/TransactionalSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour
override protected def logSource: Class[_] = classOf[TransactionalSourceLogic[_, _, _]]

private val inFlightRecords = InFlightRecords.empty
@volatile protected var threadSafeSourceActor: ActorRef = _
@volatile protected var threadSafeConsumerActor: ActorRef = _

override def preStart(): Unit = {
super.preStart()
// actually the stage actor of this stage, initialized in super preStart
threadSafeSourceActor = this.sourceActor.ref
threadSafeConsumerActor = this.consumerActor
}

override def messageHandling = super.messageHandling.orElse(drainHandling).orElse {
case (_, Revoked(tps)) =>
Expand Down Expand Up @@ -152,18 +161,17 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour
override protected def addToPartitionAssignmentHandler(
handler: PartitionAssignmentHandler
): PartitionAssignmentHandler = {
// FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from
// another thread (consumer actor) not thread safe
// Note: this runs on a different thread so be careful with any mutable state
val blockingRevokedCall = new PartitionAssignmentHandler {
override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

// This is invoked in the KafkaConsumerActor thread when doing poll.
override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
if (waitForDraining(revokedTps)) {
sourceActor.ref.tell(Revoked(revokedTps.toList), consumerActor)
threadSafeSourceActor.tell(Revoked(revokedTps.toList), consumerActor)
} else {
sourceActor.ref.tell(Failure(new Error("Timeout while draining")), consumerActor)
consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor)
threadSafeSourceActor.tell(Failure(new Error("Timeout while draining")), consumerActor)
threadSafeConsumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor)
}

override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
Expand All @@ -175,7 +183,7 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour
import akka.pattern.ask
implicit val timeout = Timeout(consumerSettings.commitTimeout)
try {
Await.result(ask(stageActor.ref, Drain(partitions, None, Drained)), timeout.duration)
Await.result(ask(threadSafeSourceActor, Drain(partitions, None, Drained)), timeout.duration)
true
} catch {
case t: Throwable =>
Expand Down Expand Up @@ -241,11 +249,19 @@ private[kafka] final class TransactionalSubSource[K, V](

new SubSourceLogic(shape, txConsumerSettings, subscription, subSourceStageLogicFactory = factory) {

@volatile var threadSafeSourceActor: ActorRef = _
@volatile var threadSafeConsumerActor: ActorRef = _

override def preStart(): Unit = {
super.preStart()
threadSafeSourceActor = this.sourceActor.ref
threadSafeConsumerActor = this.consumerActor
}

override protected def addToPartitionAssignmentHandler(
handler: PartitionAssignmentHandler
): PartitionAssignmentHandler = {
// FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from
// another thread (consumer actor) not thread safe
// Note: this runs on a different thread so be careful with any mutable state
val blockingRevokedCall = new PartitionAssignmentHandler {
override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

Expand All @@ -255,10 +271,10 @@ private[kafka] final class TransactionalSubSource[K, V](
else if (waitForDraining(revokedTps)) {
subSources.values
.map(_.controlAndStageActor.stageActor)
.foreach(_.tell(Revoked(revokedTps.toList), stageActor.ref))
.foreach(_.tell(Revoked(revokedTps.toList), threadSafeSourceActor))
} else {
sourceActor.ref.tell(Status.Failure(new Error("Timeout while draining")), stageActor.ref)
consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref)
threadSafeSourceActor.tell(Status.Failure(new Error("Timeout while draining")), threadSafeSourceActor)
threadSafeConsumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), threadSafeSourceActor)
}

override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
Expand All @@ -279,7 +295,7 @@ private[kafka] final class TransactionalSubSource[K, V](
Await.result(Future.sequence(drainCommandFutures), timeout.duration)
true
} catch {
case t: Throwable =>
case _: Throwable =>
false
}
}
Expand Down

0 comments on commit 50fda13

Please sign in to comment.