Skip to content

Commit

Permalink
Merge pull request #22 from MuleSoft-AI-Chain-Project/bugfix/list-sou…
Browse files Browse the repository at this point in the history
…rces

Bugfix/list sources
  • Loading branch information
tbolis-at-mulesoft authored Nov 7, 2024
2 parents 2b217e5 + 132afbc commit 3c462f4
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.mule.mulechain</groupId>
<artifactId>mulechain-vectors</artifactId>
<version>0.1.78-SNAPSHOT</version>
<version>0.1.83-SNAPSHOT</version>
<packaging>mule-extension</packaging>
<name>MAC Vectors</name>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.mule.extension.mulechain.vectors.internal.helper.parameter;

import org.mule.runtime.api.meta.ExpressionSupport;
import org.mule.runtime.extension.api.annotation.Expression;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.display.Summary;

/**
 * Parameter holder for options applied when querying the vector store.
 *
 * <p>Only the embedding page size is currently exposed. The offset and limit
 * parameters below are disabled (commented out).
 */
public class QueryParameters {

  /**
   * Page size used when querying the vector store. Optional; the declared
   * default value is 5000.
   */
  @Parameter
  @Optional(defaultValue = "5000")
  @Expression(ExpressionSupport.SUPPORTED)
  @Summary("The embedding page size used when querying the vector store. Defaults to 5000 embeddings.")
  private Number embeddingPageSize;

  // Disabled parameter: offset.
  // @Parameter
  // @Expression(ExpressionSupport.SUPPORTED)
  // @Summary("The offset used when querying the vector store")
  // @Optional(defaultValue = "0")
  // private Number offset;

  // Disabled parameter: limit.
  // @Parameter
  // @Expression(ExpressionSupport.SUPPORTED)
  // @Summary("The limit applied used when querying the vector store")
  // @Optional
  // private Number limit;

  /**
   * Returns the configured embedding page size as an {@code int}.
   *
   * @return the integer value of the configured page size, or 5000 when no
   *         value has been set
   */
  public int embeddingPageSize() {
    if (embeddingPageSize == null) {
      // Field was never populated: fall back to the same default declared on the annotation.
      return 5000;
    }
    return embeddingPageSize.intValue();
  }

  // Disabled accessor: offset.
  // public int offset() {
  // return offset.intValue();
  // }

