KAFKA-14505; [1/N] Add support for transactional writes to CoordinatorRuntime (apache#14844)

This patch adds support for transactional writes to the CoordinatorRuntime framework. This mainly consists of adding CoordinatorRuntime#scheduleTransactionalWriteOperation and threading the producerId and producerEpoch through various interfaces. The patch also extends CoordinatorLoaderImpl and CoordinatorPartitionWriter accordingly.
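The diffs below thread the producer id and epoch through the loader, the partition writer, and the replay interface. Batches that are not part of a transaction carry the RecordBatch sentinel values, and a batch is treated as transactional exactly when its producer id is not the sentinel. A minimal standalone sketch of that convention (the class and method names here are illustrative, not part of the patch):

import org.apache.kafka.common.record.RecordBatch;

public final class ProducerSentinelSketch {
    // Mirrors the producerId != RecordBatch.NO_PRODUCER_ID check that
    // CoordinatorPartitionWriter uses below to flag a batch as transactional.
    static boolean isTransactional(long producerId) {
        return producerId != RecordBatch.NO_PRODUCER_ID;
    }

    public static void main(String[] args) {
        System.out.println(isTransactional(RecordBatch.NO_PRODUCER_ID)); // false: regular write
        System.out.println(isTransactional(100L));                       // true: transactional write
    }
}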

Reviewers: Justine Olshan <[email protected]>
dajac authored and clolov committed Apr 5, 2024
1 parent 3e63319 commit 24d0db9
Showing 11 changed files with 414 additions and 65 deletions.
@@ -139,7 +139,11 @@ class CoordinatorLoaderImpl[T](
batch.asScala.foreach { record =>
numRecords = numRecords + 1
try {
coordinator.replay(deserializer.deserialize(record.key, record.value))
coordinator.replay(
batch.producerId,
batch.producerEpoch,
deserializer.deserialize(record.key, record.value)
)
} catch {
case ex: UnknownRecordTypeException =>
warn(s"Unknown record type ${ex.unknownType} while loading offsets and group metadata " +
@@ -21,7 +21,7 @@ import kafka.server.{ActionQueue, ReplicaManager}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.errors.RecordTooLargeException
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.{CompressionType, MemoryRecords, TimestampType}
import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch, TimestampType}
import org.apache.kafka.common.record.Record.EMPTY_HEADERS
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.utils.Time
@@ -106,13 +106,17 @@ class CoordinatorPartitionWriter[T](
* Write records to the partitions. Records are written in one batch so
* atomicity is guaranteed.
*
* @param tp The partition to write records to.
* @param records The list of records. The records are written in a single batch.
* @param tp The partition to write records to.
* @param producerId The producer id.
* @param producerEpoch The producer epoch.
* @param records The list of records. The records are written in a single batch.
* @return The log end offset right after the written records.
* @throws KafkaException Any KafkaException caught during the write operation.
*/
override def append(
tp: TopicPartition,
producerId: Long,
producerEpoch: Short,
records: util.List[T]
): Long = {
if (records.isEmpty) throw new IllegalStateException("records must be non-empty.")
@@ -129,7 +133,12 @@
compressionType,
TimestampType.CREATE_TIME,
0L,
maxBatchSize
time.milliseconds(),
producerId,
producerEpoch,
0,
producerId != RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PARTITION_LEADER_EPOCH
)

records.forEach { record =>
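For context on what the new builder arguments above change in the produced batch, here is a small standalone sketch (using the same MemoryRecords helpers that the tests below use) that builds a non-transactional and a transactional batch and reads the producer metadata back. It is illustrative only and not part of the patch:

import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;

public final class BatchMetadataSketch {
    public static void main(String[] args) {
        SimpleRecord record = new SimpleRecord(
            "k".getBytes(StandardCharsets.UTF_8),
            "v".getBytes(StandardCharsets.UTF_8)
        );

        // Non-transactional batch: producer id/epoch stay at the RecordBatch sentinels.
        MemoryRecords plain = MemoryRecords.withRecords(0L, CompressionType.NONE, record);
        RecordBatch plainBatch = plain.batches().iterator().next();
        System.out.println(plainBatch.isTransactional()); // false
        System.out.println(plainBatch.producerId());      // -1 (RecordBatch.NO_PRODUCER_ID)

        // Transactional batch: carries the producer id/epoch and is flagged transactional.
        MemoryRecords txn = MemoryRecords.withTransactionalRecords(
            0L, CompressionType.NONE, 100L, (short) 5, 0,
            RecordBatch.NO_PARTITION_LEADER_EPOCH, record
        );
        RecordBatch txnBatch = txn.batches().iterator().next();
        System.out.println(txnBatch.isTransactional()); // true
        System.out.println(txnBatch.producerId());      // 100
        System.out.println(txnBatch.producerEpoch());   // 5
    }
}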
@@ -21,7 +21,7 @@ import kafka.server.ReplicaManager
import kafka.utils.TestUtils
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.errors.NotLeaderOrFollowerException
import org.apache.kafka.common.record.{CompressionType, FileRecords, MemoryRecords, SimpleRecord}
import org.apache.kafka.common.record.{CompressionType, FileRecords, MemoryRecords, RecordBatch, SimpleRecord}
import org.apache.kafka.common.utils.{MockTime, Time}
import org.apache.kafka.coordinator.group.runtime.CoordinatorLoader.UnknownRecordTypeException
import org.apache.kafka.coordinator.group.runtime.{CoordinatorLoader, CoordinatorPlayback}
@@ -104,7 +104,7 @@ class CoordinatorLoaderImplTest {
)) { loader =>
when(replicaManager.getLog(tp)).thenReturn(Some(log))
when(log.logStartOffset).thenReturn(0L)
when(replicaManager.getLogEndOffset(tp)).thenReturn(Some(5L))
when(replicaManager.getLogEndOffset(tp)).thenReturn(Some(7L))

val readResult1 = logReadResult(startOffset = 0, records = Seq(
new SimpleRecord("k1".getBytes, "v1".getBytes),
@@ -131,13 +131,27 @@
minOneMessage = true
)).thenReturn(readResult2)

val readResult3 = logReadResult(startOffset = 5, producerId = 100L, producerEpoch = 5, records = Seq(
new SimpleRecord("k6".getBytes, "v6".getBytes),
new SimpleRecord("k7".getBytes, "v7".getBytes)
))

when(log.read(
startOffset = 5L,
maxLength = 1000,
isolation = FetchIsolation.LOG_END,
minOneMessage = true
)).thenReturn(readResult3)

assertNotNull(loader.load(tp, coordinator).get(10, TimeUnit.SECONDS))

verify(coordinator).replay(("k1", "v1"))
verify(coordinator).replay(("k2", "v2"))
verify(coordinator).replay(("k3", "v3"))
verify(coordinator).replay(("k4", "v4"))
verify(coordinator).replay(("k5", "v5"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k1", "v1"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k2", "v2"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k3", "v3"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k4", "v4"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k5", "v5"))
verify(coordinator).replay(100L, 5.toShort, ("k6", "v6"))
verify(coordinator).replay(100L, 5.toShort, ("k7", "v7"))
}
}

@@ -220,7 +234,7 @@ class CoordinatorLoaderImplTest {

loader.load(tp, coordinator).get(10, TimeUnit.SECONDS)

verify(coordinator).replay(("k2", "v2"))
verify(coordinator).replay(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, ("k2", "v2"))
}
}

@@ -354,14 +368,28 @@

private def logReadResult(
startOffset: Long,
producerId: Long = RecordBatch.NO_PRODUCER_ID,
producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
records: Seq[SimpleRecord]
): FetchDataInfo = {
val fileRecords = mock(classOf[FileRecords])
val memoryRecords = MemoryRecords.withRecords(
startOffset,
CompressionType.NONE,
records: _*
)
val memoryRecords = if (producerId == RecordBatch.NO_PRODUCER_ID) {
MemoryRecords.withRecords(
startOffset,
CompressionType.NONE,
records: _*
)
} else {
MemoryRecords.withTransactionalRecords(
startOffset,
CompressionType.NONE,
producerId,
producerEpoch,
0,
RecordBatch.NO_PARTITION_LEADER_EPOCH,
records: _*
)
}

when(fileRecords.sizeInBytes).thenReturn(memoryRecords.sizeInBytes)

@@ -27,7 +27,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.utils.{MockTime, Time}
import org.apache.kafka.coordinator.group.runtime.PartitionWriter
import org.apache.kafka.storage.internals.log.{AppendOrigin, LogConfig}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue}
import org.junit.jupiter.api.Test
import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.Mockito.{mock, verify, when}
@@ -133,7 +133,12 @@ class CoordinatorPartitionWriterTest {
("k2", "v2"),
)

assertEquals(11, partitionRecordWriter.append(tp, records.asJava))
assertEquals(11, partitionRecordWriter.append(
tp,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
records.asJava
))

val batch = recordsCapture.getValue.getOrElse(tp,
throw new AssertionError(s"No records for $tp"))
@@ -149,6 +154,86 @@
assertEquals(records, receivedRecords)
}

@Test
def testTransactionalWriteRecords(): Unit = {
val tp = new TopicPartition("foo", 0)
val replicaManager = mock(classOf[ReplicaManager])
val time = new MockTime()
val partitionRecordWriter = new CoordinatorPartitionWriter(
replicaManager,
new StringKeyValueSerializer(),
CompressionType.NONE,
time
)

when(replicaManager.getLogConfig(tp)).thenReturn(Some(LogConfig.fromProps(
Collections.emptyMap(),
new Properties()
)))

val recordsCapture: ArgumentCaptor[Map[TopicPartition, MemoryRecords]] =
ArgumentCaptor.forClass(classOf[Map[TopicPartition, MemoryRecords]])
val callbackCapture: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] =
ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])

when(replicaManager.appendRecords(
ArgumentMatchers.eq(0L),
ArgumentMatchers.eq(1.toShort),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(AppendOrigin.COORDINATOR),
recordsCapture.capture(),
callbackCapture.capture(),
ArgumentMatchers.any(),
ArgumentMatchers.any(),
ArgumentMatchers.any(),
ArgumentMatchers.any(),
ArgumentMatchers.any()
)).thenAnswer(_ => {
callbackCapture.getValue.apply(Map(
tp -> new PartitionResponse(
Errors.NONE,
5,
10,
RecordBatch.NO_TIMESTAMP,
-1,
Collections.emptyList(),
""
)
))
})

val records = List(
("k0", "v0"),
("k1", "v1"),
("k2", "v2"),
)

assertEquals(11, partitionRecordWriter.append(
tp,
100L,
50.toShort,
records.asJava
))

val batch = recordsCapture.getValue.getOrElse(tp,
throw new AssertionError(s"No records for $tp"))
assertEquals(1, batch.batches().asScala.toList.size)

val firstBatch = batch.batches.asScala.head
assertEquals(100L, firstBatch.producerId)
assertEquals(50.toShort, firstBatch.producerEpoch)
assertTrue(firstBatch.isTransactional)

val receivedRecords = batch.records.asScala.map { record =>
(
Charset.defaultCharset().decode(record.key).toString,
Charset.defaultCharset().decode(record.value).toString,
)
}.toList

assertEquals(records, receivedRecords)
}

@Test
def testWriteRecordsWithFailure(): Unit = {
val tp = new TopicPartition("foo", 0)
@@ -195,8 +280,12 @@
("k2", "v2"),
)

assertThrows(classOf[NotLeaderOrFollowerException],
() => partitionRecordWriter.append(tp, records.asJava))
assertThrows(classOf[NotLeaderOrFollowerException], () => partitionRecordWriter.append(
tp,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
records.asJava)
)
}

@Test
@@ -224,8 +313,12 @@
("k1", new String(randomBytes)),
)

assertThrows(classOf[RecordTooLargeException],
() => partitionRecordWriter.append(tp, records.asJava))
assertThrows(classOf[RecordTooLargeException], () => partitionRecordWriter.append(
tp,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
records.asJava)
)
}

@Test
@@ -244,8 +337,12 @@
new Properties()
)))

assertThrows(classOf[IllegalStateException],
() => partitionRecordWriter.append(tp, List.empty.asJava))
assertThrows(classOf[IllegalStateException], () => partitionRecordWriter.append(
tp,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
List.empty.asJava)
)
}

@Test
@@ -267,7 +364,11 @@
("k2", "v2"),
)

assertThrows(classOf[NotLeaderOrFollowerException],
() => partitionRecordWriter.append(tp, records.asJava))
assertThrows(classOf[NotLeaderOrFollowerException], () => partitionRecordWriter.append(
tp,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
records.asJava)
)
}
}
@@ -590,12 +590,18 @@ private ApiMessage messageOrNull(ApiMessageAndVersion apiMessageAndVersion) {

/**
* Replays the Record to update the hard state of the group coordinator.
* @param record The record to apply to the state machine.
*
* @param producerId The producer id.
* @param producerEpoch The producer epoch.
* @param record The record to apply to the state machine.
* @throws RuntimeException
*/
@Override
public void replay(Record record) throws RuntimeException {
public void replay(
long producerId,
short producerEpoch,
Record record
) throws RuntimeException {
ApiMessageAndVersion key = record.key();
ApiMessageAndVersion value = record.value();

@@ -25,12 +25,17 @@
* @param <U> The type of the record.
*/
public interface CoordinatorPlayback<U> {

/**
* Applies the given record to this object.
*
* @param record A record.
* @param producerId The producer id.
* @param producerEpoch The producer epoch.
* @param record A record.
* @throws RuntimeException if the record can not be applied.
*/
void replay(U record) throws RuntimeException;
void replay(
long producerId,
short producerEpoch,
U record
) throws RuntimeException;
}
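A minimal sketch of what an implementation of the extended interface could look like. The handling shown here, applying regular records immediately and staging transactional records per producer id until the transaction outcome is known, is a hypothetical illustration and not the GroupCoordinatorShard implementation above:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.kafka.common.record.RecordBatch;

// Hypothetical playback: regular writes are applied directly, transactional
// writes are buffered by producer id (commit/abort handling is out of scope here).
public class SketchPlayback implements CoordinatorPlayback<String> {
    private final List<String> state = new ArrayList<>();
    private final Map<Long, List<String>> pendingByProducerId = new HashMap<>();

    @Override
    public void replay(long producerId, short producerEpoch, String record) {
        if (producerId == RecordBatch.NO_PRODUCER_ID) {
            // Non-transactional record: apply to the state machine right away.
            state.add(record);
        } else {
            // Transactional record: stage it under the producer id; the producer
            // epoch could be used to fence stale producers (not shown).
            pendingByProducerId.computeIfAbsent(producerId, id -> new ArrayList<>()).add(record);
        }
    }
}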