Commit 24c7791
fix mongo KV range scans (#343)
agavra authored Sep 9, 2024
1 parent 1501630 commit 24c7791
Showing 4 changed files with 165 additions and 34 deletions.
@@ -29,6 +29,7 @@ public class KVDoc {
public static final String VALUE = "value";
public static final String EPOCH = "epoch";
public static final String TIMESTAMP = "ts";
public static final String KAFKA_PARTITION = "partition";
public static final String TOMBSTONE_TS = "tombstoneTs";

// We use a string key for ID because mongo range scans don't work as expected for binary
@@ -38,6 +39,7 @@ public class KVDoc {
byte[] value;
long epoch;
long timestamp;
int kafkaPartition;
Date tombstoneTs;

public KVDoc() {
@@ -48,12 +50,14 @@ public KVDoc(
@BsonProperty(ID) String id,
@BsonProperty(VALUE) byte[] value,
@BsonProperty(EPOCH) long epoch,
@BsonProperty(TIMESTAMP) long timestamp
@BsonProperty(TIMESTAMP) long timestamp,
@BsonProperty(KAFKA_PARTITION) int kafkaPartition
) {
this.id = id;
this.value = value;
this.epoch = epoch;
this.timestamp = timestamp;
this.kafkaPartition = kafkaPartition;
}

public String getKey() {
@@ -88,6 +92,14 @@ public long getTimestamp() {
return timestamp;
}

public int getKafkaPartition() {
return kafkaPartition;
}

public void setKafkaPartition(final int kafkaPartition) {
this.kafkaPartition = kafkaPartition;
}

public Date getTombstoneTs() {
return tombstoneTs;
}
@@ -108,12 +120,14 @@ public boolean equals(final Object o) {
return epoch == kvDoc.epoch
&& Objects.equals(id, kvDoc.id)
&& Arrays.equals(value, kvDoc.value)
&& Objects.equals(timestamp, kvDoc.timestamp)
&& Objects.equals(kafkaPartition, kvDoc.kafkaPartition)
&& Objects.equals(tombstoneTs, kvDoc.tombstoneTs);
}

@Override
public int hashCode() {
int result = Objects.hash(id, epoch, tombstoneTs);
int result = Objects.hash(id, epoch, tombstoneTs, timestamp, kafkaPartition);
result = 31 * result + Arrays.hashCode(value);
return result;
}
@@ -125,6 +139,8 @@ public String toString() {
+ ", value=" + Arrays.toString(value)
+ ", epoch=" + epoch
+ ", tombstoneTs=" + tombstoneTs
+ ", timestamp=" + timestamp
+ ", kafkaPartition=" + kafkaPartition
+ '}';
}
}
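The new field also flows through the BsonProperty constructor above. A minimal sketch of constructing the updated KVDoc and reading the partition back (the key string and all values here are illustrative, not taken from the commit):

final KVDoc doc = new KVDoc(
    "someEncodedKey",      // id: the string-encoded key (illustrative)
    new byte[] {1, 2, 3},  // value
    5L,                    // epoch
    100L,                  // timestamp
    0                      // kafkaPartition, new in this commit
);
assert doc.getKafkaPartition() == 0;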
@@ -170,7 +170,8 @@ public KeyValueIterator<Bytes, byte[]> range(
Filters.gte(KVDoc.ID, keyCodec.encode(from)),
Filters.lte(KVDoc.ID, keyCodec.encode(to)),
Filters.not(Filters.exists(KVDoc.TOMBSTONE_TS)),
Filters.gte(KVDoc.TIMESTAMP, minValidTs)
Filters.gte(KVDoc.TIMESTAMP, minValidTs),
Filters.eq(KVDoc.KAFKA_PARTITION, kafkaPartition)
)
);
return Iterators.kv(
@@ -185,7 +186,8 @@ public KeyValueIterator<Bytes, byte[]> range(
public KeyValueIterator<Bytes, byte[]> all(final int kafkaPartition, final long minValidTs) {
final FindIterable<KVDoc> result = docs.find(Filters.and(
Filters.not(Filters.exists(KVDoc.TOMBSTONE_TS)),
Filters.gte(KVDoc.TIMESTAMP, minValidTs)
Filters.gte(KVDoc.TIMESTAMP, minValidTs),
Filters.eq(KVDoc.KAFKA_PARTITION, kafkaPartition)
));
return Iterators.kv(
result.iterator(),
@@ -213,6 +215,7 @@ public WriteModel<KVDoc> insert(
Updates.set(KVDoc.VALUE, value),
Updates.set(KVDoc.EPOCH, epoch),
Updates.set(KVDoc.TIMESTAMP, epochMillis),
Updates.set(KVDoc.KAFKA_PARTITION, kafkaPartition),
Updates.unset(KVDoc.TOMBSTONE_TS)
),
new UpdateOptions().upsert(true)
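The core of the fix is in the two query changes above: range() and all() previously did not filter on the writing partition, so a scan could also see documents written for other kafka partitions of the same store. A minimal sketch of the partition-scoped range query, assuming docs is the MongoCollection<KVDoc> behind MongoKVTable and that the key bounds have already been encoded to their string IDs (the helper name and parameters here are illustrative, not part of the commit):

import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;

final class PartitionedRangeSketch {

  static FindIterable<KVDoc> partitionedRange(
      final MongoCollection<KVDoc> docs,
      final String encodedFrom,
      final String encodedTo,
      final long minValidTs,
      final int kafkaPartition
  ) {
    return docs.find(Filters.and(
        Filters.gte(KVDoc.ID, encodedFrom),                // lower key bound
        Filters.lte(KVDoc.ID, encodedTo),                  // upper key bound
        Filters.not(Filters.exists(KVDoc.TOMBSTONE_TS)),   // skip tombstoned docs
        Filters.gte(KVDoc.TIMESTAMP, minValidTs),          // skip expired docs
        Filters.eq(KVDoc.KAFKA_PARTITION, kafkaPartition)  // new: only this partition's docs
    ));
  }
}

Without the Filters.eq clause, the same scan on a shared collection would also pick up documents written by other partitions' writers, which is exactly what the new test cases below guard against.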
@@ -37,20 +37,31 @@
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.processor.api.Processor;
import org.apache.kafka.streams.processor.api.ProcessorContext;
import org.apache.kafka.streams.processor.api.ProcessorSupplier;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;
import org.apache.kafka.streams.state.internals.RocksDBKeyValueBytesStoreSupplier;
import org.hamcrest.Matchers;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
@@ -81,7 +92,7 @@ public void before(

final var result = admin.createTopics(
List.of(
new NewTopic(inputTopic(), Optional.of(1), Optional.empty()),
new NewTopic(inputTopic(), Optional.of(2), Optional.empty()),
new NewTopic(outputTopic(), Optional.of(1), Optional.empty())
)
);
@@ -104,40 +115,104 @@ private String outputTopic() {
*/
@Test
public void shouldMatchRocksDB() throws Exception {
final KeyValueBytesStoreSupplier rocksDbStore =
new RocksDBKeyValueBytesStoreSupplier(name, false);

final KeyValueBytesStoreSupplier responsiveStore =
ResponsiveStores.keyValueStore(ResponsiveKeyValueParams.keyValue(name));

final StoreComparatorSuppliers.CompareFunction compare =
(String method, Object[] args, Object actual, Object truth) -> {
final String reason = method + " should yield identical results.";
assertThat(reason, actual, Matchers.equalTo(truth));
};
final StoreComparatorSuppliers.MultiKeyValueStoreSupplier
multiKeyValueStoreSupplier = multiKeyValueStoreSupplier(name);

final Materialized<String, String, KeyValueStore<Bytes, byte[]>> combinedStore =
Materialized.as(new StoreComparatorSuppliers.MultiKeyValueStoreSupplier(
rocksDbStore, responsiveStore, compare
));
Materialized.as(multiKeyValueStoreSupplier);

final StoreBuilder<KeyValueStore<String, String>> storeBuilder =
Stores.keyValueStoreBuilder(
multiKeyValueStoreSupplier(name + "2"),
Serdes.String(),
Serdes.String());

// Start from timestamp of 0L to get predictable results
final List<KeyValueTimestamp<String, String>> inputEvents = Arrays.asList(
new KeyValueTimestamp<>("key", "a", 0L),
new KeyValueTimestamp<>("keyB", "x", 0L),
new KeyValueTimestamp<>("keyC", "y", 0L),
new KeyValueTimestamp<>("keyD", "z", 0L),

new KeyValueTimestamp<>("key", "c", 1_000L),
new KeyValueTimestamp<>("keyB", "x", 1_200L),
new KeyValueTimestamp<>("keyC", "y", 1_300L),
new KeyValueTimestamp<>("keyD", "z", 1_400L),

new KeyValueTimestamp<>("key", "b", 2_000L),
new KeyValueTimestamp<>("keyB", "x", 2_200L),
new KeyValueTimestamp<>("keyC", "y", 2_300L),
new KeyValueTimestamp<>("keyD", "z", 2_400L),

new KeyValueTimestamp<>("key", "d", 3_000L),
new KeyValueTimestamp<>("key", "b", 3_000L),
new KeyValueTimestamp<>("key", null, 4_000L),

new KeyValueTimestamp<>("key2", "e", 4_000L),
new KeyValueTimestamp<>("key2B", "x", 4_200L),
new KeyValueTimestamp<>("key2C", "y", 4_300L),
new KeyValueTimestamp<>("key2D", "z", 4_400L),

new KeyValueTimestamp<>("key2", "b", 5_000L),

new KeyValueTimestamp<>("STOP", "b", 18_000L)
);
final CountDownLatch outputLatch = new CountDownLatch(1);

final StreamsBuilder builder = new StreamsBuilder();
builder.addStateStore(storeBuilder);

final KStream<String, String> input = builder.stream(inputTopic());
input
// add a processor that issues a range scan on a KV state store
// since there are no DSL methods for this we have to do it
// via the process API
.process(new ProcessorSupplier<String, String, String, String>() {
@Override
public Set<StoreBuilder<?>> stores() {
return Set.of(storeBuilder);
}

@Override
public Processor<String, String, String, String> get() {
return new Processor<>() {

private ProcessorContext<String, String> context;
private KeyValueStore<String, String> store;

@Override
public void init(final ProcessorContext<String, String> context) {
this.store = context.getStateStore(storeBuilder.name());
this.context = context;
}

@Override
public void process(final Record<String, String> record) {
store.put(record.key(), record.value());
if (record.value() == null) {
context.forward(record);
return;
}

final StringBuilder combined;
try (
KeyValueIterator<String, String> range = store.range(
record.key(),
record.key() + "Z"
)
) {
combined = new StringBuilder(record.value());
while (range.hasNext()) {
final KeyValue<String, String> next = range.next();
combined.append(next.value);
}
}

context.forward(record.withValue(combined.toString()));
}
};
}
})
.groupByKey()
.aggregate(() -> "", (k, v1, agg) -> agg + v1, combinedStore)
.toStream()
@@ -169,4 +244,27 @@ public void shouldMatchRocksDB() throws Exception {
}
}

@NotNull
private StoreComparatorSuppliers.MultiKeyValueStoreSupplier multiKeyValueStoreSupplier(
final String name
) {
final KeyValueBytesStoreSupplier rocksDbStore =
new RocksDBKeyValueBytesStoreSupplier(name, false);

final KeyValueBytesStoreSupplier responsiveStore =
ResponsiveStores.keyValueStore(ResponsiveKeyValueParams.keyValue(name));

final StoreComparatorSuppliers.CompareFunction compare =
(String method, Object[] args, Object actual, Object truth) -> {
final String reason = method + " should yield identical results.";
assertThat(reason, actual, Matchers.equalTo(truth));
};

return new StoreComparatorSuppliers.MultiKeyValueStoreSupplier(
rocksDbStore,
responsiveStore,
compare
);
}

}
@@ -210,17 +210,24 @@ public void shouldIncludeResultsWithNewerTimestamp() {
}

@Test
public void shouldHandleRangeScansCorrectly() {
public void shouldHandlePartitionedRangeScansCorrectly() {
// Given:
final MongoKVTable table = new MongoKVTable(client, name, UNSHARDED);
var writerFactory = table.init(0);
var writer = writerFactory.createWriter(0);
writer.insert(bytes(10, 11, 12, 12, 13), byteArray(1), 100);
writer.insert(bytes(10, 11, 12, 13), byteArray(2), 100);
writer.insert(bytes(10, 11, 13), byteArray(3), 100);
writer.insert(bytes(10, 11, 13, 14), byteArray(4), 100);
writer.insert(bytes(11, 12), byteArray(5), 100);
writer.flush();

final var writerFactory0 = table.init(0);
final var writer0 = writerFactory0.createWriter(0);
final var writerFactory1 = table.init(1);
final var writer1 = writerFactory1.createWriter(1);

writer0.insert(bytes(10, 11, 12, 12, 13), byteArray(1), 100);
writer0.insert(bytes(10, 11, 12, 13), byteArray(2), 100);
writer0.insert(bytes(10, 11, 13), byteArray(3), 100);
writer1.insert(bytes(10, 11, 13, 13), byteArray(3), 100); // in range, excluded by partition
writer0.insert(bytes(10, 11, 13, 14), byteArray(4), 100);
writer0.insert(bytes(11, 12), byteArray(5), 100);

writer0.flush();
writer1.flush();

// When:
final var iter = table.range(0, bytes(10, 11, 12, 13), bytes(10, 11, 13, 14), -1);
@@ -294,12 +301,19 @@ public void shouldFilterExpiredItemsFromRangeScans() {
@Test
public void shouldHandleFullScansCorrectly() {
final MongoKVTable table = new MongoKVTable(client, name, UNSHARDED);
var writerFactory = table.init(0);
var writer = writerFactory.createWriter(0);
writer.insert(bytes(10, 11, 12, 13), byteArray(2), 100);
writer.insert(bytes(10, 11, 13), byteArray(3), 100);
writer.insert(bytes(10, 11, 13, 14), byteArray(4), 100);
writer.flush();

final var writerFactory0 = table.init(0);
final var writer0 = writerFactory0.createWriter(0);
final var writerFactory1 = table.init(1);
final var writer1 = writerFactory1.createWriter(1);

writer0.insert(bytes(10, 11, 12, 13), byteArray(2), 100);
writer0.insert(bytes(10, 11, 13), byteArray(3), 100);
writer0.insert(bytes(10, 11, 13, 14), byteArray(4), 100);
writer1.insert(bytes(11, 13, 14), byteArray(5), 100); // excluded by partition

writer0.flush();
writer1.flush();

// When:
final var iter = table.all(0, -1);