  // Disabled accessor: limit.
  // public int limit() {
  // return limit.intValue();
  // }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.mule.extension.mulechain.vectors.internal.operation;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static org.apache.commons.io.IOUtils.toInputStream;
import static org.mule.extension.mulechain.vectors.internal.util.JsonUtils.readConfigFile;
import static org.mule.runtime.extension.api.annotation.param.MediaType.APPLICATION_JSON;
Expand All @@ -14,28 +15,23 @@
import org.mule.extension.mulechain.vectors.internal.helper.EmbeddingStoreIngestorHelper;
import org.mule.extension.mulechain.vectors.internal.helper.factory.EmbeddingModelFactory;
import org.mule.extension.mulechain.vectors.internal.helper.factory.EmbeddingStoreFactory;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.FileTypeParameters;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.*;
import org.mule.extension.mulechain.vectors.internal.config.Configuration;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.MetadataFilterParameters;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.EmbeddingModelNameParameters;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.store.embedding.filter.Filter;
import org.json.JSONArray;
import org.json.JSONObject;
import org.mule.extension.mulechain.vectors.internal.util.JsonUtils;
import org.mule.runtime.extension.api.annotation.Alias;
import org.mule.runtime.extension.api.annotation.param.MediaType;
import org.mule.runtime.extension.api.annotation.param.ParameterGroup;
import org.mule.runtime.extension.api.annotation.param.*;

import static java.util.stream.Collectors.joining;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;

import org.mule.runtime.extension.api.annotation.param.Config;
import org.mule.extension.mulechain.vectors.internal.helper.parameter.StorageTypeParameters;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -359,68 +355,104 @@ public InputStream queryByFilterFromEmbedding(String storeName, String question,
@Alias("EMBEDDING-list-sources")
public InputStream listSourcesFromStore(String storeName,
@Config Configuration configuration,
@ParameterGroup(name = "Additional Properties") EmbeddingModelNameParameters modelParams) {
@ParameterGroup(name = "Querying Strategy") QueryParameters queryParams,
@ParameterGroup(name = "Additional Properties") EmbeddingModelNameParameters modelParams
) {

EmbeddingOperationValidator.validateOperationType(
Constants.EMBEDDING_OPERATION_TYPE_FILTER_BY_METADATA,configuration.getVectorStore());

EmbeddingModel embeddingModel = EmbeddingModelFactory.createModel(configuration, modelParams);
EmbeddingStore<TextSegment> store = EmbeddingStoreFactory.createStore(configuration, storeName, embeddingModel.dimension());

Embedding queryEmbedding = embeddingModel.embed(".").content();
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(16384)
.minScore(0.0)
.build();
// Create a general-purpose query vector (here, a zero vector). A zero vector is often used when you need to
// retrieve all embeddings without any specific bias.
float[] queryVector = new float[embeddingModel.dimension()];
for (int i = 0; i < embeddingModel.dimension(); i++) {
queryVector[i]=0.0f; // Zero vector
}

EmbeddingSearchResult<TextSegment> searchResult = store.search(searchRequest);
List<EmbeddingMatch<TextSegment>> embeddingMatches = searchResult.matches();
String information = embeddingMatches.stream()
.map(match -> match.embedded().text())
.collect(joining("\n\n"));
Embedding queryEmbedding = new Embedding(queryVector);

JSONObject jsonObject = new JSONObject();
jsonObject.put("storeName", storeName);
JSONArray sources = new JSONArray();
String absoluteDirectoryPath;
String fileName;
String url;
String ingestionDatetime;

JSONObject contentObject;
String fullPath;

HashMap<String, JSONObject> sourcesJSONObjectHashMap = new HashMap<String, JSONObject>();
for (EmbeddingMatch<TextSegment> match : embeddingMatches) {

Metadata matchMetadata = match.embedded().metadata();
fileName = matchMetadata.getString(Constants.METADATA_KEY_FILE_NAME);
url = matchMetadata.getString(Constants.METADATA_KEY_URL);
fullPath = matchMetadata.getString(Constants.METADATA_KEY_FULL_PATH);
absoluteDirectoryPath = matchMetadata.getString(Constants.METADATA_KEY_ABSOLUTE_DIRECTORY_PATH);
ingestionDatetime = matchMetadata.getString(Constants.METADATA_KEY_INGESTION_DATETIME);

contentObject = new JSONObject();
contentObject.put(Constants.METADATA_KEY_ABSOLUTE_DIRECTORY_PATH, absoluteDirectoryPath);
contentObject.put(Constants.METADATA_KEY_FULL_PATH, fullPath);
contentObject.put(Constants.METADATA_KEY_FILE_NAME, fileName);
contentObject.put(Constants.METADATA_KEY_URL, url);
contentObject.put(Constants.METADATA_KEY_INGESTION_DATETIME, ingestionDatetime);

String key =
((fullPath != null &&!fullPath.isEmpty()) ? fullPath :
(url != null && !url.isEmpty()) ? url : "") +
((ingestionDatetime != null && !ingestionDatetime.isEmpty()) ? ingestionDatetime : "");

// Add contentObject to sources only if it has at least one key-value pair
if (!contentObject.isEmpty() && !key.isEmpty()) {
List<EmbeddingMatch<TextSegment>> embeddingMatches = null;
HashMap<String, JSONObject> sourcesJSONObjectHashMap = new HashMap<String, JSONObject>();
String lowerBoundaryIngestionDateTime = "0000-00-00T00:00:00.000Z";
int lowerBoundaryIndex = -1;

LOGGER.debug("Embedding page size: " + queryParams.embeddingPageSize());
String previousPageEmbeddingId = "";
do {

EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(queryParams.embeddingPageSize())
.minScore(0.0)
.filter(
metadataKey(Constants.METADATA_KEY_INGESTION_DATETIME).isGreaterThan(lowerBoundaryIngestionDateTime).or(
metadataKey(Constants.METADATA_KEY_INGESTION_DATETIME).isGreaterThanOrEqualTo(lowerBoundaryIngestionDateTime).and(
metadataKey("index").isGreaterThan(lowerBoundaryIndex))))
.build();

EmbeddingSearchResult<TextSegment> searchResult = store.search(searchRequest);
embeddingMatches = searchResult.matches();

String currentPageEmbeddingId = "";
for (EmbeddingMatch<TextSegment> match : embeddingMatches) {

Metadata matchMetadata = match.embedded().metadata();
String index = matchMetadata.getString("index");
String fileName = matchMetadata.getString(Constants.METADATA_KEY_FILE_NAME);
String url = matchMetadata.getString(Constants.METADATA_KEY_URL);
String fullPath = matchMetadata.getString(Constants.METADATA_KEY_FULL_PATH);
String absoluteDirectoryPath = matchMetadata.getString(Constants.METADATA_KEY_ABSOLUTE_DIRECTORY_PATH);
String ingestionDatetime = matchMetadata.getString(Constants.METADATA_KEY_INGESTION_DATETIME);

if(lowerBoundaryIngestionDateTime.compareTo(ingestionDatetime) < 0) {

lowerBoundaryIngestionDateTime = ingestionDatetime;
lowerBoundaryIndex = -1;
} else if(lowerBoundaryIngestionDateTime.compareTo(ingestionDatetime) == 0) {

if(Integer.parseInt(index) > lowerBoundaryIndex) {
lowerBoundaryIndex = Integer.parseInt(index);
}
}

JSONObject contentObject = new JSONObject();
contentObject.put(Constants.METADATA_KEY_ABSOLUTE_DIRECTORY_PATH, absoluteDirectoryPath);
contentObject.put(Constants.METADATA_KEY_FULL_PATH, fullPath);
contentObject.put(Constants.METADATA_KEY_FILE_NAME, fileName);
contentObject.put(Constants.METADATA_KEY_URL, url);
contentObject.put(Constants.METADATA_KEY_INGESTION_DATETIME, ingestionDatetime);

String key =
((fullPath != null && !fullPath.isEmpty()) ? fullPath :
(url != null && !url.isEmpty()) ? url : "") +
((ingestionDatetime != null && !ingestionDatetime.isEmpty()) ? ingestionDatetime : "");

// Add contentObject to sources only if it has at least one key-value pair and its lookup key is non-empty
if (!contentObject.isEmpty() && !key.isEmpty()) {

sourcesJSONObjectHashMap.put(key, contentObject);
}
currentPageEmbeddingId = match.embeddingId();
}

sourcesJSONObjectHashMap.put(key, contentObject);
LOGGER.debug("previousPageEmbeddingId: " + previousPageEmbeddingId + ", currentPageEmbeddingId: " + currentPageEmbeddingId);
if(previousPageEmbeddingId.compareTo(currentPageEmbeddingId) == 0) {
break;
} else {
previousPageEmbeddingId = currentPageEmbeddingId;
}
}

} while(embeddingMatches.size() == queryParams.embeddingPageSize());

jsonObject.put("sources", JsonUtils.jsonObjectCollectionToJsonArray(sourcesJSONObjectHashMap.values()));
jsonObject.put("sourceCount", sourcesJSONObjectHashMap.size());

return toInputStream(jsonObject.toString(), StandardCharsets.UTF_8);
}
Expand Down

0 comments on commit 3c462f4

Please sign in to comment.