Skip to content

Commit

Permalink
feat: Multi vector (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 authored Jul 2, 2024
1 parent cafb26d commit f5edb21
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 16 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ on:
- edited
- synchronize
- reopened
schedule:
- cron: 0 0 * * *

permissions:
contents: write
checks: write

jobs:
build:
name: Build
name: Test and Build
runs-on: ubuntu-latest

steps:
Expand All @@ -27,7 +29,7 @@ jobs:
java-version: '17'
distribution: 'temurin'

- name: Build And Test
- name: Gradle Build
uses: gradle/gradle-build-action@v2
with:
gradle-version: 8.5
Expand Down
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ Reference: [Creating a collection with sparse vectors](https://qdrant.tech/docum

</details>

<details>
<summary><b>Multi-vector</b></summary>

```json
{
"collection_name": "{collection_name}",
"id": 1,
"vector": {
"some-multi": [
[
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0
],
[
1.0,
0.9,
0.8,
0.5,
0.4,
0.8,
0.6,
0.4,
0.2,
0.1
]
]
},
"payload": {
"name": "kafka",
"description": "Kafka is a distributed streaming platform",
"url": "https://kafka.apache.org/"
}
}
```

</details>

<details>
<summary><b>Combination of named dense and sparse vectors</b></summary>

Expand Down
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ plugins {
}

group = 'io.qdrant'
version = '1.0.0'
version = '1.1.0'
description = 'Kafka Sink Connector for Qdrant.'
java.sourceCompatibility = JavaVersion.VERSION_1_8
java.targetCompatibility = JavaVersion.VERSION_1_8
Expand Down Expand Up @@ -43,7 +43,7 @@ def kafkaVersion = '3.5.0'

dependencies {
implementation "org.apache.kafka:connect-api:$kafkaVersion"
implementation 'io.qdrant:client:1.9.1'
implementation 'io.qdrant:client:1.10.0'
implementation 'io.grpc:grpc-protobuf:1.59.0'
implementation "io.grpc:grpc-netty-shaded:1.59.0"
implementation 'com.google.guava:guava:33.2.1-jre'
Expand Down
23 changes: 23 additions & 0 deletions src/intTest/java/io/qdrant/kafka/BaseKafkaConnectTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -148,4 +150,25 @@ void writeSparseVector(String collectionName, Object id, String name, int vector

connect.kafka().produce(topicName, message);
}

void writeMultiVector(
String collectionName, Object id, int vectorSize, String name, int multiSize)
throws Exception {
Map<String, Object> messageMap = new HashMap<>();
messageMap.put("collection_name", collectionName);
messageMap.put("id", id);

Map<String, Object> vectorMap = new HashMap<>();
List<List<Float>> multiVector = new ArrayList<>(multiSize);
for (int i = 0; i < multiSize; i++) {
multiVector.add(randomVector(vectorSize));
}
vectorMap.put(name, multiVector);

messageMap.put("vector", vectorMap);

String message = new ObjectMapper().writeValueAsString(messageMap);

connect.kafka().produce(topicName, message);
}
}
20 changes: 20 additions & 0 deletions src/intTest/java/io/qdrant/kafka/BaseQdrantTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections.CreateCollection;
import io.qdrant.client.grpc.Collections.Distance;
import io.qdrant.client.grpc.Collections.MultiVectorComparator;
import io.qdrant.client.grpc.Collections.MultiVectorConfig;
import io.qdrant.client.grpc.Collections.SparseVectorConfig;
import io.qdrant.client.grpc.Collections.SparseVectorParams;
import io.qdrant.client.grpc.Collections.VectorParams;
Expand Down Expand Up @@ -44,6 +46,10 @@ public class BaseQdrantTest {
String sparseVecCollection = "sparse-vec-collection";
String sparseVecName = "sparse-vec";

String multiVecCollection = "multi-vec-collection";
String multiVecName = "multi-vec";
int multiVecSize = 58;

@BeforeEach
void setup() throws Exception {
qdrantClient =
Expand Down Expand Up @@ -76,6 +82,19 @@ void setup() throws Exception {
.putMap(sparseVecName, SparseVectorParams.getDefaultInstance()))
.build())
.get();

// Create multi vector collection
Map<String, VectorParams> multiVecParams = new HashMap<>();
multiVecParams.put(
multiVecName,
VectorParams.newBuilder()
.setSize(multiVecSize)
.setDistance(Distance.Dot)
.setMultivectorConfig(
MultiVectorConfig.newBuilder().setComparator(MultiVectorComparator.MaxSim).build())
.build());

qdrantClient.createCollectionAsync(multiVecCollection, multiVecParams).get();
}

@AfterEach
Expand All @@ -84,6 +103,7 @@ void tearDown() {
qdrantClient.deleteCollectionAsync(unnamedVecCollection);
qdrantClient.deleteCollectionAsync(namedVecCollection);
qdrantClient.deleteCollectionAsync(sparseVecCollection);
qdrantClient.deleteCollectionAsync(multiVecCollection);
qdrantClient.close();
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/intTest/java/io/qdrant/kafka/QdrantSinkConnectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,19 @@ public void testSparseVector() throws Exception {

waitForPoints(sparseVecCollection, pointsCount);
}

@Test
public void testMultiVector() throws Exception {
connect.configureConnector(CONNECTOR_NAME, connectorProperties());
waitForConnectorToStart(CONNECTOR_NAME, 1);

int pointsCount = randomPositiveInt(100);
int multiSize = randomPositiveInt(20);

for (int i = 0; i < pointsCount; i++) {
writeMultiVector(multiVecCollection, i, multiVecSize, multiVecName, multiSize);
}

waitForPoints(multiVecCollection, pointsCount);
}
}
57 changes: 46 additions & 11 deletions src/main/java/io/qdrant/kafka/VectorsFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.kafka.connect.errors.DataException;

/* Helper to convert JSON vector representations into io.qdrant.client.grpc.Points.Vectors. */

// Example JSON inputs:
// {
// "vector": [
Expand Down Expand Up @@ -48,6 +49,17 @@
// }
// }
// }

// {
// "vector": {
// "some-name": [
// [0.041732933, 0.013779674, -0.027564144],
// [0.051345434, 0.013743223, -0.027576543],
// [0.041732933, 0.013779674, -0.027564144]
// ]
// }
// }

class VectorsFactory {

public static Vectors vectors(Value vectorValue) throws DataException {
Expand All @@ -67,17 +79,6 @@ public static Vectors vectors(Value vectorValue) throws DataException {
return vectorsBuilder.build();
}

private static Vector parseDenseVector(ListValue listValue) throws DataException {
Vector.Builder vectorBuilder = Vector.newBuilder();
for (Value value : listValue.getValuesList()) {
if (!value.hasDoubleValue()) {
throw new DataException("Vector data must be a list of floats");
}
vectorBuilder.addData((float) value.getDoubleValue());
}
return vectorBuilder.build();
}

private static NamedVectors parseNamedVectors(Struct struct) throws DataException {
NamedVectors.Builder namedVectorsBuilder = NamedVectors.newBuilder();
for (Map.Entry<String, Value> entry : struct.getFieldsMap().entrySet()) {
Expand All @@ -94,6 +95,40 @@ private static NamedVectors parseNamedVectors(Struct struct) throws DataExceptio
return namedVectorsBuilder.build();
}

private static Vector parseDenseVector(ListValue listValue) throws DataException {
Vector.Builder vectorBuilder = Vector.newBuilder();
for (Value value : listValue.getValuesList()) {
if (value.hasListValue()) {
return parseMultiDenseVector(listValue);
}

if (!value.hasDoubleValue()) {
throw new DataException("Dense vector data must be a list of floats");
}
vectorBuilder.addData((float) value.getDoubleValue());
}
return vectorBuilder.build();
}

private static Vector parseMultiDenseVector(ListValue listValue) throws DataException {
Vector.Builder vectorBuilder = Vector.newBuilder();
int numRows = listValue.getValuesCount();

for (Value row : listValue.getValuesList()) {
if (!row.hasListValue()) {
throw new DataException("Multi vector data must be a list of lists of floats");
}
for (Value value : row.getListValue().getValuesList()) {
if (!value.hasDoubleValue()) {
throw new DataException("Multi vector data must be a list of lists of floats");
}
vectorBuilder.addData((float) value.getDoubleValue());
}
}
vectorBuilder.setVectorsCount(numRows);
return vectorBuilder.build();
}

private static Vector parseSparseVector(Struct struct) throws DataException {
Map<String, Value> fields = struct.getFieldsMap();

Expand Down
53 changes: 52 additions & 1 deletion src/test/java/io/qdrant/kafka/VectorsFactoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,53 @@ void testNamedVectors() {
.addValues(Value.newBuilder().setDoubleValue(0.52354).build())
.build())
.build());
fieldsMap.put(
"multi",
Value.newBuilder()
.setListValue(
ListValue.newBuilder()
.addValues(
Value.newBuilder()
.setListValue(
ListValue.newBuilder()
.addValues(Value.newBuilder().setDoubleValue(0.32).build())
.addValues(Value.newBuilder().setDoubleValue(0.432).build())
.addValues(Value.newBuilder().setDoubleValue(0.423).build())
.addValues(Value.newBuilder().setDoubleValue(0.52354).build())
.build())
.build())
.addValues(
Value.newBuilder()
.setListValue(
ListValue.newBuilder()
.addValues(Value.newBuilder().setDoubleValue(0.32).build())
.addValues(Value.newBuilder().setDoubleValue(0.432).build())
.addValues(Value.newBuilder().setDoubleValue(0.423).build())
.addValues(Value.newBuilder().setDoubleValue(0.52354).build())
.build())
.build())
.addValues(
Value.newBuilder()
.setListValue(
ListValue.newBuilder()
.addValues(Value.newBuilder().setDoubleValue(0.32).build())
.addValues(Value.newBuilder().setDoubleValue(0.432).build())
.addValues(Value.newBuilder().setDoubleValue(0.423).build())
.addValues(Value.newBuilder().setDoubleValue(0.52354).build())
.build())
.build())
.addValues(
Value.newBuilder()
.setListValue(
ListValue.newBuilder()
.addValues(Value.newBuilder().setDoubleValue(0.32).build())
.addValues(Value.newBuilder().setDoubleValue(0.432).build())
.addValues(Value.newBuilder().setDoubleValue(0.423).build())
.addValues(Value.newBuilder().setDoubleValue(0.52354).build())
.build())
.build())
.build())
.build());

Struct struct = Struct.newBuilder().putAllFields(fieldsMap).build();
Value vectorValue = Value.newBuilder().setStructValue(struct).build();
Expand All @@ -74,9 +121,13 @@ void testNamedVectors() {

assertTrue(vectors.hasVectors());
NamedVectors namedVectors = vectors.getVectors();
assertEquals(2, namedVectors.getVectorsCount());
assertEquals(3, namedVectors.getVectorsCount());
assertTrue(namedVectors.containsVectors("boi"));
assertTrue(namedVectors.containsVectors("gal"));
assertTrue(namedVectors.containsVectors("multi"));

Vector multiVector = namedVectors.getVectorsOrThrow("multi");
assertEquals(4, multiVector.getVectorsCount());
}

@Test
Expand Down

0 comments on commit f5edb21

Please sign in to comment.