Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #51

Merged
merged 10 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin
Submodule bin added at 1b79cf
32 changes: 25 additions & 7 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.mulesoft.connectors</groupId>
<artifactId>mule4-vectors-connector</artifactId>
<version>0.1.124-SNAPSHOT</version>
<version>0.2.0</version>
<packaging>mule-extension</packaging>
<name>MuleSoft Vectors Connector - Mule 4</name>
<description>MuleSoft Vectors Connector provides access to a broad number of external Vector Stores.</description>
Expand Down Expand Up @@ -236,12 +236,6 @@
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-weaviate</artifactId>
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-document-transformer-jsoup</artifactId>
Expand All @@ -254,6 +248,30 @@
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-qdrant</artifactId>
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.65.1</version>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<version>1.65.1</version>
</dependency>

<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.25.5</version>
</dependency>

<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ private Constants() {}
public static final String VECTOR_STORE_MILVUS = "MILVUS";
public static final String VECTOR_STORE_CHROMA = "CHROMA";
public static final String VECTOR_STORE_PINECONE = "PINECONE";
public static final String VECTOR_STORE_WEAVIATE = "WEAVIATE";
public static final String VECTOR_STORE_AI_SEARCH = "AI_SEARCH";
public static final String VECTOR_STORE_QDRANT = "QDRANT";

public static final String STORE_SCHEMA_METADATA_FIELD_NAME = "metadata";
public static final String STORE_SCHEMA_VECTOR_FIELD_NAME = "vector";
Expand Down Expand Up @@ -74,7 +74,6 @@ private Constants() {}
public static final String JSON_KEY_TEXT = "text";
public static final String JSON_KEY_STATUS = "status";
public static final String JSON_KEY_EMBEDDING = "embedding";
public static final String JSON_KEY_EMBEDDINGS = "embeddings";
public static final String JSON_KEY_DIMENSIONS = "dimensions";
public static final String JSON_KEY_RESPONSE = "response";
public static final String JSON_KEY_QUESTION = "question";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
* Constants.VECTOR_STORE_MILVUS,
* Constants.VECTOR_STORE_CHROMA,
* Constants.VECTOR_STORE_PINECONE,
* Constants.VECTOR_STORE_WEAVIATE,
* Constants.VECTOR_STORE_AI_SEARCH
* )));
* </pre>
Expand All @@ -47,11 +46,10 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

