Skip to content

Commit

Permalink
Merge pull request #51 from MuleSoft-AI-Chain-Project/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
tbolis-at-mulesoft authored Nov 22, 2024
2 parents a22d265 + 2832d3d commit d24620d
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 72 deletions.
1 change: 1 addition & 0 deletions bin
Submodule bin added at 1b79cf
32 changes: 25 additions & 7 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.mulesoft.connectors</groupId>
<artifactId>mule4-vectors-connector</artifactId>
<version>0.1.124-SNAPSHOT</version>
<version>0.2.0</version>
<packaging>mule-extension</packaging>
<name>MuleSoft Vectors Connector - Mule 4</name>
<description>MuleSoft Vectors Connector provides access to a broad number of external Vector Stores.</description>
Expand Down Expand Up @@ -236,12 +236,6 @@
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-weaviate</artifactId>
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-document-transformer-jsoup</artifactId>
Expand All @@ -254,6 +248,30 @@
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-qdrant</artifactId>
<version>${langchain4jVersion}</version>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.65.1</version>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<version>1.65.1</version>
</dependency>

<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.25.5</version>
</dependency>

<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ private Constants() {}
public static final String VECTOR_STORE_MILVUS = "MILVUS";
public static final String VECTOR_STORE_CHROMA = "CHROMA";
public static final String VECTOR_STORE_PINECONE = "PINECONE";
public static final String VECTOR_STORE_WEAVIATE = "WEAVIATE";
public static final String VECTOR_STORE_AI_SEARCH = "AI_SEARCH";
public static final String VECTOR_STORE_QDRANT = "QDRANT";

public static final String STORE_SCHEMA_METADATA_FIELD_NAME = "metadata";
public static final String STORE_SCHEMA_VECTOR_FIELD_NAME = "vector";
Expand Down Expand Up @@ -74,7 +74,6 @@ private Constants() {}
public static final String JSON_KEY_TEXT = "text";
public static final String JSON_KEY_STATUS = "status";
public static final String JSON_KEY_EMBEDDING = "embedding";
public static final String JSON_KEY_EMBEDDINGS = "embeddings";
public static final String JSON_KEY_DIMENSIONS = "dimensions";
public static final String JSON_KEY_RESPONSE = "response";
public static final String JSON_KEY_QUESTION = "question";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
* Constants.VECTOR_STORE_MILVUS,
* Constants.VECTOR_STORE_CHROMA,
* Constants.VECTOR_STORE_PINECONE,
* Constants.VECTOR_STORE_WEAVIATE,
* Constants.VECTOR_STORE_AI_SEARCH
* )));
* </pre>
Expand All @@ -47,11 +46,10 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

