Skip to content

Commit

Permalink
integration[patch]: feat: implement max marginal relevance search for…
Browse files Browse the repository at this point in the history
… Weaviate vector store (#3395)

* feat: implement max marginal relevance search for Weaviate vector store

* formatting

* Adds docs

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
the-powerpointer and jacoblee93 authored Nov 28, 2023
1 parent 85470e4 commit ed51ace
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
10 changes: 9 additions & 1 deletion docs/core_docs/docs/integrations/vectorstores/weaviate.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ import QueryExample from "@examples/indexes/vector_stores/weaviate_search.ts";

<CodeBlock language="typescript">{QueryExample}</CodeBlock>

## Usage delete documents
## Usage, maximal marginal relevance

import MaximumMarginalRelevanceExample from "@examples/indexes/vector_stores/weaviate_mmr.ts";

You can use maximal marginal relevance search, which optimizes for similarity to the query AND diversity.

<CodeBlock language="typescript">{MaximumMarginalRelevanceExample}</CodeBlock>

## Usage, delete documents

import DeleteExample from "@examples/indexes/vector_stores/weaviate_delete.ts";

Expand Down
28 changes: 28 additions & 0 deletions examples/src/indexes/vector_stores/weaviate_mmr.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import weaviate from "weaviate-ts-client";
import { WeaviateStore } from "langchain/vectorstores/weaviate";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";

export async function run() {
// Something wrong with the weaviate-ts-client types, so we need to disable
const client = (weaviate as any).client({
scheme: process.env.WEAVIATE_SCHEME || "https",
host: process.env.WEAVIATE_HOST || "localhost",
apiKey: new (weaviate as any).ApiKey(
process.env.WEAVIATE_API_KEY || "default"
),
});

// Create a store for an existing index
const store = await WeaviateStore.fromExistingIndex(new OpenAIEmbeddings(), {
client,
indexName: "Test",
metadataKeys: ["foo"],
});

const resultOne = await store.maxMarginalRelevanceSearch("Hello world", {
k: 1,
});

console.log(resultOne);
}
76 changes: 72 additions & 4 deletions langchain/src/vectorstores/weaviate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import type {
WeaviateObject,
WhereFilter,
} from "weaviate-ts-client";
import { VectorStore } from "./base.js";
import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js";
import { Embeddings } from "../embeddings/base.js";
import { Document } from "../document.js";
import { maximalMarginalRelevance } from "../util/math.js";

// Note this function is not generic, it is designed specifically for Weaviate
// https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction
Expand Down Expand Up @@ -263,11 +264,35 @@ export class WeaviateStore extends VectorStore {
k: number,
filter?: WeaviateFilter
): Promise<[Document, number][]> {
const resultsWithEmbedding =
await this.similaritySearchVectorWithScoreAndEmbedding(query, k, filter);
return resultsWithEmbedding.map(([document, score, _embedding]) => [
document,
score,
]);
}

/**
* Method to perform a similarity search on the stored vectors in the
* Weaviate index. It returns the top k most similar documents, their
* similarity scores and embedding vectors.
* @param query The query vector.
* @param k The number of most similar documents to return.
* @param filter Optional filter to apply to the search.
* @returns An array of tuples, where each tuple contains a document, its similarity score and its embedding vector.
*/
async similaritySearchVectorWithScoreAndEmbedding(
query: number[],
k: number,
filter?: WeaviateFilter
): Promise<[Document, number, number[]][]> {
try {
let builder = await this.client.graphql
let builder = this.client.graphql
.get()
.withClassName(this.indexName)
.withFields(`${this.queryAttrs.join(" ")} _additional { distance }`)
.withFields(
`${this.queryAttrs.join(" ")} _additional { distance vector }`
)
.withNearVector({
vector: query,
distance: filter?.distance,
Expand All @@ -284,7 +309,7 @@ export class WeaviateStore extends VectorStore {

const result = await builder.do();

const documents: [Document, number][] = [];
const documents: [Document, number, number[]][] = [];
for (const data of result.data.Get[this.indexName]) {
const { [this.textKey]: text, _additional, ...rest }: ResultRow = data;

Expand All @@ -294,6 +319,7 @@ export class WeaviateStore extends VectorStore {
metadata: rest,
}),
_additional.distance,
_additional.vector,
]);
}
return documents;
Expand All @@ -302,6 +328,48 @@ export class WeaviateStore extends VectorStore {
}
}

/**
* Return documents selected using the maximal marginal relevance.
* Maximal marginal relevance optimizes for similarity to the query AND diversity
* among selected documents.
*
* @param {string} query - Text to look up documents similar to.
* @param {number} options.k - Number of documents to return.
* @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm.
* @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results,
* where 0 corresponds to maximum diversity and 1 to minimum diversity.
* @param {this["FilterType"]} options.filter - Optional filter
* @param _callbacks
*
* @returns {Promise<Document[]>} - List of documents selected by maximal marginal relevance.
*/
override async maxMarginalRelevanceSearch(
query: string,
options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>,
_callbacks: undefined
): Promise<Document[]> {
const { k, fetchK = 20, lambda = 0.5, filter } = options;
const queryEmbedding: number[] = await this.embeddings.embedQuery(query);
const allResults: [Document, number, number[]][] =
await this.similaritySearchVectorWithScoreAndEmbedding(
queryEmbedding,
fetchK,
filter
);
const embeddingList = allResults.map(
([_doc, _score, embedding]) => embedding
);
const mmrIndexes = maximalMarginalRelevance(
queryEmbedding,
embeddingList,
lambda,
k
);
return mmrIndexes
.filter((idx) => idx !== -1)
.map((idx) => allResults[idx][0]);
}

/**
* Static method to create a new `WeaviateStore` instance from a list of
* texts. It first creates documents from the texts and metadata, then
Expand Down

1 comment on commit ed51ace

@vercel
Copy link

@vercel vercel bot commented on ed51ace Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.