From 24c7791d4edd0f1a3a5ab15e4c43e8209bc61dd5 Mon Sep 17 00:00:00 2001 From: Almog Gavra Date: Mon, 9 Sep 2024 15:46:32 -0700 Subject: [PATCH] fix mongo KV range scans (#343) --- .../kafka/internal/db/mongo/KVDoc.java | 20 ++- .../kafka/internal/db/mongo/MongoKVTable.java | 7 +- ...esponsiveKeyValueStoreIntegrationTest.java | 128 ++++++++++++++++-- .../internal/db/mongo/MongoKVTableTest.java | 44 ++++-- 4 files changed, 165 insertions(+), 34 deletions(-) diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/KVDoc.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/KVDoc.java index df599be54..82a6958f6 100644 --- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/KVDoc.java +++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/KVDoc.java @@ -29,6 +29,7 @@ public class KVDoc { public static final String VALUE = "value"; public static final String EPOCH = "epoch"; public static final String TIMESTAMP = "ts"; + public static final String KAFKA_PARTITION = "partition"; public static final String TOMBSTONE_TS = "tombstoneTs"; // We use a string key for ID because mongo range scans don't work as expected for binary @@ -38,6 +39,7 @@ public class KVDoc { byte[] value; long epoch; long timestamp; + int kafkaPartition; Date tombstoneTs; public KVDoc() { @@ -48,12 +50,14 @@ public KVDoc( @BsonProperty(ID) String id, @BsonProperty(VALUE) byte[] value, @BsonProperty(EPOCH) long epoch, - @BsonProperty(TIMESTAMP) long timestamp + @BsonProperty(TIMESTAMP) long timestamp, + @BsonProperty(KAFKA_PARTITION) int kafkaPartition ) { this.id = id; this.value = value; this.epoch = epoch; this.timestamp = timestamp; + this.kafkaPartition = kafkaPartition; } public String getKey() { @@ -88,6 +92,14 @@ public long getTimestamp() { return timestamp; } + public int getKafkaPartition() { + return kafkaPartition; + } + + public void setKafkaPartition(final int kafkaPartition) { + this.kafkaPartition = kafkaPartition; + } + public Date getTombstoneTs() { return tombstoneTs; } @@ -108,12 +120,14 @@ public boolean equals(final Object o) { return epoch == kvDoc.epoch && Objects.equals(id, kvDoc.id) && Arrays.equals(value, kvDoc.value) + && Objects.equals(timestamp, kvDoc.timestamp) + && Objects.equals(kafkaPartition, kvDoc.kafkaPartition) && Objects.equals(tombstoneTs, kvDoc.tombstoneTs); } @Override public int hashCode() { - int result = Objects.hash(id, epoch, tombstoneTs); + int result = Objects.hash(id, epoch, tombstoneTs, timestamp, kafkaPartition); result = 31 * result + Arrays.hashCode(value); return result; } @@ -125,6 +139,8 @@ public String toString() { + ", value=" + Arrays.toString(value) + ", epoch=" + epoch + ", tombstoneTs=" + tombstoneTs + + ", timestamp=" + timestamp + + ", kafkaPartition=" + kafkaPartition + '}'; } } \ No newline at end of file diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/MongoKVTable.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/MongoKVTable.java index 201de93f2..d85ae2e3e 100644 --- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/MongoKVTable.java +++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/mongo/MongoKVTable.java @@ -170,7 +170,8 @@ public KeyValueIterator range( Filters.gte(KVDoc.ID, keyCodec.encode(from)), Filters.lte(KVDoc.ID, keyCodec.encode(to)), Filters.not(Filters.exists(KVDoc.TOMBSTONE_TS)), - Filters.gte(KVDoc.TIMESTAMP, minValidTs) + Filters.gte(KVDoc.TIMESTAMP, minValidTs), + Filters.eq(KVDoc.KAFKA_PARTITION, kafkaPartition) ) ); return Iterators.kv( @@ -185,7 +186,8 @@ public KeyValueIterator range( public KeyValueIterator all(final int kafkaPartition, final long minValidTs) { final FindIterable result = docs.find(Filters.and( Filters.not(Filters.exists(KVDoc.TOMBSTONE_TS)), - Filters.gte(KVDoc.TIMESTAMP, minValidTs) + Filters.gte(KVDoc.TIMESTAMP, minValidTs), + Filters.eq(KVDoc.KAFKA_PARTITION, kafkaPartition) )); return Iterators.kv( result.iterator(), @@ -213,6 +215,7 @@ public WriteModel insert( Updates.set(KVDoc.VALUE, value), Updates.set(KVDoc.EPOCH, epoch), Updates.set(KVDoc.TIMESTAMP, epochMillis), + Updates.set(KVDoc.KAFKA_PARTITION, kafkaPartition), Updates.unset(KVDoc.TOMBSTONE_TS) ), new UpdateOptions().upsert(true) diff --git a/kafka-client/src/test/java/dev/responsive/kafka/integration/ResponsiveKeyValueStoreIntegrationTest.java b/kafka-client/src/test/java/dev/responsive/kafka/integration/ResponsiveKeyValueStoreIntegrationTest.java index 825221821..04eee505b 100644 --- a/kafka-client/src/test/java/dev/responsive/kafka/integration/ResponsiveKeyValueStoreIntegrationTest.java +++ b/kafka-client/src/test/java/dev/responsive/kafka/integration/ResponsiveKeyValueStoreIntegrationTest.java @@ -37,20 +37,31 @@ import java.util.Map; import java.util.Optional; import java.util.Random; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import org.apache.kafka.clients.admin.Admin; import org.apache.kafka.clients.admin.NewTopic; import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.StreamsBuilder; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; import org.apache.kafka.streams.state.internals.RocksDBKeyValueBytesStoreSupplier; import org.hamcrest.Matchers; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; @@ -81,7 +92,7 @@ public void before( final var result = admin.createTopics( List.of( - new NewTopic(inputTopic(), Optional.of(1), Optional.empty()), + new NewTopic(inputTopic(), Optional.of(2), Optional.empty()), new NewTopic(outputTopic(), Optional.of(1), Optional.empty()) ) ); @@ -104,40 +115,104 @@ private String outputTopic() { */ @Test public void shouldMatchRocksDB() throws Exception { - final KeyValueBytesStoreSupplier rocksDbStore = - new RocksDBKeyValueBytesStoreSupplier(name, false); - - final KeyValueBytesStoreSupplier responsiveStore = - ResponsiveStores.keyValueStore(ResponsiveKeyValueParams.keyValue(name)); - - final StoreComparatorSuppliers.CompareFunction compare = - (String method, Object[] args, Object actual, Object truth) -> { - final String reason = method + " should yield identical results."; - assertThat(reason, actual, Matchers.equalTo(truth)); - }; + final StoreComparatorSuppliers.MultiKeyValueStoreSupplier + multiKeyValueStoreSupplier = multiKeyValueStoreSupplier(name); final Materialized> combinedStore = - Materialized.as(new StoreComparatorSuppliers.MultiKeyValueStoreSupplier( - rocksDbStore, responsiveStore, compare - )); + Materialized.as(multiKeyValueStoreSupplier); + + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder( + multiKeyValueStoreSupplier(name + "2"), + Serdes.String(), + Serdes.String()); // Start from timestamp of 0L to get predictable results final List> inputEvents = Arrays.asList( new KeyValueTimestamp<>("key", "a", 0L), + new KeyValueTimestamp<>("keyB", "x", 0L), + new KeyValueTimestamp<>("keyC", "y", 0L), + new KeyValueTimestamp<>("keyD", "z", 0L), + new KeyValueTimestamp<>("key", "c", 1_000L), + new KeyValueTimestamp<>("keyB", "x", 1_200L), + new KeyValueTimestamp<>("keyC", "y", 1_300L), + new KeyValueTimestamp<>("keyD", "z", 1_400L), + new KeyValueTimestamp<>("key", "b", 2_000L), + new KeyValueTimestamp<>("keyB", "x", 2_200L), + new KeyValueTimestamp<>("keyC", "y", 2_300L), + new KeyValueTimestamp<>("keyD", "z", 2_400L), + new KeyValueTimestamp<>("key", "d", 3_000L), new KeyValueTimestamp<>("key", "b", 3_000L), new KeyValueTimestamp<>("key", null, 4_000L), + new KeyValueTimestamp<>("key2", "e", 4_000L), + new KeyValueTimestamp<>("key2B", "x", 4_200L), + new KeyValueTimestamp<>("key2C", "y", 4_300L), + new KeyValueTimestamp<>("key2D", "z", 4_400L), + new KeyValueTimestamp<>("key2", "b", 5_000L), + new KeyValueTimestamp<>("STOP", "b", 18_000L) ); final CountDownLatch outputLatch = new CountDownLatch(1); final StreamsBuilder builder = new StreamsBuilder(); + builder.addStateStore(storeBuilder); + final KStream input = builder.stream(inputTopic()); input + // add a processor that issues a range scan on a KV state store + // since there are no DSL methods for this we have to do it + // via the process API + .process(new ProcessorSupplier() { + @Override + public Set> stores() { + return Set.of(storeBuilder); + } + + @Override + public Processor get() { + return new Processor<>() { + + private ProcessorContext context; + private KeyValueStore store; + + @Override + public void init(final ProcessorContext context) { + this.store = context.getStateStore(storeBuilder.name()); + this.context = context; + } + + @Override + public void process(final Record record) { + store.put(record.key(), record.value()); + if (record.value() == null) { + context.forward(record); + return; + } + + final StringBuilder combined; + try ( + KeyValueIterator range = store.range( + record.key(), + record.key() + "Z" + ) + ) { + combined = new StringBuilder(record.value()); + while (range.hasNext()) { + final KeyValue next = range.next(); + combined.append(next.value); + } + } + + context.forward(record.withValue(combined.toString())); + } + }; + } + }) .groupByKey() .aggregate(() -> "", (k, v1, agg) -> agg + v1, combinedStore) .toStream() @@ -169,4 +244,27 @@ public void shouldMatchRocksDB() throws Exception { } } + @NotNull + private StoreComparatorSuppliers.MultiKeyValueStoreSupplier multiKeyValueStoreSupplier( + final String name + ) { + final KeyValueBytesStoreSupplier rocksDbStore = + new RocksDBKeyValueBytesStoreSupplier(name, false); + + final KeyValueBytesStoreSupplier responsiveStore = + ResponsiveStores.keyValueStore(ResponsiveKeyValueParams.keyValue(name)); + + final StoreComparatorSuppliers.CompareFunction compare = + (String method, Object[] args, Object actual, Object truth) -> { + final String reason = method + " should yield identical results."; + assertThat(reason, actual, Matchers.equalTo(truth)); + }; + + return new StoreComparatorSuppliers.MultiKeyValueStoreSupplier( + rocksDbStore, + responsiveStore, + compare + ); + } + } diff --git a/kafka-client/src/test/java/dev/responsive/kafka/internal/db/mongo/MongoKVTableTest.java b/kafka-client/src/test/java/dev/responsive/kafka/internal/db/mongo/MongoKVTableTest.java index 903b33190..44df7329e 100644 --- a/kafka-client/src/test/java/dev/responsive/kafka/internal/db/mongo/MongoKVTableTest.java +++ b/kafka-client/src/test/java/dev/responsive/kafka/internal/db/mongo/MongoKVTableTest.java @@ -210,17 +210,24 @@ public void shouldIncludeResultsWithNewerTimestamp() { } @Test - public void shouldHandleRangeScansCorrectly() { + public void shouldHandlePartitionedRangeScansCorrectly() { // Given: final MongoKVTable table = new MongoKVTable(client, name, UNSHARDED); - var writerFactory = table.init(0); - var writer = writerFactory.createWriter(0); - writer.insert(bytes(10, 11, 12, 12, 13), byteArray(1), 100); - writer.insert(bytes(10, 11, 12, 13), byteArray(2), 100); - writer.insert(bytes(10, 11, 13), byteArray(3), 100); - writer.insert(bytes(10, 11, 13, 14), byteArray(4), 100); - writer.insert(bytes(11, 12), byteArray(5), 100); - writer.flush(); + + final var writerFactory0 = table.init(0); + final var writer0 = writerFactory0.createWriter(0); + final var writerFactory1 = table.init(1); + final var writer1 = writerFactory1.createWriter(1); + + writer0.insert(bytes(10, 11, 12, 12, 13), byteArray(1), 100); + writer0.insert(bytes(10, 11, 12, 13), byteArray(2), 100); + writer0.insert(bytes(10, 11, 13), byteArray(3), 100); + writer1.insert(bytes(10, 11, 13, 13), byteArray(3), 100); // in range, excluded by partition + writer0.insert(bytes(10, 11, 13, 14), byteArray(4), 100); + writer0.insert(bytes(11, 12), byteArray(5), 100); + + writer0.flush(); + writer1.flush(); // When: final var iter = table.range(0, bytes(10, 11, 12, 13), bytes(10, 11, 13, 14), -1); @@ -294,12 +301,19 @@ public void shouldFilterExpiredItemsFromRangeScans() { @Test public void shouldHandleFullScansCorrectly() { final MongoKVTable table = new MongoKVTable(client, name, UNSHARDED); - var writerFactory = table.init(0); - var writer = writerFactory.createWriter(0); - writer.insert(bytes(10, 11, 12, 13), byteArray(2), 100); - writer.insert(bytes(10, 11, 13), byteArray(3), 100); - writer.insert(bytes(10, 11, 13, 14), byteArray(4), 100); - writer.flush(); + + final var writerFactory0 = table.init(0); + final var writer0 = writerFactory0.createWriter(0); + final var writerFactory1 = table.init(1); + final var writer1 = writerFactory1.createWriter(1); + + writer0.insert(bytes(10, 11, 12, 13), byteArray(2), 100); + writer0.insert(bytes(10, 11, 13), byteArray(3), 100); + writer0.insert(bytes(10, 11, 13, 14), byteArray(4), 100); + writer1.insert(bytes(11, 13, 14), byteArray(5), 100); // excluded by partition + + writer0.flush(); + writer1.flush(); // When: final var iter = table.all(0, -1);