// Weaviate not supported for FILTER_BY_METADATA operation
EMBEDDING_OPERATION_TYPE_TO_SUPPORTED_VECTOR_STORES.put(Constants.EMBEDDING_OPERATION_TYPE_FILTER_BY_METADATA,
new HashSet<>(Arrays.asList(
Constants.VECTOR_STORE_PGVECTOR,
Expand All @@ -60,7 +58,8 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

EMBEDDING_OPERATION_TYPE_TO_SUPPORTED_VECTOR_STORES.put(Constants.EMBEDDING_OPERATION_TYPE_REMOVE_EMBEDDINGS,
Expand All @@ -71,7 +70,6 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
// Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH
)));

Expand All @@ -83,8 +81,8 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
// Constants.VECTOR_STORE_PINECONE, // Do not support GTE with strings.
// Constants.VECTOR_STORE_WEAVIATE, // Not Supported
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ public Set<Value> resolve() throws ValueResolvingException {
return ValueBuilder.getValuesFor(
Constants.VECTOR_STORE_PGVECTOR,
Constants.VECTOR_STORE_ELASTICSEARCH,
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_OPENSEARCH
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_QDRANT
); // MuleChainVectorsConstants.VECTOR_STORE_NEO4J
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.mule.extension.mulechain.vectors.internal.store.opensearch.OpenSearchStore;
import org.mule.extension.mulechain.vectors.internal.store.pgvector.PGVectorStore;
import org.mule.extension.mulechain.vectors.internal.store.pinecone.PineconeStore;
import org.mule.extension.mulechain.vectors.internal.store.weviate.WeaviateStore;
import org.mule.extension.mulechain.vectors.internal.store.qdrant.QdrantStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -274,11 +274,6 @@ public BaseStore build() {
baseStore = new AISearchStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_WEAVIATE:

baseStore = new WeaviateStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_CHROMA:

baseStore = new ChromaStore(storeName, configuration, queryParams, dimension);
Expand All @@ -299,6 +294,11 @@ public BaseStore build() {
baseStore = new OpenSearchStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_QDRANT:

baseStore = new QdrantStore(storeName, configuration, queryParams, dimension);
break;

default:
//throw new IllegalOperationException("Unsupported Vector Store: " + configuration.getVectorStore());
baseStore = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package org.mule.extension.mulechain.vectors.internal.store.qdrant;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import org.json.JSONObject;
import org.mule.extension.mulechain.vectors.internal.config.Configuration;
import org.mule.extension.mulechain.vectors.internal.constant.Constants;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.QueryParameters;
import org.mule.extension.mulechain.vectors.internal.store.BaseStore;
import org.mule.extension.mulechain.vectors.internal.util.JsonUtils;

import java.util.*;
import java.util.concurrent.ExecutionException;

/**
 * Vector store implementation that talks to a Qdrant collection through the Qdrant gRPC client.
 *
 * <p>Connection settings are read from the {@code Constants.VECTOR_STORE_QDRANT} section of the
 * JSON configuration file referenced by the connector configuration. When the store is
 * constructed, the collection named after the store is created (cosine distance, vector size
 * equal to {@code dimension}) if it does not exist yet and {@code dimension} is positive.
 */
public class QdrantStore extends BaseStore {

  /** Upper bound on the number of points {@link #listSources()} scrolls through. */
  private static final int MAX_SCROLLED_POINTS = 10000;

  private final QdrantClient client;
  private final String payloadTextKey;

  /**
   * Reads the Qdrant settings, builds the gRPC client and makes sure the collection exists.
   *
   * @param storeName     name of the store; also used as the Qdrant collection name
   * @param configuration connector configuration providing the path of the JSON config file
   * @param queryParams   query parameters (the page size is used when listing sources)
   * @param dimension     vector size used when the collection has to be created; the collection
   *                      is only created when this value is greater than zero
   * @throws RuntimeException if checking for or creating the collection fails or is interrupted
   */
  public QdrantStore(String storeName, Configuration configuration, QueryParameters queryParams, int dimension) {

    super(storeName, configuration, queryParams, dimension);

    JSONObject config = JsonUtils.readConfigFile(configuration.getConfigFilePath());
    JSONObject vectorStoreConfig = config.getJSONObject(Constants.VECTOR_STORE_QDRANT);
    String host = vectorStoreConfig.getString("QDRANT_HOST");
    String apiKey = vectorStoreConfig.getString("QDRANT_API_KEY");
    int port = vectorStoreConfig.getInt("QDRANT_GRPC_PORT");
    boolean useTls = vectorStoreConfig.getBoolean("QDRANT_USE_TLS");
    // Read every setting before opening the client so that an incomplete configuration
    // fails without leaving an unreferenced gRPC client behind.
    this.payloadTextKey = vectorStoreConfig.getString("QDRANT_TEXT_KEY");
    this.client = new QdrantClient(QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build());

    try {
      if (!this.client.collectionExistsAsync(this.storeName).get() && dimension > 0) {
        this.client.createCollectionAsync(storeName,
            Collections.VectorParams.newBuilder()
                .setDistance(Collections.Distance.Cosine)
                .setSize(dimension)
                .build())
            .get();
      }
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers further up the stack can still observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } catch (ExecutionException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Builds a langchain4j embedding store bound to this store's client, collection name and
   * payload text key.
   *
   * @return a new {@link QdrantEmbeddingStore} for the collection
   */
  public EmbeddingStore<TextSegment> buildEmbeddingStore() {

    return QdrantEmbeddingStore.builder()
        .client(client)
        .payloadTextKey(payloadTextKey)
        .collectionName(storeName)
        .build();
  }

  /**
   * Lists the sources found in the collection.
   *
   * <p>Points are scrolled page by page (page size taken from the query parameters) until no
   * further page offset is returned or {@code MAX_SCROLLED_POINTS} points have been collected.
   * Each point's payload is converted to JSON, turned into a source object via the inherited
   * {@code getSourceObject} helper and merged into a map of sources.
   *
   * @return a JSON object holding the collected sources under {@code Constants.JSON_KEY_SOURCES}
   *         and their number under {@code Constants.JSON_KEY_SOURCE_COUNT}
   * @throws RuntimeException if scrolling fails, is interrupted, or a payload cannot be rendered
   */
  @Override
  public JSONObject listSources() {
    try {
      HashMap<String, JSONObject> sourceObjectMap = new HashMap<>();
      JSONObject jsonObject = new JSONObject();

      boolean keepScrolling = true;
      Points.PointId nextOffset = null;
      List<Points.RetrievedPoint> points = new ArrayList<>(MAX_SCROLLED_POINTS);
      while (keepScrolling && points.size() < MAX_SCROLLED_POINTS) {
        // Never request more than the remaining budget of points.
        Points.ScrollPoints.Builder request = Points.ScrollPoints.newBuilder()
            .setCollectionName(storeName)
            .setLimit(Math.min(queryParams.embeddingPageSize(), MAX_SCROLLED_POINTS - points.size()));
        if (nextOffset != null) {
          request.setOffset(nextOffset);
        }

        Points.ScrollResponse response = client.scrollAsync(request.build()).get();

        points.addAll(response.getResultList());
        nextOffset = response.getNextPageOffset();
        // Continue only while the response carries a numeric or UUID offset for the next page.
        keepScrolling = nextOffset.hasNum() || nextOffset.hasUuid();
      }

      for (Points.RetrievedPoint point : points) {
        JSONObject metadataObject = new JSONObject(JsonFactory.toJson(point.getPayloadMap()));
        JSONObject sourceObject = getSourceObject(metadataObject);
        addOrUpdateSourceObjectIntoSourceObjectMap(sourceObjectMap, sourceObject);
      }

      jsonObject.put(Constants.JSON_KEY_SOURCES,
          JsonUtils.jsonObjectCollectionToJsonArray(sourceObjectMap.values()));
      jsonObject.put(Constants.JSON_KEY_SOURCE_COUNT, sourceObjectMap.size());

      return jsonObject;
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers further up the stack can still observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } catch (ExecutionException | InvalidProtocolBufferException e) {
      throw new RuntimeException(e);
    }
  }
}

/**
 * Renders Qdrant payload values as JSON by mapping them onto protobuf {@link Struct} and
 * {@link Value} messages and printing the result with {@link JsonFormat}.
 */
final class JsonFactory {

  /**
   * Serializes a payload map to a JSON string.
   *
   * @param map payload entries keyed by field name
   * @return the JSON text produced by {@code JsonFormat.printer()}
   * @throws InvalidProtocolBufferException if the printer cannot render the message
   */
  public static String toJson(Map<String, JsonWithInt.Value> map)
      throws InvalidProtocolBufferException {

    Struct.Builder root = Struct.newBuilder();
    for (Map.Entry<String, JsonWithInt.Value> entry : map.entrySet()) {
      root.putFields(entry.getKey(), toProtobufValue(entry.getValue()));
    }
    Struct payload = root.build();
    return JsonFormat.printer().print(payload);
  }

  /**
   * Recursively converts a Qdrant payload value into the equivalent protobuf {@link Value}.
   *
   * @param value the Qdrant value to convert
   * @return the converted protobuf value
   * @throws IllegalArgumentException if the value's kind is not one of the handled cases
   */
  private static Value toProtobufValue(JsonWithInt.Value value) {
    Value.Builder converted = Value.newBuilder();

    switch (value.getKindCase()) {
      case NULL_VALUE:
        converted.setNullValueValue(0);
        break;

      case BOOL_VALUE:
        converted.setBoolValue(value.getBoolValue());
        break;

      case STRING_VALUE:
        converted.setStringValue(value.getStringValue());
        break;

      case INTEGER_VALUE:
        // Integers are stored through the numeric (double) field of the protobuf value.
        converted.setNumberValue(value.getIntegerValue());
        break;

      case DOUBLE_VALUE:
        converted.setNumberValue(value.getDoubleValue());
        break;

      case STRUCT_VALUE:
        Struct.Builder nested = Struct.newBuilder();
        for (Map.Entry<String, JsonWithInt.Value> field
            : value.getStructValue().getFieldsMap().entrySet()) {
          nested.putFields(field.getKey(), toProtobufValue(field.getValue()));
        }
        converted.setStructValue(nested);
        break;

      case LIST_VALUE:
        // Acquire the nested list builder once up front so it is touched even when the
        // source list has no elements.
        converted.getListValueBuilder();
        for (JsonWithInt.Value element : value.getListValue().getValuesList()) {
          converted.getListValueBuilder().addValues(toProtobufValue(element));
        }
        break;

      default:
        throw new IllegalArgumentException("Unsupported payload value type: " + value.getKindCase());
    }

    return converted.build();
  }
}

This file was deleted.