From f5edb210b4422a94cec3b985c8c5511d1c3b6e31 Mon Sep 17 00:00:00 2001 From: Anush Date: Tue, 2 Jul 2024 12:31:37 +0530 Subject: [PATCH] feat: Multi vector (#2) --- .github/workflows/test.yml | 6 +- README.md | 45 +++++++++++++++ build.gradle | 4 +- .../io/qdrant/kafka/BaseKafkaConnectTest.java | 23 ++++++++ .../java/io/qdrant/kafka/BaseQdrantTest.java | 20 +++++++ .../qdrant/kafka/QdrantSinkConnectorTest.java | 15 +++++ .../java/io/qdrant/kafka/VectorsFactory.java | 57 +++++++++++++++---- .../io/qdrant/kafka/VectorsFactoryTest.java | 53 ++++++++++++++++- 8 files changed, 207 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 753613d..0b04789 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,6 +7,8 @@ on: - edited - synchronize - reopened + schedule: + - cron: 0 0 * * * permissions: contents: write @@ -14,7 +16,7 @@ permissions: jobs: build: - name: Build + name: Test and Build runs-on: ubuntu-latest steps: @@ -27,7 +29,7 @@ jobs: java-version: '17' distribution: 'temurin' - - name: Build And Test + - name: Gradle Build uses: gradle/gradle-build-action@v2 with: gradle-version: 8.5 diff --git a/README.md b/README.md index e2044c2..b2f0c89 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,51 @@ Reference: [Creating a collection with sparse vectors](https://qdrant.tech/docum +
+ Multi-vector + +```json +{ + "collection_name": "{collection_name}", + "id": 1, + "vector": { + "some-multi": [ + [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0 + ], + [ + 1.0, + 0.9, + 0.8, + 0.5, + 0.4, + 0.8, + 0.6, + 0.4, + 0.2, + 0.1 + ] + ] + }, + "payload": { + "name": "kafka", + "description": "Kafka is a distributed streaming platform", + "url": "https://kafka.apache.org/" + } +} +``` + +
+
Combination of named dense and sparse vectors diff --git a/build.gradle b/build.gradle index 3b66e75..8375dc4 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ plugins { } group = 'io.qdrant' -version = '1.0.0' +version = '1.1.0' description = 'Kafka Sink Connector for Qdrant.' java.sourceCompatibility = JavaVersion.VERSION_1_8 java.targetCompatibility = JavaVersion.VERSION_1_8 @@ -43,7 +43,7 @@ def kafkaVersion = '3.5.0' dependencies { implementation "org.apache.kafka:connect-api:$kafkaVersion" - implementation 'io.qdrant:client:1.9.1' + implementation 'io.qdrant:client:1.10.0' implementation 'io.grpc:grpc-protobuf:1.59.0' implementation "io.grpc:grpc-netty-shaded:1.59.0" implementation 'com.google.guava:guava:33.2.1-jre' diff --git a/src/intTest/java/io/qdrant/kafka/BaseKafkaConnectTest.java b/src/intTest/java/io/qdrant/kafka/BaseKafkaConnectTest.java index c12860c..d732d7c 100644 --- a/src/intTest/java/io/qdrant/kafka/BaseKafkaConnectTest.java +++ b/src/intTest/java/io/qdrant/kafka/BaseKafkaConnectTest.java @@ -8,8 +8,10 @@ import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; @@ -148,4 +150,25 @@ void writeSparseVector(String collectionName, Object id, String name, int vector connect.kafka().produce(topicName, message); } + + void writeMultiVector( + String collectionName, Object id, int vectorSize, String name, int multiSize) + throws Exception { + Map messageMap = new HashMap<>(); + messageMap.put("collection_name", collectionName); + messageMap.put("id", id); + + Map vectorMap = new HashMap<>(); + List> multiVector = new ArrayList<>(multiSize); + for (int i = 0; i < multiSize; i++) { + multiVector.add(randomVector(vectorSize)); + } + vectorMap.put(name, multiVector); + + messageMap.put("vector", vectorMap); + + String message = new ObjectMapper().writeValueAsString(messageMap); + + connect.kafka().produce(topicName, message); + } } diff --git a/src/intTest/java/io/qdrant/kafka/BaseQdrantTest.java b/src/intTest/java/io/qdrant/kafka/BaseQdrantTest.java index d93b695..16f8689 100644 --- a/src/intTest/java/io/qdrant/kafka/BaseQdrantTest.java +++ b/src/intTest/java/io/qdrant/kafka/BaseQdrantTest.java @@ -7,6 +7,8 @@ import io.qdrant.client.QdrantGrpcClient; import io.qdrant.client.grpc.Collections.CreateCollection; import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.Collections.MultiVectorComparator; +import io.qdrant.client.grpc.Collections.MultiVectorConfig; import io.qdrant.client.grpc.Collections.SparseVectorConfig; import io.qdrant.client.grpc.Collections.SparseVectorParams; import io.qdrant.client.grpc.Collections.VectorParams; @@ -44,6 +46,10 @@ public class BaseQdrantTest { String sparseVecCollection = "sparse-vec-collection"; String sparseVecName = "sparse-vec"; + String multiVecCollection = "multi-vec-collection"; + String multiVecName = "multi-vec"; + int multiVecSize = 58; + @BeforeEach void setup() throws Exception { qdrantClient = @@ -76,6 +82,19 @@ void setup() throws Exception { .putMap(sparseVecName, SparseVectorParams.getDefaultInstance())) .build()) .get(); + + // Create multi vector collection + Map multiVecParams = new HashMap<>(); + multiVecParams.put( + multiVecName, + VectorParams.newBuilder() + .setSize(multiVecSize) + .setDistance(Distance.Dot) + .setMultivectorConfig( + MultiVectorConfig.newBuilder().setComparator(MultiVectorComparator.MaxSim).build()) + .build()); + + qdrantClient.createCollectionAsync(multiVecCollection, multiVecParams).get(); } @AfterEach @@ -84,6 +103,7 @@ void tearDown() { qdrantClient.deleteCollectionAsync(unnamedVecCollection); qdrantClient.deleteCollectionAsync(namedVecCollection); qdrantClient.deleteCollectionAsync(sparseVecCollection); + qdrantClient.deleteCollectionAsync(multiVecCollection); qdrantClient.close(); } } diff --git a/src/intTest/java/io/qdrant/kafka/QdrantSinkConnectorTest.java b/src/intTest/java/io/qdrant/kafka/QdrantSinkConnectorTest.java index 9adae95..b0ab8c8 100644 --- a/src/intTest/java/io/qdrant/kafka/QdrantSinkConnectorTest.java +++ b/src/intTest/java/io/qdrant/kafka/QdrantSinkConnectorTest.java @@ -61,4 +61,19 @@ public void testSparseVector() throws Exception { waitForPoints(sparseVecCollection, pointsCount); } + + @Test + public void testMultiVector() throws Exception { + connect.configureConnector(CONNECTOR_NAME, connectorProperties()); + waitForConnectorToStart(CONNECTOR_NAME, 1); + + int pointsCount = randomPositiveInt(100); + int multiSize = randomPositiveInt(20); + + for (int i = 0; i < pointsCount; i++) { + writeMultiVector(multiVecCollection, i, multiVecSize, multiVecName, multiSize); + } + + waitForPoints(multiVecCollection, pointsCount); + } } diff --git a/src/main/java/io/qdrant/kafka/VectorsFactory.java b/src/main/java/io/qdrant/kafka/VectorsFactory.java index a381047..e95d69a 100644 --- a/src/main/java/io/qdrant/kafka/VectorsFactory.java +++ b/src/main/java/io/qdrant/kafka/VectorsFactory.java @@ -13,6 +13,7 @@ import org.apache.kafka.connect.errors.DataException; /* Helper to convert JSON vector representations into io.qdrant.client.grpc.Points.Vectors. */ + // Example JSON inputs: // { // "vector": [ @@ -48,6 +49,17 @@ // } // } // } + +// { +// "vector": { +// "some-name": [ +// [0.041732933, 0.013779674, -0.027564144], +// [0.051345434, 0.013743223, -0.027576543], +// [0.041732933, 0.013779674, -0.027564144] +// ] +// } +// } + class VectorsFactory { public static Vectors vectors(Value vectorValue) throws DataException { @@ -67,17 +79,6 @@ public static Vectors vectors(Value vectorValue) throws DataException { return vectorsBuilder.build(); } - private static Vector parseDenseVector(ListValue listValue) throws DataException { - Vector.Builder vectorBuilder = Vector.newBuilder(); - for (Value value : listValue.getValuesList()) { - if (!value.hasDoubleValue()) { - throw new DataException("Vector data must be a list of floats"); - } - vectorBuilder.addData((float) value.getDoubleValue()); - } - return vectorBuilder.build(); - } - private static NamedVectors parseNamedVectors(Struct struct) throws DataException { NamedVectors.Builder namedVectorsBuilder = NamedVectors.newBuilder(); for (Map.Entry entry : struct.getFieldsMap().entrySet()) { @@ -94,6 +95,40 @@ private static NamedVectors parseNamedVectors(Struct struct) throws DataExceptio return namedVectorsBuilder.build(); } + private static Vector parseDenseVector(ListValue listValue) throws DataException { + Vector.Builder vectorBuilder = Vector.newBuilder(); + for (Value value : listValue.getValuesList()) { + if (value.hasListValue()) { + return parseMultiDenseVector(listValue); + } + + if (!value.hasDoubleValue()) { + throw new DataException("Dense vector data must be a list of floats"); + } + vectorBuilder.addData((float) value.getDoubleValue()); + } + return vectorBuilder.build(); + } + + private static Vector parseMultiDenseVector(ListValue listValue) throws DataException { + Vector.Builder vectorBuilder = Vector.newBuilder(); + int numRows = listValue.getValuesCount(); + + for (Value row : listValue.getValuesList()) { + if (!row.hasListValue()) { + throw new DataException("Multi vector data must be a list of lists of floats"); + } + for (Value value : row.getListValue().getValuesList()) { + if (!value.hasDoubleValue()) { + throw new DataException("Multi vector data must be a list of lists of floats"); + } + vectorBuilder.addData((float) value.getDoubleValue()); + } + } + vectorBuilder.setVectorsCount(numRows); + return vectorBuilder.build(); + } + private static Vector parseSparseVector(Struct struct) throws DataException { Map fields = struct.getFieldsMap(); diff --git a/src/test/java/io/qdrant/kafka/VectorsFactoryTest.java b/src/test/java/io/qdrant/kafka/VectorsFactoryTest.java index c3fce20..c567790 100644 --- a/src/test/java/io/qdrant/kafka/VectorsFactoryTest.java +++ b/src/test/java/io/qdrant/kafka/VectorsFactoryTest.java @@ -66,6 +66,53 @@ void testNamedVectors() { .addValues(Value.newBuilder().setDoubleValue(0.52354).build()) .build()) .build()); + fieldsMap.put( + "multi", + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setDoubleValue(0.32).build()) + .addValues(Value.newBuilder().setDoubleValue(0.432).build()) + .addValues(Value.newBuilder().setDoubleValue(0.423).build()) + .addValues(Value.newBuilder().setDoubleValue(0.52354).build()) + .build()) + .build()) + .addValues( + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setDoubleValue(0.32).build()) + .addValues(Value.newBuilder().setDoubleValue(0.432).build()) + .addValues(Value.newBuilder().setDoubleValue(0.423).build()) + .addValues(Value.newBuilder().setDoubleValue(0.52354).build()) + .build()) + .build()) + .addValues( + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setDoubleValue(0.32).build()) + .addValues(Value.newBuilder().setDoubleValue(0.432).build()) + .addValues(Value.newBuilder().setDoubleValue(0.423).build()) + .addValues(Value.newBuilder().setDoubleValue(0.52354).build()) + .build()) + .build()) + .addValues( + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setDoubleValue(0.32).build()) + .addValues(Value.newBuilder().setDoubleValue(0.432).build()) + .addValues(Value.newBuilder().setDoubleValue(0.423).build()) + .addValues(Value.newBuilder().setDoubleValue(0.52354).build()) + .build()) + .build()) + .build()) + .build()); Struct struct = Struct.newBuilder().putAllFields(fieldsMap).build(); Value vectorValue = Value.newBuilder().setStructValue(struct).build(); @@ -74,9 +121,13 @@ void testNamedVectors() { assertTrue(vectors.hasVectors()); NamedVectors namedVectors = vectors.getVectors(); - assertEquals(2, namedVectors.getVectorsCount()); + assertEquals(3, namedVectors.getVectorsCount()); assertTrue(namedVectors.containsVectors("boi")); assertTrue(namedVectors.containsVectors("gal")); + assertTrue(namedVectors.containsVectors("multi")); + + Vector multiVector = namedVectors.getVectorsOrThrow("multi"); + assertEquals(4, multiVector.getVectorsCount()); } @Test