// Weaviate not supported for FILTER_BY_METADATA operation
EMBEDDING_OPERATION_TYPE_TO_SUPPORTED_VECTOR_STORES.put(Constants.EMBEDDING_OPERATION_TYPE_FILTER_BY_METADATA,
new HashSet<>(Arrays.asList(
Constants.VECTOR_STORE_PGVECTOR,
Expand All @@ -60,7 +58,8 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

EMBEDDING_OPERATION_TYPE_TO_SUPPORTED_VECTOR_STORES.put(Constants.EMBEDDING_OPERATION_TYPE_REMOVE_EMBEDDINGS,
Expand All @@ -71,7 +70,6 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
// Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH
)));

Expand All @@ -83,8 +81,8 @@ public class EmbeddingOperationValidator {
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
// Constants.VECTOR_STORE_PINECONE, // Do not support GTE with strings.
// Constants.VECTOR_STORE_WEAVIATE, // Not Supported
Constants.VECTOR_STORE_AI_SEARCH
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_QDRANT
)));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ public Set<Value> resolve() throws ValueResolvingException {
return ValueBuilder.getValuesFor(
Constants.VECTOR_STORE_PGVECTOR,
Constants.VECTOR_STORE_ELASTICSEARCH,
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_MILVUS,
Constants.VECTOR_STORE_CHROMA,
Constants.VECTOR_STORE_PINECONE,
Constants.VECTOR_STORE_WEAVIATE,
Constants.VECTOR_STORE_AI_SEARCH,
Constants.VECTOR_STORE_OPENSEARCH
Constants.VECTOR_STORE_OPENSEARCH,
Constants.VECTOR_STORE_QDRANT
); // MuleChainVectorsConstants.VECTOR_STORE_NEO4J
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.mule.extension.mulechain.vectors.internal.store.opensearch.OpenSearchStore;
import org.mule.extension.mulechain.vectors.internal.store.pgvector.PGVectorStore;
import org.mule.extension.mulechain.vectors.internal.store.pinecone.PineconeStore;
import org.mule.extension.mulechain.vectors.internal.store.weviate.WeaviateStore;
import org.mule.extension.mulechain.vectors.internal.store.qdrant.QdrantStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -274,11 +274,6 @@ public BaseStore build() {
baseStore = new AISearchStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_WEAVIATE:

baseStore = new WeaviateStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_CHROMA:

baseStore = new ChromaStore(storeName, configuration, queryParams, dimension);
Expand All @@ -299,6 +294,11 @@ public BaseStore build() {
baseStore = new OpenSearchStore(storeName, configuration, queryParams, dimension);
break;

case Constants.VECTOR_STORE_QDRANT:

baseStore = new QdrantStore(storeName, configuration, queryParams, dimension);
break;

default:
//throw new IllegalOperationException("Unsupported Vector Store: " + configuration.getVectorStore());
baseStore = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package org.mule.extension.mulechain.vectors.internal.store.qdrant;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import org.json.JSONObject;
import org.mule.extension.mulechain.vectors.internal.config.Configuration;
import org.mule.extension.mulechain.vectors.internal.constant.Constants;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.QueryParameters;
import org.mule.extension.mulechain.vectors.internal.store.BaseStore;
import org.mule.extension.mulechain.vectors.internal.util.JsonUtils;

import java.util.*;
import java.util.concurrent.ExecutionException;

/**
 * {@link BaseStore} implementation backed by a Qdrant collection accessed through the Qdrant
 * gRPC client.
 *
 * <p>Connection settings are read from the {@code QDRANT} section of the JSON config file referenced
 * by the connector {@link Configuration}: {@code QDRANT_HOST}, {@code QDRANT_GRPC_PORT},
 * {@code QDRANT_USE_TLS}, {@code QDRANT_API_KEY} and {@code QDRANT_TEXT_KEY}.
 *
 * <p>NOTE(review): the {@link QdrantClient} created here is never closed by this class; confirm
 * whether the store lifecycle needs an explicit shutdown hook.
 */
public class QdrantStore extends BaseStore {

  /** Upper bound on the number of points scanned by {@link #listSources()}. */
  private static final int MAX_POINTS = 10000;

  private final QdrantClient client;
  private final String payloadTextKey;

  /**
   * Reads the Qdrant settings, opens the gRPC client and, when the collection named
   * {@code storeName} does not exist yet and {@code dimension} is positive, creates it with
   * cosine distance and the given vector size.
   *
   * @param storeName name of the Qdrant collection
   * @param configuration connector configuration providing the config file path
   * @param queryParams query parameters (the embedding page size is used by {@link #listSources()})
   * @param dimension vector size used when the collection has to be created
   * @throws RuntimeException if the existence check or the collection creation fails or is interrupted
   */
  public QdrantStore(String storeName, Configuration configuration, QueryParameters queryParams, int dimension) {

    super(storeName, configuration, queryParams, dimension);

    JSONObject config = JsonUtils.readConfigFile(configuration.getConfigFilePath());
    JSONObject vectorStoreConfig = config.getJSONObject(Constants.VECTOR_STORE_QDRANT);
    String host = vectorStoreConfig.getString("QDRANT_HOST");
    String apiKey = vectorStoreConfig.getString("QDRANT_API_KEY");
    int port = vectorStoreConfig.getInt("QDRANT_GRPC_PORT");
    boolean useTls = vectorStoreConfig.getBoolean("QDRANT_USE_TLS");
    // Read every setting before opening the client so that a missing key fails fast
    // instead of leaving an unreferenced, still-open gRPC channel behind.
    this.payloadTextKey = vectorStoreConfig.getString("QDRANT_TEXT_KEY");
    this.client = new QdrantClient(QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build());

    try {
      if (!this.client.collectionExistsAsync(this.storeName).get() && dimension > 0) {
        this.client.createCollectionAsync(storeName,
            Collections.VectorParams.newBuilder()
                .setDistance(Collections.Distance.Cosine)
                .setSize(dimension)
                .build())
            .get();
      }
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers further up the stack can still observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Interrupted while initializing Qdrant collection '" + storeName + "'", e);
    } catch (ExecutionException e) {
      throw new RuntimeException("Failed to initialize Qdrant collection '" + storeName + "'", e);
    }
  }

  /**
   * Builds a LangChain4j embedding store bound to this store's client, collection name and
   * payload text key.
   *
   * @return a new {@link QdrantEmbeddingStore} for the configured collection
   */
  public EmbeddingStore<TextSegment> buildEmbeddingStore() {

    return QdrantEmbeddingStore.builder()
        .client(client)
        .payloadTextKey(payloadTextKey)
        .collectionName(storeName)
        .build();
  }

  /**
   * Scrolls through the collection (at most {@link #MAX_POINTS} points) and aggregates the
   * payload of every point into per-source entries.
   *
   * @return a JSON object holding the aggregated sources under {@code Constants.JSON_KEY_SOURCES}
   *         and their number under {@code Constants.JSON_KEY_SOURCE_COUNT}
   * @throws RuntimeException if scrolling fails or is interrupted, or if a payload cannot be
   *         rendered as JSON
   */
  @Override
  public JSONObject listSources() {
    try {
      HashMap<String, JSONObject> sourceObjectMap = new HashMap<>();

      for (Points.RetrievedPoint point : scrollPoints()) {
        JSONObject metadataObject = new JSONObject(JsonFactory.toJson(point.getPayloadMap()));
        JSONObject sourceObject = getSourceObject(metadataObject);
        addOrUpdateSourceObjectIntoSourceObjectMap(sourceObjectMap, sourceObject);
      }

      JSONObject jsonObject = new JSONObject();
      jsonObject.put(Constants.JSON_KEY_SOURCES,
          JsonUtils.jsonObjectCollectionToJsonArray(sourceObjectMap.values()));
      jsonObject.put(Constants.JSON_KEY_SOURCE_COUNT, sourceObjectMap.size());

      return jsonObject;
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers further up the stack can still observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Interrupted while listing sources of Qdrant collection '" + storeName + "'", e);
    } catch (ExecutionException | InvalidProtocolBufferException e) {
      throw new RuntimeException("Failed to list sources of Qdrant collection '" + storeName + "'", e);
    }
  }

  /**
   * Pages through the collection with the scroll API, requesting
   * {@code queryParams.embeddingPageSize()} points per call, until the server stops returning a
   * next-page offset or {@link #MAX_POINTS} points have been collected.
   */
  private List<Points.RetrievedPoint> scrollPoints() throws ExecutionException, InterruptedException {
    List<Points.RetrievedPoint> points = new ArrayList<>();
    Points.PointId nextOffset = null;
    boolean keepScrolling = true;

    while (keepScrolling && points.size() < MAX_POINTS) {
      Points.ScrollPoints.Builder request = Points.ScrollPoints.newBuilder()
          .setCollectionName(storeName)
          // Never ask for more than what is still allowed under the overall cap.
          .setLimit(Math.min(queryParams.embeddingPageSize(), MAX_POINTS - points.size()));
      if (nextOffset != null) {
        request.setOffset(nextOffset);
      }

      Points.ScrollResponse response = client.scrollAsync(request.build()).get();
      points.addAll(response.getResultList());

      // An offset carrying neither a numeric nor a UUID id marks the last page.
      nextOffset = response.getNextPageOffset();
      keepScrolling = nextOffset.hasNum() || nextOffset.hasUuid();
    }
    return points;
  }
}

/**
 * Static helper that renders a Qdrant point payload as a JSON string by converting it to a
 * protobuf {@link Struct} and printing that with {@link JsonFormat}.
 */
final class JsonFactory {

  /** Static utility class; not meant to be instantiated. */
  private JsonFactory() {}

  /**
   * Serializes a Qdrant payload map to JSON.
   *
   * @param map payload of a point, keyed by field name
   * @return the JSON text produced by {@code JsonFormat.printer()} for the converted struct
   * @throws InvalidProtocolBufferException if the printer cannot serialize the struct
   * @throws IllegalArgumentException if a value (at any nesting level) has no supported kind
   */
  public static String toJson(Map<String, JsonWithInt.Value> map)
      throws InvalidProtocolBufferException {

    Struct.Builder structBuilder = Struct.newBuilder();
    map.forEach((key, value) -> structBuilder.putFields(key, toProtobufValue(value)));
    return JsonFormat.printer().print(structBuilder.build());
  }

  /**
   * Recursively converts a Qdrant {@code JsonWithInt.Value} into the equivalent protobuf
   * {@link Value}.
   */
  private static Value toProtobufValue(JsonWithInt.Value value) {
    switch (value.getKindCase()) {
      case NULL_VALUE:
        // 0 is the wire number passed for the null value.
        return Value.newBuilder().setNullValueValue(0).build();

      case BOOL_VALUE:
        return Value.newBuilder().setBoolValue(value.getBoolValue()).build();

      case STRING_VALUE:
        return Value.newBuilder().setStringValue(value.getStringValue()).build();

      case INTEGER_VALUE:
        // The integer is widened to the double accepted by setNumberValue, so very large
        // integer magnitudes may not round-trip exactly.
        return Value.newBuilder().setNumberValue(value.getIntegerValue()).build();

      case DOUBLE_VALUE:
        return Value.newBuilder().setNumberValue(value.getDoubleValue()).build();

      case STRUCT_VALUE:
        Struct.Builder structBuilder = Struct.newBuilder();
        value.getStructValue()
            .getFieldsMap()
            .forEach((key, val) -> structBuilder.putFields(key, toProtobufValue(val)));
        return Value.newBuilder().setStructValue(structBuilder).build();

      case LIST_VALUE:
        Value.Builder listBuilder = Value.newBuilder();
        // The receiver of the method reference is evaluated eagerly, so getListValueBuilder()
        // is called exactly once, even when the payload list is empty.
        value.getListValue().getValuesList().stream()
            .map(JsonFactory::toProtobufValue)
            .forEach(listBuilder.getListValueBuilder()::addValues);
        return listBuilder.build();

      default:
        throw new IllegalArgumentException("Unsupported payload value type: " + value.getKindCase());
    }
  }
}

This file was deleted.

0 comments on commit d24620d

Please sign in to comment.