diff --git a/docker-compose.yml b/docker-compose.yml index f050399fc8a9..d07b5cfb3a1b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - ./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -25,6 +26,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - ./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -39,6 +41,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - ./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -53,6 +56,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - ./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -67,6 +71,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - ./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -81,6 +86,7 @@ services: - ./environment_tests/scripts:/scripts - ./langchain:/langchain - ./langchain-core:/langchain-core + - ./libs/langchain-community:/langchain-community - ./libs/langchain-anthropic:/langchain-anthropic - 
./libs/langchain-openai:/langchain-openai command: bash /scripts/docker-ci-entrypoint.sh @@ -92,6 +98,7 @@ services: # - ./environment_tests/scripts:/scripts # - ./langchain:/langchain-workspace # - ./langchain-core:/langchain-core + # - ./libs/langchain-community:/langchain-community-workspace # - ./libs/langchain-anthropic:/langchain-anthropic-workspace # command: bash /scripts/docker-bun-ci-entrypoint.sh success: diff --git a/docs/core_docs/docs/expression_language/get_started.mdx b/docs/core_docs/docs/expression_language/get_started.mdx new file mode 100644 index 000000000000..a3e6e8035675 --- /dev/null +++ b/docs/core_docs/docs/expression_language/get_started.mdx @@ -0,0 +1,92 @@ +--- +sidebar_position: 0 +title: Get started +--- + +import CodeBlock from "@theme/CodeBlock"; +import BasicExample from "@examples/guides/expression_language/get_started/basic.ts"; +import BasicPromptExample from "@examples/guides/expression_language/get_started/prompt.ts"; +import BasicChatModelExample from "@examples/guides/expression_language/get_started/chat_model.ts"; +import BasicLLMModelExample from "@examples/guides/expression_language/get_started/llm_model.ts"; +import BasicOutputParserExample from "@examples/guides/expression_language/get_started/output_parser.ts"; +import BasicRagExample from "@examples/guides/expression_language/get_started/rag.ts"; + +# Get started + +LCEL makes it easy to build complex chains from basic components, and supports out of the box functionality such as streaming, parallelism, and logging. + +## Basic example: prompt + model + output parser + +The most basic and common use case is chaining a prompt template and a model together. 
To see how this works, let's create a chain that takes a topic and generates a joke: + +<CodeBlock language="typescript">{BasicExample}</CodeBlock> + +:::tip + +[LangSmith trace](https://smith.langchain.com/public/dcac6d79-5254-4889-a974-4b3abaf605b4/r) + +::: + +Notice in this line we're chaining our prompt, LLM model and output parser together: + +```typescript +const chain = prompt.pipe(model).pipe(outputParser); +``` + +The `.pipe()` method allows for chaining together any number of runnables. It will pass the output of one through to the input of the next. + +Here, the prompt is passed a `topic` and when invoked it returns a formatted string with the `{topic}` input variable replaced with the string we passed to the invoke call. +That string is then passed as the input to the LLM which returns a `BaseMessage` object. Finally, the output parser takes that `BaseMessage` object and returns the content of that object as a string. + +### 1. Prompt + +`prompt` is a `BasePromptTemplate`, which means it takes in an object of template variables and produces a `PromptValue`. +A `PromptValue` is a wrapper around a completed prompt that can be passed to either an `LLM` (which takes a string as input) or `ChatModel` (which takes a sequence of messages as input). +It can work with either language model type because it defines logic both for producing BaseMessages and for producing a string. + +<CodeBlock language="typescript">{BasicPromptExample}</CodeBlock> + +### 2. Model + +The `PromptValue` is then passed to `model`. In this case our `model` is a `ChatModel`, meaning it will output a `BaseMessage`. + +<CodeBlock language="typescript">{BasicChatModelExample}</CodeBlock> + +If our model were an LLM, it would output a string. + +<CodeBlock language="typescript">{BasicLLMModelExample}</CodeBlock> + +### 3. Output parser + +And lastly we pass our `model` output to the `outputParser`, which is a `BaseOutputParser`, meaning it takes either a string or a `BaseMessage` as input. The `StringOutputParser` simply converts any input into a string.
+ +<CodeBlock language="typescript">{BasicOutputParserExample}</CodeBlock> + +## RAG Search Example + +For our next example, we want to run a retrieval-augmented generation chain to add some context when responding to questions. + +<CodeBlock language="typescript">{BasicRagExample}</CodeBlock> + +:::tip + +[LangSmith trace](https://smith.langchain.com/public/f0205e20-c46f-47cd-a3a4-6a95451f8a25/r) + +::: + +In this chain we add some extra logic around retrieving context from a vector store. + +We first instantiated our model, vector store and output parser. Then we defined our prompt, which takes in two input variables: + +- `context` -> this is a string which is returned from our vector store based on a semantic search from the input. +- `question` -> this is the question we want to ask. + +Next we created a `setupAndRetrieval` runnable. This has two components which return the values required by our prompt: + +- `context` -> this is a `RunnableLambda` which takes the input from the `.invoke()` call, makes a request to our vector store, and returns the page content of the first result. +- `question` -> this uses a `RunnablePassthrough` which simply passes whatever the input was through to the next step, and in our case it returns it to the key in the object we defined. + +Both of these are wrapped inside a `RunnableMap`. This is a special type of runnable that takes an object of runnables and executes them all in parallel. +It then returns an object with the same keys as the input object, but with the values replaced with the output of the runnables. + +Finally, we pass the output of `setupAndRetrieval` to our `prompt` and then to our `model` and `outputParser` as before. diff --git a/environment_tests/scripts/docker-ci-entrypoint.sh b/environment_tests/scripts/docker-ci-entrypoint.sh index dd98d4276256..a4e2d45e7d45 100644 --- a/environment_tests/scripts/docker-ci-entrypoint.sh +++ b/environment_tests/scripts/docker-ci-entrypoint.sh @@ -19,6 +19,7 @@ cp ../root/yarn.lock ../root/.yarnrc.yml .
# Avoid calling "yarn add ../langchain" as yarn berry does seem to hang for ~30s # before installation actually occurs sed -i 's/"@langchain\/core": "workspace:\*"/"@langchain\/core": "..\/langchain-core"/g' package.json +sed -i 's/"@langchain\/community": "workspace:\*"/"@langchain\/community": "..\/langchain-community"/g' package.json sed -i 's/"@langchain\/anthropic": "workspace:\*"/"@langchain\/anthropic": "..\/langchain-anthropic"/g' package.json sed -i 's/"@langchain\/openai": "workspace:\*"/"@langchain\/openai": "..\/langchain-openai"/g' package.json sed -i 's/"langchain": "workspace:\*"/"langchain": "..\/langchain"/g' package.json diff --git a/environment_tests/test-exports-bun/package.json b/environment_tests/test-exports-bun/package.json index 8294ac754257..649c5e7a1f3d 100644 --- a/environment_tests/test-exports-bun/package.json +++ b/environment_tests/test-exports-bun/package.json @@ -18,6 +18,7 @@ "license": "MIT", "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "d3-dsv": "2", diff --git a/environment_tests/test-exports-cf/package.json b/environment_tests/test-exports-cf/package.json index b87c55713fda..fa01eb216974 100644 --- a/environment_tests/test-exports-cf/package.json +++ b/environment_tests/test-exports-cf/package.json @@ -9,6 +9,7 @@ }, "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "langchain": "workspace:*" diff --git a/environment_tests/test-exports-cjs/package.json b/environment_tests/test-exports-cjs/package.json index b0b43151e2b4..e9aecb14e38d 100644 --- a/environment_tests/test-exports-cjs/package.json +++ b/environment_tests/test-exports-cjs/package.json @@ -19,6 +19,7 @@ "license": "MIT", "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", 
"@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "d3-dsv": "2", diff --git a/environment_tests/test-exports-esbuild/package.json b/environment_tests/test-exports-esbuild/package.json index 15241cd5ef20..5d456fba7335 100644 --- a/environment_tests/test-exports-esbuild/package.json +++ b/environment_tests/test-exports-esbuild/package.json @@ -17,6 +17,7 @@ "license": "MIT", "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "d3-dsv": "2", diff --git a/environment_tests/test-exports-esm/package.json b/environment_tests/test-exports-esm/package.json index a79fc8e8b174..9eeb68a80474 100644 --- a/environment_tests/test-exports-esm/package.json +++ b/environment_tests/test-exports-esm/package.json @@ -20,6 +20,7 @@ "license": "MIT", "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "d3-dsv": "2", diff --git a/environment_tests/test-exports-vercel/package.json b/environment_tests/test-exports-vercel/package.json index c784c3d21a8d..dde8eb4e7128 100644 --- a/environment_tests/test-exports-vercel/package.json +++ b/environment_tests/test-exports-vercel/package.json @@ -10,6 +10,7 @@ }, "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", "@types/node": "18.15.11", diff --git a/environment_tests/test-exports-vite/package.json b/environment_tests/test-exports-vite/package.json index 3a3ea14886d5..2240bd9398ca 100644 --- a/environment_tests/test-exports-vite/package.json +++ b/environment_tests/test-exports-vite/package.json @@ -11,6 +11,7 @@ }, "dependencies": { "@langchain/anthropic": "workspace:*", + "@langchain/community": "workspace:*", "@langchain/core": "workspace:*", "@langchain/openai": "workspace:*", 
"langchain": "workspace:*" diff --git a/examples/src/experimental/masking/basic.ts b/examples/src/experimental/masking/basic.ts index 77a5c78de2b2..7d364ab00449 100644 --- a/examples/src/experimental/masking/basic.ts +++ b/examples/src/experimental/masking/basic.ts @@ -20,11 +20,11 @@ maskingParser.addTransformer(piiMaskingTransformer); const input = "Contact me at jane.doe@email.com or 555-123-4567. Also reach me at john.smith@email.com"; -const masked = await maskingParser.parse(input); +const masked = await maskingParser.mask(input); console.log(masked); // Contact me at [email-a31e486e324f6] or [phone-da8fc1584f224]. Also reach me at [email-d5b6237633d95] -const rehydrated = maskingParser.rehydrate(masked); +const rehydrated = await maskingParser.rehydrate(masked); console.log(rehydrated); // Contact me at jane.doe@email.com or 555-123-4567. Also reach me at john.smith@email.com diff --git a/examples/src/experimental/masking/kitchen_sink.ts b/examples/src/experimental/masking/kitchen_sink.ts index 07e85e7fc50f..4743c20a7b2f 100644 --- a/examples/src/experimental/masking/kitchen_sink.ts +++ b/examples/src/experimental/masking/kitchen_sink.ts @@ -70,7 +70,7 @@ const message = // Mask and rehydrate the message maskingParser - .parse(message) + .mask(message) .then((maskedMessage: string) => { console.log(`Masked message: ${maskedMessage}`); return maskingParser.rehydrate(maskedMessage); diff --git a/examples/src/experimental/masking/next.ts b/examples/src/experimental/masking/next.ts index 85621a4b8dca..dac49a781978 100644 --- a/examples/src/experimental/masking/next.ts +++ b/examples/src/experimental/masking/next.ts @@ -36,11 +36,11 @@ export async function POST(req: Request) { const formattedPreviousMessages = messages.slice(0, -1).map(formatMessage); const currentMessageContent = messages[messages.length - 1].content; // Extract the content of the last message // Mask sensitive information in the current message - const guardedMessageContent = await 
maskingParser.parse( + const guardedMessageContent = await maskingParser.mask( currentMessageContent ); // Mask sensitive information in the chat history - const guardedHistory = await maskingParser.parse( + const guardedHistory = await maskingParser.mask( formattedPreviousMessages.join("\n") ); @@ -64,6 +64,11 @@ export async function POST(req: Request) { headers: { "content-type": "text/plain; charset=utf-8" }, }); } catch (e: any) { - return Response.json({ error: e.message }, { status: 500 }); + return new Response(JSON.stringify({ error: e.message }), { + status: 500, + headers: { + "content-type": "application/json", + }, + }); } } diff --git a/examples/src/guides/expression_language/get_started/basic.ts b/examples/src/guides/expression_language/get_started/basic.ts new file mode 100644 index 000000000000..a6035b82a531 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/basic.ts @@ -0,0 +1,20 @@ +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { ChatPromptTemplate } from "langchain/prompts"; +import { StringOutputParser } from "langchain/schema/output_parser"; + +const prompt = ChatPromptTemplate.fromMessages([ + ["human", "Tell me a short joke about {topic}"], +]); +const model = new ChatOpenAI({}); +const outputParser = new StringOutputParser(); + +const chain = prompt.pipe(model).pipe(outputParser); + +const response = await chain.invoke({ + topic: "ice cream", +}); +console.log(response); +/** +Why did the ice cream go to the gym? +Because it wanted to get a little "cone"ditioning! 
+ */ diff --git a/examples/src/guides/expression_language/get_started/chat_model.ts b/examples/src/guides/expression_language/get_started/chat_model.ts new file mode 100644 index 000000000000..f1da2c7c8072 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/chat_model.ts @@ -0,0 +1,14 @@ +import { ChatOpenAI } from "langchain/chat_models/openai"; + +const model = new ChatOpenAI({}); +const promptAsString = "Human: Tell me a short joke about ice cream"; + +const response = await model.invoke(promptAsString); +console.log(response); +/** +AIMessage { + content: 'Sure, here you go: Why did the ice cream go to school? Because it wanted to get a little "sundae" education!', + name: undefined, + additional_kwargs: { function_call: undefined, tool_calls: undefined } +} + */ diff --git a/examples/src/guides/expression_language/get_started/llm_model.ts b/examples/src/guides/expression_language/get_started/llm_model.ts new file mode 100644 index 000000000000..e689a8f828a0 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/llm_model.ts @@ -0,0 +1,12 @@ +import { OpenAI } from "langchain/llms/openai"; + +const model = new OpenAI({}); +const promptAsString = "Human: Tell me a short joke about ice cream"; + +const response = await model.invoke(promptAsString); +console.log(response); +/** +Why did the ice cream go to therapy? + +Because it was feeling a little rocky road. + */ diff --git a/examples/src/guides/expression_language/get_started/output_parser.ts b/examples/src/guides/expression_language/get_started/output_parser.ts new file mode 100644 index 000000000000..7640166c1452 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/output_parser.ts @@ -0,0 +1,12 @@ +import { AIMessage } from "langchain/schema"; +import { StringOutputParser } from "langchain/schema/output_parser"; + +const outputParser = new StringOutputParser(); +const message = new AIMessage( + 'Sure, here you go: Why did the ice cream go to school? 
Because it wanted to get a little "sundae" education!' +); +const parsed = await outputParser.invoke(message); +console.log(parsed); +/** +Sure, here you go: Why did the ice cream go to school? Because it wanted to get a little "sundae" education! + */ diff --git a/examples/src/guides/expression_language/get_started/prompt.ts b/examples/src/guides/expression_language/get_started/prompt.ts new file mode 100644 index 000000000000..fe178719f954 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/prompt.ts @@ -0,0 +1,34 @@ +import { ChatPromptTemplate } from "langchain/prompts"; + +const prompt = ChatPromptTemplate.fromMessages([ + ["human", "Tell me a short joke about {topic}"], +]); +const promptValue = await prompt.invoke({ topic: "ice cream" }); +console.log(promptValue); +/** +ChatPromptValue { + messages: [ + HumanMessage { + content: 'Tell me a short joke about ice cream', + name: undefined, + additional_kwargs: {} + } + ] +} + */ +const promptAsMessages = promptValue.toChatMessages(); +console.log(promptAsMessages); +/** +[ + HumanMessage { + content: 'Tell me a short joke about ice cream', + name: undefined, + additional_kwargs: {} + } +] + */ +const promptAsString = promptValue.toString(); +console.log(promptAsString); +/** +Human: Tell me a short joke about ice cream + */ diff --git a/examples/src/guides/expression_language/get_started/rag.ts b/examples/src/guides/expression_language/get_started/rag.ts new file mode 100644 index 000000000000..5127c837ff93 --- /dev/null +++ b/examples/src/guides/expression_language/get_started/rag.ts @@ -0,0 +1,47 @@ +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { Document } from "langchain/document"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { ChatPromptTemplate } from "langchain/prompts"; +import { + RunnableLambda, + RunnableMap, + RunnablePassthrough, +} from "langchain/runnables"; +import { StringOutputParser } from "langchain/schema/output_parser"; 
+import { HNSWLib } from "langchain/vectorstores/hnswlib"; + +const vectorStore = await HNSWLib.fromDocuments( + [ + new Document({ pageContent: "Harrison worked at Kensho" }), + new Document({ pageContent: "Bears like to eat honey." }), + ], + new OpenAIEmbeddings() +); +const retriever = vectorStore.asRetriever(1); + +const prompt = ChatPromptTemplate.fromMessages([ + [ + "ai", + `Answer the question based on only the following context: + +{context}`, + ], + ["human", "{question}"], +]); +const model = new ChatOpenAI({}); +const outputParser = new StringOutputParser(); + +const setupAndRetrieval = RunnableMap.from({ + context: new RunnableLambda({ + func: (input: string) => + retriever.invoke(input).then((response) => response[0].pageContent), + }).withConfig({ runName: "contextRetriever" }), + question: new RunnablePassthrough(), +}); +const chain = setupAndRetrieval.pipe(prompt).pipe(model).pipe(outputParser); + +const response = await chain.invoke("Where did Harrison work?"); +console.log(response); +/** +Harrison worked at Kensho. 
+ */ diff --git a/langchain-core/.gitignore b/langchain-core/.gitignore index 05ba9919a328..c6100ae28393 100644 --- a/langchain-core/.gitignore +++ b/langchain-core/.gitignore @@ -106,6 +106,9 @@ utils/json_patch.d.ts utils/json_schema.cjs utils/json_schema.js utils/json_schema.d.ts +utils/math.cjs +utils/math.js +utils/math.d.ts utils/stream.cjs utils/stream.js utils/stream.d.ts diff --git a/langchain-core/package.json b/langchain-core/package.json index 132e1840c301..fafa426e1b89 100644 --- a/langchain-core/package.json +++ b/langchain-core/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/core", - "version": "0.0.10", + "version": "0.0.11-rc.1", "description": "Core LangChain.js abstractions and schemas", "type": "module", "engines": { @@ -40,6 +40,7 @@ "decamelize": "1.2.0", "js-tiktoken": "^1.0.8", "langsmith": "~0.0.48", + "ml-distance": "^4.0.0", "p-queue": "^6.6.2", "p-retry": "4", "uuid": "^9.0.0", @@ -59,6 +60,7 @@ "eslint-plugin-prettier": "^4.2.1", "jest": "^29.5.0", "jest-environment-node": "^29.6.4", + "ml-matrix": "^6.10.4", "prettier": "^2.8.3", "release-it": "^15.10.1", "rimraf": "^5.0.1", @@ -263,6 +265,11 @@ "import": "./utils/json_schema.js", "require": "./utils/json_schema.cjs" }, + "./utils/math": { + "types": "./utils/math.d.ts", + "import": "./utils/math.js", + "require": "./utils/math.cjs" + }, "./utils/stream": { "types": "./utils/stream.d.ts", "import": "./utils/stream.js", @@ -400,6 +407,9 @@ "utils/json_schema.cjs", "utils/json_schema.js", "utils/json_schema.d.ts", + "utils/math.cjs", + "utils/math.js", + "utils/math.d.ts", "utils/stream.cjs", "utils/stream.js", "utils/stream.d.ts", diff --git a/langchain-core/scripts/create-entrypoints.js b/langchain-core/scripts/create-entrypoints.js index b49de2176afb..328796a52d64 100644 --- a/langchain-core/scripts/create-entrypoints.js +++ b/langchain-core/scripts/create-entrypoints.js @@ -44,6 +44,7 @@ const entrypoints = { "utils/hash": "utils/hash", "utils/json_patch": "utils/json_patch", 
"utils/json_schema": "utils/json_schema", + "utils/math": "utils/math", "utils/stream": "utils/stream", "utils/testing": "utils/testing/index", "utils/tiktoken": "utils/tiktoken", diff --git a/langchain-core/src/documents/transformers.ts b/langchain-core/src/documents/transformers.ts index b42ccab2dc16..45b5a6a06607 100644 --- a/langchain-core/src/documents/transformers.ts +++ b/langchain-core/src/documents/transformers.ts @@ -36,3 +36,20 @@ export abstract class BaseDocumentTransformer< return this.transformDocuments(input); } } + +/** + * Class for document transformers that return exactly one transformed document + * for each input document. Documents are transformed one at a time, in input order. + */ +export abstract class MappingDocumentTransformer extends BaseDocumentTransformer { + async transformDocuments(documents: Document[]): Promise<Document[]> { + const newDocuments = []; + for (const document of documents) { + const transformedDocument = await this._transformDocument(document); + newDocuments.push(transformedDocument); + } + return newDocuments; + } + + abstract _transformDocument(document: Document): Promise<Document>; +} diff --git a/langchain-core/src/load/import_map.ts b/langchain-core/src/load/import_map.ts index 40fe9f8af9a4..0c94e2838e41 100644 --- a/langchain-core/src/load/import_map.ts +++ b/langchain-core/src/load/import_map.ts @@ -35,6 +35,7 @@ export * as utils__env from "../utils/env.js"; export * as utils__hash from "../utils/hash.js"; export * as utils__json_patch from "../utils/json_patch.js"; export * as utils__json_schema from "../utils/json_schema.js"; +export * as utils__math from "../utils/math.js"; export * as utils__stream from "../utils/stream.js"; export * as utils__testing from "../utils/testing/index.js"; export * as utils__tiktoken from "../utils/tiktoken.js"; diff --git a/langchain-core/src/load/index.ts b/langchain-core/src/load/index.ts index 69e32bc52cbc..870b7a374f93 100644 --- a/langchain-core/src/load/index.ts +++ b/langchain-core/src/load/index.ts @@ -104,15 +104,25 @@ async function reviver(
| (typeof finalImportMap)[keyof typeof finalImportMap] | OptionalImportMap[keyof OptionalImportMap] | null = null; + + const optionalImportNamespaceAliases = [namespace.join("/")]; + if (namespace[0] === "langchain_community") { + optionalImportNamespaceAliases.push( + ["langchain", ...namespace.slice(1)].join("/") + ); + } + const matchingNamespaceAlias = optionalImportNamespaceAliases.find( + (alias) => alias in optionalImportsMap + ); if ( defaultOptionalImportEntrypoints .concat(optionalImportEntrypoints) .includes(namespace.join("/")) || - namespace.join("/") in optionalImportsMap + matchingNamespaceAlias ) { - if (namespace.join("/") in optionalImportsMap) { + if (matchingNamespaceAlias !== undefined) { module = await optionalImportsMap[ - namespace.join("/") as keyof typeof optionalImportsMap + matchingNamespaceAlias as keyof typeof optionalImportsMap ]; } else { throw new Error( @@ -126,6 +136,7 @@ async function reviver( if ( namespace[0] === "langchain" || namespace[0] === "langchain_core" || + namespace[0] === "langchain_community" || namespace[0] === "langchain_anthropic" || namespace[0] === "langchain_openai" ) { diff --git a/langchain-core/src/messages/index.ts b/langchain-core/src/messages/index.ts index 9ddc2fd64e06..2019e67e7620 100644 --- a/langchain-core/src/messages/index.ts +++ b/langchain-core/src/messages/index.ts @@ -734,3 +734,29 @@ export function mapStoredMessageToChatMessage(message: StoredMessage) { throw new Error(`Got unexpected type: ${storedMessage.type}`); } } + +/** + * Transforms an array of `StoredMessage` instances into an array of + * `BaseMessage` instances. It uses the `mapV1MessageToStoredMessage` + * function to ensure all messages are in the `StoredMessage` format, then + * creates new instances of the appropriate `BaseMessage` subclass based + * on the type of each message. This function is used to prepare stored + * messages for use in a chat context. 
+ */ +export function mapStoredMessagesToChatMessages( + messages: StoredMessage[] +): BaseMessage[] { + return messages.map(mapStoredMessageToChatMessage); +} + +/** + * Transforms an array of `BaseMessage` instances into an array of + * `StoredMessage` instances. It does this by calling the `toDict` method + * on each `BaseMessage`, which returns a `StoredMessage`. This function + * is used to prepare chat messages for storage. + */ +export function mapChatMessagesToStoredMessages( + messages: BaseMessage[] +): StoredMessage[] { + return messages.map((message) => message.toDict()); +} diff --git a/langchain-core/src/utils/math.ts b/langchain-core/src/utils/math.ts new file mode 100644 index 000000000000..fe703c2d5f79 --- /dev/null +++ b/langchain-core/src/utils/math.ts @@ -0,0 +1,180 @@ +import { + similarity as ml_distance_similarity, + distance as ml_distance, +} from "ml-distance"; + +type VectorFunction = (xVector: number[], yVector: number[]) => number; + +/** + * Apply a row-wise function between two matrices with the same number of columns. + * + * @param {number[][]} X - The first matrix. + * @param {number[][]} Y - The second matrix. + * @param {VectorFunction} func - The function to apply. + * + * @throws {Error} If the number of columns in X and Y are not the same. + * + * @returns {number[][] | [[]]} A matrix where each row represents the result of applying the function between the corresponding rows of X and Y. + */ + +export function matrixFunc( + X: number[][], + Y: number[][], + func: VectorFunction +): number[][] { + if ( + X.length === 0 || + X[0].length === 0 || + Y.length === 0 || + Y[0].length === 0 + ) { + return [[]]; + } + + if (X[0].length !== Y[0].length) { + throw new Error( + `Number of columns in X and Y must be the same. 
X has shape ${[ + X.length, + X[0].length, + ]} and Y has shape ${[Y.length, Y[0].length]}.` + ); + } + + return X.map((xVector) => + Y.map((yVector) => func(xVector, yVector)).map((similarity) => + Number.isNaN(similarity) ? 0 : similarity + ) + ); +} + +export function normalize(M: number[][], similarity = false): number[][] { + const max = matrixMaxVal(M); + return M.map((row) => + row.map((val) => (similarity ? 1 - val / max : val / max)) + ); +} + +/** + * This function calculates the row-wise cosine similarity between two matrices with the same number of columns. + * + * @param {number[][]} X - The first matrix. + * @param {number[][]} Y - The second matrix. + * + * @throws {Error} If the number of columns in X and Y are not the same. + * + * @returns {number[][] | [[]]} A matrix where each row represents the cosine similarity values between the corresponding rows of X and Y. + */ +export function cosineSimilarity(X: number[][], Y: number[][]): number[][] { + return matrixFunc(X, Y, ml_distance_similarity.cosine); +} + +export function innerProduct(X: number[][], Y: number[][]): number[][] { + return matrixFunc(X, Y, ml_distance.innerProduct); +} + +export function euclideanDistance(X: number[][], Y: number[][]): number[][] { + return matrixFunc(X, Y, ml_distance.euclidean); +} + +/** + * This function implements the Maximal Marginal Relevance algorithm + * to select a set of embeddings that maximizes the diversity and relevance to a query embedding. + * + * @param {number[]|number[][]} queryEmbedding - The query embedding. + * @param {number[][]} embeddingList - The list of embeddings to select from. + * @param {number} [lambda=0.5] - The trade-off parameter between relevance and diversity. + * @param {number} [k=4] - The maximum number of embeddings to select. + * + * @returns {number[]} The indexes of the selected embeddings in the embeddingList. 
+ */ +export function maximalMarginalRelevance( + queryEmbedding: number[] | number[][], + embeddingList: number[][], + lambda = 0.5, + k = 4 +): number[] { + if (Math.min(k, embeddingList.length) <= 0) { + return []; + } + + const queryEmbeddingExpanded = ( + Array.isArray(queryEmbedding[0]) ? queryEmbedding : [queryEmbedding] + ) as number[][]; + + const similarityToQuery = cosineSimilarity( + queryEmbeddingExpanded, + embeddingList + )[0]; + const mostSimilarEmbeddingIndex = argMax(similarityToQuery).maxIndex; + + const selectedEmbeddings = [embeddingList[mostSimilarEmbeddingIndex]]; + const selectedEmbeddingsIndexes = [mostSimilarEmbeddingIndex]; + + while (selectedEmbeddingsIndexes.length < Math.min(k, embeddingList.length)) { + let bestScore = -Infinity; + let bestIndex = -1; + + const similarityToSelected = cosineSimilarity( + embeddingList, + selectedEmbeddings + ); + + similarityToQuery.forEach((queryScore, queryScoreIndex) => { + if (selectedEmbeddingsIndexes.includes(queryScoreIndex)) { + return; + } + const maxSimilarityToSelected = Math.max( + ...similarityToSelected[queryScoreIndex] + ); + const score = + lambda * queryScore - (1 - lambda) * maxSimilarityToSelected; + + if (score > bestScore) { + bestScore = score; + bestIndex = queryScoreIndex; + } + }); + selectedEmbeddings.push(embeddingList[bestIndex]); + selectedEmbeddingsIndexes.push(bestIndex); + } + + return selectedEmbeddingsIndexes; +} + +type MaxInfo = { + maxIndex: number; + maxValue: number; +}; + +/** + * Finds the maximum value in the given array together with its index. + * @param {number[]} array - The input array. + * + * @returns {MaxInfo} The index (`maxIndex`) and value (`maxValue`) of the maximum element. If the array is empty, `maxIndex` is -1 and `maxValue` is NaN.
+ */ +function argMax(array: number[]): MaxInfo { + if (array.length === 0) { + return { + maxIndex: -1, + maxValue: NaN, + }; + } + + let maxValue = array[0]; + let maxIndex = 0; + + for (let i = 1; i < array.length; i += 1) { + if (array[i] > maxValue) { + maxIndex = i; + maxValue = array[i]; + } + } + return { maxIndex, maxValue }; +} + +function matrixMaxVal(arrays: number[][]): number { + return arrays.reduce( + (acc, array) => Math.max(acc, argMax(array).maxValue), + 0 + ); +} diff --git a/langchain/src/util/tests/math_utils.test.ts b/langchain-core/src/utils/tests/math_utils.test.ts similarity index 100% rename from langchain/src/util/tests/math_utils.test.ts rename to langchain-core/src/utils/tests/math_utils.test.ts diff --git a/langchain/package.json b/langchain/package.json index 5c1b3fe2ce79..b858507fa64a 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -1,6 +1,6 @@ { "name": "langchain", - "version": "0.0.203", + "version": "0.0.204-rc.0", "description": "Typescript bindings for langchain", "type": "module", "engines": { @@ -882,10 +882,6 @@ "author": "LangChain", "license": "MIT", "devDependencies": { - "@aws-crypto/sha256-js": "^5.0.0", - "@aws-sdk/client-bedrock-runtime": "^3.422.0", - "@aws-sdk/client-dynamodb": "^3.310.0", - "@aws-sdk/client-kendra": "^3.352.0", "@aws-sdk/client-lambda": "^3.310.0", "@aws-sdk/client-s3": "^3.310.0", "@aws-sdk/client-sagemaker-runtime": "^3.414.0", @@ -930,7 +926,6 @@ "@tsconfig/recommended": "^1.0.2", "@types/d3-dsv": "^2", "@types/decamelize": "^1.2.0", - "@types/flat": "^5.0.2", "@types/html-to-text": "^9", "@types/js-yaml": "^4", "@types/jsdom": "^21.1.1", @@ -988,7 +983,6 @@ "llmonitor": "^0.5.9", "lodash": "^4.17.21", "mammoth": "^1.5.1", - "ml-matrix": "^6.10.4", "mongodb": "^5.2.0", "mysql2": "^3.3.3", "neo4j-driver": "^5.12.0", @@ -1026,10 +1020,6 @@ "youtubei.js": "^5.8.0" }, "peerDependencies": { - "@aws-crypto/sha256-js": "^5.0.0", - "@aws-sdk/client-bedrock-runtime": "^3.422.0", - 
"@aws-sdk/client-dynamodb": "^3.310.0", - "@aws-sdk/client-kendra": "^3.352.0", "@aws-sdk/client-lambda": "^3.310.0", "@aws-sdk/client-s3": "^3.310.0", "@aws-sdk/client-sagemaker-runtime": "^3.310.0", @@ -1129,18 +1119,6 @@ "youtubei.js": "^5.8.0" }, "peerDependenciesMeta": { - "@aws-crypto/sha256-js": { - "optional": true - }, - "@aws-sdk/client-bedrock-runtime": { - "optional": true - }, - "@aws-sdk/client-dynamodb": { - "optional": true - }, - "@aws-sdk/client-kendra": { - "optional": true - }, "@aws-sdk/client-lambda": { "optional": true }, @@ -1435,17 +1413,17 @@ }, "dependencies": { "@anthropic-ai/sdk": "^0.9.1", - "@langchain/core": "~0.0.10", + "@langchain/community": "~0.0.0", + "@langchain/core": "~0.0.11-rc.1", + "@langchain/openai": "~0.0.2-rc.0", "binary-extensions": "^2.2.0", "expr-eval": "^2.0.2", - "flat": "^5.0.2", "js-tiktoken": "^1.0.7", "js-yaml": "^4.1.0", "jsonpointer": "^5.0.1", "langchainhub": "~0.0.6", "langsmith": "~0.0.48", "ml-distance": "^4.0.0", - "openai": "^4.19.0", "openapi-types": "^12.1.3", "p-retry": "4", "uuid": "^9.0.0", diff --git a/langchain/scripts/check-tree-shaking.js b/langchain/scripts/check-tree-shaking.js index 66a1f194199f..bae5213269d7 100644 --- a/langchain/scripts/check-tree-shaking.js +++ b/langchain/scripts/check-tree-shaking.js @@ -28,6 +28,7 @@ export function listExternals() { /node\:/, /js-tiktoken/, /@langchain\/core/, + /@langchain\/community/, "axios", // axios is a dependency of openai "convex", "convex/server", diff --git a/langchain/src/agents/openai/output_parser.ts b/langchain/src/agents/openai/output_parser.ts index 960f6db84e47..fc14b6a4f160 100644 --- a/langchain/src/agents/openai/output_parser.ts +++ b/langchain/src/agents/openai/output_parser.ts @@ -1,4 +1,4 @@ -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import { AgentAction, AgentFinish, diff --git a/langchain/src/cache/cloudflare_kv.ts b/langchain/src/cache/cloudflare_kv.ts 
index d438c7cd7cc5..b3f86e2b187e 100644 --- a/langchain/src/cache/cloudflare_kv.ts +++ b/langchain/src/cache/cloudflare_kv.ts @@ -1,77 +1 @@ -import type { KVNamespace } from "@cloudflare/workers-types"; - -import { BaseCache, Generation } from "../schema/index.js"; -import { - getCacheKey, - serializeGeneration, - deserializeStoredGeneration, -} from "./base.js"; - -/** - * Represents a specific implementation of a caching mechanism using Cloudflare KV - * as the underlying storage system. It extends the `BaseCache` class and - * overrides its methods to provide the Cloudflare KV-specific logic. - * @example - * ```typescript - * // Example of using OpenAI with Cloudflare KV as cache in a Cloudflare Worker - * const cache = new CloudflareKVCache(env.KV_NAMESPACE); - * const model = new ChatAnthropic({ - * cache, - * }); - * const response = await model.invoke("How are you today?"); - * return new Response(JSON.stringify(response), { - * headers: { "content-type": "application/json" }, - * }); - * - * ``` - */ -export class CloudflareKVCache extends BaseCache { - private binding: KVNamespace; - - constructor(binding: KVNamespace) { - super(); - this.binding = binding; - } - - /** - * Retrieves data from the cache. It constructs a cache key from the given - * `prompt` and `llmKey`, and retrieves the corresponding value from the - * Cloudflare KV namespace. - * @param prompt The prompt used to construct the cache key. - * @param llmKey The LLM key used to construct the cache key. - * @returns An array of Generations if found, null otherwise. 
- */ - public async lookup(prompt: string, llmKey: string) { - let idx = 0; - let key = getCacheKey(prompt, llmKey, String(idx)); - let value = await this.binding.get(key); - const generations: Generation[] = []; - - while (value) { - generations.push(deserializeStoredGeneration(JSON.parse(value))); - idx += 1; - key = getCacheKey(prompt, llmKey, String(idx)); - value = await this.binding.get(key); - } - - return generations.length > 0 ? generations : null; - } - - /** - * Updates the cache with new data. It constructs a cache key from the - * given `prompt` and `llmKey`, and stores the `value` in the Cloudflare KV - * namespace. - * @param prompt The prompt used to construct the cache key. - * @param llmKey The LLM key used to construct the cache key. - * @param value The value to be stored in the cache. - */ - public async update(prompt: string, llmKey: string, value: Generation[]) { - for (let i = 0; i < value.length; i += 1) { - const key = getCacheKey(prompt, llmKey, String(i)); - await this.binding.put( - key, - JSON.stringify(serializeGeneration(value[i])) - ); - } - } -} +export * from "@langchain/community/caches/cloudflare_kv"; diff --git a/langchain/src/cache/momento.ts b/langchain/src/cache/momento.ts index 3c452a429e45..0a24cf70f21b 100644 --- a/langchain/src/cache/momento.ts +++ b/langchain/src/cache/momento.ts @@ -1,173 +1 @@ -/* eslint-disable no-instanceof/no-instanceof */ -import { - ICacheClient, - CacheGet, - CacheSet, - InvalidArgumentError, -} from "@gomomento/sdk-core"; - -import { BaseCache, Generation } from "../schema/index.js"; -import { - deserializeStoredGeneration, - getCacheKey, - serializeGeneration, -} from "./base.js"; -import { ensureCacheExists } from "../util/momento.js"; - -/** - * The settings to instantiate the Momento standard cache. - */ -export interface MomentoCacheProps { - /** - * The Momento cache client. - */ - client: ICacheClient; - /** - * The name of the cache to use to store the data. 
- */ - cacheName: string; - /** - * The time to live for the cache items. If not specified, - * the cache client default is used. - */ - ttlSeconds?: number; - /** - * If true, ensure that the cache exists before returning. - * If false, the cache is not checked for existence. - * Defaults to true. - */ - ensureCacheExists?: true; -} - -/** - * A cache that uses Momento as the backing store. - * See https://gomomento.com. - * @example - * ```typescript - * const cache = new MomentoCache({ - * client: new CacheClient({ - * configuration: Configurations.Laptop.v1(), - * credentialProvider: CredentialProvider.fromEnvironmentVariable({ - * environmentVariableName: "MOMENTO_API_KEY", - * }), - * defaultTtlSeconds: 60 * 60 * 24, // Cache TTL set to 24 hours. - * }), - * cacheName: "langchain", - * }); - * // Initialize the OpenAI model with Momento cache for caching responses - * const model = new ChatOpenAI({ - * cache, - * }); - * await model.invoke("How are you today?"); - * const cachedValues = await cache.lookup("How are you today?", "llmKey"); - * ``` - */ -export class MomentoCache extends BaseCache { - private client: ICacheClient; - - private readonly cacheName: string; - - private readonly ttlSeconds?: number; - - private constructor(props: MomentoCacheProps) { - super(); - this.client = props.client; - this.cacheName = props.cacheName; - - this.validateTtlSeconds(props.ttlSeconds); - this.ttlSeconds = props.ttlSeconds; - } - - /** - * Create a new standard cache backed by Momento. - * - * @param {MomentoCacheProps} props The settings to instantiate the cache. - * @param {ICacheClient} props.client The Momento cache client. - * @param {string} props.cacheName The name of the cache to use to store the data. - * @param {number} props.ttlSeconds The time to live for the cache items. If not specified, - * the cache client default is used. - * @param {boolean} props.ensureCacheExists If true, ensure that the cache exists before returning. 
- * If false, the cache is not checked for existence. Defaults to true. - * @throws {@link InvalidArgumentError} if {@link props.ttlSeconds} is not strictly positive. - * @returns The Momento-backed cache. - */ - public static async fromProps( - props: MomentoCacheProps - ): Promise { - const instance = new MomentoCache(props); - if (props.ensureCacheExists || props.ensureCacheExists === undefined) { - await ensureCacheExists(props.client, props.cacheName); - } - return instance; - } - - /** - * Validate the user-specified TTL, if provided, is strictly positive. - * @param ttlSeconds The TTL to validate. - */ - private validateTtlSeconds(ttlSeconds?: number): void { - if (ttlSeconds !== undefined && ttlSeconds <= 0) { - throw new InvalidArgumentError("ttlSeconds must be positive."); - } - } - - /** - * Lookup LLM generations in cache by prompt and associated LLM key. - * @param prompt The prompt to lookup. - * @param llmKey The LLM key to lookup. - * @returns The generations associated with the prompt and LLM key, or null if not found. - */ - public async lookup( - prompt: string, - llmKey: string - ): Promise { - const key = getCacheKey(prompt, llmKey); - const getResponse = await this.client.get(this.cacheName, key); - - if (getResponse instanceof CacheGet.Hit) { - const value = getResponse.valueString(); - const parsedValue = JSON.parse(value); - if (!Array.isArray(parsedValue)) { - return null; - } - return JSON.parse(value).map(deserializeStoredGeneration); - } else if (getResponse instanceof CacheGet.Miss) { - return null; - } else if (getResponse instanceof CacheGet.Error) { - throw getResponse.innerException(); - } else { - throw new Error(`Unknown response type: ${getResponse.toString()}`); - } - } - - /** - * Update the cache with the given generations. - * - * Note this overwrites any existing generations for the given prompt and LLM key. - * - * @param prompt The prompt to update. - * @param llmKey The LLM key to update. 
- * @param value The generations to store. - */ - public async update( - prompt: string, - llmKey: string, - value: Generation[] - ): Promise { - const key = getCacheKey(prompt, llmKey); - const setResponse = await this.client.set( - this.cacheName, - key, - JSON.stringify(value.map(serializeGeneration)), - { ttl: this.ttlSeconds } - ); - - if (setResponse instanceof CacheSet.Success) { - // pass - } else if (setResponse instanceof CacheSet.Error) { - throw setResponse.innerException(); - } else { - throw new Error(`Unknown response type: ${setResponse.toString()}`); - } - } -} +export * from "@langchain/community/caches/momento"; diff --git a/langchain/src/cache/upstash_redis.ts b/langchain/src/cache/upstash_redis.ts index 7f1660d6606d..8e1a82be82e3 100644 --- a/langchain/src/cache/upstash_redis.ts +++ b/langchain/src/cache/upstash_redis.ts @@ -1,91 +1 @@ -import { Redis, type RedisConfigNodejs } from "@upstash/redis"; - -import { BaseCache, Generation, StoredGeneration } from "../schema/index.js"; -import { - deserializeStoredGeneration, - getCacheKey, - serializeGeneration, -} from "./base.js"; - -export type UpstashRedisCacheProps = { - /** - * The config to use to instantiate an Upstash Redis client. - */ - config?: RedisConfigNodejs; - /** - * An existing Upstash Redis client. - */ - client?: Redis; -}; - -/** - * A cache that uses Upstash as the backing store. - * See https://docs.upstash.com/redis. 
- * @example - * ```typescript - * const cache = new UpstashRedisCache({ - * config: { - * url: "UPSTASH_REDIS_REST_URL", - * token: "UPSTASH_REDIS_REST_TOKEN", - * }, - * }); - * // Initialize the OpenAI model with Upstash Redis cache for caching responses - * const model = new ChatOpenAI({ - * cache, - * }); - * await model.invoke("How are you today?"); - * const cachedValues = await cache.lookup("How are you today?", "llmKey"); - * ``` - */ -export class UpstashRedisCache extends BaseCache { - private redisClient: Redis; - - constructor(props: UpstashRedisCacheProps) { - super(); - const { config, client } = props; - - if (client) { - this.redisClient = client; - } else if (config) { - this.redisClient = new Redis(config); - } else { - throw new Error( - `Upstash Redis caches require either a config object or a pre-configured client.` - ); - } - } - - /** - * Lookup LLM generations in cache by prompt and associated LLM key. - */ - public async lookup(prompt: string, llmKey: string) { - let idx = 0; - let key = getCacheKey(prompt, llmKey, String(idx)); - let value = await this.redisClient.get(key); - const generations: Generation[] = []; - - while (value) { - generations.push(deserializeStoredGeneration(value)); - idx += 1; - key = getCacheKey(prompt, llmKey, String(idx)); - value = await this.redisClient.get(key); - } - - return generations.length > 0 ? generations : null; - } - - /** - * Update the cache with the given generations. - * - * Note this overwrites any existing generations for the given prompt and LLM key. 
- */ - public async update(prompt: string, llmKey: string, value: Generation[]) { - for (let i = 0; i < value.length; i += 1) { - const key = getCacheKey(prompt, llmKey, String(i)); - await this.redisClient.set( - key, - JSON.stringify(serializeGeneration(value[i])) - ); - } - } -} +export * from "@langchain/community/caches/upstash_redis"; diff --git a/langchain/src/callbacks/handlers/llmonitor.ts b/langchain/src/callbacks/handlers/llmonitor.ts index 9453d62ed7ef..92fe3e48f946 100644 --- a/langchain/src/callbacks/handlers/llmonitor.ts +++ b/langchain/src/callbacks/handlers/llmonitor.ts @@ -1,340 +1 @@ -import monitor from "llmonitor"; -import { LLMonitorOptions, ChatMessage, cJSON } from "llmonitor/types"; - -import { BaseRun, RunUpdate as BaseRunUpdate, KVMap } from "langsmith/schemas"; - -import { getEnvironmentVariable } from "../../util/env.js"; - -import { - BaseMessage, - ChainValues, - Generation, - LLMResult, -} from "../../schema/index.js"; -import { Serialized } from "../../load/serializable.js"; - -import { BaseCallbackHandler, BaseCallbackHandlerInput } from "../base.js"; - -type Role = "user" | "ai" | "system" | "function" | "tool"; - -// Langchain Helpers -// Input can be either a single message, an array of message, or an array of array of messages (batch requests) - -const parseRole = (id: string[]): Role => { - const roleHint = id[id.length - 1]; - - if (roleHint.includes("Human")) return "user"; - if (roleHint.includes("System")) return "system"; - if (roleHint.includes("AI")) return "ai"; - if (roleHint.includes("Function")) return "function"; - if (roleHint.includes("Tool")) return "tool"; - - return "ai"; -}; - -type Message = BaseMessage | Generation | string; - -type OutputMessage = ChatMessage | string; - -const PARAMS_TO_CAPTURE = [ - "stop", - "stop_sequences", - "function_call", - "functions", - "tools", - "tool_choice", - "response_format", -]; - -export const convertToLLMonitorMessages = ( - input: Message | Message[] | Message[][] -): 
OutputMessage | OutputMessage[] | OutputMessage[][] => { - const parseMessage = (raw: Message): OutputMessage => { - if (typeof raw === "string") return raw; - // sometimes the message is nested in a "message" property - if ("message" in raw) return parseMessage(raw.message as Message); - - // Serialize - const message = JSON.parse(JSON.stringify(raw)); - - try { - // "id" contains an array describing the constructor, with last item actual schema type - const role = parseRole(message.id); - - const obj = message.kwargs; - const text = message.text ?? obj.content; - - return { - role, - text, - ...(obj.additional_kwargs ?? {}), - }; - } catch (e) { - // if parsing fails, return the original message - return message.text ?? message; - } - }; - - if (Array.isArray(input)) { - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore Confuses the compiler - return input.length === 1 - ? convertToLLMonitorMessages(input[0]) - : input.map(convertToLLMonitorMessages); - } - return parseMessage(input); -}; - -const parseInput = (rawInput: Record) => { - if (!rawInput) return null; - - const { input, inputs, question } = rawInput; - - if (input) return input; - if (inputs) return inputs; - if (question) return question; - - return rawInput; -}; - -const parseOutput = (rawOutput: Record) => { - if (!rawOutput) return null; - - const { text, output, answer, result } = rawOutput; - - if (text) return text; - if (answer) return answer; - if (output) return output; - if (result) return result; - - return rawOutput; -}; - -const parseExtraAndName = ( - llm: Serialized, - extraParams?: KVMap, - metadata?: KVMap -) => { - const params = { - ...(extraParams?.invocation_params ?? {}), - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore this is a valid property - ...(llm?.kwargs ?? 
{}), - ...(metadata || {}), - }; - - const { model, model_name, modelName, model_id, userId, userProps, ...rest } = - params; - - const name = model || modelName || model_name || model_id || llm.id.at(-1); - - // Filter rest to only include params we want to capture - const extra = Object.fromEntries( - Object.entries(rest).filter( - ([key]) => - PARAMS_TO_CAPTURE.includes(key) || - ["string", "number", "boolean"].includes(typeof rest[key]) - ) - ) as cJSON; - - return { name, extra, userId, userProps }; -}; - -export interface Run extends BaseRun { - id: string; - child_runs: this[]; - child_execution_order: number; -} - -export interface RunUpdate extends BaseRunUpdate { - events: BaseRun["events"]; -} - -export interface LLMonitorHandlerFields - extends BaseCallbackHandlerInput, - LLMonitorOptions {} - -export class LLMonitorHandler - extends BaseCallbackHandler - implements LLMonitorHandlerFields -{ - name = "llmonitor_handler"; - - monitor: typeof monitor; - - constructor(fields: LLMonitorHandlerFields = {}) { - super(fields); - - this.monitor = monitor; - - if (fields) { - const { appId, apiUrl, verbose } = fields; - - this.monitor.init({ - verbose, - appId: appId ?? getEnvironmentVariable("LLMONITOR_APP_ID"), - apiUrl: apiUrl ?? 
getEnvironmentVariable("LLMONITOR_API_URL"), - }); - } - } - - async handleLLMStart( - llm: Serialized, - prompts: string[], - runId: string, - parentRunId?: string, - extraParams?: KVMap, - tags?: string[], - metadata?: KVMap - ): Promise { - const { name, extra, userId, userProps } = parseExtraAndName( - llm, - extraParams, - metadata - ); - - await this.monitor.trackEvent("llm", "start", { - runId, - parentRunId, - name, - input: convertToLLMonitorMessages(prompts), - extra, - userId, - userProps, - tags, - runtime: "langchain-js", - }); - } - - async handleChatModelStart( - llm: Serialized, - messages: BaseMessage[][], - runId: string, - parentRunId?: string, - extraParams?: KVMap, - tags?: string[], - metadata?: KVMap - ): Promise { - const { name, extra, userId, userProps } = parseExtraAndName( - llm, - extraParams, - metadata - ); - - await this.monitor.trackEvent("llm", "start", { - runId, - parentRunId, - name, - input: convertToLLMonitorMessages(messages), - extra, - userId, - userProps, - tags, - runtime: "langchain-js", - }); - } - - async handleLLMEnd(output: LLMResult, runId: string): Promise { - const { generations, llmOutput } = output; - - await this.monitor.trackEvent("llm", "end", { - runId, - output: convertToLLMonitorMessages(generations), - tokensUsage: { - completion: llmOutput?.tokenUsage?.completionTokens, - prompt: llmOutput?.tokenUsage?.promptTokens, - }, - }); - } - - async handleLLMError(error: Error, runId: string): Promise { - await this.monitor.trackEvent("llm", "error", { - runId, - error, - }); - } - - async handleChainStart( - chain: Serialized, - inputs: ChainValues, - runId: string, - parentRunId?: string, - tags?: string[], - metadata?: KVMap - ): Promise { - const { agentName, userId, userProps, ...rest } = metadata || {}; - - // allow the user to specify an agent name - const name = agentName || chain.id.at(-1); - - // Attempt to automatically detect if this is an agent or chain - const runType = - agentName || 
["AgentExecutor", "PlanAndExecute"].includes(name) - ? "agent" - : "chain"; - - await this.monitor.trackEvent(runType, "start", { - runId, - parentRunId, - name, - userId, - userProps, - input: parseInput(inputs) as cJSON, - extra: rest, - tags, - runtime: "langchain-js", - }); - } - - async handleChainEnd(outputs: ChainValues, runId: string): Promise { - await this.monitor.trackEvent("chain", "end", { - runId, - output: parseOutput(outputs) as cJSON, - }); - } - - async handleChainError(error: Error, runId: string): Promise { - await this.monitor.trackEvent("chain", "error", { - runId, - error, - }); - } - - async handleToolStart( - tool: Serialized, - input: string, - runId: string, - parentRunId?: string, - tags?: string[], - metadata?: KVMap - ): Promise { - const { toolName, userId, userProps, ...rest } = metadata || {}; - const name = toolName || tool.id.at(-1); - - await this.monitor.trackEvent("tool", "start", { - runId, - parentRunId, - name, - userId, - userProps, - input, - extra: rest, - tags, - runtime: "langchain-js", - }); - } - - async handleToolEnd(output: string, runId: string): Promise { - await this.monitor.trackEvent("tool", "end", { - runId, - output, - }); - } - - async handleToolError(error: Error, runId: string): Promise { - await this.monitor.trackEvent("tool", "error", { - runId, - error, - }); - } -} +export * from "@langchain/community/callbacks/handlers/llmonitor"; diff --git a/langchain/src/callbacks/tests/llmonitor.int.test.ts b/langchain/src/callbacks/tests/llmonitor.int.test.ts index eb796840d66c..62c589f501d8 100644 --- a/langchain/src/callbacks/tests/llmonitor.int.test.ts +++ b/langchain/src/callbacks/tests/llmonitor.int.test.ts @@ -16,7 +16,7 @@ import { Calculator } from "../../tools/calculator.js"; import { initializeAgentExecutorWithOptions } from "../../agents/initialize.js"; -test("Test traced agent with openai functions", async () => { +test.skip("Test traced agent with openai functions", async () => { const tools = [new 
Calculator()]; const chat = new ChatOpenAI({ modelName: "gpt-3.5-turbo", temperature: 0 }); @@ -41,7 +41,7 @@ test("Test traced agent with openai functions", async () => { console.log(result); }); -test("Test traced chain with tags", async () => { +test.skip("Test traced chain with tags", async () => { const llm = new OpenAI(); const qaPrompt = new PromptTemplate({ template: "Q: {question} A:", @@ -75,7 +75,7 @@ test("Test traced chain with tags", async () => { ); }); -test("Test traced chat call with tags", async () => { +test.skip("Test traced chat call with tags", async () => { const chat = new ChatOpenAI({ callbacks: [new LLMonitorHandler({ verbose: true })], }); diff --git a/langchain/src/chains/openai_functions/openapi.ts b/langchain/src/chains/openai_functions/openapi.ts index 06eb92ad4995..5306eef626db 100644 --- a/langchain/src/chains/openai_functions/openapi.ts +++ b/langchain/src/chains/openai_functions/openapi.ts @@ -1,4 +1,4 @@ -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import { JsonSchema7ObjectType } from "zod-to-json-schema/src/parsers/object.js"; import { JsonSchema7ArrayType } from "zod-to-json-schema/src/parsers/array.js"; import { JsonSchema7Type } from "zod-to-json-schema/src/parseDef.js"; diff --git a/langchain/src/chains/openai_moderation.ts b/langchain/src/chains/openai_moderation.ts index 862552099405..2474baaab0ba 100644 --- a/langchain/src/chains/openai_moderation.ts +++ b/langchain/src/chains/openai_moderation.ts @@ -1,4 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { type ClientOptions, OpenAIClient } from "@langchain/openai"; import { BaseChain, ChainInputs } from "./base.js"; import { ChainValues } from "../schema/index.js"; import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; diff --git a/langchain/src/chat_models/baiduwenxin.ts b/langchain/src/chat_models/baiduwenxin.ts index 060da0298e76..61ffe21de4d4 
100644 --- a/langchain/src/chat_models/baiduwenxin.ts +++ b/langchain/src/chat_models/baiduwenxin.ts @@ -1,542 +1 @@ -import { BaseChatModel, BaseChatModelParams } from "./base.js"; -import { - AIMessage, - BaseMessage, - ChatGeneration, - ChatMessage, - ChatResult, -} from "../schema/index.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Type representing the role of a message in the Wenxin chat model. - */ -export type WenxinMessageRole = "assistant" | "user"; - -/** - * Interface representing a message in the Wenxin chat model. - */ -interface WenxinMessage { - role: WenxinMessageRole; - content: string; -} - -/** - * Interface representing the usage of tokens in a chat completion. - */ -interface TokenUsage { - completionTokens?: number; - promptTokens?: number; - totalTokens?: number; -} - -/** - * Interface representing a request for a chat completion. - */ -interface ChatCompletionRequest { - messages: WenxinMessage[]; - stream?: boolean; - user_id?: string; - temperature?: number; - top_p?: number; - penalty_score?: number; - system?: string; -} - -/** - * Interface representing a response from a chat completion. - */ -interface ChatCompletionResponse { - id: string; - object: string; - created: number; - result: string; - need_clear_history: boolean; - usage: TokenUsage; -} - -/** - * Interface defining the input to the ChatBaiduWenxin class. - */ -declare interface BaiduWenxinChatInput { - /** Model name to use. Available options are: ERNIE-Bot, ERNIE-Bot-turbo, ERNIE-Bot-4 - * @default "ERNIE-Bot-turbo" - */ - modelName: string; - - /** Whether to stream the results or not. Defaults to false. */ - streaming?: boolean; - - /** Messages to pass as a prefix to the prompt */ - prefixMessages?: WenxinMessage[]; - - /** - * ID of the end-user who made requests. - */ - userId?: string; - - /** - * API key to use when making requests. 
Defaults to the value of - * `BAIDU_API_KEY` environment variable. - */ - baiduApiKey?: string; - - /** - * Secret key to use when making requests. Defaults to the value of - * `BAIDU_SECRET_KEY` environment variable. - */ - baiduSecretKey?: string; - - /** Amount of randomness injected into the response. Ranges - * from 0 to 1 (0 is not included). Use temp closer to 0 for analytical / - * multiple choice, and temp closer to 1 for creative - * and generative tasks. Defaults to 0.95. - */ - temperature?: number; - - /** Total probability mass of tokens to consider at each step. Range - * from 0 to 1.0. Defaults to 0.8. - */ - topP?: number; - - /** Penalizes repeated tokens according to frequency. Range - * from 1.0 to 2.0. Defaults to 1.0. - */ - penaltyScore?: number; -} - -/** - * Function that extracts the custom role of a generic chat message. - * @param message Chat message from which to extract the custom role. - * @returns The custom role of the chat message. - */ -function extractGenericMessageCustomRole(message: ChatMessage) { - if (message.role !== "assistant" && message.role !== "user") { - console.warn(`Unknown message role: ${message.role}`); - } - - return message.role as WenxinMessageRole; -} - -/** - * Function that converts a base message to a Wenxin message role. - * @param message Base message to convert. - * @returns The Wenxin message role. 
- */ -function messageToWenxinRole(message: BaseMessage): WenxinMessageRole { - const type = message._getType(); - switch (type) { - case "ai": - return "assistant"; - case "human": - return "user"; - case "system": - throw new Error("System messages should not be here"); - case "function": - throw new Error("Function messages not supported"); - case "generic": { - if (!ChatMessage.isInstance(message)) - throw new Error("Invalid generic chat message"); - return extractGenericMessageCustomRole(message); - } - default: - throw new Error(`Unknown message type: ${type}`); - } -} - -/** - * Wrapper around Baidu ERNIE large language models that use the Chat endpoint. - * - * To use you should have the `BAIDU_API_KEY` and `BAIDU_SECRET_KEY` - * environment variable set. - * - * @augments BaseLLM - * @augments BaiduERNIEInput - * @example - * ```typescript - * const ernieTurbo = new ChatBaiduWenxin({ - * baiduApiKey: "YOUR-API-KEY", - * baiduSecretKey: "YOUR-SECRET-KEY", - * }); - * - * const ernie = new ChatBaiduWenxin({ - * modelName: "ERNIE-Bot", - * temperature: 1, - * baiduApiKey: "YOUR-API-KEY", - * baiduSecretKey: "YOUR-SECRET-KEY", - * }); - * - * const messages = [new HumanMessage("Hello")]; - * - * let res = await ernieTurbo.call(messages); - * - * res = await ernie.call(messages); - * ``` - */ -export class ChatBaiduWenxin - extends BaseChatModel - implements BaiduWenxinChatInput -{ - static lc_name() { - return "ChatBaiduWenxin"; - } - - get callKeys(): string[] { - return ["stop", "signal", "options"]; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - baiduApiKey: "BAIDU_API_KEY", - baiduSecretKey: "BAIDU_SECRET_KEY", - }; - } - - get lc_aliases(): { [key: string]: string } | undefined { - return undefined; - } - - lc_serializable = true; - - baiduApiKey?: string; - - baiduSecretKey?: string; - - accessToken: string; - - streaming = false; - - prefixMessages?: WenxinMessage[]; - - userId?: string; - - modelName = "ERNIE-Bot-turbo"; 
- - apiUrl: string; - - temperature?: number | undefined; - - topP?: number | undefined; - - penaltyScore?: number | undefined; - - constructor(fields?: Partial & BaseChatModelParams) { - super(fields ?? {}); - - this.baiduApiKey = - fields?.baiduApiKey ?? getEnvironmentVariable("BAIDU_API_KEY"); - if (!this.baiduApiKey) { - throw new Error("Baidu API key not found"); - } - - this.baiduSecretKey = - fields?.baiduSecretKey ?? getEnvironmentVariable("BAIDU_SECRET_KEY"); - if (!this.baiduSecretKey) { - throw new Error("Baidu Secret key not found"); - } - - this.streaming = fields?.streaming ?? this.streaming; - this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages; - this.userId = fields?.userId ?? this.userId; - this.temperature = fields?.temperature ?? this.temperature; - this.topP = fields?.topP ?? this.topP; - this.penaltyScore = fields?.penaltyScore ?? this.penaltyScore; - - this.modelName = fields?.modelName ?? this.modelName; - - if (this.modelName === "ERNIE-Bot") { - this.apiUrl = - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"; - } else if (this.modelName === "ERNIE-Bot-turbo") { - this.apiUrl = - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"; - } else if (this.modelName === "ERNIE-Bot-4") { - this.apiUrl = - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"; - } else { - throw new Error(`Invalid model name: ${this.modelName}`); - } - } - - /** - * Method that retrieves the access token for making requests to the Baidu - * API. - * @param options Optional parsed call options. - * @returns The access token for making requests to the Baidu API. 
- */ - async getAccessToken(options?: this["ParsedCallOptions"]) { - const url = `https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=${this.baiduApiKey}&client_secret=${this.baiduSecretKey}`; - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json", - }, - signal: options?.signal, - }); - if (!response.ok) { - const text = await response.text(); - const error = new Error( - `Baidu get access token failed with status code ${response.status}, response: ${text}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).response = response; - throw error; - } - const json = await response.json(); - return json.access_token; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams(): Omit { - return { - stream: this.streaming, - user_id: this.userId, - temperature: this.temperature, - top_p: this.topP, - penalty_score: this.penaltyScore, - }; - } - - /** - * Get the identifying parameters for the model - */ - identifyingParams() { - return { - model_name: this.modelName, - ...this.invocationParams(), - }; - } - - /** @ignore */ - async _generate( - messages: BaseMessage[], - options?: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const tokenUsage: TokenUsage = {}; - - const params = this.invocationParams(); - - // Wenxin requires the system message to be put in the params, not messages array - const systemMessage = messages.find( - (message) => message._getType() === "system" - ); - if (systemMessage) { - // eslint-disable-next-line no-param-reassign - messages = messages.filter((message) => message !== systemMessage); - params.system = systemMessage.text; - } - const messagesMapped: WenxinMessage[] = messages.map((message) => ({ - role: messageToWenxinRole(message), - content: message.text, - })); - - const data = params.stream - ? 
await new Promise((resolve, reject) => { - let response: ChatCompletionResponse; - let rejected = false; - let resolved = false; - this.completionWithRetry( - { - ...params, - messages: messagesMapped, - }, - true, - options?.signal, - (event) => { - const data = JSON.parse(event.data); - - if (data?.error_code) { - if (rejected) { - return; - } - rejected = true; - reject(new Error(data?.error_msg)); - return; - } - - const message = data as { - id: string; - object: string; - created: number; - sentence_id?: number; - is_end: boolean; - result: string; - need_clear_history: boolean; - usage: TokenUsage; - }; - - // on the first message set the response properties - if (!response) { - response = { - id: message.id, - object: message.object, - created: message.created, - result: message.result, - need_clear_history: message.need_clear_history, - usage: message.usage, - }; - } else { - response.result += message.result; - response.created = message.created; - response.need_clear_history = message.need_clear_history; - response.usage = message.usage; - } - - // TODO this should pass part.index to the callback - // when that's supported there - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(message.result ?? ""); - - if (message.is_end) { - if (resolved || rejected) { - return; - } - resolved = true; - resolve(response); - } - } - ).catch((error) => { - if (!rejected) { - rejected = true; - reject(error); - } - }); - }) - : await this.completionWithRetry( - { - ...params, - messages: messagesMapped, - }, - false, - options?.signal - ).then((data) => { - if (data?.error_code) { - throw new Error(data?.error_msg); - } - return data; - }); - - const { - completion_tokens: completionTokens, - prompt_tokens: promptTokens, - total_tokens: totalTokens, - } = data.usage ?? {}; - - if (completionTokens) { - tokenUsage.completionTokens = - (tokenUsage.completionTokens ?? 
0) + completionTokens; - } - - if (promptTokens) { - tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens; - } - - if (totalTokens) { - tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens; - } - - const generations: ChatGeneration[] = []; - const text = data.result ?? ""; - generations.push({ - text, - message: new AIMessage(text), - }); - return { - generations, - llmOutput: { tokenUsage }, - }; - } - - /** @ignore */ - async completionWithRetry( - request: ChatCompletionRequest, - stream: boolean, - signal?: AbortSignal, - onmessage?: (event: MessageEvent) => void - ) { - // The first run will get the accessToken - if (!this.accessToken) { - this.accessToken = await this.getAccessToken(); - } - - const makeCompletionRequest = async () => { - const url = `${this.apiUrl}?access_token=${this.accessToken}`; - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(request), - signal, - }); - - if (!stream) { - return response.json(); - } else { - if (response.body) { - // response will not be a stream if an error occurred - if ( - !response.headers - .get("content-type") - ?.startsWith("text/event-stream") - ) { - onmessage?.( - new MessageEvent("message", { - data: await response.text(), - }) - ); - return; - } - - const reader = response.body.getReader(); - - const decoder = new TextDecoder("utf-8"); - let data = ""; - - let continueReading = true; - while (continueReading) { - const { done, value } = await reader.read(); - if (done) { - continueReading = false; - break; - } - data += decoder.decode(value); - - let continueProcessing = true; - while (continueProcessing) { - const newlineIndex = data.indexOf("\n"); - if (newlineIndex === -1) { - continueProcessing = false; - break; - } - const line = data.slice(0, newlineIndex); - data = data.slice(newlineIndex + 1); - - if (line.startsWith("data:")) { - const event = new MessageEvent("message", { - data: 
line.slice("data:".length).trim(), - }); - onmessage?.(event); - } - } - } - } - } - }; - return this.caller.call(makeCompletionRequest); - } - - _llmType() { - return "baiduwenxin"; - } - - /** @ignore */ - _combineLLMOutput() { - return []; - } -} +export * from "@langchain/community/chat_models/baiduwenxin"; diff --git a/langchain/src/chat_models/bedrock/index.ts b/langchain/src/chat_models/bedrock/index.ts index 04fabc096d00..a78523c28a2b 100644 --- a/langchain/src/chat_models/bedrock/index.ts +++ b/langchain/src/chat_models/bedrock/index.ts @@ -1,38 +1 @@ -import { defaultProvider } from "@aws-sdk/credential-provider-node"; -import { BaseBedrockInput } from "../../util/bedrock.js"; -import { BedrockChat as BaseBedrockChat } from "./web.js"; -import { BaseChatModelParams } from "../base.js"; - -/** - * @example - * ```typescript - * const model = new BedrockChat({ - * model: "anthropic.claude-v2", - * region: "us-east-1", - * }); - * const res = await model.invoke([{ content: "Tell me a joke" }]); - * console.log(res); - * ``` - */ -export class BedrockChat extends BaseBedrockChat { - static lc_name() { - return "BedrockChat"; - } - - constructor(fields?: Partial & BaseChatModelParams) { - super({ - ...fields, - credentials: fields?.credentials ?? defaultProvider(), - }); - } -} - -export { - convertMessagesToPromptAnthropic, - convertMessagesToPrompt, -} from "./web.js"; - -/** - * @deprecated Use `BedrockChat` instead. 
- */ -export const ChatBedrock = BedrockChat; +export * from "@langchain/community/chat_models/bedrock"; diff --git a/langchain/src/chat_models/bedrock/web.ts b/langchain/src/chat_models/bedrock/web.ts index 0431f5baa359..9ef97f768f8a 100644 --- a/langchain/src/chat_models/bedrock/web.ts +++ b/langchain/src/chat_models/bedrock/web.ts @@ -1,431 +1 @@ -import { SignatureV4 } from "@smithy/signature-v4"; -import { HttpRequest } from "@smithy/protocol-http"; -import { EventStreamCodec } from "@smithy/eventstream-codec"; -import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; -import { Sha256 } from "@aws-crypto/sha256-js"; - -import { - BaseBedrockInput, - BedrockLLMInputOutputAdapter, - type CredentialType, -} from "../../util/bedrock.js"; -import { getEnvironmentVariable } from "../../util/env.js"; -import { SimpleChatModel, BaseChatModelParams } from "../base.js"; -import { CallbackManagerForLLMRun } from "../../callbacks/manager.js"; -import { - AIMessageChunk, - BaseMessage, - AIMessage, - ChatGenerationChunk, - ChatMessage, -} from "../../schema/index.js"; -import type { SerializedFields } from "../../load/map_keys.js"; - -function convertOneMessageToText( - message: BaseMessage, - humanPrompt: string, - aiPrompt: string -): string { - if (message._getType() === "human") { - return `${humanPrompt} ${message.content}`; - } else if (message._getType() === "ai") { - return `${aiPrompt} ${message.content}`; - } else if (message._getType() === "system") { - return `${humanPrompt} ${message.content}`; - } else if (ChatMessage.isInstance(message)) { - return `\n\n${ - message.role[0].toUpperCase() + message.role.slice(1) - }: {message.content}`; - } - throw new Error(`Unknown role: ${message._getType()}`); -} - -export function convertMessagesToPromptAnthropic( - messages: BaseMessage[], - humanPrompt = "\n\nHuman:", - aiPrompt = "\n\nAssistant:" -): string { - const messagesCopy = [...messages]; - - if ( - messagesCopy.length === 0 || - messagesCopy[messagesCopy.length 
- 1]._getType() !== "ai" - ) { - messagesCopy.push(new AIMessage({ content: "" })); - } - - return messagesCopy - .map((message) => convertOneMessageToText(message, humanPrompt, aiPrompt)) - .join(""); -} - -/** - * Function that converts an array of messages into a single string prompt - * that can be used as input for a chat model. It delegates the conversion - * logic to the appropriate provider-specific function. - * @param messages Array of messages to be converted. - * @param options Options to be used during the conversion. - * @returns A string prompt that can be used as input for a chat model. - */ -export function convertMessagesToPrompt( - messages: BaseMessage[], - provider: string -): string { - if (provider === "anthropic") { - return convertMessagesToPromptAnthropic(messages); - } - throw new Error(`Provider ${provider} does not support chat.`); -} - -/** - * A type of Large Language Model (LLM) that interacts with the Bedrock - * service. It extends the base `LLM` class and implements the - * `BaseBedrockInput` interface. The class is designed to authenticate and - * interact with the Bedrock service, which is a part of Amazon Web - * Services (AWS). It uses AWS credentials for authentication and can be - * configured with various parameters such as the model to use, the AWS - * region, and the maximum number of tokens to generate. 
- * @example - * ```typescript - * const model = new BedrockChat({ - * model: "anthropic.claude-v2", - * region: "us-east-1", - * }); - * const res = await model.invoke([{ content: "Tell me a joke" }]); - * console.log(res); - * ``` - */ -export class BedrockChat extends SimpleChatModel implements BaseBedrockInput { - model = "amazon.titan-tg1-large"; - - region: string; - - credentials: CredentialType; - - temperature?: number | undefined = undefined; - - maxTokens?: number | undefined = undefined; - - fetchFn: typeof fetch; - - endpointHost?: string; - - /** @deprecated */ - stopSequences?: string[]; - - modelKwargs?: Record; - - codec: EventStreamCodec = new EventStreamCodec(toUtf8, fromUtf8); - - streaming = false; - - lc_serializable = true; - - get lc_aliases(): Record { - return { - model: "model_id", - region: "region_name", - }; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID", - "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY", - }; - } - - get lc_attributes(): SerializedFields | undefined { - return { region: this.region }; - } - - _llmType() { - return "bedrock"; - } - - static lc_name() { - return "BedrockChat"; - } - - constructor(fields?: Partial & BaseChatModelParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"]; - if (!allowedModels.includes(this.model.split(".")[0])) { - throw new Error( - `Unknown model: '${this.model}', only these are supported: ${allowedModels}` - ); - } - const region = - fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION"); - if (!region) { - throw new Error( - "Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field." 
- ); - } - this.region = region; - - const credentials = fields?.credentials; - if (!credentials) { - throw new Error( - "Please set the AWS credentials in the 'credentials' field." - ); - } - this.credentials = credentials; - - this.temperature = fields?.temperature ?? this.temperature; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis); - this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl; - this.stopSequences = fields?.stopSequences; - this.modelKwargs = fields?.modelKwargs; - this.streaming = fields?.streaming ?? this.streaming; - } - - /** Call out to Bedrock service model. - Arguments: - prompt: The prompt to pass into the model. - - Returns: - The string generated by the model. - - Example: - response = model.call("Tell me a joke.") - */ - async _call( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const service = "bedrock-runtime"; - const endpointHost = - this.endpointHost ?? `${service}.${this.region}.amazonaws.com`; - const provider = this.model.split(".")[0]; - if (this.streaming) { - const stream = this._streamResponseChunks(messages, options, runManager); - let finalResult: ChatGenerationChunk | undefined; - for await (const chunk of stream) { - if (finalResult === undefined) { - finalResult = chunk; - } else { - finalResult = finalResult.concat(chunk); - } - } - const messageContent = finalResult?.message.content; - if (messageContent && typeof messageContent !== "string") { - throw new Error( - "Non-string output for ChatBedrock is currently not supported." - ); - } - return messageContent ?? ""; - } - - const response = await this._signedFetch(messages, options, { - bedrockMethod: "invoke", - endpointHost, - provider, - }); - const json = await response.json(); - if (!response.ok) { - throw new Error( - `Error ${response.status}: ${json.message ?? 
JSON.stringify(json)}` - ); - } - const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); - return text; - } - - async _signedFetch( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - fields: { - bedrockMethod: "invoke" | "invoke-with-response-stream"; - endpointHost: string; - provider: string; - } - ) { - const { bedrockMethod, endpointHost, provider } = fields; - const inputBody = BedrockLLMInputOutputAdapter.prepareInput( - provider, - convertMessagesToPromptAnthropic(messages), - this.maxTokens, - this.temperature, - options.stop ?? this.stopSequences, - this.modelKwargs, - fields.bedrockMethod - ); - - const url = new URL( - `https://${endpointHost}/model/${this.model}/${bedrockMethod}` - ); - - const request = new HttpRequest({ - hostname: url.hostname, - path: url.pathname, - protocol: url.protocol, - method: "POST", // method must be uppercase - body: JSON.stringify(inputBody), - query: Object.fromEntries(url.searchParams.entries()), - headers: { - // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html - host: url.host, - accept: "application/json", - "content-type": "application/json", - }, - }); - - const signer = new SignatureV4({ - credentials: this.credentials, - service: "bedrock", - region: this.region, - sha256: Sha256, - }); - - const signedRequest = await signer.sign(request); - - // Send request to AWS using the low-level fetch API - const response = await this.caller.callWithOptions( - { signal: options.signal }, - async () => - this.fetchFn(url, { - headers: signedRequest.headers, - body: signedRequest.body, - method: signedRequest.method, - }) - ); - return response; - } - - async *_streamResponseChunks( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const provider = this.model.split(".")[0]; - const service = "bedrock-runtime"; - - const endpointHost = - 
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`; - - const bedrockMethod = - provider === "anthropic" || provider === "cohere" || provider === "meta" - ? "invoke-with-response-stream" - : "invoke"; - - const response = await this._signedFetch(messages, options, { - bedrockMethod, - endpointHost, - provider, - }); - - if (response.status < 200 || response.status >= 300) { - throw Error( - `Failed to access underlying url '${endpointHost}': got ${ - response.status - } ${response.statusText}: ${await response.text()}` - ); - } - - if ( - provider === "anthropic" || - provider === "cohere" || - provider === "meta" - ) { - const reader = response.body?.getReader(); - const decoder = new TextDecoder(); - for await (const chunk of this._readChunks(reader)) { - const event = this.codec.decode(chunk); - if ( - (event.headers[":event-type"] !== undefined && - event.headers[":event-type"].value !== "chunk") || - event.headers[":content-type"].value !== "application/json" - ) { - throw Error(`Failed to get event chunk: got ${chunk}`); - } - const body = JSON.parse(decoder.decode(event.body)); - if (body.message) { - throw new Error(body.message); - } - if (body.bytes !== undefined) { - const chunkResult = JSON.parse( - decoder.decode( - Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 
0) - ) - ); - const text = BedrockLLMInputOutputAdapter.prepareOutput( - provider, - chunkResult - ); - yield new ChatGenerationChunk({ - text, - message: new AIMessageChunk({ content: text }), - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); - } - } - } else { - const json = await response.json(); - const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); - yield new ChatGenerationChunk({ - text, - message: new AIMessageChunk({ content: text }), - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); - } - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - _readChunks(reader: any) { - function _concatChunks(a: Uint8Array, b: Uint8Array) { - const newBuffer = new Uint8Array(a.length + b.length); - newBuffer.set(a); - newBuffer.set(b, a.length); - return newBuffer; - } - - function getMessageLength(buffer: Uint8Array) { - if (buffer.byteLength === 0) return 0; - const view = new DataView( - buffer.buffer, - buffer.byteOffset, - buffer.byteLength - ); - - return view.getUint32(0, false); - } - - return { - async *[Symbol.asyncIterator]() { - let readResult = await reader.read(); - - let buffer: Uint8Array = new Uint8Array(0); - while (!readResult.done) { - const chunk: Uint8Array = readResult.value; - - buffer = _concatChunks(buffer, chunk); - let messageLength = getMessageLength(buffer); - - while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) { - yield buffer.slice(0, messageLength); - buffer = buffer.slice(messageLength); - messageLength = getMessageLength(buffer); - } - - readResult = await reader.read(); - } - }, - }; - } - - _combineLLMOutput() { - return {}; - } -} - -/** - * @deprecated Use `BedrockChat` instead. 
- */ -export const ChatBedrock = BedrockChat; +export * from "@langchain/community/chat_models/bedrock/web"; diff --git a/langchain/src/chat_models/cloudflare_workersai.ts b/langchain/src/chat_models/cloudflare_workersai.ts index b8d2f5971814..009e5183f0a0 100644 --- a/langchain/src/chat_models/cloudflare_workersai.ts +++ b/langchain/src/chat_models/cloudflare_workersai.ts @@ -1,247 +1 @@ -import { SimpleChatModel, BaseChatModelParams } from "./base.js"; -import { BaseLanguageModelCallOptions } from "../base_language/index.js"; -import { - AIMessageChunk, - BaseMessage, - ChatGenerationChunk, - ChatMessage, -} from "../schema/index.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { CloudflareWorkersAIInput } from "../llms/cloudflare_workersai.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { convertEventStreamToIterableReadableDataStream } from "../util/event-source-parse.js"; - -/** - * An interface defining the options for a Cloudflare Workers AI call. It extends - * the BaseLanguageModelCallOptions interface. - */ -export interface ChatCloudflareWorkersAICallOptions - extends BaseLanguageModelCallOptions {} - -/** - * A class that enables calls to the Cloudflare Workers AI API to access large language - * models in a chat-like fashion. It extends the SimpleChatModel class and - * implements the CloudflareWorkersAIInput interface. 
- * @example - * ```typescript - * const model = new ChatCloudflareWorkersAI({ - * model: "@cf/meta/llama-2-7b-chat-int8", - * cloudflareAccountId: process.env.CLOUDFLARE_ACCOUNT_ID, - * cloudflareApiToken: process.env.CLOUDFLARE_API_TOKEN - * }); - * - * const response = await model.invoke([ - * ["system", "You are a helpful assistant that translates English to German."], - * ["human", `Translate "I love programming".`] - * ]); - * - * console.log(response); - * ``` - */ -export class ChatCloudflareWorkersAI - extends SimpleChatModel - implements CloudflareWorkersAIInput -{ - static lc_name() { - return "ChatCloudflareWorkersAI"; - } - - lc_serializable = true; - - model = "@cf/meta/llama-2-7b-chat-int8"; - - cloudflareAccountId?: string; - - cloudflareApiToken?: string; - - baseUrl: string; - - streaming = false; - - constructor(fields?: CloudflareWorkersAIInput & BaseChatModelParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - this.streaming = fields?.streaming ?? this.streaming; - this.cloudflareAccountId = - fields?.cloudflareAccountId ?? - getEnvironmentVariable("CLOUDFLARE_ACCOUNT_ID"); - this.cloudflareApiToken = - fields?.cloudflareApiToken ?? - getEnvironmentVariable("CLOUDFLARE_API_TOKEN"); - this.baseUrl = - fields?.baseUrl ?? - `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`; - if (this.baseUrl.endsWith("/")) { - this.baseUrl = this.baseUrl.slice(0, -1); - } - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - cloudflareApiToken: "CLOUDFLARE_API_TOKEN", - }; - } - - _llmType() { - return "cloudflare"; - } - - /** Get the identifying parameters for this LLM. 
*/ - get identifyingParams() { - return { model: this.model }; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams(_options?: this["ParsedCallOptions"]) { - return { - model: this.model, - }; - } - - _combineLLMOutput() { - return {}; - } - - /** - * Method to validate the environment. - */ - validateEnvironment() { - if (!this.cloudflareAccountId) { - throw new Error( - `No Cloudflare account ID found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.` - ); - } - if (!this.cloudflareApiToken) { - throw new Error( - `No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_KEY" in your environment variables.` - ); - } - } - - async _request( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - stream?: boolean - ) { - this.validateEnvironment(); - const url = `${this.baseUrl}/${this.model}`; - const headers = { - Authorization: `Bearer ${this.cloudflareApiToken}`, - "Content-Type": "application/json", - }; - - const formattedMessages = this._formatMessages(messages); - - const data = { messages: formattedMessages, stream }; - return this.caller.call(async () => { - const response = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify(data), - signal: options.signal, - }); - if (!response.ok) { - const error = new Error( - `Cloudflare LLM call failed with status code ${response.status}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).response = response; - throw error; - } - return response; - }); - } - - async *_streamResponseChunks( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const response = await this._request(messages, options, true); - if (!response.body) { - throw new Error("Empty response from Cloudflare. 
Please try again."); - } - const stream = convertEventStreamToIterableReadableDataStream( - response.body - ); - for await (const chunk of stream) { - if (chunk !== "[DONE]") { - const parsedChunk = JSON.parse(chunk); - const generationChunk = new ChatGenerationChunk({ - message: new AIMessageChunk({ content: parsedChunk.response }), - text: parsedChunk.response, - }); - yield generationChunk; - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(generationChunk.text ?? ""); - } - } - } - - protected _formatMessages( - messages: BaseMessage[] - ): { role: string; content: string }[] { - const formattedMessages = messages.map((message) => { - let role; - if (message._getType() === "human") { - role = "user"; - } else if (message._getType() === "ai") { - role = "assistant"; - } else if (message._getType() === "system") { - role = "system"; - } else if (ChatMessage.isInstance(message)) { - role = message.role; - } else { - console.warn( - `Unsupported message type passed to Cloudflare: "${message._getType()}"` - ); - role = "user"; - } - if (typeof message.content !== "string") { - throw new Error( - "ChatCloudflareWorkersAI currently does not support non-string message content." 
- ); - } - return { - role, - content: message.content, - }; - }); - return formattedMessages; - } - - /** @ignore */ - async _call( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - if (!this.streaming) { - const response = await this._request(messages, options); - - const responseData = await response.json(); - - return responseData.result.response; - } else { - const stream = this._streamResponseChunks(messages, options, runManager); - let finalResult: ChatGenerationChunk | undefined; - for await (const chunk of stream) { - if (finalResult === undefined) { - finalResult = chunk; - } else { - finalResult = finalResult.concat(chunk); - } - } - const messageContent = finalResult?.message.content; - if (messageContent && typeof messageContent !== "string") { - throw new Error( - "Non-string output for ChatCloudflareWorkersAI is currently not supported." - ); - } - return messageContent ?? ""; - } - } -} +export * from "@langchain/community/chat_models/cloudflare_workersai"; diff --git a/langchain/src/chat_models/fireworks.ts b/langchain/src/chat_models/fireworks.ts index 29e12cc34ea2..438f96c7c33d 100644 --- a/langchain/src/chat_models/fireworks.ts +++ b/langchain/src/chat_models/fireworks.ts @@ -1,137 +1 @@ -import type { OpenAI as OpenAIClient } from "openai"; -import type { ChatOpenAICallOptions, OpenAIChatInput } from "./openai.js"; -import type { OpenAICoreRequestOptions } from "../types/openai-types.js"; -import type { BaseChatModelParams } from "./base.js"; -import { ChatOpenAI } from "./openai.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -type FireworksUnsupportedArgs = - | "frequencyPenalty" - | "presencePenalty" - | "logitBias" - | "functions"; - -type FireworksUnsupportedCallOptions = "functions" | "function_call" | "tools"; - -export type ChatFireworksCallOptions = Partial< - Omit ->; - -/** - * Wrapper around Fireworks API for large language models fine-tuned for 
chat - * - * Fireworks API is compatible to the OpenAI API with some limitations described in - * https://readme.fireworks.ai/docs/openai-compatibility. - * - * To use, you should have the `openai` package installed and - * the `FIREWORKS_API_KEY` environment variable set. - * @example - * ```typescript - * const model = new ChatFireworks({ - * temperature: 0.9, - * fireworksApiKey: "YOUR-API-KEY", - * }); - * - * const response = await model.invoke("Hello, how are you?"); - * console.log(response); - * ``` - */ -export class ChatFireworks extends ChatOpenAI { - static lc_name() { - return "ChatFireworks"; - } - - _llmType() { - return "fireworks"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - fireworksApiKey: "FIREWORKS_API_KEY", - }; - } - - lc_serializable = true; - - fireworksApiKey?: string; - - constructor( - fields?: Partial< - Omit - > & - BaseChatModelParams & { fireworksApiKey?: string } - ) { - const fireworksApiKey = - fields?.fireworksApiKey || getEnvironmentVariable("FIREWORKS_API_KEY"); - - if (!fireworksApiKey) { - throw new Error( - `Fireworks API key not found. 
Please set the FIREWORKS_API_KEY environment variable or provide the key into "fireworksApiKey"` - ); - } - - super({ - ...fields, - modelName: - fields?.modelName || "accounts/fireworks/models/llama-v2-13b-chat", - openAIApiKey: fireworksApiKey, - configuration: { - baseURL: "https://api.fireworks.ai/inference/v1", - }, - }); - - this.fireworksApiKey = fireworksApiKey; - } - - toJSON() { - const result = super.toJSON(); - - if ( - "kwargs" in result && - typeof result.kwargs === "object" && - result.kwargs != null - ) { - delete result.kwargs.openai_api_key; - delete result.kwargs.configuration; - } - - return result; - } - - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, - options?: OpenAICoreRequestOptions - ): Promise>; - - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise; - - /** - * Calls the Fireworks API with retry logic in case of failures. - * @param request The request to send to the Fireworks API. - * @param options Optional configuration for the API call. - * @returns The response from the Fireworks API. 
- */ - async completionWithRetry( - request: - | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming - | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise< - | AsyncIterable - | OpenAIClient.Chat.Completions.ChatCompletion - > { - delete request.frequency_penalty; - delete request.presence_penalty; - delete request.logit_bias; - delete request.functions; - - if (request.stream === true) { - return super.completionWithRetry(request, options); - } - - return super.completionWithRetry(request, options); - } -} +export * from "@langchain/community/chat_models/fireworks"; diff --git a/langchain/src/chat_models/googlepalm.ts b/langchain/src/chat_models/googlepalm.ts index 56cca1023943..53f338a9674d 100644 --- a/langchain/src/chat_models/googlepalm.ts +++ b/langchain/src/chat_models/googlepalm.ts @@ -1,340 +1 @@ -import { DiscussServiceClient } from "@google-ai/generativelanguage"; -import type { protos } from "@google-ai/generativelanguage"; -import { GoogleAuth } from "google-auth-library"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { - AIMessage, - BaseMessage, - ChatMessage, - ChatResult, - isBaseMessage, -} from "../schema/index.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { BaseChatModel, BaseChatModelParams } from "./base.js"; - -export type BaseMessageExamplePair = { - input: BaseMessage; - output: BaseMessage; -}; - -/** - * An interface defining the input to the ChatGooglePaLM class. - */ -export interface GooglePaLMChatInput extends BaseChatModelParams { - /** - * Model Name to use - * - * Note: The format must follow the pattern - `models/{model}` - */ - modelName?: string; - - /** - * Controls the randomness of the output. - * - * Values can range from [0.0,1.0], inclusive. 
A value closer to 1.0 - * will produce responses that are more varied and creative, while - * a value closer to 0.0 will typically result in less surprising - * responses from the model. - * - * Note: The default value varies by model - */ - temperature?: number; - - /** - * Top-p changes how the model selects tokens for output. - * - * Tokens are selected from most probable to least until the sum - * of their probabilities equals the top-p value. - * - * For example, if tokens A, B, and C have a probability of - * .3, .2, and .1 and the top-p value is .5, then the model will - * select either A or B as the next token (using temperature). - * - * Note: The default value varies by model - */ - topP?: number; - - /** - * Top-k changes how the model selects tokens for output. - * - * A top-k of 1 means the selected token is the most probable among - * all tokens in the model’s vocabulary (also called greedy decoding), - * while a top-k of 3 means that the next token is selected from - * among the 3 most probable tokens (using temperature). - * - * Note: The default value varies by model - */ - topK?: number; - - examples?: - | protos.google.ai.generativelanguage.v1beta2.IExample[] - | BaseMessageExamplePair[]; - - /** - * Google Palm API key to use - */ - apiKey?: string; -} - -function getMessageAuthor(message: BaseMessage) { - const type = message._getType(); - if (ChatMessage.isInstance(message)) { - return message.role; - } - return message.name ?? type; -} - -/** - * A class that wraps the Google Palm chat model. 
- * @example - * ```typescript - * const model = new ChatGooglePaLM({ - * apiKey: "", - * temperature: 0.7, - * modelName: "models/chat-bison-001", - * topK: 40, - * topP: 1, - * examples: [ - * { - * input: new HumanMessage("What is your favorite sock color?"), - * output: new AIMessage("My favorite sock color be arrrr-ange!"), - * }, - * ], - * }); - * const questions = [ - * new SystemMessage( - * "You are a funny assistant that answers in pirate language." - * ), - * new HumanMessage("What is your favorite food?"), - * ]; - * const res = await model.call(questions); - * console.log({ res }); - * ``` - */ -export class ChatGooglePaLM - extends BaseChatModel - implements GooglePaLMChatInput -{ - static lc_name() { - return "ChatGooglePaLM"; - } - - lc_serializable = true; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "GOOGLE_PALM_API_KEY", - }; - } - - modelName = "models/chat-bison-001"; - - temperature?: number; // default value chosen based on model - - topP?: number; // default value chosen based on model - - topK?: number; // default value chosen based on model - - examples: protos.google.ai.generativelanguage.v1beta2.IExample[] = []; - - apiKey?: string; - - private client: DiscussServiceClient; - - constructor(fields?: GooglePaLMChatInput) { - super(fields ?? {}); - - this.modelName = fields?.modelName ?? this.modelName; - - this.temperature = fields?.temperature ?? this.temperature; - if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { - throw new Error("`temperature` must be in the range of [0.0,1.0]"); - } - - this.topP = fields?.topP ?? this.topP; - if (this.topP && this.topP < 0) { - throw new Error("`topP` must be a positive integer"); - } - - this.topK = fields?.topK ?? 
this.topK; - if (this.topK && this.topK < 0) { - throw new Error("`topK` must be a positive integer"); - } - - this.examples = - fields?.examples?.map((example) => { - if ( - (isBaseMessage(example.input) && - typeof example.input.content !== "string") || - (isBaseMessage(example.output) && - typeof example.output.content !== "string") - ) { - throw new Error( - "GooglePaLM example messages may only have string content." - ); - } - return { - input: { - ...example.input, - content: example.input?.content as string, - }, - output: { - ...example.output, - content: example.output?.content as string, - }, - }; - }) ?? this.examples; - - this.apiKey = - fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PALM_API_KEY"); - if (!this.apiKey) { - throw new Error( - "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" - ); - } - - this.client = new DiscussServiceClient({ - authClient: new GoogleAuth().fromAPIKey(this.apiKey), - }); - } - - _combineLLMOutput() { - return []; - } - - _llmType() { - return "googlepalm"; - } - - async _generate( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const palmMessages = await this.caller.callWithOptions( - { signal: options.signal }, - this._generateMessage.bind(this), - this._mapBaseMessagesToPalmMessages(messages), - this._getPalmContextInstruction(messages), - this.examples - ); - const chatResult = this._mapPalmMessagesToChatResult(palmMessages); - - // Google Palm doesn't provide streaming as of now. But to support streaming handlers - // we call the handler with entire response text - void runManager?.handleLLMNewToken( - chatResult.generations.length > 0 ? 
chatResult.generations[0].text : "" - ); - - return chatResult; - } - - protected async _generateMessage( - messages: protos.google.ai.generativelanguage.v1beta2.IMessage[], - context?: string, - examples?: protos.google.ai.generativelanguage.v1beta2.IExample[] - ): Promise { - const [palmMessages] = await this.client.generateMessage({ - candidateCount: 1, - model: this.modelName, - temperature: this.temperature, - topK: this.topK, - topP: this.topP, - prompt: { - context, - examples, - messages, - }, - }); - return palmMessages; - } - - protected _getPalmContextInstruction( - messages: BaseMessage[] - ): string | undefined { - // get the first message and checks if it's a system 'system' messages - const systemMessage = - messages.length > 0 && getMessageAuthor(messages[0]) === "system" - ? messages[0] - : undefined; - if ( - systemMessage?.content !== undefined && - typeof systemMessage.content !== "string" - ) { - throw new Error("Non-string system message content is not supported."); - } - return systemMessage?.content; - } - - protected _mapBaseMessagesToPalmMessages( - messages: BaseMessage[] - ): protos.google.ai.generativelanguage.v1beta2.IMessage[] { - // remove all 'system' messages - const nonSystemMessages = messages.filter( - (m) => getMessageAuthor(m) !== "system" - ); - - // requires alternate human & ai messages. Throw error if two messages are consecutive - nonSystemMessages.forEach((msg, index) => { - if (index < 1) return; - if ( - getMessageAuthor(msg) === getMessageAuthor(nonSystemMessages[index - 1]) - ) { - throw new Error( - `Google PaLM requires alternate messages between authors` - ); - } - }); - - return nonSystemMessages.map((m) => { - if (typeof m.content !== "string") { - throw new Error( - "ChatGooglePaLM does not support non-string message content." 
- ); - } - return { - author: getMessageAuthor(m), - content: m.content, - citationMetadata: { - citationSources: m.additional_kwargs.citationSources as - | protos.google.ai.generativelanguage.v1beta2.ICitationSource[] - | undefined, - }, - }; - }); - } - - protected _mapPalmMessagesToChatResult( - msgRes: protos.google.ai.generativelanguage.v1beta2.IGenerateMessageResponse - ): ChatResult { - if ( - msgRes.candidates && - msgRes.candidates.length > 0 && - msgRes.candidates[0] - ) { - const message = msgRes.candidates[0]; - return { - generations: [ - { - text: message.content ?? "", - message: new AIMessage({ - content: message.content ?? "", - name: message.author === null ? undefined : message.author, - additional_kwargs: { - citationSources: message.citationMetadata?.citationSources, - filters: msgRes.filters, // content filters applied - }, - }), - }, - ], - }; - } - // if rejected or error, return empty generations with reason in filters - return { - generations: [], - llmOutput: { - filters: msgRes.filters, - }, - }; - } -} +export * from "@langchain/community/chat_models/googlepalm"; diff --git a/langchain/src/chat_models/googlevertexai/index.ts b/langchain/src/chat_models/googlevertexai/index.ts index e8a3a07da320..b2977f65c20a 100644 --- a/langchain/src/chat_models/googlevertexai/index.ts +++ b/langchain/src/chat_models/googlevertexai/index.ts @@ -1,64 +1 @@ -import { GoogleAuthOptions } from "google-auth-library"; -import { BaseChatGoogleVertexAI, GoogleVertexAIChatInput } from "./common.js"; -import { GoogleVertexAILLMConnection } from "../../util/googlevertexai-connection.js"; -import { GAuthClient } from "../../util/googlevertexai-gauth.js"; - -/** - * Enables calls to the Google Cloud's Vertex AI API to access - * Large Language Models in a chat-like fashion. - * - * To use, you will need to have one of the following authentication - * methods in place: - * - You are logged into an account permitted to the Google Cloud project - * using Vertex AI. 
- * - You are running this on a machine using a service account permitted to - * the Google Cloud project using Vertex AI. - * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the - * path of a credentials file for a service account permitted to the - * Google Cloud project using Vertex AI. - * @example - * ```typescript - * const model = new ChatGoogleVertexAI({ - * temperature: 0.7, - * }); - * const result = await model.invoke("What is the capital of France?"); - * ``` - */ -export class ChatGoogleVertexAI extends BaseChatGoogleVertexAI { - static lc_name() { - return "ChatVertexAI"; - } - - constructor(fields?: GoogleVertexAIChatInput) { - super(fields); - - const client = new GAuthClient({ - scopes: "https://www.googleapis.com/auth/cloud-platform", - ...fields?.authOptions, - }); - - this.connection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - false - ); - - this.streamedConnection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - true - ); - } -} - -export type { - ChatExample, - GoogleVertexAIChatAuthor, - GoogleVertexAIChatInput, - GoogleVertexAIChatInstance, - GoogleVertexAIChatMessage, - GoogleVertexAIChatMessageFields, - GoogleVertexAIChatPrediction, -} from "./common.js"; +export * from "@langchain/community/chat_models/googlevertexai"; diff --git a/langchain/src/chat_models/googlevertexai/web.ts b/langchain/src/chat_models/googlevertexai/web.ts index acbaa9144f4c..4c350a89266a 100644 --- a/langchain/src/chat_models/googlevertexai/web.ts +++ b/langchain/src/chat_models/googlevertexai/web.ts @@ -1,66 +1 @@ -import { GoogleVertexAILLMConnection } from "../../util/googlevertexai-connection.js"; -import { - WebGoogleAuthOptions, - WebGoogleAuth, -} from "../../util/googlevertexai-webauth.js"; -import { BaseChatGoogleVertexAI, GoogleVertexAIChatInput } from "./common.js"; - -/** - * Enables calls to the Google Cloud's Vertex AI API to access - * Large 
Language Models in a chat-like fashion. - * - * This entrypoint and class are intended to be used in web environments like Edge - * functions where you do not have access to the file system. It supports passing - * service account credentials directly as a "GOOGLE_VERTEX_AI_WEB_CREDENTIALS" - * environment variable or directly as "authOptions.credentials". - * @example - * ```typescript - * const model = new ChatGoogleVertexAI({ - * temperature: 0.7, - * }); - * const result = await model.invoke( - * "How do I implement a binary search algorithm in Python?", - * ); - * ``` - */ -export class ChatGoogleVertexAI extends BaseChatGoogleVertexAI { - static lc_name() { - return "ChatVertexAI"; - } - - get lc_secrets(): { [key: string]: string } { - return { - "authOptions.credentials": "GOOGLE_VERTEX_AI_WEB_CREDENTIALS", - }; - } - - constructor(fields?: GoogleVertexAIChatInput) { - super(fields); - - const client = new WebGoogleAuth(fields?.authOptions); - - this.connection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - false - ); - - this.streamedConnection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - true - ); - } -} - -export type { - ChatExample, - GoogleVertexAIChatAuthor, - GoogleVertexAIChatInput, - GoogleVertexAIChatInstance, - GoogleVertexAIChatMessage, - GoogleVertexAIChatMessageFields, - GoogleVertexAIChatPrediction, -} from "./common.js"; +export * from "@langchain/community/chat_models/googlevertexai/web"; diff --git a/langchain/src/chat_models/iflytek_xinghuo/index.ts b/langchain/src/chat_models/iflytek_xinghuo/index.ts index ac54461be18a..9b988a537704 100644 --- a/langchain/src/chat_models/iflytek_xinghuo/index.ts +++ b/langchain/src/chat_models/iflytek_xinghuo/index.ts @@ -1,43 +1 @@ -import WebSocket from "ws"; -import { BaseChatIflytekXinghuo } from "./common.js"; -import { - BaseWebSocketStream, - WebSocketStreamOptions, -} from 
"../../util/iflytek_websocket_stream.js"; - -class WebSocketStream extends BaseWebSocketStream { - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - openWebSocket(url: string, options: WebSocketStreamOptions): WebSocket { - return new WebSocket(url, options.protocols ?? []); - } -} - -/** - * @example - * ```typescript - * const model = new ChatIflytekXinghuo(); - * const response = await model.call([new HumanMessage("Nice to meet you!")]); - * console.log(response); - * ``` - */ -export class ChatIflytekXinghuo extends BaseChatIflytekXinghuo { - async openWebSocketStream( - options: WebSocketStreamOptions - ): Promise { - const host = "spark-api.xf-yun.com"; - const date = new Date().toUTCString(); - const url = `GET /${this.version}/chat HTTP/1.1`; - const { createHmac } = await import("node:crypto"); - const hash = createHmac("sha256", this.iflytekApiSecret) - .update(`host: ${host}\ndate: ${date}\n${url}`) - .digest("base64"); - const authorization_origin = `api_key="${this.iflytekApiKey}", algorithm="hmac-sha256", headers="host date request-line", signature="${hash}"`; - const authorization = Buffer.from(authorization_origin).toString("base64"); - let authWebSocketUrl = this.apiUrl; - authWebSocketUrl += `?authorization=${authorization}`; - authWebSocketUrl += `&host=${encodeURIComponent(host)}`; - authWebSocketUrl += `&date=${encodeURIComponent(date)}`; - return new WebSocketStream(authWebSocketUrl, options) as WebSocketStream; - } -} +export * from "@langchain/community/chat_models/iflytek_xinghuo"; diff --git a/langchain/src/chat_models/iflytek_xinghuo/web.ts b/langchain/src/chat_models/iflytek_xinghuo/web.ts index 87b372b802ad..8867445ee59c 100644 --- a/langchain/src/chat_models/iflytek_xinghuo/web.ts +++ b/langchain/src/chat_models/iflytek_xinghuo/web.ts @@ -1,49 +1 @@ -import { BaseChatIflytekXinghuo } from "./common.js"; -import { - WebSocketStreamOptions, - BaseWebSocketStream, -} from 
"../../util/iflytek_websocket_stream.js"; - -class WebSocketStream extends BaseWebSocketStream { - openWebSocket(url: string, options: WebSocketStreamOptions): WebSocket { - return new WebSocket(url, options.protocols ?? []); - } -} - -/** - * @example - * ```typescript - * const model = new ChatIflytekXinghuo(); - * const response = await model.call([new HumanMessage("Nice to meet you!")]); - * console.log(response); - * ``` - */ -export class ChatIflytekXinghuo extends BaseChatIflytekXinghuo { - async openWebSocketStream( - options: WebSocketStreamOptions - ): Promise { - const host = "spark-api.xf-yun.com"; - const date = new Date().toUTCString(); - const url = `GET /${this.version}/chat HTTP/1.1`; - const keyBuffer = new TextEncoder().encode(this.iflytekApiSecret); - const dataBuffer = new TextEncoder().encode( - `host: ${host}\ndate: ${date}\n${url}` - ); - const cryptoKey = await crypto.subtle.importKey( - "raw", - keyBuffer, - { name: "HMAC", hash: "SHA-256" }, - false, - ["sign"] - ); - const signature = await crypto.subtle.sign("HMAC", cryptoKey, dataBuffer); - const hash = window.btoa(String.fromCharCode(...new Uint8Array(signature))); - const authorization_origin = `api_key="${this.iflytekApiKey}", algorithm="hmac-sha256", headers="host date request-line", signature="${hash}"`; - const authorization = window.btoa(authorization_origin); - let authWebSocketUrl = this.apiUrl; - authWebSocketUrl += `?authorization=${authorization}`; - authWebSocketUrl += `&host=${encodeURIComponent(host)}`; - authWebSocketUrl += `&date=${encodeURIComponent(date)}`; - return new WebSocketStream(authWebSocketUrl, options) as WebSocketStream; - } -} +export * from "@langchain/community/chat_models/iflytek_xinghuo/web"; diff --git a/langchain/src/chat_models/llama_cpp.ts b/langchain/src/chat_models/llama_cpp.ts index 3df8b0d2a3c9..ae06e6116cfc 100644 --- a/langchain/src/chat_models/llama_cpp.ts +++ b/langchain/src/chat_models/llama_cpp.ts @@ -1,322 +1 @@ -import { - LlamaModel, 
- LlamaContext, - LlamaChatSession, - type ConversationInteraction, -} from "node-llama-cpp"; -import { SimpleChatModel, BaseChatModelParams } from "./base.js"; -import { - LlamaBaseCppInputs, - createLlamaModel, - createLlamaContext, -} from "../util/llama_cpp.js"; -import { BaseLanguageModelCallOptions } from "../base_language/index.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { - BaseMessage, - ChatGenerationChunk, - AIMessageChunk, - ChatMessage, -} from "../schema/index.js"; - -/** - * Note that the modelPath is the only required parameter. For testing you - * can set this in the environment variable `LLAMA_PATH`. - */ -export interface LlamaCppInputs - extends LlamaBaseCppInputs, - BaseChatModelParams {} - -export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions { - /** The maximum number of tokens the response should contain. */ - maxTokens?: number; - /** A function called when matching the provided token array */ - onToken?: (tokens: number[]) => void; -} - -/** - * To use this model you need to have the `node-llama-cpp` module installed. - * This can be installed using `npm install -S node-llama-cpp` and the minimum - * version supported in version 2.0.0. - * This also requires that have a locally built version of Llama2 installed. - * @example - * ```typescript - * // Initialize the ChatLlamaCpp model with the path to the model binary file. - * const model = new ChatLlamaCpp({ - * modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin", - * temperature: 0.5, - * }); - * - * // Call the model with a message and await the response. - * const response = await model.call([ - * new HumanMessage({ content: "My name is John." }), - * ]); - * - * // Log the response to the console. 
- * console.log({ response }); - * - * ``` - */ -export class ChatLlamaCpp extends SimpleChatModel { - declare CallOptions: LlamaCppCallOptions; - - static inputs: LlamaCppInputs; - - maxTokens?: number; - - temperature?: number; - - topK?: number; - - topP?: number; - - trimWhitespaceSuffix?: boolean; - - _model: LlamaModel; - - _context: LlamaContext; - - _session: LlamaChatSession | null; - - static lc_name() { - return "ChatLlamaCpp"; - } - - constructor(inputs: LlamaCppInputs) { - super(inputs); - this.maxTokens = inputs?.maxTokens; - this.temperature = inputs?.temperature; - this.topK = inputs?.topK; - this.topP = inputs?.topP; - this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix; - this._model = createLlamaModel(inputs); - this._context = createLlamaContext(this._model, inputs); - this._session = null; - } - - _llmType() { - return "llama2_cpp"; - } - - /** @ignore */ - _combineLLMOutput() { - return {}; - } - - invocationParams() { - return { - maxTokens: this.maxTokens, - temperature: this.temperature, - topK: this.topK, - topP: this.topP, - trimWhitespaceSuffix: this.trimWhitespaceSuffix, - }; - } - - /** @ignore */ - async _call( - messages: BaseMessage[], - options: this["ParsedCallOptions"] - ): Promise { - let prompt = ""; - - if (messages.length > 1) { - // We need to build a new _session - prompt = this._buildSession(messages); - } else if (!this._session) { - prompt = this._buildSession(messages); - } else { - if (typeof messages[0].content !== "string") { - throw new Error( - "ChatLlamaCpp does not support non-string message content in sessions." 
- ); - } - // If we already have a session then we should just have a single prompt - prompt = messages[0].content; - } - - try { - const promptOptions = { - onToken: options.onToken, - maxTokens: this?.maxTokens, - temperature: this?.temperature, - topK: this?.topK, - topP: this?.topP, - trimWhitespaceSuffix: this?.trimWhitespaceSuffix, - }; - // @ts-expect-error - TS2531: Object is possibly 'null'. - const completion = await this._session.prompt(prompt, promptOptions); - return completion; - } catch (e) { - throw new Error("Error getting prompt completion."); - } - } - - async *_streamResponseChunks( - input: BaseMessage[], - _options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const promptOptions = { - temperature: this?.temperature, - topK: this?.topK, - topP: this?.topP, - }; - - const prompt = this._buildPrompt(input); - - const stream = await this.caller.call(async () => - this._context.evaluate(this._context.encode(prompt), promptOptions) - ); - - for await (const chunk of stream) { - yield new ChatGenerationChunk({ - text: this._context.decode([chunk]), - message: new AIMessageChunk({ - content: this._context.decode([chunk]), - }), - generationInfo: {}, - }); - await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? 
""); - } - } - - // This constructs a new session if we need to adding in any sys messages or previous chats - protected _buildSession(messages: BaseMessage[]): string { - let prompt = ""; - let sysMessage = ""; - let noSystemMessages: BaseMessage[] = []; - let interactions: ConversationInteraction[] = []; - - // Let's see if we have a system message - if (messages.findIndex((msg) => msg._getType() === "system") !== -1) { - const sysMessages = messages.filter( - (message) => message._getType() === "system" - ); - - const systemMessageContent = sysMessages[sysMessages.length - 1].content; - - if (typeof systemMessageContent !== "string") { - throw new Error( - "ChatLlamaCpp does not support non-string message content in sessions." - ); - } - // Only use the last provided system message - sysMessage = systemMessageContent; - - // Now filter out the system messages - noSystemMessages = messages.filter( - (message) => message._getType() !== "system" - ); - } else { - noSystemMessages = messages; - } - - // Lets see if we just have a prompt left or are their previous interactions? - if (noSystemMessages.length > 1) { - // Is the last message a prompt? - if ( - noSystemMessages[noSystemMessages.length - 1]._getType() === "human" - ) { - const finalMessageContent = - noSystemMessages[noSystemMessages.length - 1].content; - if (typeof finalMessageContent !== "string") { - throw new Error( - "ChatLlamaCpp does not support non-string message content in sessions." - ); - } - prompt = finalMessageContent; - interactions = this._convertMessagesToInteractions( - noSystemMessages.slice(0, noSystemMessages.length - 1) - ); - } else { - interactions = this._convertMessagesToInteractions(noSystemMessages); - } - } else { - if (typeof noSystemMessages[0].content !== "string") { - throw new Error( - "ChatLlamaCpp does not support non-string message content in sessions." 
- ); - } - // If there was only a single message we assume it's a prompt - prompt = noSystemMessages[0].content; - } - - // Now lets construct a session according to what we got - if (sysMessage !== "" && interactions.length > 0) { - this._session = new LlamaChatSession({ - context: this._context, - conversationHistory: interactions, - systemPrompt: sysMessage, - }); - } else if (sysMessage !== "" && interactions.length === 0) { - this._session = new LlamaChatSession({ - context: this._context, - systemPrompt: sysMessage, - }); - } else if (sysMessage === "" && interactions.length > 0) { - this._session = new LlamaChatSession({ - context: this._context, - conversationHistory: interactions, - }); - } else { - this._session = new LlamaChatSession({ - context: this._context, - }); - } - - return prompt; - } - - // This builds a an array of interactions - protected _convertMessagesToInteractions( - messages: BaseMessage[] - ): ConversationInteraction[] { - const result: ConversationInteraction[] = []; - - for (let i = 0; i < messages.length; i += 2) { - if (i + 1 < messages.length) { - const prompt = messages[i].content; - const response = messages[i + 1].content; - if (typeof prompt !== "string" || typeof response !== "string") { - throw new Error( - "ChatLlamaCpp does not support non-string message content." 
- ); - } - result.push({ - prompt, - response, - }); - } - } - - return result; - } - - protected _buildPrompt(input: BaseMessage[]): string { - const prompt = input - .map((message) => { - let messageText; - if (message._getType() === "human") { - messageText = `[INST] ${message.content} [/INST]`; - } else if (message._getType() === "ai") { - messageText = message.content; - } else if (message._getType() === "system") { - messageText = `<> ${message.content} <>`; - } else if (ChatMessage.isInstance(message)) { - messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( - 1 - )}: ${message.content}`; - } else { - console.warn( - `Unsupported message type passed to llama_cpp: "${message._getType()}"` - ); - messageText = ""; - } - return messageText; - }) - .join("\n"); - - return prompt; - } -} +export * from "@langchain/community/chat_models/llama_cpp"; diff --git a/langchain/src/chat_models/minimax.ts b/langchain/src/chat_models/minimax.ts index c0e521d6fec0..6cb2a1436dbd 100644 --- a/langchain/src/chat_models/minimax.ts +++ b/langchain/src/chat_models/minimax.ts @@ -1,880 +1 @@ -import type { OpenAI as OpenAIClient } from "openai"; - -import { BaseChatModel, BaseChatModelParams } from "./base.js"; -import { - AIMessage, - BaseMessage, - ChatGeneration, - ChatMessage, - ChatResult, - HumanMessage, -} from "../schema/index.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { StructuredTool } from "../tools/index.js"; -import { BaseFunctionCallOptions } from "../base_language/index.js"; -import { formatToOpenAIFunction } from "../tools/convert_to_openai.js"; - -/** - * Type representing the sender_type of a message in the Minimax chat model. - */ -export type MinimaxMessageRole = "BOT" | "USER" | "FUNCTION"; - -/** - * Interface representing a message in the Minimax chat model. 
- */ -interface MinimaxChatCompletionRequestMessage { - sender_type: MinimaxMessageRole; - sender_name?: string; - text: string; -} - -/** - * Interface representing a request for a chat completion. - */ -interface MinimaxChatCompletionRequest { - model: string; - messages: MinimaxChatCompletionRequestMessage[]; - stream?: boolean; - prompt?: string; - temperature?: number; - top_p?: number; - tokens_to_generate?: number; - skip_info_mask?: boolean; - mask_sensitive_info?: boolean; - beam_width?: number; - use_standard_sse?: boolean; - role_meta?: RoleMeta; - bot_setting?: BotSetting[]; - reply_constraints?: ReplyConstraints; - sample_messages?: MinimaxChatCompletionRequestMessage[]; - /** - * A list of functions the model may generate JSON inputs for. - * @type {Array} - */ - functions?: OpenAIClient.Chat.ChatCompletionCreateParams.Function[]; - plugins?: string[]; -} - -interface RoleMeta { - role_meta: string; - bot_name: string; -} - -interface RawGlyph { - type: "raw"; - raw_glyph: string; -} - -interface JsonGlyph { - type: "json_value"; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - json_properties: any; -} - -type ReplyConstraintsGlyph = RawGlyph | JsonGlyph; - -interface ReplyConstraints { - sender_type: string; - sender_name: string; - glyph?: ReplyConstraintsGlyph; -} - -interface BotSetting { - content: string; - bot_name: string; -} - -export declare interface ConfigurationParameters { - basePath?: string; - headers?: Record; -} - -/** - * Interface defining the input to the ChatMinimax class. - */ -declare interface MinimaxChatInputBase { - /** Model name to use - * @default "abab5.5-chat" - */ - modelName: string; - - /** Whether to stream the results or not. Defaults to false. */ - streaming?: boolean; - - prefixMessages?: MinimaxChatCompletionRequestMessage[]; - - /** - * API key to use when making requests. Defaults to the value of - * `MINIMAX_GROUP_ID` environment variable. 
- */ - minimaxGroupId?: string; - - /** - * Secret key to use when making requests. Defaults to the value of - * `MINIMAX_API_KEY` environment variable. - */ - minimaxApiKey?: string; - - /** Amount of randomness injected into the response. Ranges - * from 0 to 1 (0 is not included). Use temp closer to 0 for analytical / - * multiple choice, and temp closer to 1 for creative - * and generative tasks. Defaults to 0.95. - */ - temperature?: number; - - /** - * The smaller the sampling method, the more determinate the result; - * the larger the number, the more random the result. - */ - topP?: number; - - /** - * Enable Chatcompletion pro - */ - proVersion?: boolean; - - /** - * Pay attention to the maximum number of tokens generated, - * this parameter does not affect the generation effect of the model itself, - * but only realizes the function by truncating the tokens exceeding the limit. - * It is necessary to ensure that the number of tokens of the input context plus this value is less than 6144 or 16384, - * otherwise the request will fail. - */ - tokensToGenerate?: number; -} - -declare interface MinimaxChatInputNormal { - /** - * Dialogue setting, characters, or functionality setting. - */ - prompt?: string; - /** - * Sensitize text information in the output that may involve privacy issues, - * currently including but not limited to emails, domain names, - * links, ID numbers, home addresses, etc. Default false, ie. enable sensitization. - */ - skipInfoMask?: boolean; - - /** - * Whether to use the standard SSE format, when set to true, - * the streaming results will be separated by two line breaks. - * This parameter only takes effect when stream is set to true. - */ - useStandardSse?: boolean; - - /** - * If it is true, this indicates that the current request is set to continuation mode, - * and the response is a continuation of the last sentence in the incoming messages; - * at this time, the last sender is not limited to USER, it can also be BOT. 
- * Assuming the last sentence of incoming messages is {"sender_type": " U S E R", "text": "天生我材"}, - * the completion of the reply may be "It must be useful." - */ - continueLastMessage?: boolean; - - /** - * How many results to generate; the default is 1 and the maximum is not more than 4. - * Because beamWidth generates multiple results, it will consume more tokens. - */ - beamWidth?: number; - - /** - * Dialogue Metadata - */ - roleMeta?: RoleMeta; -} - -declare interface MinimaxChatInputPro extends MinimaxChatInputBase { - /** - * For the text information in the output that may involve privacy issues, - * code masking is currently included but not limited to emails, domains, links, ID numbers, home addresses, etc., - * with the default being true, that is, code masking is enabled. - */ - maskSensitiveInfo?: boolean; - - /** - * Default bot name - */ - defaultBotName?: string; - - /** - * Default user name - */ - defaultUserName?: string; - - /** - * Setting for each robot, only available for pro version. - */ - botSetting?: BotSetting[]; - - replyConstraints?: ReplyConstraints; -} - -type MinimaxChatInput = MinimaxChatInputNormal & MinimaxChatInputPro; - -/** - * Function that extracts the custom sender_type of a generic chat message. - * @param message Chat message from which to extract the custom sender_type. - * @returns The custom sender_type of the chat message. - */ -function extractGenericMessageCustomRole(message: ChatMessage) { - if (message.role !== "ai" && message.role !== "user") { - console.warn(`Unknown message role: ${message.role}`); - } - if (message.role === "ai") { - return "BOT" as MinimaxMessageRole; - } - if (message.role === "user") { - return "USER" as MinimaxMessageRole; - } - return message.role as MinimaxMessageRole; -} - -/** - * Function that converts a base message to a Minimax message sender_type. - * @param message Base message to convert. - * @returns The Minimax message sender_type. 
- */ -function messageToMinimaxRole(message: BaseMessage): MinimaxMessageRole { - const type = message._getType(); - switch (type) { - case "ai": - return "BOT"; - case "human": - return "USER"; - case "system": - throw new Error("System messages not supported"); - case "function": - return "FUNCTION"; - case "generic": { - if (!ChatMessage.isInstance(message)) - throw new Error("Invalid generic chat message"); - return extractGenericMessageCustomRole(message); - } - default: - throw new Error(`Unknown message type: ${type}`); - } -} - -export interface ChatMinimaxCallOptions extends BaseFunctionCallOptions { - tools?: StructuredTool[]; - defaultUserName?: string; - defaultBotName?: string; - plugins?: string[]; - botSetting?: BotSetting[]; - replyConstraints?: ReplyConstraints; - sampleMessages?: BaseMessage[]; -} - -/** - * Wrapper around Minimax large language models that use the Chat endpoint. - * - * To use you should have the `MINIMAX_GROUP_ID` and `MINIMAX_API_KEY` - * environment variable set. 
- * @example - * ```typescript - * // Define a chat prompt with a system message setting the context for translation - * const chatPrompt = ChatPromptTemplate.fromMessages([ - * SystemMessagePromptTemplate.fromTemplate( - * "You are a helpful assistant that translates {input_language} to {output_language}.", - * ), - * HumanMessagePromptTemplate.fromTemplate("{text}"), - * ]); - * - * // Create a new LLMChain with the chat model and the defined prompt - * const chainB = new LLMChain({ - * prompt: chatPrompt, - * llm: new ChatMinimax({ temperature: 0.01 }), - * }); - * - * // Call the chain with the input language, output language, and the text to translate - * const resB = await chainB.call({ - * input_language: "English", - * output_language: "Chinese", - * text: "I love programming.", - * }); - * - * // Log the result - * console.log({ resB }); - * - * ``` - */ -export class ChatMinimax - extends BaseChatModel - implements MinimaxChatInput -{ - static lc_name() { - return "ChatMinimax"; - } - - get callKeys(): (keyof ChatMinimaxCallOptions)[] { - return [ - ...(super.callKeys as (keyof ChatMinimaxCallOptions)[]), - "functions", - "tools", - "defaultBotName", - "defaultUserName", - "plugins", - "replyConstraints", - "botSetting", - "sampleMessages", - ]; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - minimaxApiKey: "MINIMAX_API_KEY", - minimaxGroupId: "MINIMAX_GROUP_ID", - }; - } - - lc_serializable = true; - - minimaxGroupId?: string; - - minimaxApiKey?: string; - - streaming = false; - - prompt?: string; - - modelName = "abab5.5-chat"; - - defaultBotName?: string = "Assistant"; - - defaultUserName?: string = "I"; - - prefixMessages?: MinimaxChatCompletionRequestMessage[]; - - apiUrl: string; - - basePath?: string = "https://api.minimax.chat/v1"; - - headers?: Record; - - temperature?: number = 0.9; - - topP?: number = 0.8; - - tokensToGenerate?: number; - - skipInfoMask?: boolean; - - proVersion?: boolean = true; - - beamWidth?: 
number; - - botSetting?: BotSetting[]; - - continueLastMessage?: boolean; - - maskSensitiveInfo?: boolean; - - roleMeta?: RoleMeta; - - useStandardSse?: boolean; - - replyConstraints?: ReplyConstraints; - - constructor( - fields?: Partial & - BaseChatModelParams & { - configuration?: ConfigurationParameters; - } - ) { - super(fields ?? {}); - - this.minimaxGroupId = - fields?.minimaxGroupId ?? getEnvironmentVariable("MINIMAX_GROUP_ID"); - if (!this.minimaxGroupId) { - throw new Error("Minimax GroupID not found"); - } - - this.minimaxApiKey = - fields?.minimaxApiKey ?? getEnvironmentVariable("MINIMAX_API_KEY"); - - if (!this.minimaxApiKey) { - throw new Error("Minimax ApiKey not found"); - } - - this.streaming = fields?.streaming ?? this.streaming; - this.prompt = fields?.prompt ?? this.prompt; - this.temperature = fields?.temperature ?? this.temperature; - this.topP = fields?.topP ?? this.topP; - this.skipInfoMask = fields?.skipInfoMask ?? this.skipInfoMask; - this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages; - this.maskSensitiveInfo = - fields?.maskSensitiveInfo ?? this.maskSensitiveInfo; - this.beamWidth = fields?.beamWidth ?? this.beamWidth; - this.continueLastMessage = - fields?.continueLastMessage ?? this.continueLastMessage; - this.tokensToGenerate = fields?.tokensToGenerate ?? this.tokensToGenerate; - this.roleMeta = fields?.roleMeta ?? this.roleMeta; - this.botSetting = fields?.botSetting ?? this.botSetting; - this.useStandardSse = fields?.useStandardSse ?? this.useStandardSse; - this.replyConstraints = fields?.replyConstraints ?? this.replyConstraints; - this.defaultBotName = fields?.defaultBotName ?? this.defaultBotName; - - this.modelName = fields?.modelName ?? this.modelName; - this.basePath = fields?.configuration?.basePath ?? this.basePath; - this.headers = fields?.configuration?.headers ?? this.headers; - this.proVersion = fields?.proVersion ?? this.proVersion; - - const modelCompletion = this.proVersion - ? 
"chatcompletion_pro" - : "chatcompletion"; - this.apiUrl = `${this.basePath}/text/${modelCompletion}`; - } - - fallbackBotName(options?: this["ParsedCallOptions"]) { - let botName = options?.defaultBotName ?? this.defaultBotName ?? "Assistant"; - if (this.botSetting) { - botName = this.botSetting[0].bot_name; - } - return botName; - } - - defaultReplyConstraints(options?: this["ParsedCallOptions"]) { - const constraints = options?.replyConstraints ?? this.replyConstraints; - if (!constraints) { - let botName = - options?.defaultBotName ?? this.defaultBotName ?? "Assistant"; - if (this.botSetting) { - botName = this.botSetting[0].bot_name; - } - - return { - sender_type: "BOT", - sender_name: botName, - }; - } - return constraints; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams( - options?: this["ParsedCallOptions"] - ): Omit { - return { - model: this.modelName, - stream: this.streaming, - prompt: this.prompt, - temperature: this.temperature, - top_p: this.topP, - tokens_to_generate: this.tokensToGenerate, - skip_info_mask: this.skipInfoMask, - mask_sensitive_info: this.maskSensitiveInfo, - beam_width: this.beamWidth, - use_standard_sse: this.useStandardSse, - role_meta: this.roleMeta, - bot_setting: options?.botSetting ?? this.botSetting, - reply_constraints: this.defaultReplyConstraints(options), - sample_messages: this.messageToMinimaxMessage( - options?.sampleMessages, - options - ), - functions: - options?.functions ?? - (options?.tools - ? options?.tools.map(formatToOpenAIFunction) - : undefined), - plugins: options?.plugins, - }; - } - - /** - * Get the identifying parameters for the model - */ - identifyingParams() { - return { - ...this.invocationParams(), - }; - } - - /** - * Convert a list of messages to the format expected by the model. 
- * @param messages - * @param options - */ - messageToMinimaxMessage( - messages?: BaseMessage[], - options?: this["ParsedCallOptions"] - ): MinimaxChatCompletionRequestMessage[] | undefined { - return messages - ?.filter((message) => { - if (ChatMessage.isInstance(message)) { - return message.role !== "system"; - } - return message._getType() !== "system"; - }) - ?.map((message) => { - const sender_type = messageToMinimaxRole(message); - if (typeof message.content !== "string") { - throw new Error( - "ChatMinimax does not support non-string message content." - ); - } - return { - sender_type, - text: message.content, - sender_name: - message.name ?? - (sender_type === "BOT" - ? this.fallbackBotName() - : options?.defaultUserName ?? this.defaultUserName), - }; - }); - } - - /** @ignore */ - async _generate( - messages: BaseMessage[], - options?: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const tokenUsage = { totalTokens: 0 }; - this.botSettingFallback(options, messages); - - const params = this.invocationParams(options); - const messagesMapped: MinimaxChatCompletionRequestMessage[] = [ - ...(this.messageToMinimaxMessage(messages, options) ?? []), - ...(this.prefixMessages ?? []), - ]; - - const data = params.stream - ? await new Promise((resolve, reject) => { - let response: ChatCompletionResponse; - let rejected = false; - let resolved = false; - this.completionWithRetry( - { - ...params, - messages: messagesMapped, - }, - true, - options?.signal, - (event) => { - const data = JSON.parse(event.data); - - if (data?.error_code) { - if (rejected) { - return; - } - rejected = true; - reject(data); - return; - } - - const message = data as ChatCompletionResponse; - // on the first message set the response properties - - if (!message.choices[0].finish_reason) { - // the last stream message - let streamText; - if (this.proVersion) { - const messages = message.choices[0].messages ?? 
[]; - streamText = messages[0].text; - } else { - streamText = message.choices[0].delta; - } - - // TODO this should pass part.index to the callback - // when that's supported there - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(streamText ?? ""); - return; - } - - response = message; - if (!this.proVersion) { - response.choices[0].text = message.reply; - } - - if (resolved || rejected) { - return; - } - resolved = true; - resolve(response); - } - ).catch((error) => { - if (!rejected) { - rejected = true; - reject(error); - } - }); - }) - : await this.completionWithRetry( - { - ...params, - messages: messagesMapped, - }, - false, - options?.signal - ); - - const { total_tokens: totalTokens } = data.usage ?? {}; - - if (totalTokens) { - tokenUsage.totalTokens = totalTokens; - } - - if (data.base_resp?.status_code !== 0) { - throw new Error(`Minimax API error: ${data.base_resp?.status_msg}`); - } - const generations: ChatGeneration[] = []; - - if (this.proVersion) { - for (const choice of data.choices) { - const messages = choice.messages ?? []; - // 取最后一条消息 - if (messages) { - const message = messages[messages.length - 1]; - const text = message?.text ?? ""; - generations.push({ - text, - message: minimaxResponseToChatMessage(message), - }); - } - } - } else { - for (const choice of data.choices) { - const text = choice?.text ?? ""; - generations.push({ - text, - message: minimaxResponseToChatMessage({ - sender_type: "BOT", - sender_name: - options?.defaultBotName ?? this.defaultBotName ?? 
"Assistant", - text, - }), - }); - } - } - return { - generations, - llmOutput: { tokenUsage }, - }; - } - - /** @ignore */ - async completionWithRetry( - request: MinimaxChatCompletionRequest, - stream: boolean, - signal?: AbortSignal, - onmessage?: (event: MessageEvent) => void - ) { - // The first run will get the accessToken - const makeCompletionRequest = async () => { - const url = `${this.apiUrl}?GroupId=${this.minimaxGroupId}`; - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${this.minimaxApiKey}`, - ...this.headers, - }, - body: JSON.stringify(request), - signal, - }); - - if (!stream) { - const json = await response.json(); - return json as ChatCompletionResponse; - } else { - if (response.body) { - const reader = response.body.getReader(); - - const decoder = new TextDecoder("utf-8"); - let data = ""; - - let continueReading = true; - while (continueReading) { - const { done, value } = await reader.read(); - if (done) { - continueReading = false; - break; - } - data += decoder.decode(value); - - let continueProcessing = true; - while (continueProcessing) { - const newlineIndex = data.indexOf("\n"); - if (newlineIndex === -1) { - continueProcessing = false; - break; - } - const line = data.slice(0, newlineIndex); - data = data.slice(newlineIndex + 1); - - if (line.startsWith("data:")) { - const event = new MessageEvent("message", { - data: line.slice("data:".length).trim(), - }); - onmessage?.(event); - } - } - } - return {} as ChatCompletionResponse; - } - return {} as ChatCompletionResponse; - } - }; - return this.caller.call(makeCompletionRequest); - } - - _llmType() { - return "minimax"; - } - - /** @ignore */ - _combineLLMOutput() { - return []; - } - - private botSettingFallback( - options?: this["ParsedCallOptions"], - messages?: BaseMessage[] - ) { - const botSettings = options?.botSetting ?? 
this.botSetting; - if (!botSettings) { - const systemMessages = messages?.filter((message) => { - if (ChatMessage.isInstance(message)) { - return message.role === "system"; - } - return message._getType() === "system"; - }); - - // get the last system message - if (!systemMessages?.length) { - return; - } - const lastSystemMessage = systemMessages[systemMessages.length - 1]; - - if (typeof lastSystemMessage.content !== "string") { - throw new Error( - "ChatMinimax does not support non-string message content." - ); - } - - // setting the default botSetting. - this.botSetting = [ - { - content: lastSystemMessage.content, - bot_name: - options?.defaultBotName ?? this.defaultBotName ?? "Assistant", - }, - ]; - } - } -} - -function minimaxResponseToChatMessage( - message: ChatCompletionResponseMessage -): BaseMessage { - switch (message.sender_type) { - case "USER": - return new HumanMessage(message.text || ""); - case "BOT": - return new AIMessage(message.text || "", { - function_call: message.function_call, - }); - case "FUNCTION": - return new AIMessage(message.text || ""); - default: - return new ChatMessage( - message.text || "", - message.sender_type ?? "unknown" - ); - } -} - -/** ---Response Model---* */ -/** - * Interface representing a message responsed in the Minimax chat model. - */ -interface ChatCompletionResponseMessage { - sender_type: MinimaxMessageRole; - sender_name?: string; - text: string; - function_call?: ChatCompletionResponseMessageFunctionCall; -} - -/** - * Interface representing the usage of tokens in a chat completion. - */ -interface TokenUsage { - total_tokens?: number; -} - -interface BaseResp { - status_code?: number; - status_msg?: string; -} - -/** - * The name and arguments of a function that should be called, as generated by the model. - * @export - * @interface ChatCompletionResponseMessageFunctionCall - */ -export interface ChatCompletionResponseMessageFunctionCall { - /** - * The name of the function to call. 
- * @type {string} - * @memberof ChatCompletionResponseMessageFunctionCall - */ - name?: string; - /** - * The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. - * @type {string} - * @memberof ChatCompletionResponseMessageFunctionCall - */ - arguments?: string; -} - -/** - * - * @export - * @interface ChatCompletionResponseChoices - */ -export interface ChatCompletionResponseChoicesPro { - /** - * - * @type {string} - * @memberof ChatCompletionResponseChoices - */ - messages?: ChatCompletionResponseMessage[]; - - /** - * - * @type {string} - * @memberof ChatCompletionResponseChoices - */ - finish_reason?: string; -} - -interface ChatCompletionResponseChoices { - delta?: string; - text?: string; - index?: number; - finish_reason?: string; -} - -/** - * Interface representing a response from a chat completion. 
- */ -interface ChatCompletionResponse { - model: string; - created: number; - reply: string; - input_sensitive?: boolean; - input_sensitive_type?: number; - output_sensitive?: boolean; - output_sensitive_type?: number; - usage?: TokenUsage; - base_resp?: BaseResp; - choices: Array< - ChatCompletionResponseChoicesPro & ChatCompletionResponseChoices - >; -} +export * from "@langchain/community/chat_models/minimax"; diff --git a/langchain/src/chat_models/ollama.ts b/langchain/src/chat_models/ollama.ts index a3fbdb367eb2..175e84ab3511 100644 --- a/langchain/src/chat_models/ollama.ts +++ b/langchain/src/chat_models/ollama.ts @@ -1,298 +1 @@ -import { SimpleChatModel, BaseChatModelParams } from "./base.js"; -import { BaseLanguageModelCallOptions } from "../base_language/index.js"; -import { createOllamaStream, OllamaInput } from "../util/ollama.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { - AIMessageChunk, - BaseMessage, - ChatGenerationChunk, - ChatMessage, -} from "../schema/index.js"; -import type { StringWithAutocomplete } from "../util/types.js"; - -/** - * An interface defining the options for an Ollama API call. It extends - * the BaseLanguageModelCallOptions interface. - */ -export interface OllamaCallOptions extends BaseLanguageModelCallOptions {} - -/** - * A class that enables calls to the Ollama API to access large language - * models in a chat-like fashion. It extends the SimpleChatModel class and - * implements the OllamaInput interface. - * @example - * ```typescript - * const prompt = ChatPromptTemplate.fromMessages([ - * [ - * "system", - * `You are an expert translator. 
Format all responses as JSON objects with two keys: "original" and "translated".`, - * ], - * ["human", `Translate "{input}" into {language}.`], - * ]); - * - * const model = new ChatOllama({ - * baseUrl: "http://api.example.com", - * model: "llama2", - * format: "json", - * }); - * - * const chain = prompt.pipe(model); - * - * const result = await chain.invoke({ - * input: "I love programming", - * language: "German", - * }); - * - * ``` - */ -export class ChatOllama - extends SimpleChatModel - implements OllamaInput -{ - static lc_name() { - return "ChatOllama"; - } - - lc_serializable = true; - - model = "llama2"; - - baseUrl = "http://localhost:11434"; - - embeddingOnly?: boolean; - - f16KV?: boolean; - - frequencyPenalty?: number; - - logitsAll?: boolean; - - lowVram?: boolean; - - mainGpu?: number; - - mirostat?: number; - - mirostatEta?: number; - - mirostatTau?: number; - - numBatch?: number; - - numCtx?: number; - - numGpu?: number; - - numGqa?: number; - - numKeep?: number; - - numThread?: number; - - penalizeNewline?: boolean; - - presencePenalty?: number; - - repeatLastN?: number; - - repeatPenalty?: number; - - ropeFrequencyBase?: number; - - ropeFrequencyScale?: number; - - temperature?: number; - - stop?: string[]; - - tfsZ?: number; - - topK?: number; - - topP?: number; - - typicalP?: number; - - useMLock?: boolean; - - useMMap?: boolean; - - vocabOnly?: boolean; - - format?: StringWithAutocomplete<"json">; - - constructor(fields: OllamaInput & BaseChatModelParams) { - super(fields); - this.model = fields.model ?? this.model; - this.baseUrl = fields.baseUrl?.endsWith("/") - ? fields.baseUrl.slice(0, -1) - : fields.baseUrl ?? 
this.baseUrl; - this.embeddingOnly = fields.embeddingOnly; - this.f16KV = fields.f16KV; - this.frequencyPenalty = fields.frequencyPenalty; - this.logitsAll = fields.logitsAll; - this.lowVram = fields.lowVram; - this.mainGpu = fields.mainGpu; - this.mirostat = fields.mirostat; - this.mirostatEta = fields.mirostatEta; - this.mirostatTau = fields.mirostatTau; - this.numBatch = fields.numBatch; - this.numCtx = fields.numCtx; - this.numGpu = fields.numGpu; - this.numGqa = fields.numGqa; - this.numKeep = fields.numKeep; - this.numThread = fields.numThread; - this.penalizeNewline = fields.penalizeNewline; - this.presencePenalty = fields.presencePenalty; - this.repeatLastN = fields.repeatLastN; - this.repeatPenalty = fields.repeatPenalty; - this.ropeFrequencyBase = fields.ropeFrequencyBase; - this.ropeFrequencyScale = fields.ropeFrequencyScale; - this.temperature = fields.temperature; - this.stop = fields.stop; - this.tfsZ = fields.tfsZ; - this.topK = fields.topK; - this.topP = fields.topP; - this.typicalP = fields.typicalP; - this.useMLock = fields.useMLock; - this.useMMap = fields.useMMap; - this.vocabOnly = fields.vocabOnly; - this.format = fields.format; - } - - _llmType() { - return "ollama"; - } - - /** - * A method that returns the parameters for an Ollama API call. It - * includes model and options parameters. - * @param options Optional parsed call options. - * @returns An object containing the parameters for an Ollama API call. 
- */ - invocationParams(options?: this["ParsedCallOptions"]) { - return { - model: this.model, - format: this.format, - options: { - embedding_only: this.embeddingOnly, - f16_kv: this.f16KV, - frequency_penalty: this.frequencyPenalty, - logits_all: this.logitsAll, - low_vram: this.lowVram, - main_gpu: this.mainGpu, - mirostat: this.mirostat, - mirostat_eta: this.mirostatEta, - mirostat_tau: this.mirostatTau, - num_batch: this.numBatch, - num_ctx: this.numCtx, - num_gpu: this.numGpu, - num_gqa: this.numGqa, - num_keep: this.numKeep, - num_thread: this.numThread, - penalize_newline: this.penalizeNewline, - presence_penalty: this.presencePenalty, - repeat_last_n: this.repeatLastN, - repeat_penalty: this.repeatPenalty, - rope_frequency_base: this.ropeFrequencyBase, - rope_frequency_scale: this.ropeFrequencyScale, - temperature: this.temperature, - stop: options?.stop ?? this.stop, - tfs_z: this.tfsZ, - top_k: this.topK, - top_p: this.topP, - typical_p: this.typicalP, - use_mlock: this.useMLock, - use_mmap: this.useMMap, - vocab_only: this.vocabOnly, - }, - }; - } - - _combineLLMOutput() { - return {}; - } - - async *_streamResponseChunks( - input: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const stream = await this.caller.call(async () => - createOllamaStream( - this.baseUrl, - { - ...this.invocationParams(options), - prompt: this._formatMessagesAsPrompt(input), - }, - options - ) - ); - for await (const chunk of stream) { - if (!chunk.done) { - yield new ChatGenerationChunk({ - text: chunk.response, - message: new AIMessageChunk({ content: chunk.response }), - }); - await runManager?.handleLLMNewToken(chunk.response ?? 
""); - } else { - yield new ChatGenerationChunk({ - text: "", - message: new AIMessageChunk({ content: "" }), - generationInfo: { - model: chunk.model, - total_duration: chunk.total_duration, - load_duration: chunk.load_duration, - prompt_eval_count: chunk.prompt_eval_count, - prompt_eval_duration: chunk.prompt_eval_duration, - eval_count: chunk.eval_count, - eval_duration: chunk.eval_duration, - }, - }); - } - } - } - - protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { - const formattedMessages = messages - .map((message) => { - let messageText; - if (message._getType() === "human") { - messageText = `[INST] ${message.content} [/INST]`; - } else if (message._getType() === "ai") { - messageText = message.content; - } else if (message._getType() === "system") { - messageText = `<> ${message.content} <>`; - } else if (ChatMessage.isInstance(message)) { - messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( - 1 - )}: ${message.content}`; - } else { - console.warn( - `Unsupported message type passed to Ollama: "${message._getType()}"` - ); - messageText = ""; - } - return messageText; - }) - .join("\n"); - return formattedMessages; - } - - /** @ignore */ - async _call( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const chunks = []; - for await (const chunk of this._streamResponseChunks( - messages, - options, - runManager - )) { - chunks.push(chunk.message.content); - } - return chunks.join(""); - } -} +export * from "@langchain/community/chat_models/ollama"; diff --git a/langchain/src/chat_models/openai.ts b/langchain/src/chat_models/openai.ts index c8e99c525246..cf822cb30a8d 100644 --- a/langchain/src/chat_models/openai.ts +++ b/langchain/src/chat_models/openai.ts @@ -1,836 +1,20 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { + ChatOpenAI, + type ChatOpenAICallOptions, + messageToOpenAIRole, +} from 
"@langchain/openai"; import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { - AIMessage, - AIMessageChunk, - BaseMessage, - ChatGeneration, - ChatGenerationChunk, - ChatMessage, - ChatMessageChunk, - ChatResult, - FunctionMessageChunk, - HumanMessageChunk, - SystemMessageChunk, - ToolMessage, - ToolMessageChunk, -} from "../schema/index.js"; -import { StructuredTool } from "../tools/base.js"; -import { formatToOpenAITool } from "../tools/convert_to_openai.js"; -import { - AzureOpenAIInput, - OpenAICallOptions, - OpenAIChatInput, - OpenAICoreRequestOptions, - LegacyOpenAIInput, -} from "../types/openai-types.js"; -import { OpenAIEndpointConfig, getEndpoint } from "../util/azure.js"; -import { getEnvironmentVariable } from "../util/env.js"; +import { BaseMessage, ChatMessage, ChatResult } from "../schema/index.js"; import { promptLayerTrackRequest } from "../util/prompt-layer.js"; -import { BaseChatModel, BaseChatModelParams } from "./base.js"; -import { BaseFunctionCallOptions } from "../base_language/index.js"; -import { NewTokenIndices } from "../callbacks/base.js"; -import { wrapOpenAIClientError } from "../util/openai.js"; -import { - FunctionDef, - formatFunctionDefinitions, -} from "../util/openai-format-fndef.js"; - -export type { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput }; - -interface TokenUsage { - completionTokens?: number; - promptTokens?: number; - totalTokens?: number; -} - -interface OpenAILLMOutput { - tokenUsage: TokenUsage; -} - -// TODO import from SDK when available -type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool"; - -type OpenAICompletionParam = - OpenAIClient.Chat.Completions.ChatCompletionMessageParam; -type OpenAIFnDef = OpenAIClient.Chat.ChatCompletionCreateParams.Function; -type OpenAIFnCallOption = OpenAIClient.Chat.ChatCompletionFunctionCallOption; - -function extractGenericMessageCustomRole(message: ChatMessage) { - if ( - message.role !== "system" && - message.role !== 
"assistant" && - message.role !== "user" && - message.role !== "function" && - message.role !== "tool" - ) { - console.warn(`Unknown message role: ${message.role}`); - } - - return message.role as OpenAIRoleEnum; -} - -function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum { - const type = message._getType(); - switch (type) { - case "system": - return "system"; - case "ai": - return "assistant"; - case "human": - return "user"; - case "function": - return "function"; - case "tool": - return "tool"; - case "generic": { - if (!ChatMessage.isInstance(message)) - throw new Error("Invalid generic chat message"); - return extractGenericMessageCustomRole(message); - } - default: - throw new Error(`Unknown message type: ${type}`); - } -} - -function openAIResponseToChatMessage( - message: OpenAIClient.Chat.Completions.ChatCompletionMessage -): BaseMessage { - switch (message.role) { - case "assistant": - return new AIMessage(message.content || "", { - function_call: message.function_call, - tool_calls: message.tool_calls, - }); - default: - return new ChatMessage(message.content || "", message.role ?? "unknown"); - } -} - -function _convertDeltaToMessageChunk( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - delta: Record, - defaultRole?: OpenAIRoleEnum -) { - const role = delta.role ?? defaultRole; - const content = delta.content ?? 
""; - let additional_kwargs; - if (delta.function_call) { - additional_kwargs = { - function_call: delta.function_call, - }; - } else if (delta.tool_calls) { - additional_kwargs = { - tool_calls: delta.tool_calls, - }; - } else { - additional_kwargs = {}; - } - if (role === "user") { - return new HumanMessageChunk({ content }); - } else if (role === "assistant") { - return new AIMessageChunk({ content, additional_kwargs }); - } else if (role === "system") { - return new SystemMessageChunk({ content }); - } else if (role === "function") { - return new FunctionMessageChunk({ - content, - additional_kwargs, - name: delta.name, - }); - } else if (role === "tool") { - return new ToolMessageChunk({ - content, - additional_kwargs, - tool_call_id: delta.tool_call_id, - }); - } else { - return new ChatMessageChunk({ content, role }); - } -} - -function convertMessagesToOpenAIParams(messages: BaseMessage[]) { - // TODO: Function messages do not support array content, fix cast - return messages.map( - (message) => - ({ - role: messageToOpenAIRole(message), - content: message.content, - name: message.name, - function_call: message.additional_kwargs.function_call, - tool_calls: message.additional_kwargs.tool_calls, - tool_call_id: (message as ToolMessage).tool_call_id, - } as OpenAICompletionParam) - ); -} - -export interface ChatOpenAICallOptions - extends OpenAICallOptions, - BaseFunctionCallOptions { - tools?: StructuredTool[] | OpenAIClient.ChatCompletionTool[]; - tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption; - promptIndex?: number; - response_format?: { type: "json_object" }; - seed?: number; -} - -/** - * Wrapper around OpenAI large language models that use the Chat endpoint. - * - * To use you should have the `openai` package installed, with the - * `OPENAI_API_KEY` environment variable set. 
- * - * To use with Azure you should have the `openai` package installed, with the - * `AZURE_OPENAI_API_KEY`, - * `AZURE_OPENAI_API_INSTANCE_NAME`, - * `AZURE_OPENAI_API_DEPLOYMENT_NAME` - * and `AZURE_OPENAI_API_VERSION` environment variable set. - * `AZURE_OPENAI_BASE_PATH` is optional and will override `AZURE_OPENAI_API_INSTANCE_NAME` if you need to use a custom endpoint. - * - * @remarks - * Any parameters that are valid to be passed to {@link - * https://platform.openai.com/docs/api-reference/chat/create | - * `openai.createChatCompletion`} can be passed through {@link modelKwargs}, even - * if not explicitly available on this class. - * @example - * ```typescript - * // Create a new instance of ChatOpenAI with specific temperature and model name settings - * const model = new ChatOpenAI({ - * temperature: 0.9, - * modelName: "ft:gpt-3.5-turbo-0613:{ORG_NAME}::{MODEL_ID}", - * }); - * - * // Invoke the model with a message and await the response - * const message = await model.invoke("Hi there!"); - * - * // Log the response to the console - * console.log(message); - * - * ``` - */ -export class ChatOpenAI< - CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions - > - extends BaseChatModel - implements OpenAIChatInput, AzureOpenAIInput -{ - static lc_name() { - return "ChatOpenAI"; - } - - get callKeys() { - return [ - ...super.callKeys, - "options", - "function_call", - "functions", - "tools", - "tool_choice", - "promptIndex", - "response_format", - "seed", - ]; - } - - lc_serializable = true; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - openAIApiKey: "OPENAI_API_KEY", - azureOpenAIApiKey: "AZURE_OPENAI_API_KEY", - organization: "OPENAI_ORGANIZATION", - }; - } - - get lc_aliases(): Record { - return { - modelName: "model", - openAIApiKey: "openai_api_key", - azureOpenAIApiVersion: "azure_openai_api_version", - azureOpenAIApiKey: "azure_openai_api_key", - azureOpenAIApiInstanceName: "azure_openai_api_instance_name", - 
azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name", - }; - } - - temperature = 1; - - topP = 1; - - frequencyPenalty = 0; - - presencePenalty = 0; - - n = 1; - - logitBias?: Record; - - modelName = "gpt-3.5-turbo"; - modelKwargs?: OpenAIChatInput["modelKwargs"]; - - stop?: string[]; - - user?: string; - - timeout?: number; - - streaming = false; - - maxTokens?: number; - - openAIApiKey?: string; - - azureOpenAIApiVersion?: string; - - azureOpenAIApiKey?: string; - - azureOpenAIApiInstanceName?: string; - - azureOpenAIApiDeploymentName?: string; - - azureOpenAIBasePath?: string; - - organization?: string; - - private client: OpenAIClient; - - private clientConfig: ClientOptions; - - constructor( - fields?: Partial & - Partial & - BaseChatModelParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, - /** @deprecated */ - configuration?: ClientOptions & LegacyOpenAIInput - ) { - super(fields ?? {}); - - this.openAIApiKey = - fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY"); - - this.azureOpenAIApiKey = - fields?.azureOpenAIApiKey ?? - getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - - if (!this.azureOpenAIApiKey && !this.openAIApiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); - } - - this.azureOpenAIApiInstanceName = - fields?.azureOpenAIApiInstanceName ?? - getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME"); - - this.azureOpenAIApiDeploymentName = - fields?.azureOpenAIApiDeploymentName ?? - getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"); - - this.azureOpenAIApiVersion = - fields?.azureOpenAIApiVersion ?? - getEnvironmentVariable("AZURE_OPENAI_API_VERSION"); - - this.azureOpenAIBasePath = - fields?.azureOpenAIBasePath ?? - getEnvironmentVariable("AZURE_OPENAI_BASE_PATH"); - - this.organization = - fields?.configuration?.organization ?? - getEnvironmentVariable("OPENAI_ORGANIZATION"); - - this.modelName = fields?.modelName ?? this.modelName; - this.modelKwargs = fields?.modelKwargs ?? 
{}; - this.timeout = fields?.timeout; - - this.temperature = fields?.temperature ?? this.temperature; - this.topP = fields?.topP ?? this.topP; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; - this.maxTokens = fields?.maxTokens; - this.n = fields?.n ?? this.n; - this.logitBias = fields?.logitBias; - this.stop = fields?.stop; - this.user = fields?.user; - - this.streaming = fields?.streaming ?? false; - - if (this.azureOpenAIApiKey) { - if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { - throw new Error("Azure OpenAI API instance name not found"); - } - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI API deployment name not found"); - } - if (!this.azureOpenAIApiVersion) { - throw new Error("Azure OpenAI API version not found"); - } - this.openAIApiKey = this.openAIApiKey ?? ""; - } - - this.clientConfig = { - apiKey: this.openAIApiKey, - organization: this.organization, - baseURL: configuration?.basePath ?? fields?.configuration?.basePath, - dangerouslyAllowBrowser: true, - defaultHeaders: - configuration?.baseOptions?.headers ?? - fields?.configuration?.baseOptions?.headers, - defaultQuery: - configuration?.baseOptions?.params ?? 
- fields?.configuration?.baseOptions?.params, - ...configuration, - ...fields?.configuration, - }; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams( - options?: this["ParsedCallOptions"] - ): Omit { - function isStructuredToolArray( - tools?: unknown[] - ): tools is StructuredTool[] { - return ( - tools !== undefined && - tools.every((tool) => - Array.isArray((tool as StructuredTool).lc_namespace) - ) - ); - } - const params: Omit< - OpenAIClient.Chat.ChatCompletionCreateParams, - "messages" - > = { - model: this.modelName, - temperature: this.temperature, - top_p: this.topP, - frequency_penalty: this.frequencyPenalty, - presence_penalty: this.presencePenalty, - max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens, - n: this.n, - logit_bias: this.logitBias, - stop: options?.stop ?? this.stop, - user: this.user, - stream: this.streaming, - functions: options?.functions, - function_call: options?.function_call, - tools: isStructuredToolArray(options?.tools) - ? 
options?.tools.map(formatToOpenAITool) - : options?.tools, - tool_choice: options?.tool_choice, - response_format: options?.response_format, - seed: options?.seed, - ...this.modelKwargs, - }; - return params; - } - - /** @ignore */ - _identifyingParams(): Omit< - OpenAIClient.Chat.ChatCompletionCreateParams, - "messages" - > & { - model_name: string; - } & ClientOptions { - return { - model_name: this.modelName, - ...this.invocationParams(), - ...this.clientConfig, - }; - } - - async *_streamResponseChunks( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const messagesMapped: OpenAICompletionParam[] = - convertMessagesToOpenAIParams(messages); - const params = { - ...this.invocationParams(options), - messages: messagesMapped, - stream: true as const, - }; - let defaultRole: OpenAIRoleEnum | undefined; - const streamIterable = await this.completionWithRetry(params, options); - for await (const data of streamIterable) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } +export { + type AzureOpenAIInput, + type OpenAICallOptions, + type OpenAIChatInput, +} from "@langchain/openai"; - const { delta } = choice; - if (!delta) { - continue; - } - const chunk = _convertDeltaToMessageChunk(delta, defaultRole); - defaultRole = delta.role ?? defaultRole; - const newTokenIndices = { - prompt: options.promptIndex ?? 0, - completion: choice.index ?? 0, - }; - if (typeof chunk.content !== "string") { - console.log( - "[WARNING]: Received non-string content from OpenAI. This is currently not supported." - ); - continue; - } - const generationChunk = new ChatGenerationChunk({ - message: chunk, - text: chunk.content, - generationInfo: newTokenIndices, - }); - yield generationChunk; - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken( - generationChunk.text ?? 
"", - newTokenIndices, - undefined, - undefined, - undefined, - { chunk: generationChunk } - ); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - } - - /** - * Get the identifying parameters for the model - * - */ - identifyingParams() { - return this._identifyingParams(); - } - - /** @ignore */ - async _generate( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const tokenUsage: TokenUsage = {}; - const params = this.invocationParams(options); - const messagesMapped: OpenAICompletionParam[] = - convertMessagesToOpenAIParams(messages); - - if (params.stream) { - const stream = this._streamResponseChunks(messages, options, runManager); - const finalChunks: Record = {}; - for await (const chunk of stream) { - const index = - (chunk.generationInfo as NewTokenIndices)?.completion ?? 0; - if (finalChunks[index] === undefined) { - finalChunks[index] = chunk; - } else { - finalChunks[index] = finalChunks[index].concat(chunk); - } - } - const generations = Object.entries(finalChunks) - .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) - .map(([_, value]) => value); - - const { functions, function_call } = this.invocationParams(options); - - // OpenAI does not support token usage report under stream mode, - // fallback to estimation. 
- - const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt( - messages, - functions, - function_call - ); - const completionTokenUsage = await this.getNumTokensFromGenerations( - generations - ); - - tokenUsage.promptTokens = promptTokenUsage; - tokenUsage.completionTokens = completionTokenUsage; - tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage; - return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }; - } else { - const data = await this.completionWithRetry( - { - ...params, - stream: false, - messages: messagesMapped, - }, - { - signal: options?.signal, - ...options?.options, - } - ); - const { - completion_tokens: completionTokens, - prompt_tokens: promptTokens, - total_tokens: totalTokens, - } = data?.usage ?? {}; - - if (completionTokens) { - tokenUsage.completionTokens = - (tokenUsage.completionTokens ?? 0) + completionTokens; - } - - if (promptTokens) { - tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens; - } - - if (totalTokens) { - tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens; - } - - const generations: ChatGeneration[] = []; - for (const part of data?.choices ?? []) { - const text = part.message?.content ?? ""; - const generation: ChatGeneration = { - text, - message: openAIResponseToChatMessage( - part.message ?? { role: "assistant" } - ), - }; - if (part.finish_reason) { - generation.generationInfo = { finish_reason: part.finish_reason }; - } - generations.push(generation); - } - return { - generations, - llmOutput: { tokenUsage }, - }; - } - } - - /** - * Estimate the number of tokens a prompt will use. 
- * Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts - */ - private async getEstimatedTokenCountFromPrompt( - messages: BaseMessage[], - functions?: OpenAIFnDef[], - function_call?: "none" | "auto" | OpenAIFnCallOption - ): Promise { - // It appears that if functions are present, the first system message is padded with a trailing newline. This - // was inferred by trying lots of combinations of messages and functions and seeing what the token counts were. - - let tokens = (await this.getNumTokensFromMessages(messages)).totalCount; - - // If there are functions, add the function definitions as they count towards token usage - if (functions && function_call !== "auto") { - const promptDefinitions = formatFunctionDefinitions( - functions as unknown as FunctionDef[] - ); - tokens += await this.getNumTokens(promptDefinitions); - tokens += 9; // Add nine per completion - } - - // If there's a system message _and_ functions are present, subtract four tokens. I assume this is because - // functions typically add a system message, but reuse the first one if it's already there. This offsets - // the extra 9 tokens added by the function definitions. - if (functions && messages.find((m) => m._getType() === "system")) { - tokens -= 4; - } - - // If function_call is 'none', add one token. - // If it's a FunctionCall object, add 4 + the number of tokens in the function name. - // If it's undefined or 'auto', don't add anything. - if (function_call === "none") { - tokens += 1; - } else if (typeof function_call === "object") { - tokens += (await this.getNumTokens(function_call.name)) + 4; - } - - return tokens; - } - - /** - * Estimate the number of tokens an array of generations have used. 
- */ - private async getNumTokensFromGenerations(generations: ChatGeneration[]) { - const generationUsages = await Promise.all( - generations.map(async (generation) => { - if (generation.message.additional_kwargs?.function_call) { - return (await this.getNumTokensFromMessages([generation.message])) - .countPerMessage[0]; - } else { - return await this.getNumTokens(generation.message.content); - } - }) - ); - - return generationUsages.reduce((a, b) => a + b, 0); - } - - async getNumTokensFromMessages(messages: BaseMessage[]) { - let totalCount = 0; - let tokensPerMessage = 0; - let tokensPerName = 0; - - // From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb - if (this.modelName === "gpt-3.5-turbo-0301") { - tokensPerMessage = 4; - tokensPerName = -1; - } else { - tokensPerMessage = 3; - tokensPerName = 1; - } - - const countPerMessage = await Promise.all( - messages.map(async (message) => { - const textCount = await this.getNumTokens(message.content); - const roleCount = await this.getNumTokens(messageToOpenAIRole(message)); - const nameCount = - message.name !== undefined - ? 
tokensPerName + (await this.getNumTokens(message.name)) - : 0; - let count = textCount + tokensPerMessage + roleCount + nameCount; - - // From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate - const openAIMessage = message; - if (openAIMessage._getType() === "function") { - count -= 2; - } - if (openAIMessage.additional_kwargs?.function_call) { - count += 3; - } - if (openAIMessage?.additional_kwargs.function_call?.name) { - count += await this.getNumTokens( - openAIMessage.additional_kwargs.function_call?.name - ); - } - if (openAIMessage.additional_kwargs.function_call?.arguments) { - count += await this.getNumTokens( - // Remove newlines and spaces - JSON.stringify( - JSON.parse( - openAIMessage.additional_kwargs.function_call?.arguments - ) - ) - ); - } - - totalCount += count; - return count; - }) - ); - - totalCount += 3; // every reply is primed with <|start|>assistant<|message|> - - return { totalCount, countPerMessage }; - } - - /** - * Calls the OpenAI API with retry logic in case of failures. - * @param request The request to send to the OpenAI API. - * @param options Optional configuration for the API call. - * @returns The response from the OpenAI API. 
- */ - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, - options?: OpenAICoreRequestOptions - ): Promise>; - - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise; - - async completionWithRetry( - request: - | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming - | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise< - | AsyncIterable - | OpenAIClient.Chat.Completions.ChatCompletion - > { - const requestOptions = this._getClientOptions(options); - return this.caller.call(async () => { - try { - const res = await this.client.chat.completions.create( - request, - requestOptions - ); - return res; - } catch (e) { - const error = wrapOpenAIClientError(e); - throw error; - } - }); - } - - private _getClientOptions(options: OpenAICoreRequestOptions | undefined) { - if (!this.client) { - const openAIEndpointConfig: OpenAIEndpointConfig = { - azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, - azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, - azureOpenAIApiKey: this.azureOpenAIApiKey, - azureOpenAIBasePath: this.azureOpenAIBasePath, - baseURL: this.clientConfig.baseURL, - }; - - const endpoint = getEndpoint(openAIEndpointConfig); - const params = { - ...this.clientConfig, - baseURL: endpoint, - timeout: this.timeout, - maxRetries: 0, - }; - if (!params.baseURL) { - delete params.baseURL; - } - - this.client = new OpenAIClient(params); - } - const requestOptions = { - ...this.clientConfig, - ...options, - } as OpenAICoreRequestOptions; - if (this.azureOpenAIApiKey) { - requestOptions.headers = { - "api-key": this.azureOpenAIApiKey, - ...requestOptions.headers, - }; - requestOptions.query = { - "api-version": this.azureOpenAIApiVersion, - ...requestOptions.query, - }; - } - return requestOptions; - } - - _llmType() { - return "openai"; - } - - /** 
@ignore */ - _combineLLMOutput(...llmOutputs: OpenAILLMOutput[]): OpenAILLMOutput { - return llmOutputs.reduce<{ - [key in keyof OpenAILLMOutput]: Required; - }>( - (acc, llmOutput) => { - if (llmOutput && llmOutput.tokenUsage) { - acc.tokenUsage.completionTokens += - llmOutput.tokenUsage.completionTokens ?? 0; - acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0; - acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 0; - } - return acc; - }, - { - tokenUsage: { - completionTokens: 0, - promptTokens: 0, - totalTokens: 0, - }, - } - ); - } -} +export { type ChatOpenAICallOptions, ChatOpenAI }; export class PromptLayerChatOpenAI extends ChatOpenAI { promptLayerApiKey?: string; diff --git a/langchain/src/chat_models/portkey.ts b/langchain/src/chat_models/portkey.ts index fd41231b9ac3..41b8ae09537c 100644 --- a/langchain/src/chat_models/portkey.ts +++ b/langchain/src/chat_models/portkey.ts @@ -1,182 +1 @@ -import { LLMOptions } from "portkey-ai"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { PortkeySession, getPortkeySession } from "../llms/portkey.js"; -import { - AIMessage, - AIMessageChunk, - BaseMessage, - ChatGeneration, - ChatGenerationChunk, - ChatMessage, - ChatMessageChunk, - ChatResult, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, -} from "../schema/index.js"; -import { BaseChatModel } from "./base.js"; - -interface Message { - role?: string; - content?: string; -} - -function portkeyResponseToChatMessage(message: Message): BaseMessage { - switch (message.role) { - case "user": - return new HumanMessage(message.content || ""); - case "assistant": - return new AIMessage(message.content || ""); - case "system": - return new SystemMessage(message.content || ""); - default: - return new ChatMessage(message.content || "", message.role ?? 
"unknown"); - } -} - -function _convertDeltaToMessageChunk( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - delta: Record -) { - const { role } = delta; - const content = delta.content ?? ""; - let additional_kwargs; - if (delta.function_call) { - additional_kwargs = { - function_call: delta.function_call, - }; - } else { - additional_kwargs = {}; - } - if (role === "user") { - return new HumanMessageChunk({ content }); - } else if (role === "assistant") { - return new AIMessageChunk({ content, additional_kwargs }); - } else if (role === "system") { - return new SystemMessageChunk({ content }); - } else if (role === "function") { - return new FunctionMessageChunk({ - content, - additional_kwargs, - name: delta.name, - }); - } else { - return new ChatMessageChunk({ content, role }); - } -} - -export class PortkeyChat extends BaseChatModel { - apiKey?: string = undefined; - - baseURL?: string = undefined; - - mode?: string = undefined; - - llms?: [LLMOptions] | null = undefined; - - session: PortkeySession; - - constructor(init?: Partial) { - super(init ?? {}); - this.apiKey = init?.apiKey; - this.baseURL = init?.baseURL; - this.mode = init?.mode; - this.llms = init?.llms; - this.session = getPortkeySession({ - apiKey: this.apiKey, - baseURL: this.baseURL, - llms: this.llms, - mode: this.mode, - }); - } - - _llmType() { - return "portkey"; - } - - async _generate( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - _?: CallbackManagerForLLMRun - ): Promise { - const messagesList = messages.map((message) => { - if (typeof message.content !== "string") { - throw new Error( - "PortkeyChat does not support non-string message content." 
- ); - } - return { - role: message._getType() as string, - content: message.content, - }; - }); - const response = await this.session.portkey.chatCompletions.create({ - messages: messagesList, - ...options, - stream: false, - }); - const generations: ChatGeneration[] = []; - for (const data of response.choices ?? []) { - const text = data.message?.content ?? ""; - const generation: ChatGeneration = { - text, - message: portkeyResponseToChatMessage(data.message ?? {}), - }; - if (data.finish_reason) { - generation.generationInfo = { finish_reason: data.finish_reason }; - } - generations.push(generation); - } - - return { - generations, - }; - } - - async *_streamResponseChunks( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const messagesList = messages.map((message) => { - if (typeof message.content !== "string") { - throw new Error( - "PortkeyChat does not support non-string message content." - ); - } - return { - role: message._getType() as string, - content: message.content, - }; - }); - const response = await this.session.portkey.chatCompletions.create({ - messages: messagesList, - ...options, - stream: true, - }); - for await (const data of response) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - const chunk = new ChatGenerationChunk({ - message: _convertDeltaToMessageChunk(choice.delta ?? {}), - text: choice.message?.content ?? "", - generationInfo: { - finishReason: choice.finish_reason, - }, - }); - yield chunk; - void runManager?.handleLLMNewToken(chunk.text ?? 
""); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - } - - _combineLLMOutput() { - return {}; - } -} +export * from "@langchain/community/chat_models/portkey"; diff --git a/langchain/src/chat_models/yandex.ts b/langchain/src/chat_models/yandex.ts index 5365183792ca..3dbac8ffb540 100644 --- a/langchain/src/chat_models/yandex.ts +++ b/langchain/src/chat_models/yandex.ts @@ -1,142 +1 @@ -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { YandexGPTInputs } from "../llms/yandex.js"; -import { - AIMessage, - BaseMessage, - ChatResult, - ChatGeneration, -} from "../schema/index.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { BaseChatModel } from "./base.js"; - -const apiUrl = "https://llm.api.cloud.yandex.net/llm/v1alpha/chat"; - -interface ParsedMessage { - role: string; - text: string; -} - -function _parseChatHistory(history: BaseMessage[]): [ParsedMessage[], string] { - const chatHistory: ParsedMessage[] = []; - let instruction = ""; - - for (const message of history) { - if (typeof message.content !== "string") { - throw new Error( - "ChatYandexGPT does not support non-string message content." - ); - } - if ("content" in message) { - if (message._getType() === "human") { - chatHistory.push({ role: "user", text: message.content }); - } else if (message._getType() === "ai") { - chatHistory.push({ role: "assistant", text: message.content }); - } else if (message._getType() === "system") { - instruction = message.content; - } - } - } - - return [chatHistory, instruction]; -} - -/** - * @example - * ```typescript - * const chat = new ChatYandexGPT({}); - * // The assistant is set to translate English to French. - * const res = await chat.call([ - * new SystemMessage( - * "You are a helpful assistant that translates English to French." 
- * ), - * new HumanMessage("I love programming."), - * ]); - * console.log(res); - * ``` - */ -export class ChatYandexGPT extends BaseChatModel { - apiKey?: string; - - iamToken?: string; - - temperature = 0.6; - - maxTokens = 1700; - - model = "general"; - - constructor(fields?: YandexGPTInputs) { - super(fields ?? {}); - - const apiKey = fields?.apiKey ?? getEnvironmentVariable("YC_API_KEY"); - - const iamToken = fields?.iamToken ?? getEnvironmentVariable("YC_IAM_TOKEN"); - - if (apiKey === undefined && iamToken === undefined) { - throw new Error( - "Please set the YC_API_KEY or YC_IAM_TOKEN environment variable or pass it to the constructor as the apiKey or iamToken field." - ); - } - - this.apiKey = apiKey; - this.iamToken = iamToken; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.temperature = fields?.temperature ?? this.temperature; - this.model = fields?.model ?? this.model; - } - - _llmType() { - return "yandexgpt"; - } - - _combineLLMOutput?() { - return {}; - } - - /** @ignore */ - async _generate( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - _?: CallbackManagerForLLMRun | undefined - ): Promise { - const [messageHistory, instruction] = _parseChatHistory(messages); - const headers = { "Content-Type": "application/json", Authorization: "" }; - if (this.apiKey !== undefined) { - headers.Authorization = `Api-Key ${this.apiKey}`; - } else { - headers.Authorization = `Bearer ${this.iamToken}`; - } - const bodyData = { - model: this.model, - generationOptions: { - temperature: this.temperature, - maxTokens: this.maxTokens, - }, - messages: messageHistory, - instructionText: instruction, - }; - const response = await fetch(apiUrl, { - method: "POST", - headers, - body: JSON.stringify(bodyData), - signal: options?.signal, - }); - if (!response.ok) { - throw new Error( - `Failed to fetch ${apiUrl} from YandexGPT: ${response.status}` - ); - } - const responseData = await response.json(); - const { result } = responseData; - 
const { text } = result.message; - const totalTokens = result.num_tokens; - const generations: ChatGeneration[] = [ - { text, message: new AIMessage(text) }, - ]; - - return { - generations, - llmOutput: { totalTokens }, - }; - } -} +export * from "@langchain/community/chat_models/yandex"; diff --git a/langchain/src/document_loaders/fs/openai_whisper_audio.ts b/langchain/src/document_loaders/fs/openai_whisper_audio.ts index fbbdbef6f25b..468eba28c1e5 100644 --- a/langchain/src/document_loaders/fs/openai_whisper_audio.ts +++ b/langchain/src/document_loaders/fs/openai_whisper_audio.ts @@ -1,4 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient, toFile } from "openai"; +import { type ClientOptions, OpenAIClient, toFile } from "@langchain/openai"; import { Document } from "../../document.js"; import { BufferLoader } from "./buffer.js"; diff --git a/langchain/src/document_transformers/html_to_text.ts b/langchain/src/document_transformers/html_to_text.ts index 987d1c6feed3..a3d4023dfdd0 100644 --- a/langchain/src/document_transformers/html_to_text.ts +++ b/langchain/src/document_transformers/html_to_text.ts @@ -1,43 +1 @@ -import { htmlToText } from "html-to-text"; -import type { HtmlToTextOptions } from "html-to-text"; -import { Document } from "../document.js"; -import { MappingDocumentTransformer } from "../schema/document.js"; - -/** - * A transformer that converts HTML content to plain text. 
- * @example - * ```typescript - * const loader = new CheerioWebBaseLoader("https://example.com/some-page"); - * const docs = await loader.load(); - * - * const splitter = new RecursiveCharacterTextSplitter({ - * maxCharacterCount: 1000, - * }); - * const transformer = new HtmlToTextTransformer(); - * - * // The sequence of text splitting followed by HTML to text transformation - * const sequence = splitter.pipe(transformer); - * - * // Processing the loaded documents through the sequence - * const newDocuments = await sequence.invoke(docs); - * - * console.log(newDocuments); - * ``` - */ -export class HtmlToTextTransformer extends MappingDocumentTransformer { - static lc_name() { - return "HtmlToTextTransformer"; - } - - constructor(protected options: HtmlToTextOptions = {}) { - super(options); - } - - async _transformDocument(document: Document): Promise { - const extractedContent = htmlToText(document.pageContent, this.options); - return new Document({ - pageContent: extractedContent, - metadata: { ...document.metadata }, - }); - } -} +export * from "@langchain/community/document_transformers/html_to_text"; diff --git a/langchain/src/document_transformers/mozilla_readability.ts b/langchain/src/document_transformers/mozilla_readability.ts index 1eb302bbbc38..481f786c13d0 100644 --- a/langchain/src/document_transformers/mozilla_readability.ts +++ b/langchain/src/document_transformers/mozilla_readability.ts @@ -1,52 +1 @@ -import { Readability } from "@mozilla/readability"; -import { JSDOM } from "jsdom"; -import { Options } from "mozilla-readability"; -import { Document } from "../document.js"; -import { MappingDocumentTransformer } from "../schema/document.js"; - -/** - * A transformer that uses the Mozilla Readability library to extract the - * main content from a web page. 
- * @example - * ```typescript - * const loader = new CheerioWebBaseLoader("https://example.com/article"); - * const docs = await loader.load(); - * - * const splitter = new RecursiveCharacterTextSplitter({ - * maxCharacterCount: 5000, - * }); - * const transformer = new MozillaReadabilityTransformer(); - * - * // The sequence processes the loaded documents through the splitter and then the transformer. - * const sequence = splitter.pipe(transformer); - * - * // Invoke the sequence to transform the documents into a more readable format. - * const newDocuments = await sequence.invoke(docs); - * - * console.log(newDocuments); - * ``` - */ -export class MozillaReadabilityTransformer extends MappingDocumentTransformer { - static lc_name() { - return "MozillaReadabilityTransformer"; - } - - constructor(protected options: Options = {}) { - super(options); - } - - async _transformDocument(document: Document): Promise { - const doc = new JSDOM(document.pageContent); - - const readability = new Readability(doc.window.document, this.options); - - const result = readability.parse(); - - return new Document({ - pageContent: result?.textContent ?? "", - metadata: { - ...document.metadata, - }, - }); - } -} +export * from "@langchain/community/document_transformers/mozilla_readability"; diff --git a/langchain/src/embeddings/bedrock.ts b/langchain/src/embeddings/bedrock.ts index f9ca960f3a1a..0158430fd0a5 100644 --- a/langchain/src/embeddings/bedrock.ts +++ b/langchain/src/embeddings/bedrock.ts @@ -1,142 +1 @@ -import { - BedrockRuntimeClient, - InvokeModelCommand, -} from "@aws-sdk/client-bedrock-runtime"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import type { CredentialType } from "../util/bedrock.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the BedrockEmbeddings class. - */ -export interface BedrockEmbeddingsParams extends EmbeddingsParams { - /** - * Model Name to use. 
Defaults to `amazon.titan-embed-text-v1` if not provided - * - */ - model?: string; - - /** - * A client provided by the user that allows them to customze any - * SDK configuration options. - */ - client?: BedrockRuntimeClient; - - region?: string; - - credentials?: CredentialType; -} - -/** - * Class that extends the Embeddings class and provides methods for - * generating embeddings using the Bedrock API. - * @example - * ```typescript - * const embeddings = new BedrockEmbeddings({ - * region: "your-aws-region", - * credentials: { - * accessKeyId: "your-access-key-id", - * secretAccessKey: "your-secret-access-key", - * }, - * model: "amazon.titan-embed-text-v1", - * }); - * - * // Embed a query and log the result - * const res = await embeddings.embedQuery( - * "What would be a good company name for a company that makes colorful socks?" - * ); - * console.log({ res }); - * ``` - */ -export class BedrockEmbeddings - extends Embeddings - implements BedrockEmbeddingsParams -{ - model: string; - - client: BedrockRuntimeClient; - - batchSize = 512; - - constructor(fields?: BedrockEmbeddingsParams) { - super(fields ?? {}); - - this.model = fields?.model ?? "amazon.titan-embed-text-v1"; - - this.client = - fields?.client ?? - new BedrockRuntimeClient({ - region: fields?.region, - credentials: fields?.credentials, - }); - } - - /** - * Protected method to make a request to the Bedrock API to generate - * embeddings. Handles the retry logic and returns the response from the - * API. - * @param request Request to send to the Bedrock API. - * @returns Promise that resolves to the response from the API. - */ - protected async _embedText(text: string): Promise { - return this.caller.call(async () => { - try { - // replace newlines, which can negatively affect performance. 
- const cleanedText = text.replace(/\n/g, " "); - - const res = await this.client.send( - new InvokeModelCommand({ - modelId: this.model, - body: JSON.stringify({ - inputText: cleanedText, - }), - contentType: "application/json", - accept: "application/json", - }) - ); - - const body = new TextDecoder().decode(res.body); - return JSON.parse(body).embedding; - } catch (e) { - console.error({ - error: e, - }); - // eslint-disable-next-line no-instanceof/no-instanceof - if (e instanceof Error) { - throw new Error( - `An error occurred while embedding documents with Bedrock: ${e.message}` - ); - } - - throw new Error( - "An error occurred while embedding documents with Bedrock" - ); - } - }); - } - - /** - * Method that takes a document as input and returns a promise that - * resolves to an embedding for the document. It calls the _embedText - * method with the document as the input. - * @param document Document for which to generate an embedding. - * @returns Promise that resolves to an embedding for the input document. - */ - embedQuery(document: string): Promise { - return this.caller.callWithOptions( - {}, - this._embedText.bind(this), - document - ); - } - - /** - * Method to generate embeddings for an array of texts. Calls _embedText - * method which batches and handles retry logic when calling the AWS Bedrock API. - * @param documents Array of texts for which to generate embeddings. - * @returns Promise that resolves to a 2D array of embeddings for each input document. 
- */ - async embedDocuments(documents: string[]): Promise { - return Promise.all(documents.map((document) => this._embedText(document))); - } -} +export * from "@langchain/community/embeddings/bedrock"; diff --git a/langchain/src/embeddings/cloudflare_workersai.ts b/langchain/src/embeddings/cloudflare_workersai.ts index 191213dfbf5f..b5d4cd6238a8 100644 --- a/langchain/src/embeddings/cloudflare_workersai.ts +++ b/langchain/src/embeddings/cloudflare_workersai.ts @@ -1,94 +1 @@ -import { Ai } from "@cloudflare/ai"; -import { Fetcher } from "@cloudflare/workers-types"; -import { chunkArray } from "../util/chunk.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -type AiTextEmbeddingsInput = { - text: string | string[]; -}; - -type AiTextEmbeddingsOutput = { - shape: number[]; - data: number[][]; -}; - -export interface CloudflareWorkersAIEmbeddingsParams extends EmbeddingsParams { - /** Binding */ - binding: Fetcher; - - /** Model name to use */ - modelName?: string; - - /** - * The maximum number of documents to embed in a single request. - */ - batchSize?: number; - - /** - * Whether to strip new lines from the input text. This is recommended by - * OpenAI, but may not be suitable for all use cases. - */ - stripNewLines?: boolean; -} - -export class CloudflareWorkersAIEmbeddings extends Embeddings { - modelName = "@cf/baai/bge-base-en-v1.5"; - - batchSize = 50; - - stripNewLines = true; - - ai: Ai; - - constructor(fields: CloudflareWorkersAIEmbeddingsParams) { - super(fields); - - if (!fields.binding) { - throw new Error( - "Must supply a Workers AI binding, eg { binding: env.AI }" - ); - } - this.ai = new Ai(fields.binding); - this.modelName = fields.modelName ?? this.modelName; - this.stripNewLines = fields.stripNewLines ?? this.stripNewLines; - } - - async embedDocuments(texts: string[]): Promise { - const batches = chunkArray( - this.stripNewLines ? 
texts.map((t) => t.replace(/\n/g, " ")) : texts, - this.batchSize - ); - - const batchRequests = batches.map((batch) => this.runEmbedding(batch)); - const batchResponses = await Promise.all(batchRequests); - const embeddings: number[][] = []; - - for (let i = 0; i < batchResponses.length; i += 1) { - const batchResponse = batchResponses[i]; - for (let j = 0; j < batchResponse.length; j += 1) { - embeddings.push(batchResponse[j]); - } - } - - return embeddings; - } - - async embedQuery(text: string): Promise { - const data = await this.runEmbedding([ - this.stripNewLines ? text.replace(/\n/g, " ") : text, - ]); - return data[0]; - } - - private async runEmbedding(texts: string[]) { - return this.caller.call(async () => { - const response: AiTextEmbeddingsOutput = await this.ai.run( - this.modelName, - { - text: texts, - } as AiTextEmbeddingsInput - ); - return response.data; - }); - } -} +export * from "@langchain/community/embeddings/cloudflare_workersai"; diff --git a/langchain/src/embeddings/cohere.ts b/langchain/src/embeddings/cohere.ts index 4d510c205459..83eb358d4e37 100644 --- a/langchain/src/embeddings/cohere.ts +++ b/langchain/src/embeddings/cohere.ts @@ -1,155 +1 @@ -import { chunkArray } from "../util/chunk.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the CohereEmbeddings class. - */ -export interface CohereEmbeddingsParams extends EmbeddingsParams { - modelName: string; - - /** - * The maximum number of documents to embed in a single request. This is - * limited by the Cohere API to a maximum of 96. - */ - batchSize?: number; -} - -/** - * A class for generating embeddings using the Cohere API. 
- * @example - * ```typescript - * // Embed a query using the CohereEmbeddings class - * const model = new ChatOpenAI(); - * const res = await model.embedQuery( - * "What would be a good company name for a company that makes colorful socks?", - * ); - * console.log({ res }); - * - * ``` - */ -export class CohereEmbeddings - extends Embeddings - implements CohereEmbeddingsParams -{ - modelName = "small"; - - batchSize = 48; - - private apiKey: string; - - private client: typeof import("cohere-ai"); - - /** - * Constructor for the CohereEmbeddings class. - * @param fields - An optional object with properties to configure the instance. - */ - constructor( - fields?: Partial & { - verbose?: boolean; - apiKey?: string; - } - ) { - const fieldsWithDefaults = { maxConcurrency: 2, ...fields }; - - super(fieldsWithDefaults); - - const apiKey = - fieldsWithDefaults?.apiKey || getEnvironmentVariable("COHERE_API_KEY"); - - if (!apiKey) { - throw new Error("Cohere API key not found"); - } - - this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; - this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; - this.apiKey = apiKey; - } - - /** - * Generates embeddings for an array of texts. - * @param texts - An array of strings to generate embeddings for. - * @returns A Promise that resolves to an array of embeddings. 
- */ - async embedDocuments(texts: string[]): Promise { - await this.maybeInitClient(); - - const batches = chunkArray(texts, this.batchSize); - - const batchRequests = batches.map((batch) => - this.embeddingWithRetry({ - model: this.modelName, - texts: batch, - }) - ); - - const batchResponses = await Promise.all(batchRequests); - - const embeddings: number[][] = []; - - for (let i = 0; i < batchResponses.length; i += 1) { - const batch = batches[i]; - const { body: batchResponse } = batchResponses[i]; - for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse.embeddings[j]); - } - } - - return embeddings; - } - - /** - * Generates an embedding for a single text. - * @param text - A string to generate an embedding for. - * @returns A Promise that resolves to an array of numbers representing the embedding. - */ - async embedQuery(text: string): Promise { - await this.maybeInitClient(); - - const { body } = await this.embeddingWithRetry({ - model: this.modelName, - texts: [text], - }); - return body.embeddings[0]; - } - - /** - * Generates embeddings with retry capabilities. - * @param request - An object containing the request parameters for generating embeddings. - * @returns A Promise that resolves to the API response. - */ - private async embeddingWithRetry( - request: Parameters[0] - ) { - await this.maybeInitClient(); - - return this.caller.call(this.client.embed.bind(this.client), request); - } - - /** - * Initializes the Cohere client if it hasn't been initialized already. - */ - private async maybeInitClient() { - if (!this.client) { - const { cohere } = await CohereEmbeddings.imports(); - - this.client = cohere; - this.client.init(this.apiKey); - } - } - - /** @ignore */ - static async imports(): Promise<{ - cohere: typeof import("cohere-ai"); - }> { - try { - const { default: cohere } = await import("cohere-ai"); - return { cohere }; - } catch (e) { - throw new Error( - "Please install cohere-ai as a dependency with, e.g. 
`yarn add cohere-ai`" - ); - } - } -} +export * from "@langchain/community/embeddings/cohere"; diff --git a/langchain/src/embeddings/googlepalm.ts b/langchain/src/embeddings/googlepalm.ts index 2d969dc98106..950bc455c2c6 100644 --- a/langchain/src/embeddings/googlepalm.ts +++ b/langchain/src/embeddings/googlepalm.ts @@ -1,107 +1 @@ -import { TextServiceClient } from "@google-ai/generativelanguage"; -import { GoogleAuth } from "google-auth-library"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the GooglePaLMEmbeddings class. - */ -export interface GooglePaLMEmbeddingsParams extends EmbeddingsParams { - /** - * Model Name to use - * - * Note: The format must follow the pattern - `models/{model}` - */ - modelName?: string; - /** - * Google Palm API key to use - */ - apiKey?: string; -} - -/** - * Class that extends the Embeddings class and provides methods for - * generating embeddings using the Google Palm API. - * @example - * ```typescript - * const model = new GooglePaLMEmbeddings({ - * apiKey: "", - * modelName: "models/embedding-gecko-001", - * }); - * - * // Embed a single query - * const res = await model.embedQuery( - * "What would be a good company name for a company that makes colorful socks?" - * ); - * console.log({ res }); - * - * // Embed multiple documents - * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]); - * console.log({ documentRes }); - * ``` - */ -export class GooglePaLMEmbeddings - extends Embeddings - implements GooglePaLMEmbeddingsParams -{ - apiKey?: string; - - modelName = "models/embedding-gecko-001"; - - private client: TextServiceClient; - - constructor(fields?: GooglePaLMEmbeddingsParams) { - super(fields ?? {}); - - this.modelName = fields?.modelName ?? this.modelName; - - this.apiKey = - fields?.apiKey ?? 
getEnvironmentVariable("GOOGLE_PALM_API_KEY"); - if (!this.apiKey) { - throw new Error( - "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" - ); - } - - this.client = new TextServiceClient({ - authClient: new GoogleAuth().fromAPIKey(this.apiKey), - }); - } - - protected async _embedText(text: string): Promise { - // replace newlines, which can negatively affect performance. - const cleanedText = text.replace(/\n/g, " "); - const res = await this.client.embedText({ - model: this.modelName, - text: cleanedText, - }); - return res[0].embedding?.value ?? []; - } - - /** - * Method that takes a document as input and returns a promise that - * resolves to an embedding for the document. It calls the _embedText - * method with the document as the input. - * @param document Document for which to generate an embedding. - * @returns Promise that resolves to an embedding for the input document. - */ - embedQuery(document: string): Promise { - return this.caller.callWithOptions( - {}, - this._embedText.bind(this), - document - ); - } - - /** - * Method that takes an array of documents as input and returns a promise - * that resolves to a 2D array of embeddings for each document. It calls - * the _embedText method for each document in the array. - * @param documents Array of documents for which to generate embeddings. - * @returns Promise that resolves to a 2D array of embeddings for each input document. 
- */ - embedDocuments(documents: string[]): Promise { - return Promise.all(documents.map((document) => this._embedText(document))); - } -} +export * from "@langchain/community/embeddings/googlepalm"; diff --git a/langchain/src/embeddings/googlevertexai.ts b/langchain/src/embeddings/googlevertexai.ts index 121d8c6efd9b..3eded4165344 100644 --- a/langchain/src/embeddings/googlevertexai.ts +++ b/langchain/src/embeddings/googlevertexai.ts @@ -1,145 +1 @@ -import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import { - GoogleVertexAIBasePrediction, - GoogleVertexAIBaseLLMInput, - GoogleVertexAILLMPredictions, -} from "../types/googlevertexai-types.js"; -import { GoogleVertexAILLMConnection } from "../util/googlevertexai-connection.js"; -import { AsyncCallerCallOptions } from "../util/async_caller.js"; -import { chunkArray } from "../util/chunk.js"; - -/** - * Defines the parameters required to initialize a - * GoogleVertexAIEmbeddings instance. It extends EmbeddingsParams and - * GoogleVertexAIConnectionParams. - */ -export interface GoogleVertexAIEmbeddingsParams - extends EmbeddingsParams, - GoogleVertexAIBaseLLMInput {} - -/** - * Defines additional options specific to the - * GoogleVertexAILLMEmbeddingsInstance. It extends AsyncCallerCallOptions. - */ -interface GoogleVertexAILLMEmbeddingsOptions extends AsyncCallerCallOptions {} - -/** - * Represents an instance for generating embeddings using the Google - * Vertex AI API. It contains the content to be embedded. - */ -interface GoogleVertexAILLMEmbeddingsInstance { - content: string; -} - -/** - * Defines the structure of the embeddings results returned by the Google - * Vertex AI API. It extends GoogleVertexAIBasePrediction and contains the - * embeddings and their statistics. 
- */ -interface GoogleVertexEmbeddingsResults extends GoogleVertexAIBasePrediction { - embeddings: { - statistics: { - token_count: number; - truncated: boolean; - }; - values: number[]; - }; -} - -/** - * Enables calls to the Google Cloud's Vertex AI API to access - * the embeddings generated by Large Language Models. - * - * To use, you will need to have one of the following authentication - * methods in place: - * - You are logged into an account permitted to the Google Cloud project - * using Vertex AI. - * - You are running this on a machine using a service account permitted to - * the Google Cloud project using Vertex AI. - * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the - * path of a credentials file for a service account permitted to the - * Google Cloud project using Vertex AI. - * @example - * ```typescript - * const model = new GoogleVertexAIEmbeddings(); - * const res = await model.embedQuery( - * "What would be a good company name for a company that makes colorful socks?" - * ); - * console.log({ res }); - * ``` - */ -export class GoogleVertexAIEmbeddings - extends Embeddings - implements GoogleVertexAIEmbeddingsParams -{ - model = "textembedding-gecko"; - - private connection: GoogleVertexAILLMConnection< - GoogleVertexAILLMEmbeddingsOptions, - GoogleVertexAILLMEmbeddingsInstance, - GoogleVertexEmbeddingsResults, - GoogleAuthOptions - >; - - constructor(fields?: GoogleVertexAIEmbeddingsParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - - this.connection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - new GoogleAuth({ - scopes: "https://www.googleapis.com/auth/cloud-platform", - ...fields?.authOptions, - }) - ); - } - - /** - * Takes an array of documents as input and returns a promise that - * resolves to a 2D array of embeddings for each document. It splits the - * documents into chunks and makes requests to the Google Vertex AI API to - * generate embeddings. 
- * @param documents An array of documents to be embedded. - * @returns A promise that resolves to a 2D array of embeddings for each document. - */ - async embedDocuments(documents: string[]): Promise { - const instanceChunks: GoogleVertexAILLMEmbeddingsInstance[][] = chunkArray( - documents.map((document) => ({ - content: document, - })), - 5 - ); // Vertex AI accepts max 5 instances per prediction - const parameters = {}; - const options = {}; - const responses = await Promise.all( - instanceChunks.map((instances) => - this.connection.request(instances, parameters, options) - ) - ); - const result: number[][] = - responses - ?.map( - (response) => - ( - response?.data as GoogleVertexAILLMPredictions - )?.predictions?.map((result) => result.embeddings.values) ?? [] - ) - .flat() ?? []; - return result; - } - - /** - * Takes a document as input and returns a promise that resolves to an - * embedding for the document. It calls the embedDocuments method with the - * document as the input. - * @param document A document to be embedded. - * @returns A promise that resolves to an embedding for the document. - */ - async embedQuery(document: string): Promise { - const data = await this.embedDocuments([document]); - return data[0]; - } -} +export * from "@langchain/community/embeddings/googlevertexai"; diff --git a/langchain/src/embeddings/gradient_ai.ts b/langchain/src/embeddings/gradient_ai.ts index e03b7a364c0f..83b02683d3d9 100644 --- a/langchain/src/embeddings/gradient_ai.ts +++ b/langchain/src/embeddings/gradient_ai.ts @@ -1,118 +1 @@ -import { Gradient } from "@gradientai/nodejs-sdk"; -import { getEnvironmentVariable } from "../util/env.js"; -import { chunkArray } from "../util/chunk.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -/** - * Interface for GradientEmbeddings parameters. Extends EmbeddingsParams and - * defines additional parameters specific to the GradientEmbeddings class. 
- */ -export interface GradientEmbeddingsParams extends EmbeddingsParams { - /** - * Gradient AI Access Token. - * Provide Access Token if you do not wish to automatically pull from env. - */ - gradientAccessKey?: string; - /** - * Gradient Workspace Id. - * Provide workspace id if you do not wish to automatically pull from env. - */ - workspaceId?: string; -} - -/** - * Class for generating embeddings using the Gradient AI's API. Extends the - * Embeddings class and implements GradientEmbeddingsParams and - */ -export class GradientEmbeddings - extends Embeddings - implements GradientEmbeddingsParams -{ - gradientAccessKey?: string; - - workspaceId?: string; - - batchSize = 128; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - model: any; - - constructor(fields: GradientEmbeddingsParams) { - super(fields); - - this.gradientAccessKey = - fields?.gradientAccessKey ?? - getEnvironmentVariable("GRADIENT_ACCESS_TOKEN"); - this.workspaceId = - fields?.workspaceId ?? getEnvironmentVariable("GRADIENT_WORKSPACE_ID"); - - if (!this.gradientAccessKey) { - throw new Error("Missing Gradient AI Access Token"); - } - - if (!this.workspaceId) { - throw new Error("Missing Gradient AI Workspace ID"); - } - } - - /** - * Method to generate embeddings for an array of documents. Splits the - * documents into batches and makes requests to the Gradient API to generate - * embeddings. - * @param texts Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each document. 
- */ - async embedDocuments(texts: string[]): Promise { - await this.setModel(); - - const mappedTexts = texts.map((text) => ({ input: text })); - - const batches = chunkArray(mappedTexts, this.batchSize); - - const batchRequests = batches.map((batch) => - this.caller.call(async () => - this.model.generateEmbeddings({ - inputs: batch, - }) - ) - ); - const batchResponses = await Promise.all(batchRequests); - - const embeddings: number[][] = []; - for (let i = 0; i < batchResponses.length; i += 1) { - const batch = batches[i]; - const { embeddings: batchResponse } = batchResponses[i]; - for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse[j].embedding); - } - } - return embeddings; - } - - /** - * Method to generate an embedding for a single document. Calls the - * embedDocuments method with the document as the input. - * @param text Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the document. - */ - async embedQuery(text: string): Promise { - const data = await this.embedDocuments([text]); - return data[0]; - } - - /** - * Method to set the model to use for generating embeddings. - * @sets the class' `model` value to that of the retrieved Embeddings Model. 
- */ - async setModel() { - if (this.model) return; - - const gradient = new Gradient({ - accessToken: this.gradientAccessKey, - workspaceId: this.workspaceId, - }); - this.model = await gradient.getEmbeddingsModel({ - slug: "bge-large", - }); - } -} +export * from "@langchain/community/embeddings/gradient_ai"; diff --git a/langchain/src/embeddings/hf.ts b/langchain/src/embeddings/hf.ts index 83b801a90566..7042535179ed 100644 --- a/langchain/src/embeddings/hf.ts +++ b/langchain/src/embeddings/hf.ts @@ -1,77 +1 @@ -import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the HuggingFaceInferenceEmbeddings class. - */ -export interface HuggingFaceInferenceEmbeddingsParams extends EmbeddingsParams { - apiKey?: string; - model?: string; - endpointUrl?: string; -} - -/** - * Class that extends the Embeddings class and provides methods for - * generating embeddings using Hugging Face models through the - * HuggingFaceInference API. - */ -export class HuggingFaceInferenceEmbeddings - extends Embeddings - implements HuggingFaceInferenceEmbeddingsParams -{ - apiKey?: string; - - model: string; - - endpointUrl?: string; - - client: HfInference | HfInferenceEndpoint; - - constructor(fields?: HuggingFaceInferenceEmbeddingsParams) { - super(fields ?? {}); - - this.model = fields?.model ?? "BAAI/bge-base-en-v1.5"; - this.apiKey = - fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY"); - this.endpointUrl = fields?.endpointUrl; - this.client = this.endpointUrl - ? new HfInference(this.apiKey).endpoint(this.endpointUrl) - : new HfInference(this.apiKey); - } - - async _embed(texts: string[]): Promise { - // replace newlines, which can negatively affect performance. 
- const clean = texts.map((text) => text.replace(/\n/g, " ")); - return this.caller.call(() => - this.client.featureExtraction({ - model: this.model, - inputs: clean, - }) - ) as Promise; - } - - /** - * Method that takes a document as input and returns a promise that - * resolves to an embedding for the document. It calls the _embed method - * with the document as the input and returns the first embedding in the - * resulting array. - * @param document Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the document. - */ - embedQuery(document: string): Promise { - return this._embed([document]).then((embeddings) => embeddings[0]); - } - - /** - * Method that takes an array of documents as input and returns a promise - * that resolves to a 2D array of embeddings for each document. It calls - * the _embed method with the documents as the input. - * @param documents Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each document. - */ - embedDocuments(documents: string[]): Promise { - return this._embed(documents); - } -} +export * from "@langchain/community/embeddings/hf"; diff --git a/langchain/src/embeddings/hf_transformers.ts b/langchain/src/embeddings/hf_transformers.ts index e2b8bdcc98a5..aff3ec98244e 100644 --- a/langchain/src/embeddings/hf_transformers.ts +++ b/langchain/src/embeddings/hf_transformers.ts @@ -1,105 +1 @@ -import { Pipeline, pipeline } from "@xenova/transformers"; -import { chunkArray } from "../util/chunk.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -export interface HuggingFaceTransformersEmbeddingsParams - extends EmbeddingsParams { - /** Model name to use */ - modelName: string; - - /** - * Timeout to use when making requests to OpenAI. - */ - timeout?: number; - - /** - * The maximum number of documents to embed in a single request. 
- */ - batchSize?: number; - - /** - * Whether to strip new lines from the input text. This is recommended by - * OpenAI, but may not be suitable for all use cases. - */ - stripNewLines?: boolean; -} - -/** - * @example - * ```typescript - * const model = new HuggingFaceTransformersEmbeddings({ - * modelName: "Xenova/all-MiniLM-L6-v2", - * }); - * - * // Embed a single query - * const res = await model.embedQuery( - * "What would be a good company name for a company that makes colorful socks?" - * ); - * console.log({ res }); - * - * // Embed multiple documents - * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]); - * console.log({ documentRes }); - * ``` - */ -export class HuggingFaceTransformersEmbeddings - extends Embeddings - implements HuggingFaceTransformersEmbeddingsParams -{ - modelName = "Xenova/all-MiniLM-L6-v2"; - - batchSize = 512; - - stripNewLines = true; - - timeout?: number; - - private pipelinePromise: Promise; - - constructor(fields?: Partial) { - super(fields ?? {}); - - this.modelName = fields?.modelName ?? this.modelName; - this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines; - this.timeout = fields?.timeout; - } - - async embedDocuments(texts: string[]): Promise { - const batches = chunkArray( - this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts, - this.batchSize - ); - - const batchRequests = batches.map((batch) => this.runEmbedding(batch)); - const batchResponses = await Promise.all(batchRequests); - const embeddings: number[][] = []; - - for (let i = 0; i < batchResponses.length; i += 1) { - const batchResponse = batchResponses[i]; - for (let j = 0; j < batchResponse.length; j += 1) { - embeddings.push(batchResponse[j]); - } - } - - return embeddings; - } - - async embedQuery(text: string): Promise { - const data = await this.runEmbedding([ - this.stripNewLines ? 
text.replace(/\n/g, " ") : text, - ]); - return data[0]; - } - - private async runEmbedding(texts: string[]) { - const pipe = await (this.pipelinePromise ??= pipeline( - "feature-extraction", - this.modelName - )); - - return this.caller.call(async () => { - const output = await pipe(texts, { pooling: "mean", normalize: true }); - return output.tolist(); - }); - } -} +export * from "@langchain/community/embeddings/hf_transformers"; diff --git a/langchain/src/embeddings/llama_cpp.ts b/langchain/src/embeddings/llama_cpp.ts index 51b9b3c3e007..3fcf55dda59a 100644 --- a/langchain/src/embeddings/llama_cpp.ts +++ b/langchain/src/embeddings/llama_cpp.ts @@ -1,103 +1 @@ -import { LlamaModel, LlamaContext } from "node-llama-cpp"; -import { - LlamaBaseCppInputs, - createLlamaModel, - createLlamaContext, -} from "../util/llama_cpp.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -/** - * Note that the modelPath is the only required parameter. For testing you - * can set this in the environment variable `LLAMA_PATH`. - */ -export interface LlamaCppEmbeddingsParams - extends LlamaBaseCppInputs, - EmbeddingsParams {} - -/** - * @example - * ```typescript - * // Initialize LlamaCppEmbeddings with the path to the model file - * const embeddings = new LlamaCppEmbeddings({ - * modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin", - * }); - * - * // Embed a query string using the Llama embeddings - * const res = embeddings.embedQuery("Hello Llama!"); - * - * // Output the resulting embeddings - * console.log(res); - * - * ``` - */ -export class LlamaCppEmbeddings extends Embeddings { - _model: LlamaModel; - - _context: LlamaContext; - - constructor(inputs: LlamaCppEmbeddingsParams) { - super(inputs); - const _inputs = inputs; - _inputs.embedding = true; - - this._model = createLlamaModel(_inputs); - this._context = createLlamaContext(this._model, _inputs); - } - - /** - * Generates embeddings for an array of texts. 
- * @param texts - An array of strings to generate embeddings for. - * @returns A Promise that resolves to an array of embeddings. - */ - async embedDocuments(texts: string[]): Promise { - const tokensArray = []; - - for (const text of texts) { - const encodings = await this.caller.call( - () => - new Promise((resolve) => { - resolve(this._context.encode(text)); - }) - ); - tokensArray.push(encodings); - } - - const embeddings: number[][] = []; - - for (const tokens of tokensArray) { - const embedArray: number[] = []; - - for (let i = 0; i < tokens.length; i += 1) { - const nToken: number = +tokens[i]; - embedArray.push(nToken); - } - - embeddings.push(embedArray); - } - - return embeddings; - } - - /** - * Generates an embedding for a single text. - * @param text - A string to generate an embedding for. - * @returns A Promise that resolves to an array of numbers representing the embedding. - */ - async embedQuery(text: string): Promise { - const tokens: number[] = []; - - const encodings = await this.caller.call( - () => - new Promise((resolve) => { - resolve(this._context.encode(text)); - }) - ); - - for (let i = 0; i < encodings.length; i += 1) { - const token: number = +encodings[i]; - tokens.push(token); - } - - return tokens; - } -} +export * from "@langchain/community/embeddings/llama_cpp"; diff --git a/langchain/src/embeddings/minimax.ts b/langchain/src/embeddings/minimax.ts index 80697e30f60a..8568bf09207e 100644 --- a/langchain/src/embeddings/minimax.ts +++ b/langchain/src/embeddings/minimax.ts @@ -1,222 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { chunkArray } from "../util/chunk.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import { ConfigurationParameters } from "../chat_models/minimax.js"; - -/** - * Interface for MinimaxEmbeddings parameters. Extends EmbeddingsParams and - * defines additional parameters specific to the MinimaxEmbeddings class. 
- */ -export interface MinimaxEmbeddingsParams extends EmbeddingsParams { - /** Model name to use */ - modelName: string; - - /** - * API key to use when making requests. Defaults to the value of - * `MINIMAX_GROUP_ID` environment variable. - */ - minimaxGroupId?: string; - - /** - * Secret key to use when making requests. Defaults to the value of - * `MINIMAX_API_KEY` environment variable. - */ - minimaxApiKey?: string; - - /** - * The maximum number of documents to embed in a single request. This is - * limited by the Minimax API to a maximum of 4096. - */ - batchSize?: number; - - /** - * Whether to strip new lines from the input text. This is recommended by - * Minimax, but may not be suitable for all use cases. - */ - stripNewLines?: boolean; - - /** - * The target use-case after generating the vector. - * When using embeddings, the vector of the target content is first generated through the db and stored in the vector database, - * and then the vector of the retrieval text is generated through the query. - * Note: For the parameters of the partial algorithm, we adopted a separate algorithm plan for query and db. - * Therefore, for a paragraph of text, if it is to be used as a retrieval text, it should use the db, - * and if it is used as a retrieval text, it should use the query. - */ - type?: "db" | "query"; -} - -export interface CreateMinimaxEmbeddingRequest { - /** - * @type {string} - * @memberof CreateMinimaxEmbeddingRequest - */ - model: string; - - /** - * Text to generate vector expectation - * @type {CreateEmbeddingRequestInput} - * @memberof CreateMinimaxEmbeddingRequest - */ - texts: string[]; - - /** - * The target use-case after generating the vector. When using embeddings, - * first generate the vector of the target content through the db and store it in the vector database, - * and then generate the vector of the retrieval text through the query. 
- * Note: For the parameter of the algorithm, we use the algorithm scheme of query and db separation, - * so a text, if it is to be retrieved as a text, should use the db, - * if it is used as a retrieval text, should use the query. - * @type {string} - * @memberof CreateMinimaxEmbeddingRequest - */ - type: "db" | "query"; -} - -/** - * Class for generating embeddings using the Minimax API. Extends the - * Embeddings class and implements MinimaxEmbeddingsParams - * @example - * ```typescript - * const embeddings = new MinimaxEmbeddings(); - * - * // Embed a single query - * const queryEmbedding = await embeddings.embedQuery("Hello world"); - * console.log(queryEmbedding); - * - * // Embed multiple documents - * const documentsEmbedding = await embeddings.embedDocuments([ - * "Hello world", - * "Bye bye", - * ]); - * console.log(documentsEmbedding); - * ``` - */ -export class MinimaxEmbeddings - extends Embeddings - implements MinimaxEmbeddingsParams -{ - modelName = "embo-01"; - - batchSize = 512; - - stripNewLines = true; - - minimaxGroupId?: string; - - minimaxApiKey?: string; - - type: "db" | "query" = "db"; - - apiUrl: string; - - basePath?: string = "https://api.minimax.chat/v1"; - - headers?: Record; - - constructor( - fields?: Partial & { - configuration?: ConfigurationParameters; - } - ) { - const fieldsWithDefaults = { maxConcurrency: 2, ...fields }; - super(fieldsWithDefaults); - - this.minimaxGroupId = - fields?.minimaxGroupId ?? getEnvironmentVariable("MINIMAX_GROUP_ID"); - if (!this.minimaxGroupId) { - throw new Error("Minimax GroupID not found"); - } - - this.minimaxApiKey = - fields?.minimaxApiKey ?? getEnvironmentVariable("MINIMAX_API_KEY"); - - if (!this.minimaxApiKey) { - throw new Error("Minimax ApiKey not found"); - } - - this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; - this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; - this.type = fieldsWithDefaults?.type ?? 
this.type; - this.stripNewLines = - fieldsWithDefaults?.stripNewLines ?? this.stripNewLines; - this.basePath = fields?.configuration?.basePath ?? this.basePath; - this.apiUrl = `${this.basePath}/embeddings`; - this.headers = fields?.configuration?.headers ?? this.headers; - } - - /** - * Method to generate embeddings for an array of documents. Splits the - * documents into batches and makes requests to the Minimax API to generate - * embeddings. - * @param texts Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each document. - */ - async embedDocuments(texts: string[]): Promise { - const batches = chunkArray( - this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts, - this.batchSize - ); - - const batchRequests = batches.map((batch) => - this.embeddingWithRetry({ - model: this.modelName, - texts: batch, - type: this.type, - }) - ); - const batchResponses = await Promise.all(batchRequests); - - const embeddings: number[][] = []; - for (let i = 0; i < batchResponses.length; i += 1) { - const batch = batches[i]; - const { vectors: batchResponse } = batchResponses[i]; - for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse[j]); - } - } - return embeddings; - } - - /** - * Method to generate an embedding for a single document. Calls the - * embeddingWithRetry method with the document as the input. - * @param text Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the document. - */ - async embedQuery(text: string): Promise { - const { vectors } = await this.embeddingWithRetry({ - model: this.modelName, - texts: [this.stripNewLines ? text.replace(/\n/g, " ") : text], - type: this.type, - }); - return vectors[0]; - } - - /** - * Private method to make a request to the Minimax API to generate - * embeddings. Handles the retry logic and returns the response from the - * API. - * @param request Request to send to the Minimax API. 
- * @returns Promise that resolves to the response from the API. - */ - private async embeddingWithRetry(request: CreateMinimaxEmbeddingRequest) { - const makeCompletionRequest = async () => { - const url = `${this.apiUrl}?GroupId=${this.minimaxGroupId}`; - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${this.minimaxApiKey}`, - ...this.headers, - }, - body: JSON.stringify(request), - }); - - const json = await response.json(); - return json; - }; - - return this.caller.call(makeCompletionRequest); - } -} +export * from "@langchain/community/embeddings/minimax"; diff --git a/langchain/src/embeddings/ollama.ts b/langchain/src/embeddings/ollama.ts index de9c77797dfe..9bd994c08dd9 100644 --- a/langchain/src/embeddings/ollama.ts +++ b/langchain/src/embeddings/ollama.ts @@ -1,148 +1 @@ -import { OllamaInput, OllamaRequestParams } from "../util/ollama.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -type CamelCasedRequestOptions = Omit< - OllamaInput, - "baseUrl" | "model" | "format" ->; - -/** - * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and - * defines additional parameters specific to the OllamaEmbeddings class. - */ -interface OllamaEmbeddingsParams extends EmbeddingsParams { - /** The Ollama model to use, e.g: "llama2:13b" */ - model?: string; - - /** Base URL of the Ollama server, defaults to "http://localhost:11434" */ - baseUrl?: string; - - /** Advanced Ollama API request parameters in camelCase, see - * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values - * for details of the available parameters. 
- */ - requestOptions?: CamelCasedRequestOptions; -} - -export class OllamaEmbeddings extends Embeddings { - model = "llama2"; - - baseUrl = "http://localhost:11434"; - - requestOptions?: OllamaRequestParams["options"]; - - constructor(params?: OllamaEmbeddingsParams) { - super(params || {}); - - if (params?.model) { - this.model = params.model; - } - - if (params?.baseUrl) { - this.baseUrl = params.baseUrl; - } - - if (params?.requestOptions) { - this.requestOptions = this._convertOptions(params.requestOptions); - } - } - - /** convert camelCased Ollama request options like "useMMap" to - * the snake_cased equivalent which the ollama API actually uses. - * Used only for consistency with the llms/Ollama and chatModels/Ollama classes - */ - _convertOptions(requestOptions: CamelCasedRequestOptions) { - const snakeCasedOptions: Record = {}; - const mapping: Record = { - embeddingOnly: "embedding_only", - f16KV: "f16_kv", - frequencyPenalty: "frequency_penalty", - logitsAll: "logits_all", - lowVram: "low_vram", - mainGpu: "main_gpu", - mirostat: "mirostat", - mirostatEta: "mirostat_eta", - mirostatTau: "mirostat_tau", - numBatch: "num_batch", - numCtx: "num_ctx", - numGpu: "num_gpu", - numGqa: "num_gqa", - numKeep: "num_keep", - numThread: "num_thread", - penalizeNewline: "penalize_newline", - presencePenalty: "presence_penalty", - repeatLastN: "repeat_last_n", - repeatPenalty: "repeat_penalty", - ropeFrequencyBase: "rope_frequency_base", - ropeFrequencyScale: "rope_frequency_scale", - temperature: "temperature", - stop: "stop", - tfsZ: "tfs_z", - topK: "top_k", - topP: "top_p", - typicalP: "typical_p", - useMLock: "use_mlock", - useMMap: "use_mmap", - vocabOnly: "vocab_only", - }; - - for (const [key, value] of Object.entries(requestOptions)) { - const snakeCasedOption = mapping[key as keyof CamelCasedRequestOptions]; - if (snakeCasedOption) { - snakeCasedOptions[snakeCasedOption] = value; - } - } - return snakeCasedOptions; - } - - async _request(prompt: string): 
Promise { - const { model, baseUrl, requestOptions } = this; - - let formattedBaseUrl = baseUrl; - if (formattedBaseUrl.startsWith("http://localhost:")) { - // Node 18 has issues with resolving "localhost" - // See https://github.com/node-fetch/node-fetch/issues/1624 - formattedBaseUrl = formattedBaseUrl.replace( - "http://localhost:", - "http://127.0.0.1:" - ); - } - - const response = await fetch(`${formattedBaseUrl}/api/embeddings`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ - prompt, - model, - options: requestOptions, - }), - }); - if (!response.ok) { - throw new Error( - `Request to Ollama server failed: ${response.status} ${response.statusText}` - ); - } - - const json = await response.json(); - return json.embedding; - } - - async _embed(strings: string[]): Promise { - const embeddings: number[][] = []; - - for await (const prompt of strings) { - const embedding = await this.caller.call(() => this._request(prompt)); - embeddings.push(embedding); - } - - return embeddings; - } - - async embedDocuments(documents: string[]) { - return this._embed(documents); - } - - async embedQuery(document: string) { - return (await this.embedDocuments([document]))[0]; - } -} +export * from "@langchain/community/embeddings/ollama"; diff --git a/langchain/src/embeddings/openai.ts b/langchain/src/embeddings/openai.ts index 0cc726d03ca3..9c5af9d81995 100644 --- a/langchain/src/embeddings/openai.ts +++ b/langchain/src/embeddings/openai.ts @@ -1,269 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; -import { getEnvironmentVariable } from "../util/env.js"; -import { - AzureOpenAIInput, - OpenAICoreRequestOptions, - LegacyOpenAIInput, -} from "../types/openai-types.js"; -import { chunkArray } from "../util/chunk.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; -import { getEndpoint, OpenAIEndpointConfig } from "../util/azure.js"; -import { wrapOpenAIClientError } from 
"../util/openai.js"; - -/** - * Interface for OpenAIEmbeddings parameters. Extends EmbeddingsParams and - * defines additional parameters specific to the OpenAIEmbeddings class. - */ -export interface OpenAIEmbeddingsParams extends EmbeddingsParams { - /** Model name to use */ - modelName: string; - - /** - * Timeout to use when making requests to OpenAI. - */ - timeout?: number; - - /** - * The maximum number of documents to embed in a single request. This is - * limited by the OpenAI API to a maximum of 2048. - */ - batchSize?: number; - - /** - * Whether to strip new lines from the input text. This is recommended by - * OpenAI, but may not be suitable for all use cases. - */ - stripNewLines?: boolean; -} - -/** - * Class for generating embeddings using the OpenAI API. Extends the - * Embeddings class and implements OpenAIEmbeddingsParams and - * AzureOpenAIInput. - * @example - * ```typescript - * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text - * const model = new OpenAIEmbeddings(); - * const res = await model.embedQuery( - * "What would be a good company name for a company that makes colorful socks?", - * ); - * console.log({ res }); - * - * ``` - */ -export class OpenAIEmbeddings - extends Embeddings - implements OpenAIEmbeddingsParams, AzureOpenAIInput -{ - modelName = "text-embedding-ada-002"; - - batchSize = 512; - - stripNewLines = true; - - timeout?: number; - - azureOpenAIApiVersion?: string; - - azureOpenAIApiKey?: string; - - azureOpenAIApiInstanceName?: string; - - azureOpenAIApiDeploymentName?: string; - - azureOpenAIBasePath?: string; - - organization?: string; - - private client: OpenAIClient; - - private clientConfig: ClientOptions; - - constructor( - fields?: Partial & - Partial & { - verbose?: boolean; - openAIApiKey?: string; - configuration?: ClientOptions; - }, - configuration?: ClientOptions & LegacyOpenAIInput - ) { - const fieldsWithDefaults = { maxConcurrency: 2, ...fields }; - - 
super(fieldsWithDefaults); - - let apiKey = - fieldsWithDefaults?.openAIApiKey ?? - getEnvironmentVariable("OPENAI_API_KEY"); - - const azureApiKey = - fieldsWithDefaults?.azureOpenAIApiKey ?? - getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - if (!azureApiKey && !apiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); - } - - const azureApiInstanceName = - fieldsWithDefaults?.azureOpenAIApiInstanceName ?? - getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME"); - - const azureApiDeploymentName = - (fieldsWithDefaults?.azureOpenAIApiEmbeddingsDeploymentName || - fieldsWithDefaults?.azureOpenAIApiDeploymentName) ?? - (getEnvironmentVariable("AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME") || - getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME")); - - const azureApiVersion = - fieldsWithDefaults?.azureOpenAIApiVersion ?? - getEnvironmentVariable("AZURE_OPENAI_API_VERSION"); - - this.azureOpenAIBasePath = - fieldsWithDefaults?.azureOpenAIBasePath ?? - getEnvironmentVariable("AZURE_OPENAI_BASE_PATH"); - - this.organization = - fieldsWithDefaults?.configuration?.organization ?? - getEnvironmentVariable("OPENAI_ORGANIZATION"); - - this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; - this.batchSize = - fieldsWithDefaults?.batchSize ?? (azureApiKey ? 1 : this.batchSize); - this.stripNewLines = - fieldsWithDefaults?.stripNewLines ?? 
this.stripNewLines; - this.timeout = fieldsWithDefaults?.timeout; - - this.azureOpenAIApiVersion = azureApiVersion; - this.azureOpenAIApiKey = azureApiKey; - this.azureOpenAIApiInstanceName = azureApiInstanceName; - this.azureOpenAIApiDeploymentName = azureApiDeploymentName; - - if (this.azureOpenAIApiKey) { - if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { - throw new Error("Azure OpenAI API instance name not found"); - } - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI API deployment name not found"); - } - if (!this.azureOpenAIApiVersion) { - throw new Error("Azure OpenAI API version not found"); - } - apiKey = apiKey ?? ""; - } - - this.clientConfig = { - apiKey, - organization: this.organization, - baseURL: configuration?.basePath, - dangerouslyAllowBrowser: true, - defaultHeaders: configuration?.baseOptions?.headers, - defaultQuery: configuration?.baseOptions?.params, - ...configuration, - ...fields?.configuration, - }; - } - - /** - * Method to generate embeddings for an array of documents. Splits the - * documents into batches and makes requests to the OpenAI API to generate - * embeddings. - * @param texts Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each document. - */ - async embedDocuments(texts: string[]): Promise { - const batches = chunkArray( - this.stripNewLines ? 
texts.map((t) => t.replace(/\n/g, " ")) : texts, - this.batchSize - ); - - const batchRequests = batches.map((batch) => - this.embeddingWithRetry({ - model: this.modelName, - input: batch, - }) - ); - const batchResponses = await Promise.all(batchRequests); - - const embeddings: number[][] = []; - for (let i = 0; i < batchResponses.length; i += 1) { - const batch = batches[i]; - const { data: batchResponse } = batchResponses[i]; - for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse[j].embedding); - } - } - return embeddings; - } - - /** - * Method to generate an embedding for a single document. Calls the - * embeddingWithRetry method with the document as the input. - * @param text Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the document. - */ - async embedQuery(text: string): Promise { - const { data } = await this.embeddingWithRetry({ - model: this.modelName, - input: this.stripNewLines ? text.replace(/\n/g, " ") : text, - }); - return data[0].embedding; - } - - /** - * Private method to make a request to the OpenAI API to generate - * embeddings. Handles the retry logic and returns the response from the - * API. - * @param request Request to send to the OpenAI API. - * @returns Promise that resolves to the response from the API. 
- */ - private async embeddingWithRetry( - request: OpenAIClient.EmbeddingCreateParams - ) { - if (!this.client) { - const openAIEndpointConfig: OpenAIEndpointConfig = { - azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, - azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, - azureOpenAIApiKey: this.azureOpenAIApiKey, - azureOpenAIBasePath: this.azureOpenAIBasePath, - baseURL: this.clientConfig.baseURL, - }; - - const endpoint = getEndpoint(openAIEndpointConfig); - - const params = { - ...this.clientConfig, - baseURL: endpoint, - timeout: this.timeout, - maxRetries: 0, - }; - - if (!params.baseURL) { - delete params.baseURL; - } - - this.client = new OpenAIClient(params); - } - const requestOptions: OpenAICoreRequestOptions = {}; - if (this.azureOpenAIApiKey) { - requestOptions.headers = { - "api-key": this.azureOpenAIApiKey, - ...requestOptions.headers, - }; - requestOptions.query = { - "api-version": this.azureOpenAIApiVersion, - ...requestOptions.query, - }; - } - return this.caller.call(async () => { - try { - const res = await this.client.embeddings.create( - request, - requestOptions - ); - return res; - } catch (e) { - const error = wrapOpenAIClientError(e); - throw error; - } - }); - } -} +export { + type OpenAIEmbeddingsParams, + OpenAIEmbeddings, +} from "@langchain/openai"; diff --git a/langchain/src/embeddings/tensorflow.ts b/langchain/src/embeddings/tensorflow.ts index 1fc198ccdb8f..cd09a7c82cf5 100644 --- a/langchain/src/embeddings/tensorflow.ts +++ b/langchain/src/embeddings/tensorflow.ts @@ -1,91 +1 @@ -import { load } from "@tensorflow-models/universal-sentence-encoder"; -import * as tf from "@tensorflow/tfjs-core"; - -import { Embeddings, EmbeddingsParams } from "./base.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the TensorFlowEmbeddings class. 
- */ -export interface TensorFlowEmbeddingsParams extends EmbeddingsParams {} - -/** - * Class that extends the Embeddings class and provides methods for - * generating embeddings using the Universal Sentence Encoder model from - * TensorFlow.js. - * @example - * ```typescript - * const embeddings = new TensorFlowEmbeddings(); - * const store = new MemoryVectorStore(embeddings); - * - * const documents = [ - * "A document", - * "Some other piece of text", - * "One more", - * "And another", - * ]; - * - * await store.addDocuments( - * documents.map((pageContent) => new Document({ pageContent })) - * ); - * ``` - */ -export class TensorFlowEmbeddings extends Embeddings { - constructor(fields?: TensorFlowEmbeddingsParams) { - super(fields ?? {}); - - try { - tf.backend(); - } catch (e) { - throw new Error("No TensorFlow backend found, see instructions at ..."); - } - } - - _cached: ReturnType; - - /** - * Private method that loads the Universal Sentence Encoder model if it - * hasn't been loaded already. It returns a promise that resolves to the - * loaded model. - * @returns Promise that resolves to the loaded Universal Sentence Encoder model. - */ - private async load() { - if (this._cached === undefined) { - this._cached = load(); - } - return this._cached; - } - - private _embed(texts: string[]) { - return this.caller.call(async () => { - const model = await this.load(); - return model.embed(texts); - }); - } - - /** - * Method that takes a document as input and returns a promise that - * resolves to an embedding for the document. It calls the _embed method - * with the document as the input and processes the result to return a - * single embedding. - * @param document Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the input document. 
- */ - embedQuery(document: string): Promise { - return this._embed([document]) - .then((embeddings) => embeddings.array()) - .then((embeddings) => embeddings[0]); - } - - /** - * Method that takes an array of documents as input and returns a promise - * that resolves to a 2D array of embeddings for each document. It calls - * the _embed method with the documents as the input and processes the - * result to return the embeddings. - * @param documents Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each input document. - */ - embedDocuments(documents: string[]): Promise { - return this._embed(documents).then((embeddings) => embeddings.array()); - } -} +export * from "@langchain/community/embeddings/tensorflow"; diff --git a/langchain/src/embeddings/voyage.ts b/langchain/src/embeddings/voyage.ts index 6b4d03c7e210..ea725613d118 100644 --- a/langchain/src/embeddings/voyage.ts +++ b/langchain/src/embeddings/voyage.ts @@ -1,152 +1 @@ -import { chunkArray } from "../util/chunk.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { Embeddings, EmbeddingsParams } from "./base.js"; - -/** - * Interface that extends EmbeddingsParams and defines additional - * parameters specific to the VoyageEmbeddings class. - */ -export interface VoyageEmbeddingsParams extends EmbeddingsParams { - modelName: string; - - /** - * The maximum number of documents to embed in a single request. This is - * limited by the Voyage AI API to a maximum of 8. - */ - batchSize?: number; -} - -/** - * Interface for the request body to generate embeddings. 
- */ -export interface CreateVoyageEmbeddingRequest { - /** - * @type {string} - * @memberof CreateVoyageEmbeddingRequest - */ - model: string; - - /** - * Text to generate vector expectation - * @type {CreateEmbeddingRequestInput} - * @memberof CreateVoyageEmbeddingRequest - */ - input: string | string[]; -} - -/** - * A class for generating embeddings using the Voyage AI API. - */ -export class VoyageEmbeddings - extends Embeddings - implements VoyageEmbeddingsParams -{ - modelName = "voyage-01"; - - batchSize = 8; - - private apiKey: string; - - basePath?: string = "https://api.voyageai.com/v1"; - - apiUrl: string; - - headers?: Record; - - /** - * Constructor for the VoyageEmbeddings class. - * @param fields - An optional object with properties to configure the instance. - */ - constructor( - fields?: Partial & { - verbose?: boolean; - apiKey?: string; - } - ) { - const fieldsWithDefaults = { ...fields }; - - super(fieldsWithDefaults); - - const apiKey = - fieldsWithDefaults?.apiKey || getEnvironmentVariable("VOYAGEAI_API_KEY"); - - if (!apiKey) { - throw new Error("Voyage AI API key not found"); - } - - this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; - this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; - this.apiKey = apiKey; - this.apiUrl = `${this.basePath}/embeddings`; - } - - /** - * Generates embeddings for an array of texts. - * @param texts - An array of strings to generate embeddings for. - * @returns A Promise that resolves to an array of embeddings. 
- */ - async embedDocuments(texts: string[]): Promise { - const batches = chunkArray(texts, this.batchSize); - - const batchRequests = batches.map((batch) => - this.embeddingWithRetry({ - model: this.modelName, - input: batch, - }) - ); - - const batchResponses = await Promise.all(batchRequests); - - const embeddings: number[][] = []; - - for (let i = 0; i < batchResponses.length; i += 1) { - const batch = batches[i]; - const { data: batchResponse } = batchResponses[i]; - for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse[j].embedding); - } - } - - return embeddings; - } - - /** - * Generates an embedding for a single text. - * @param text - A string to generate an embedding for. - * @returns A Promise that resolves to an array of numbers representing the embedding. - */ - async embedQuery(text: string): Promise { - const { data } = await this.embeddingWithRetry({ - model: this.modelName, - input: text, - }); - - return data[0].embedding; - } - - /** - * Makes a request to the Voyage AI API to generate embeddings for an array of texts. - * @param request - An object with properties to configure the request. - * @returns A Promise that resolves to the response from the Voyage AI API. 
- */ - - private async embeddingWithRetry(request: CreateVoyageEmbeddingRequest) { - const makeCompletionRequest = async () => { - const url = `${this.apiUrl}`; - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${this.apiKey}`, - ...this.headers, - }, - body: JSON.stringify(request), - }); - - const json = await response.json(); - return json; - }; - - return this.caller.call(makeCompletionRequest); - } -} +export * from "@langchain/community/embeddings/voyage"; diff --git a/langchain/src/experimental/openai_assistant/index.ts b/langchain/src/experimental/openai_assistant/index.ts index 2e39836f30f3..248ff9407977 100644 --- a/langchain/src/experimental/openai_assistant/index.ts +++ b/langchain/src/experimental/openai_assistant/index.ts @@ -1,4 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { type ClientOptions, OpenAIClient } from "@langchain/openai"; import { Runnable } from "../../schema/runnable/base.js"; import { sleep } from "../../util/time.js"; import type { RunnableConfig } from "../../schema/runnable/config.js"; diff --git a/langchain/src/experimental/openai_assistant/schema.ts b/langchain/src/experimental/openai_assistant/schema.ts index 10d4ce2658f4..a74ca1fce69f 100644 --- a/langchain/src/experimental/openai_assistant/schema.ts +++ b/langchain/src/experimental/openai_assistant/schema.ts @@ -1,4 +1,4 @@ -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import type { AgentFinish, AgentAction } from "../../schema/index.js"; export type OpenAIAssistantFinish = AgentFinish & { diff --git a/langchain/src/experimental/openai_assistant/tests/openai_assistant.int.test.ts b/langchain/src/experimental/openai_assistant/tests/openai_assistant.int.test.ts index d1dac7daed6a..27d86403fc7e 100644 --- a/langchain/src/experimental/openai_assistant/tests/openai_assistant.int.test.ts +++ 
b/langchain/src/experimental/openai_assistant/tests/openai_assistant.int.test.ts @@ -2,7 +2,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { z } from "zod"; -import { OpenAI as OpenAIClient } from "openai"; +import { OpenAIClient } from "@langchain/openai"; import { AgentExecutor } from "../../../agents/executor.js"; import { StructuredTool } from "../../../tools/base.js"; import { OpenAIAssistantRunnable } from "../index.js"; diff --git a/langchain/src/experimental/openai_files/index.ts b/langchain/src/experimental/openai_files/index.ts index 6522e906fd29..62f35dadf438 100644 --- a/langchain/src/experimental/openai_files/index.ts +++ b/langchain/src/experimental/openai_files/index.ts @@ -1,4 +1,4 @@ -import { OpenAI as OpenAIClient, type ClientOptions } from "openai"; +import { OpenAIClient, type ClientOptions } from "@langchain/openai"; import { Serializable } from "../../load/serializable.js"; export type OpenAIFilesInput = { diff --git a/langchain/src/graphs/neo4j_graph.ts b/langchain/src/graphs/neo4j_graph.ts index c404e7e3b2ad..99cc0011f46a 100644 --- a/langchain/src/graphs/neo4j_graph.ts +++ b/langchain/src/graphs/neo4j_graph.ts @@ -1,286 +1 @@ -import neo4j, { Neo4jError } from "neo4j-driver"; - -interface Neo4jGraphConfig { - url: string; - username: string; - password: string; - database?: string; -} - -interface StructuredSchema { - nodeProps: { [key: NodeType["labels"]]: NodeType["properties"] }; - relProps: { [key: RelType["type"]]: RelType["properties"] }; - relationships: PathType[]; -} - -type NodeType = { - labels: string; - properties: { property: string; type: string }[]; -}; -type RelType = { - type: string; - properties: { property: string; type: string }[]; -}; -type PathType = { start: string; type: string; end: string }; - -/** - * @security *Security note*: Make sure that the database connection uses credentials - * that are narrowly-scoped to only include necessary permissions. 
- * Failure to do so may result in data corruption or loss, since the calling - * code may attempt commands that would result in deletion, mutation - * of data if appropriately prompted or reading sensitive data if such - * data is present in the database. - * The best way to guard against such negative outcomes is to (as appropriate) - * limit the permissions granted to the credentials used with this tool. - * For example, creating read only users for the database is a good way to - * ensure that the calling code cannot mutate or delete data. - * - * @link See https://js.langchain.com/docs/security for more information. - */ -export class Neo4jGraph { - private driver: neo4j.Driver; - - private database: string; - - private schema = ""; - - private structuredSchema: StructuredSchema = { - nodeProps: {}, - relProps: {}, - relationships: [], - }; - - constructor({ - url, - username, - password, - database = "neo4j", - }: Neo4jGraphConfig) { - try { - this.driver = neo4j.driver(url, neo4j.auth.basic(username, password)); - this.database = database; - } catch (error) { - throw new Error( - "Could not create a Neo4j driver instance. Please check the connection details." 
- ); - } - } - - static async initialize(config: Neo4jGraphConfig): Promise { - const graph = new Neo4jGraph(config); - - try { - await graph.verifyConnectivity(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (error: any) { - console.log("Failed to verify connection."); - } - - try { - await graph.refreshSchema(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (error: any) { - const message = [ - "Could not use APOC procedures.", - "Please ensure the APOC plugin is installed in Neo4j and that", - "'apoc.meta.data()' is allowed in Neo4j configuration", - ].join("\n"); - - throw new Error(message); - } finally { - console.log("Schema refreshed successfully."); - } - - return graph; - } - - getSchema(): string { - return this.schema; - } - - getStructuredSchema() { - return this.structuredSchema; - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - async query(query: string, params: any = {}): Promise { - try { - const result = await this.driver.executeQuery(query, params, { - database: this.database, - }); - return toObjects(result.records); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (error: any) { - if ( - // eslint-disable-next-line - error instanceof Neo4jError && - error.code === "Neo.ClientError.Procedure.ProcedureNotFound" - ) { - throw new Error("Procedure not found in Neo4j."); - } - } - return undefined; - } - - async verifyConnectivity() { - await this.driver.verifyAuthentication(); - } - - async refreshSchema() { - const nodePropertiesQuery = ` - CALL apoc.meta.data() - YIELD label, other, elementType, type, property - WHERE NOT type = "RELATIONSHIP" AND elementType = "node" - WITH label AS nodeLabels, collect({property:property, type:type}) AS properties - RETURN {labels: nodeLabels, properties: properties} AS output - `; - - const relPropertiesQuery = ` - CALL apoc.meta.data() - YIELD label, other, elementType, type, property - WHERE NOT type = 
"RELATIONSHIP" AND elementType = "relationship" - WITH label AS nodeLabels, collect({property:property, type:type}) AS properties - RETURN {type: nodeLabels, properties: properties} AS output - `; - - const relQuery = ` - CALL apoc.meta.data() - YIELD label, other, elementType, type, property - WHERE type = "RELATIONSHIP" AND elementType = "node" - UNWIND other AS other_node - RETURN {start: label, type: property, end: toString(other_node)} AS output - `; - - // Assuming query method is defined and returns a Promise - const nodeProperties: NodeType[] | undefined = ( - await this.query(nodePropertiesQuery) - )?.map((el: { output: NodeType }) => el.output); - - const relationshipsProperties: RelType[] | undefined = ( - await this.query(relPropertiesQuery) - )?.map((el: { output: RelType }) => el.output); - - const relationships: PathType[] | undefined = ( - await this.query(relQuery) - )?.map((el: { output: PathType }) => el.output); - - // Structured schema similar to Python's dictionary comprehension - this.structuredSchema = { - nodeProps: Object.fromEntries( - nodeProperties?.map((el) => [el.labels, el.properties]) || [] - ), - relProps: Object.fromEntries( - relationshipsProperties?.map((el) => [el.type, el.properties]) || [] - ), - relationships: relationships || [], - }; - - // Format node properties - const formattedNodeProps = nodeProperties?.map((el) => { - const propsStr = el.properties - .map((prop) => `${prop.property}: ${prop.type}`) - .join(", "); - return `${el.labels} {${propsStr}}`; - }); - - // Format relationship properties - const formattedRelProps = relationshipsProperties?.map((el) => { - const propsStr = el.properties - .map((prop) => `${prop.property}: ${prop.type}`) - .join(", "); - return `${el.type} {${propsStr}}`; - }); - - // Format relationships - const formattedRels = relationships?.map( - (el) => `(:${el.start})-[:${el.type}]->(:${el.end})` - ); - - // Combine all formatted elements into a single string - this.schema = [ - "Node 
properties are the following:", - formattedNodeProps?.join(", "), - "Relationship properties are the following:", - formattedRelProps?.join(", "), - "The relationships are the following:", - formattedRels?.join(", "), - ].join("\n"); - } - - async close() { - await this.driver.close(); - } -} - -function toObjects(records: neo4j.Record[]) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const recordValues: Record[] = records.map((record) => { - const rObj = record.toObject(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const out: { [key: string]: any } = {}; - Object.keys(rObj).forEach((key) => { - out[key] = itemIntToString(rObj[key]); - }); - return out; - }); - return recordValues; -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function itemIntToString(item: any): any { - if (neo4j.isInt(item)) return item.toString(); - if (Array.isArray(item)) return item.map((ii) => itemIntToString(ii)); - if (["number", "string", "boolean"].indexOf(typeof item) !== -1) return item; - if (item === null) return item; - if (typeof item === "object") return objIntToString(item); -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function objIntToString(obj: any) { - const entry = extractFromNeoObjects(obj); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let newObj: any = null; - if (Array.isArray(entry)) { - newObj = entry.map((item) => itemIntToString(item)); - } else if (entry !== null && typeof entry === "object") { - newObj = {}; - Object.keys(entry).forEach((key) => { - newObj[key] = itemIntToString(entry[key]); - }); - } - return newObj; -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function extractFromNeoObjects(obj: any) { - if ( - // eslint-disable-next-line - obj instanceof (neo4j.types.Node as any) || - // eslint-disable-next-line - obj instanceof (neo4j.types.Relationship as any) - ) { - return obj.properties; - // eslint-disable-next-line - 
} else if (obj instanceof (neo4j.types.Path as any)) { - // eslint-disable-next-line - return [].concat.apply([], extractPathForRows(obj)); - } - return obj; -} - -const extractPathForRows = (path: neo4j.Path) => { - let { segments } = path; - // Zero length path. No relationship, end === start - if (!Array.isArray(path.segments) || path.segments.length < 1) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - segments = [{ ...path, end: null } as any]; - } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return segments.map((segment: any) => - [ - objIntToString(segment.start), - objIntToString(segment.relationship), - objIntToString(segment.end), - ].filter((part) => part !== null) - ); -}; +export * from "@langchain/community/graphs/neo4j_graph"; diff --git a/langchain/src/graphs/tests/neo4j_graph.int.test.ts b/langchain/src/graphs/tests/neo4j_graph.int.test.ts deleted file mode 100644 index 3b47800fc323..000000000000 --- a/langchain/src/graphs/tests/neo4j_graph.int.test.ts +++ /dev/null @@ -1,56 +0,0 @@ -/* eslint-disable no-process-env */ - -import { test } from "@jest/globals"; -import { Neo4jGraph } from "../neo4j_graph.js"; - -describe.skip("Neo4j Graph Tests", () => { - const url = process.env.NEO4J_URI as string; - const username = process.env.NEO4J_USERNAME as string; - const password = process.env.NEO4J_PASSWORD as string; - let graph: Neo4jGraph; - - beforeEach(async () => { - graph = await Neo4jGraph.initialize({ url, username, password }); - }); - afterEach(async () => { - await graph.close(); - }); - - test("Schema generation works correctly", async () => { - expect(url).toBeDefined(); - expect(username).toBeDefined(); - expect(password).toBeDefined(); - - // Clear the database - await graph.query("MATCH (n) DETACH DELETE n"); - - await graph.query( - "CREATE (a:Actor {name:'Bruce Willis'})" + - "-[:ACTED_IN {roles: ['Butch Coolidge']}]->(:Movie {title: 'Pulp Fiction'})" - ); - - await graph.refreshSchema(); - 
console.log(graph.getSchema()); - - // expect(graph.getSchema()).toMatchInlineSnapshot(` - // "Node properties are the following: - // Actor {name: STRING}, Movie {title: STRING} - // Relationship properties are the following: - // ACTED_IN {roles: LIST} - // The relationships are the following: - // (:Actor)-[:ACTED_IN]->(:Movie)" - // `); - }); - - test("Test that Neo4j database is correctly instantiated and connected", async () => { - expect(url).toBeDefined(); - expect(username).toBeDefined(); - expect(password).toBeDefined(); - - // Integers are casted to strings in the output - const expectedOutput = [{ output: { str: "test", int: "1" } }]; - const res = await graph.query('RETURN {str: "test", int: 1} AS output'); - await graph.close(); - expect(res).toEqual(expectedOutput); - }); -}); diff --git a/langchain/src/llms/ai21.ts b/langchain/src/llms/ai21.ts index 66d63b97abc3..d68f4b637a87 100644 --- a/langchain/src/llms/ai21.ts +++ b/langchain/src/llms/ai21.ts @@ -1,199 +1 @@ -import { LLM, BaseLLMParams } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Type definition for AI21 penalty data. - */ -export type AI21PenaltyData = { - scale: number; - applyToWhitespaces: boolean; - applyToPunctuations: boolean; - applyToNumbers: boolean; - applyToStopwords: boolean; - applyToEmojis: boolean; -}; - -/** - * Interface for AI21 input parameters. - */ -export interface AI21Input extends BaseLLMParams { - ai21ApiKey?: string; - model?: string; - temperature?: number; - minTokens?: number; - maxTokens?: number; - topP?: number; - presencePenalty?: AI21PenaltyData; - countPenalty?: AI21PenaltyData; - frequencyPenalty?: AI21PenaltyData; - numResults?: number; - logitBias?: Record; - stop?: string[]; - baseUrl?: string; -} - -/** - * Class representing the AI21 language model. It extends the LLM (Large - * Language Model) class, providing a standard interface for interacting - * with the AI21 language model. 
- */ -export class AI21 extends LLM implements AI21Input { - model = "j2-jumbo-instruct"; - - temperature = 0.7; - - maxTokens = 1024; - - minTokens = 0; - - topP = 1; - - presencePenalty = AI21.getDefaultAI21PenaltyData(); - - countPenalty = AI21.getDefaultAI21PenaltyData(); - - frequencyPenalty = AI21.getDefaultAI21PenaltyData(); - - numResults = 1; - - logitBias?: Record; - - ai21ApiKey?: string; - - stop?: string[]; - - baseUrl?: string; - - constructor(fields?: AI21Input) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - this.temperature = fields?.temperature ?? this.temperature; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.minTokens = fields?.minTokens ?? this.minTokens; - this.topP = fields?.topP ?? this.topP; - this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; - this.countPenalty = fields?.countPenalty ?? this.countPenalty; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.numResults = fields?.numResults ?? this.numResults; - this.logitBias = fields?.logitBias; - this.ai21ApiKey = - fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY"); - this.stop = fields?.stop; - this.baseUrl = fields?.baseUrl; - } - - /** - * Method to validate the environment. It checks if the AI21 API key is - * set. If not, it throws an error. - */ - validateEnvironment() { - if (!this.ai21ApiKey) { - throw new Error( - `No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.` - ); - } - } - - /** - * Static method to get the default penalty data for AI21. - * @returns AI21PenaltyData - */ - static getDefaultAI21PenaltyData(): AI21PenaltyData { - return { - scale: 0, - applyToWhitespaces: true, - applyToPunctuations: true, - applyToNumbers: true, - applyToStopwords: true, - applyToEmojis: true, - }; - } - - /** Get the type of LLM. */ - _llmType() { - return "ai21"; - } - - /** Get the default parameters for calling AI21 API. 
*/ - get defaultParams() { - return { - temperature: this.temperature, - maxTokens: this.maxTokens, - minTokens: this.minTokens, - topP: this.topP, - presencePenalty: this.presencePenalty, - countPenalty: this.countPenalty, - frequencyPenalty: this.frequencyPenalty, - numResults: this.numResults, - logitBias: this.logitBias, - }; - } - - /** Get the identifying parameters for this LLM. */ - get identifyingParams() { - return { ...this.defaultParams, model: this.model }; - } - - /** Call out to AI21's complete endpoint. - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - let response = ai21._call("Tell me a joke."); - */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - let stop = options?.stop; - this.validateEnvironment(); - if (this.stop && stop && this.stop.length > 0 && stop.length > 0) { - throw new Error("`stop` found in both the input and default params."); - } - stop = this.stop ?? stop ?? []; - - const baseUrl = - this.baseUrl ?? this.model === "j1-grande-instruct" - ? 
"https://api.ai21.com/studio/v1/experimental" - : "https://api.ai21.com/studio/v1"; - - const url = `${baseUrl}/${this.model}/complete`; - const headers = { - Authorization: `Bearer ${this.ai21ApiKey}`, - "Content-Type": "application/json", - }; - const data = { prompt, stopSequences: stop, ...this.defaultParams }; - const responseData = await this.caller.callWithOptions({}, async () => { - const response = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify(data), - signal: options.signal, - }); - if (!response.ok) { - const error = new Error( - `AI21 call failed with status code ${response.status}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).response = response; - throw error; - } - return response.json(); - }); - - if ( - !responseData.completions || - responseData.completions.length === 0 || - !responseData.completions[0].data - ) { - throw new Error("No completions found in response"); - } - - return responseData.completions[0].data.text ?? ""; - } -} +export * from "@langchain/community/llms/ai21"; diff --git a/langchain/src/llms/aleph_alpha.ts b/langchain/src/llms/aleph_alpha.ts index 0e8968327073..1d0c12be8e26 100644 --- a/langchain/src/llms/aleph_alpha.ts +++ b/langchain/src/llms/aleph_alpha.ts @@ -1,298 +1 @@ -import { LLM, BaseLLMParams } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface for the input parameters specific to the Aleph Alpha LLM. 
- */ -export interface AlephAlphaInput extends BaseLLMParams { - model: string; - maximum_tokens: number; - minimum_tokens?: number; - echo?: boolean; - temperature?: number; - top_k?: number; - top_p?: number; - presence_penalty?: number; - frequency_penalty?: number; - sequence_penalty?: number; - sequence_penalty_min_length?: number; - repetition_penalties_include_prompt?: boolean; - repetition_penalties_include_completion?: boolean; - use_multiplicative_presence_penalty?: boolean; - use_multiplicative_frequency_penalty?: boolean; - use_multiplicative_sequence_penalty?: boolean; - penalty_bias?: string; - penalty_exceptions?: string[]; - penalty_exceptions_include_stop_sequences?: boolean; - best_of?: number; - n?: number; - logit_bias?: object; - log_probs?: number; - tokens?: boolean; - raw_completion: boolean; - disable_optimizations?: boolean; - completion_bias_inclusion?: string[]; - completion_bias_inclusion_first_token_only: boolean; - completion_bias_exclusion?: string[]; - completion_bias_exclusion_first_token_only: boolean; - contextual_control_threshold?: number; - control_log_additive: boolean; - stop?: string[]; - aleph_alpha_api_key?: string; - base_url: string; -} - -/** - * Specific implementation of a Large Language Model (LLM) designed to - * interact with the Aleph Alpha API. It extends the base LLM class and - * includes a variety of parameters for customizing the behavior of the - * Aleph Alpha model. 
- */ -export class AlephAlpha extends LLM implements AlephAlphaInput { - model = "luminous-base"; - - maximum_tokens = 64; - - minimum_tokens = 0; - - echo: boolean; - - temperature = 0.0; - - top_k: number; - - top_p = 0.0; - - presence_penalty?: number; - - frequency_penalty?: number; - - sequence_penalty?: number; - - sequence_penalty_min_length?: number; - - repetition_penalties_include_prompt?: boolean; - - repetition_penalties_include_completion?: boolean; - - use_multiplicative_presence_penalty?: boolean; - - use_multiplicative_frequency_penalty?: boolean; - - use_multiplicative_sequence_penalty?: boolean; - - penalty_bias?: string; - - penalty_exceptions?: string[]; - - penalty_exceptions_include_stop_sequences?: boolean; - - best_of?: number; - - n?: number; - - logit_bias?: object; - - log_probs?: number; - - tokens?: boolean; - - raw_completion: boolean; - - disable_optimizations?: boolean; - - completion_bias_inclusion?: string[]; - - completion_bias_inclusion_first_token_only: boolean; - - completion_bias_exclusion?: string[]; - - completion_bias_exclusion_first_token_only: boolean; - - contextual_control_threshold?: number; - - control_log_additive: boolean; - - aleph_alpha_api_key? = getEnvironmentVariable("ALEPH_ALPHA_API_KEY"); - - stop?: string[]; - - base_url = "https://api.aleph-alpha.com/complete"; - - constructor(fields: Partial) { - super(fields ?? {}); - this.model = fields?.model ?? this.model; - this.temperature = fields?.temperature ?? this.temperature; - this.maximum_tokens = fields?.maximum_tokens ?? this.maximum_tokens; - this.minimum_tokens = fields?.minimum_tokens ?? this.minimum_tokens; - this.top_k = fields?.top_k ?? this.top_k; - this.top_p = fields?.top_p ?? this.top_p; - this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty; - this.frequency_penalty = - fields?.frequency_penalty ?? this.frequency_penalty; - this.sequence_penalty = fields?.sequence_penalty ?? 
this.sequence_penalty; - this.sequence_penalty_min_length = - fields?.sequence_penalty_min_length ?? this.sequence_penalty_min_length; - this.repetition_penalties_include_prompt = - fields?.repetition_penalties_include_prompt ?? - this.repetition_penalties_include_prompt; - this.repetition_penalties_include_completion = - fields?.repetition_penalties_include_completion ?? - this.repetition_penalties_include_completion; - this.use_multiplicative_presence_penalty = - fields?.use_multiplicative_presence_penalty ?? - this.use_multiplicative_presence_penalty; - this.use_multiplicative_frequency_penalty = - fields?.use_multiplicative_frequency_penalty ?? - this.use_multiplicative_frequency_penalty; - this.use_multiplicative_sequence_penalty = - fields?.use_multiplicative_sequence_penalty ?? - this.use_multiplicative_sequence_penalty; - this.penalty_bias = fields?.penalty_bias ?? this.penalty_bias; - this.penalty_exceptions = - fields?.penalty_exceptions ?? this.penalty_exceptions; - this.penalty_exceptions_include_stop_sequences = - fields?.penalty_exceptions_include_stop_sequences ?? - this.penalty_exceptions_include_stop_sequences; - this.best_of = fields?.best_of ?? this.best_of; - this.n = fields?.n ?? this.n; - this.logit_bias = fields?.logit_bias ?? this.logit_bias; - this.log_probs = fields?.log_probs ?? this.log_probs; - this.tokens = fields?.tokens ?? this.tokens; - this.raw_completion = fields?.raw_completion ?? this.raw_completion; - this.disable_optimizations = - fields?.disable_optimizations ?? this.disable_optimizations; - this.completion_bias_inclusion = - fields?.completion_bias_inclusion ?? this.completion_bias_inclusion; - this.completion_bias_inclusion_first_token_only = - fields?.completion_bias_inclusion_first_token_only ?? - this.completion_bias_inclusion_first_token_only; - this.completion_bias_exclusion = - fields?.completion_bias_exclusion ?? 
this.completion_bias_exclusion; - this.completion_bias_exclusion_first_token_only = - fields?.completion_bias_exclusion_first_token_only ?? - this.completion_bias_exclusion_first_token_only; - this.contextual_control_threshold = - fields?.contextual_control_threshold ?? this.contextual_control_threshold; - this.control_log_additive = - fields?.control_log_additive ?? this.control_log_additive; - this.aleph_alpha_api_key = - fields?.aleph_alpha_api_key ?? this.aleph_alpha_api_key; - this.stop = fields?.stop ?? this.stop; - } - - /** - * Validates the environment by ensuring the necessary Aleph Alpha API key - * is available. Throws an error if the API key is missing. - */ - validateEnvironment() { - if (!this.aleph_alpha_api_key) { - throw new Error( - "Aleph Alpha API Key is missing in environment variables." - ); - } - } - - /** Get the default parameters for calling Aleph Alpha API. */ - get defaultParams() { - return { - model: this.model, - temperature: this.temperature, - maximum_tokens: this.maximum_tokens, - minimum_tokens: this.minimum_tokens, - top_k: this.top_k, - top_p: this.top_p, - presence_penalty: this.presence_penalty, - frequency_penalty: this.frequency_penalty, - sequence_penalty: this.sequence_penalty, - sequence_penalty_min_length: this.sequence_penalty_min_length, - repetition_penalties_include_prompt: - this.repetition_penalties_include_prompt, - repetition_penalties_include_completion: - this.repetition_penalties_include_completion, - use_multiplicative_presence_penalty: - this.use_multiplicative_presence_penalty, - use_multiplicative_frequency_penalty: - this.use_multiplicative_frequency_penalty, - use_multiplicative_sequence_penalty: - this.use_multiplicative_sequence_penalty, - penalty_bias: this.penalty_bias, - penalty_exceptions: this.penalty_exceptions, - penalty_exceptions_include_stop_sequences: - this.penalty_exceptions_include_stop_sequences, - best_of: this.best_of, - n: this.n, - logit_bias: this.logit_bias, - log_probs: 
this.log_probs, - tokens: this.tokens, - raw_completion: this.raw_completion, - disable_optimizations: this.disable_optimizations, - completion_bias_inclusion: this.completion_bias_inclusion, - completion_bias_inclusion_first_token_only: - this.completion_bias_inclusion_first_token_only, - completion_bias_exclusion: this.completion_bias_exclusion, - completion_bias_exclusion_first_token_only: - this.completion_bias_exclusion_first_token_only, - contextual_control_threshold: this.contextual_control_threshold, - control_log_additive: this.control_log_additive, - }; - } - - /** Get the identifying parameters for this LLM. */ - get identifyingParams() { - return { ...this.defaultParams }; - } - - /** Get the type of LLM. */ - _llmType(): string { - return "aleph_alpha"; - } - - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - let stop = options?.stop; - this.validateEnvironment(); - if (this.stop && stop && this.stop.length > 0 && stop.length > 0) { - throw new Error("`stop` found in both the input and default params."); - } - stop = this.stop ?? stop ?? 
[]; - const headers = { - Authorization: `Bearer ${this.aleph_alpha_api_key}`, - "Content-Type": "application/json", - Accept: "application/json", - }; - const data = { prompt, stop_sequences: stop, ...this.defaultParams }; - const responseData = await this.caller.call(async () => { - const response = await fetch(this.base_url, { - method: "POST", - headers, - body: JSON.stringify(data), - signal: options.signal, - }); - if (!response.ok) { - // consume the response body to release the connection - // https://undici.nodejs.org/#/?id=garbage-collection - const text = await response.text(); - const error = new Error( - `Aleph Alpha call failed with status ${response.status} and body ${text}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).response = response; - throw error; - } - return response.json(); - }); - - if ( - !responseData.completions || - responseData.completions.length === 0 || - !responseData.completions[0].completion - ) { - throw new Error("No completions found in response"); - } - - return responseData.completions[0].completion ?? 
""; - } -} +export * from "@langchain/community/llms/aleph_alpha"; diff --git a/langchain/src/llms/bedrock/web.ts b/langchain/src/llms/bedrock/web.ts index f10660c9feab..3ed2ccdd2650 100644 --- a/langchain/src/llms/bedrock/web.ts +++ b/langchain/src/llms/bedrock/web.ts @@ -1,356 +1 @@ -import { SignatureV4 } from "@smithy/signature-v4"; - -import { HttpRequest } from "@smithy/protocol-http"; -import { EventStreamCodec } from "@smithy/eventstream-codec"; -import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; -import { Sha256 } from "@aws-crypto/sha256-js"; - -import { - BaseBedrockInput, - BedrockLLMInputOutputAdapter, - type CredentialType, -} from "../../util/bedrock.js"; -import { getEnvironmentVariable } from "../../util/env.js"; -import { LLM, BaseLLMParams } from "../base.js"; -import { CallbackManagerForLLMRun } from "../../callbacks/manager.js"; -import { GenerationChunk } from "../../schema/index.js"; -import type { SerializedFields } from "../../load/map_keys.js"; - -/** - * A type of Large Language Model (LLM) that interacts with the Bedrock - * service. It extends the base `LLM` class and implements the - * `BaseBedrockInput` interface. The class is designed to authenticate and - * interact with the Bedrock service, which is a part of Amazon Web - * Services (AWS). It uses AWS credentials for authentication and can be - * configured with various parameters such as the model to use, the AWS - * region, and the maximum number of tokens to generate. 
- */ -export class Bedrock extends LLM implements BaseBedrockInput { - model = "amazon.titan-tg1-large"; - - region: string; - - credentials: CredentialType; - - temperature?: number | undefined = undefined; - - maxTokens?: number | undefined = undefined; - - fetchFn: typeof fetch; - - endpointHost?: string; - - /** @deprecated */ - stopSequences?: string[]; - - modelKwargs?: Record; - - codec: EventStreamCodec = new EventStreamCodec(toUtf8, fromUtf8); - - streaming = false; - - lc_serializable = true; - - get lc_aliases(): Record { - return { - model: "model_id", - region: "region_name", - }; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID", - "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY", - }; - } - - get lc_attributes(): SerializedFields | undefined { - return { region: this.region }; - } - - _llmType() { - return "bedrock"; - } - - static lc_name() { - return "Bedrock"; - } - - constructor(fields?: Partial & BaseLLMParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"]; - if (!allowedModels.includes(this.model.split(".")[0])) { - throw new Error( - `Unknown model: '${this.model}', only these are supported: ${allowedModels}` - ); - } - const region = - fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION"); - if (!region) { - throw new Error( - "Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field." - ); - } - this.region = region; - - const credentials = fields?.credentials; - if (!credentials) { - throw new Error( - "Please set the AWS credentials in the 'credentials' field." - ); - } - this.credentials = credentials; - - this.temperature = fields?.temperature ?? this.temperature; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.fetchFn = fields?.fetchFn ?? 
fetch.bind(globalThis); - this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl; - this.stopSequences = fields?.stopSequences; - this.modelKwargs = fields?.modelKwargs; - this.streaming = fields?.streaming ?? this.streaming; - } - - /** Call out to Bedrock service model. - Arguments: - prompt: The prompt to pass into the model. - - Returns: - The string generated by the model. - - Example: - response = model.call("Tell me a joke.") - */ - async _call( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const service = "bedrock-runtime"; - const endpointHost = - this.endpointHost ?? `${service}.${this.region}.amazonaws.com`; - const provider = this.model.split(".")[0]; - if (this.streaming) { - const stream = this._streamResponseChunks(prompt, options, runManager); - let finalResult: GenerationChunk | undefined; - for await (const chunk of stream) { - if (finalResult === undefined) { - finalResult = chunk; - } else { - finalResult = finalResult.concat(chunk); - } - } - return finalResult?.text ?? ""; - } - const response = await this._signedFetch(prompt, options, { - bedrockMethod: "invoke", - endpointHost, - provider, - }); - const json = await response.json(); - if (!response.ok) { - throw new Error( - `Error ${response.status}: ${json.message ?? JSON.stringify(json)}` - ); - } - const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); - return text; - } - - async _signedFetch( - prompt: string, - options: this["ParsedCallOptions"], - fields: { - bedrockMethod: "invoke" | "invoke-with-response-stream"; - endpointHost: string; - provider: string; - } - ) { - const { bedrockMethod, endpointHost, provider } = fields; - const inputBody = BedrockLLMInputOutputAdapter.prepareInput( - provider, - prompt, - this.maxTokens, - this.temperature, - options.stop ?? 
this.stopSequences, - this.modelKwargs, - fields.bedrockMethod - ); - - const url = new URL( - `https://${endpointHost}/model/${this.model}/${bedrockMethod}` - ); - - const request = new HttpRequest({ - hostname: url.hostname, - path: url.pathname, - protocol: url.protocol, - method: "POST", // method must be uppercase - body: JSON.stringify(inputBody), - query: Object.fromEntries(url.searchParams.entries()), - headers: { - // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html - host: url.host, - accept: "application/json", - "content-type": "application/json", - }, - }); - - const signer = new SignatureV4({ - credentials: this.credentials, - service: "bedrock", - region: this.region, - sha256: Sha256, - }); - - const signedRequest = await signer.sign(request); - - // Send request to AWS using the low-level fetch API - const response = await this.caller.callWithOptions( - { signal: options.signal }, - async () => - this.fetchFn(url, { - headers: signedRequest.headers, - body: signedRequest.body, - method: signedRequest.method, - }) - ); - return response; - } - - invocationParams(options?: this["ParsedCallOptions"]) { - return { - model: this.model, - region: this.region, - temperature: this.temperature, - maxTokens: this.maxTokens, - stop: options?.stop ?? this.stopSequences, - modelKwargs: this.modelKwargs, - }; - } - - async *_streamResponseChunks( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const provider = this.model.split(".")[0]; - const bedrockMethod = - provider === "anthropic" || provider === "cohere" || provider === "meta" - ? "invoke-with-response-stream" - : "invoke"; - - const service = "bedrock-runtime"; - const endpointHost = - this.endpointHost ?? 
`${service}.${this.region}.amazonaws.com`; - - // Send request to AWS using the low-level fetch API - const response = await this._signedFetch(prompt, options, { - bedrockMethod, - endpointHost, - provider, - }); - - if (response.status < 200 || response.status >= 300) { - throw Error( - `Failed to access underlying url '${endpointHost}': got ${ - response.status - } ${response.statusText}: ${await response.text()}` - ); - } - - if ( - provider === "anthropic" || - provider === "cohere" || - provider === "meta" - ) { - const reader = response.body?.getReader(); - const decoder = new TextDecoder(); - for await (const chunk of this._readChunks(reader)) { - const event = this.codec.decode(chunk); - if ( - (event.headers[":event-type"] !== undefined && - event.headers[":event-type"].value !== "chunk") || - event.headers[":content-type"].value !== "application/json" - ) { - throw Error(`Failed to get event chunk: got ${chunk}`); - } - const body = JSON.parse(decoder.decode(event.body)); - if (body.message) { - throw new Error(body.message); - } - if (body.bytes !== undefined) { - const chunkResult = JSON.parse( - decoder.decode( - Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 
0) - ) - ); - const text = BedrockLLMInputOutputAdapter.prepareOutput( - provider, - chunkResult - ); - yield new GenerationChunk({ - text, - generationInfo: {}, - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); - } - } - } else { - const json = await response.json(); - const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); - yield new GenerationChunk({ - text, - generationInfo: {}, - }); - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(text); - } - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - _readChunks(reader: any) { - function _concatChunks(a: Uint8Array, b: Uint8Array) { - const newBuffer = new Uint8Array(a.length + b.length); - newBuffer.set(a); - newBuffer.set(b, a.length); - return newBuffer; - } - - function getMessageLength(buffer: Uint8Array) { - if (buffer.byteLength === 0) return 0; - const view = new DataView( - buffer.buffer, - buffer.byteOffset, - buffer.byteLength - ); - - return view.getUint32(0, false); - } - - return { - async *[Symbol.asyncIterator]() { - let readResult = await reader.read(); - - let buffer: Uint8Array = new Uint8Array(0); - while (!readResult.done) { - const chunk: Uint8Array = readResult.value; - - buffer = _concatChunks(buffer, chunk); - let messageLength = getMessageLength(buffer); - - while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) { - yield buffer.slice(0, messageLength); - buffer = buffer.slice(messageLength); - messageLength = getMessageLength(buffer); - } - - readResult = await reader.read(); - } - }, - }; - } -} +export * from "@langchain/community/llms/bedrock/web"; diff --git a/langchain/src/llms/cloudflare_workersai.ts b/langchain/src/llms/cloudflare_workersai.ts index 25c6b5a16276..2e5add32a465 100644 --- a/langchain/src/llms/cloudflare_workersai.ts +++ b/langchain/src/llms/cloudflare_workersai.ts @@ -1,189 +1 @@ -import { LLM, BaseLLMParams } from "./base.js"; -import { 
getEnvironmentVariable } from "../util/env.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk } from "../schema/index.js"; -import { convertEventStreamToIterableReadableDataStream } from "../util/event-source-parse.js"; - -/** - * Interface for CloudflareWorkersAI input parameters. - */ -export interface CloudflareWorkersAIInput { - cloudflareAccountId?: string; - cloudflareApiToken?: string; - model?: string; - baseUrl?: string; - streaming?: boolean; -} - -/** - * Class representing the CloudflareWorkersAI language model. It extends the LLM (Large - * Language Model) class, providing a standard interface for interacting - * with the CloudflareWorkersAI language model. - */ -export class CloudflareWorkersAI - extends LLM - implements CloudflareWorkersAIInput -{ - model = "@cf/meta/llama-2-7b-chat-int8"; - - cloudflareAccountId?: string; - - cloudflareApiToken?: string; - - baseUrl: string; - - streaming = false; - - static lc_name() { - return "CloudflareWorkersAI"; - } - - lc_serializable = true; - - constructor(fields?: CloudflareWorkersAIInput & BaseLLMParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - this.streaming = fields?.streaming ?? this.streaming; - this.cloudflareAccountId = - fields?.cloudflareAccountId ?? - getEnvironmentVariable("CLOUDFLARE_ACCOUNT_ID"); - this.cloudflareApiToken = - fields?.cloudflareApiToken ?? - getEnvironmentVariable("CLOUDFLARE_API_TOKEN"); - this.baseUrl = - fields?.baseUrl ?? - `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`; - if (this.baseUrl.endsWith("/")) { - this.baseUrl = this.baseUrl.slice(0, -1); - } - } - - /** - * Method to validate the environment. - */ - validateEnvironment() { - if (this.baseUrl === undefined) { - if (!this.cloudflareAccountId) { - throw new Error( - `No Cloudflare account ID found. 
Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.` - ); - } - if (!this.cloudflareApiToken) { - throw new Error( - `No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_KEY" in your environment variables.` - ); - } - } - } - - /** Get the identifying parameters for this LLM. */ - get identifyingParams() { - return { model: this.model }; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams() { - return { - model: this.model, - }; - } - - /** Get the type of LLM. */ - _llmType() { - return "cloudflare"; - } - - async _request( - prompt: string, - options: this["ParsedCallOptions"], - stream?: boolean - ) { - this.validateEnvironment(); - - const url = `${this.baseUrl}/${this.model}`; - const headers = { - Authorization: `Bearer ${this.cloudflareApiToken}`, - "Content-Type": "application/json", - }; - - const data = { prompt, stream }; - return this.caller.call(async () => { - const response = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify(data), - signal: options.signal, - }); - if (!response.ok) { - const error = new Error( - `Cloudflare LLM call failed with status code ${response.status}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).response = response; - throw error; - } - return response; - }); - } - - async *_streamResponseChunks( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const response = await this._request(prompt, options, true); - if (!response.body) { - throw new Error("Empty response from Cloudflare. 
Please try again."); - } - const stream = convertEventStreamToIterableReadableDataStream( - response.body - ); - for await (const chunk of stream) { - if (chunk !== "[DONE]") { - const parsedChunk = JSON.parse(chunk); - const generationChunk = new GenerationChunk({ - text: parsedChunk.response, - }); - yield generationChunk; - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(generationChunk.text ?? ""); - } - } - } - - /** Call out to CloudflareWorkersAI's complete endpoint. - Args: - prompt: The prompt to pass into the model. - Returns: - The string generated by the model. - Example: - let response = CloudflareWorkersAI.call("Tell me a joke."); - */ - async _call( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - if (!this.streaming) { - const response = await this._request(prompt, options); - - const responseData = await response.json(); - - return responseData.result.response; - } else { - const stream = this._streamResponseChunks(prompt, options, runManager); - let finalResult: GenerationChunk | undefined; - for await (const chunk of stream) { - if (finalResult === undefined) { - finalResult = chunk; - } else { - finalResult = finalResult.concat(chunk); - } - } - return finalResult?.text ?? ""; - } - } -} +export * from "@langchain/community/llms/cloudflare_workersai"; diff --git a/langchain/src/llms/cohere.ts b/langchain/src/llms/cohere.ts index 393876efb939..8b911819109e 100644 --- a/langchain/src/llms/cohere.ts +++ b/langchain/src/llms/cohere.ts @@ -1,129 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { LLM, BaseLLMParams } from "./base.js"; - -/** - * Interface for the input parameters specific to the Cohere model. - */ -export interface CohereInput extends BaseLLMParams { - /** Sampling temperature to use */ - temperature?: number; - - /** - * Maximum number of tokens to generate in the completion. 
- */ - maxTokens?: number; - - /** Model to use */ - model?: string; - - apiKey?: string; -} - -/** - * Class representing a Cohere Large Language Model (LLM). It interacts - * with the Cohere API to generate text completions. - * @example - * ```typescript - * const model = new Cohere({ - * temperature: 0.7, - * maxTokens: 20, - * maxRetries: 5, - * }); - * - * const res = await model.call( - * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" - * ); - * console.log({ res }); - * ``` - */ -export class Cohere extends LLM implements CohereInput { - static lc_name() { - return "Cohere"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "COHERE_API_KEY", - }; - } - - get lc_aliases(): { [key: string]: string } | undefined { - return { - apiKey: "cohere_api_key", - }; - } - - lc_serializable = true; - - temperature = 0; - - maxTokens = 250; - - model: string; - - apiKey: string; - - constructor(fields?: CohereInput) { - super(fields ?? {}); - - const apiKey = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY"); - - if (!apiKey) { - throw new Error( - "Please set the COHERE_API_KEY environment variable or pass it to the constructor as the apiKey field." - ); - } - - this.apiKey = apiKey; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.temperature = fields?.temperature ?? this.temperature; - this.model = fields?.model ?? 
this.model; - } - - _llmType() { - return "cohere"; - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const { cohere } = await Cohere.imports(); - - cohere.init(this.apiKey); - - // Hit the `generate` endpoint on the `large` model - const generateResponse = await this.caller.callWithOptions( - { signal: options.signal }, - cohere.generate.bind(cohere), - { - prompt, - model: this.model, - max_tokens: this.maxTokens, - temperature: this.temperature, - end_sequences: options.stop, - } - ); - try { - return generateResponse.body.generations[0].text; - } catch { - console.log(generateResponse); - throw new Error("Could not parse response."); - } - } - - /** @ignore */ - static async imports(): Promise<{ - cohere: typeof import("cohere-ai"); - }> { - try { - const { default: cohere } = await import("cohere-ai"); - return { cohere }; - } catch (e) { - throw new Error( - "Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" - ); - } - } -} +export * from "@langchain/community/llms/cohere"; diff --git a/langchain/src/llms/fireworks.ts b/langchain/src/llms/fireworks.ts index 8dd5981aad69..c8ff87b2d830 100644 --- a/langchain/src/llms/fireworks.ts +++ b/langchain/src/llms/fireworks.ts @@ -1,140 +1 @@ -import type { OpenAI as OpenAIClient } from "openai"; - -import type { BaseLLMParams } from "./base.js"; -import type { OpenAICallOptions, OpenAIInput } from "./openai.js"; -import type { OpenAICoreRequestOptions } from "../types/openai-types.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { OpenAI } from "./openai.js"; - -type FireworksUnsupportedArgs = - | "frequencyPenalty" - | "presencePenalty" - | "bestOf" - | "logitBias"; - -type FireworksUnsupportedCallOptions = "functions" | "function_call" | "tools"; - -export type FireworksCallOptions = Partial< - Omit ->; - -/** - * Wrapper around Fireworks API for large language models - * - * Fireworks API is compatible to the 
OpenAI API with some limitations described in - * https://readme.fireworks.ai/docs/openai-compatibility. - * - * To use, you should have the `openai` package installed and - * the `FIREWORKS_API_KEY` environment variable set. - */ -export class Fireworks extends OpenAI { - static lc_name() { - return "Fireworks"; - } - - _llmType() { - return "fireworks"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - fireworksApiKey: "FIREWORKS_API_KEY", - }; - } - - lc_serializable = true; - - fireworksApiKey?: string; - - constructor( - fields?: Partial< - Omit - > & - BaseLLMParams & { fireworksApiKey?: string } - ) { - const fireworksApiKey = - fields?.fireworksApiKey || getEnvironmentVariable("FIREWORKS_API_KEY"); - - if (!fireworksApiKey) { - throw new Error( - `Fireworks API key not found. Please set the FIREWORKS_API_KEY environment variable or provide the key into "fireworksApiKey"` - ); - } - - super({ - ...fields, - openAIApiKey: fireworksApiKey, - modelName: fields?.modelName || "accounts/fireworks/models/llama-v2-13b", - configuration: { - baseURL: "https://api.fireworks.ai/inference/v1", - }, - }); - - this.fireworksApiKey = fireworksApiKey; - } - - toJSON() { - const result = super.toJSON(); - - if ( - "kwargs" in result && - typeof result.kwargs === "object" && - result.kwargs != null - ) { - delete result.kwargs.openai_api_key; - delete result.kwargs.configuration; - } - - return result; - } - - async completionWithRetry( - request: OpenAIClient.CompletionCreateParamsStreaming, - options?: OpenAICoreRequestOptions - ): Promise>; - - async completionWithRetry( - request: OpenAIClient.CompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise; - - /** - * Calls the Fireworks API with retry logic in case of failures. - * @param request The request to send to the Fireworks API. - * @param options Optional configuration for the API call. - * @returns The response from the Fireworks API. 
- */ - async completionWithRetry( - request: - | OpenAIClient.CompletionCreateParamsStreaming - | OpenAIClient.CompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise< - AsyncIterable | OpenAIClient.Completions.Completion - > { - // https://readme.fireworks.ai/docs/openai-compatibility#api-compatibility - if (Array.isArray(request.prompt)) { - if (request.prompt.length > 1) { - throw new Error("Multiple prompts are not supported by Fireworks"); - } - - const prompt = request.prompt[0]; - if (typeof prompt !== "string") { - throw new Error("Only string prompts are supported by Fireworks"); - } - - request.prompt = prompt; - } - - delete request.frequency_penalty; - delete request.presence_penalty; - delete request.best_of; - delete request.logit_bias; - - if (request.stream === true) { - return super.completionWithRetry(request, options); - } - - return super.completionWithRetry(request, options); - } -} +export * from "@langchain/community/llms/fireworks"; diff --git a/langchain/src/llms/googlepalm.ts b/langchain/src/llms/googlepalm.ts index 2bd8d8d2fc49..5e95faa32f13 100644 --- a/langchain/src/llms/googlepalm.ts +++ b/langchain/src/llms/googlepalm.ts @@ -1,203 +1 @@ -import { TextServiceClient, protos } from "@google-ai/generativelanguage"; -import { GoogleAuth } from "google-auth-library"; -import { BaseLLMParams, LLM } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Input for Text generation for Google Palm - */ -export interface GooglePaLMTextInput extends BaseLLMParams { - /** - * Model Name to use - * - * Note: The format must follow the pattern - `models/{model}` - */ - modelName?: string; - - /** - * Controls the randomness of the output. - * - * Values can range from [0.0,1.0], inclusive. A value closer to 1.0 - * will produce responses that are more varied and creative, while - * a value closer to 0.0 will typically result in more straightforward - * responses from the model. 
- * - * Note: The default value varies by model - */ - temperature?: number; - - /** - * Maximum number of tokens to generate in the completion. - */ - maxOutputTokens?: number; - - /** - * Top-p changes how the model selects tokens for output. - * - * Tokens are selected from most probable to least until the sum - * of their probabilities equals the top-p value. - * - * For example, if tokens A, B, and C have a probability of - * .3, .2, and .1 and the top-p value is .5, then the model will - * select either A or B as the next token (using temperature). - * - * Note: The default value varies by model - */ - topP?: number; - - /** - * Top-k changes how the model selects tokens for output. - * - * A top-k of 1 means the selected token is the most probable among - * all tokens in the model’s vocabulary (also called greedy decoding), - * while a top-k of 3 means that the next token is selected from - * among the 3 most probable tokens (using temperature). - * - * Note: The default value varies by model - */ - topK?: number; - - /** - * The set of character sequences (up to 5) that will stop output generation. - * If specified, the API will stop at the first appearance of a stop - * sequence. - * - * Note: The stop sequence will not be included as part of the response. - */ - stopSequences?: string[]; - - /** - * A list of unique `SafetySetting` instances for blocking unsafe content. The API will block - * any prompts and responses that fail to meet the thresholds set by these settings. If there - * is no `SafetySetting` for a given `SafetyCategory` provided in the list, the API will use - * the default safety setting for that category. 
- */ - safetySettings?: protos.google.ai.generativelanguage.v1beta2.ISafetySetting[]; - - /** - * Google Palm API key to use - */ - apiKey?: string; -} - -/** - * Google Palm 2 Language Model Wrapper to generate texts - */ -export class GooglePaLM extends LLM implements GooglePaLMTextInput { - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "GOOGLE_PALM_API_KEY", - }; - } - - modelName = "models/text-bison-001"; - - temperature?: number; // default value chosen based on model - - maxOutputTokens?: number; // defaults to 64 - - topP?: number; // default value chosen based on model - - topK?: number; // default value chosen based on model - - stopSequences: string[] = []; - - safetySettings?: protos.google.ai.generativelanguage.v1beta2.ISafetySetting[]; // default safety setting for that category - - apiKey?: string; - - private client: TextServiceClient; - - constructor(fields?: GooglePaLMTextInput) { - super(fields ?? {}); - - this.modelName = fields?.modelName ?? this.modelName; - - this.temperature = fields?.temperature ?? this.temperature; - if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { - throw new Error("`temperature` must be in the range of [0.0,1.0]"); - } - - this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens; - if (this.maxOutputTokens && this.maxOutputTokens < 0) { - throw new Error("`maxOutputTokens` must be a positive integer"); - } - - this.topP = fields?.topP ?? this.topP; - if (this.topP && this.topP < 0) { - throw new Error("`topP` must be a positive integer"); - } - - if (this.topP && this.topP > 1) { - throw new Error("Google PaLM `topP` must in the range of [0,1]"); - } - - this.topK = fields?.topK ?? this.topK; - if (this.topK && this.topK < 0) { - throw new Error("`topK` must be a positive integer"); - } - - this.stopSequences = fields?.stopSequences ?? this.stopSequences; - - this.safetySettings = fields?.safetySettings ?? 
this.safetySettings; - if (this.safetySettings && this.safetySettings.length > 0) { - const safetySettingsSet = new Set( - this.safetySettings.map((s) => s.category) - ); - if (safetySettingsSet.size !== this.safetySettings.length) { - throw new Error( - "The categories in `safetySettings` array must be unique" - ); - } - } - - this.apiKey = - fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PALM_API_KEY"); - if (!this.apiKey) { - throw new Error( - "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" - ); - } - - this.client = new TextServiceClient({ - authClient: new GoogleAuth().fromAPIKey(this.apiKey), - }); - } - - _llmType(): string { - return "googlepalm"; - } - - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const res = await this.caller.callWithOptions( - { signal: options.signal }, - this._generateText.bind(this), - prompt - ); - return res ?? ""; - } - - protected async _generateText( - prompt: string - ): Promise { - const res = await this.client.generateText({ - model: this.modelName, - temperature: this.temperature, - candidateCount: 1, - topK: this.topK, - topP: this.topP, - maxOutputTokens: this.maxOutputTokens, - stopSequences: this.stopSequences, - safetySettings: this.safetySettings, - prompt: { - text: prompt, - }, - }); - return res[0].candidates && res[0].candidates.length > 0 - ? 
res[0].candidates[0].output - : undefined; - } -} +export * from "@langchain/community/llms/googlepalm"; diff --git a/langchain/src/llms/googlevertexai/index.ts b/langchain/src/llms/googlevertexai/index.ts index c3c7cbd6127a..0616e82bdd49 100644 --- a/langchain/src/llms/googlevertexai/index.ts +++ b/langchain/src/llms/googlevertexai/index.ts @@ -1,66 +1 @@ -import { GoogleAuthOptions } from "google-auth-library"; -import { GoogleVertexAILLMConnection } from "../../util/googlevertexai-connection.js"; -import { GoogleVertexAIBaseLLMInput } from "../../types/googlevertexai-types.js"; -import { BaseGoogleVertexAI } from "./common.js"; -import { GAuthClient } from "../../util/googlevertexai-gauth.js"; - -/** - * Interface representing the input to the Google Vertex AI model. - */ -export interface GoogleVertexAITextInput - extends GoogleVertexAIBaseLLMInput {} - -/** - * Enables calls to the Google Cloud's Vertex AI API to access - * Large Language Models. - * - * To use, you will need to have one of the following authentication - * methods in place: - * - You are logged into an account permitted to the Google Cloud project - * using Vertex AI. - * - You are running this on a machine using a service account permitted to - * the Google Cloud project using Vertex AI. - * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the - * path of a credentials file for a service account permitted to the - * Google Cloud project using Vertex AI. 
- * @example - * ```typescript - * const model = new GoogleVertexAI({ - * temperature: 0.7, - * }); - * const stream = await model.stream( - * "What would be a good company name for a company that makes colorful socks?", - * ); - * for await (const chunk of stream) { - * console.log(chunk); - * } - * ``` - */ -export class GoogleVertexAI extends BaseGoogleVertexAI { - static lc_name() { - return "VertexAI"; - } - - constructor(fields?: GoogleVertexAITextInput) { - super(fields); - - const client = new GAuthClient({ - scopes: "https://www.googleapis.com/auth/cloud-platform", - ...fields?.authOptions, - }); - - this.connection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - false - ); - - this.streamedConnection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - true - ); - } -} +export * from "@langchain/community/llms/googlevertexai"; diff --git a/langchain/src/llms/googlevertexai/web.ts b/langchain/src/llms/googlevertexai/web.ts index 0b656308d53b..cecd871df86d 100644 --- a/langchain/src/llms/googlevertexai/web.ts +++ b/langchain/src/llms/googlevertexai/web.ts @@ -1,66 +1 @@ -import { - WebGoogleAuth, - WebGoogleAuthOptions, -} from "../../util/googlevertexai-webauth.js"; -import { GoogleVertexAILLMConnection } from "../../util/googlevertexai-connection.js"; -import { GoogleVertexAIBaseLLMInput } from "../../types/googlevertexai-types.js"; -import { BaseGoogleVertexAI } from "./common.js"; - -/** - * Interface representing the input to the Google Vertex AI model. - */ -export interface GoogleVertexAITextInput - extends GoogleVertexAIBaseLLMInput {} - -/** - * Enables calls to the Google Cloud's Vertex AI API to access - * Large Language Models. - * - * This entrypoint and class are intended to be used in web environments like Edge - * functions where you do not have access to the file system. 
It supports passing - * service account credentials directly as a "GOOGLE_VERTEX_AI_WEB_CREDENTIALS" - * environment variable or directly as "authOptions.credentials". - * @example - * ```typescript - * const model = new GoogleVertexAI({ - * temperature: 0.7, - * }); - * const stream = await model.stream( - * "What would be a good company name for a company that makes colorful socks?", - * ); - * for await (const chunk of stream) { - * console.log(chunk); - * } - * ``` - */ -export class GoogleVertexAI extends BaseGoogleVertexAI { - static lc_name() { - return "VertexAI"; - } - - get lc_secrets(): { [key: string]: string } { - return { - "authOptions.credentials": "GOOGLE_VERTEX_AI_WEB_CREDENTIALS", - }; - } - - constructor(fields?: GoogleVertexAITextInput) { - super(fields); - - const client = new WebGoogleAuth(fields?.authOptions); - - this.connection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - false - ); - - this.streamedConnection = new GoogleVertexAILLMConnection( - { ...fields, ...this }, - this.caller, - client, - true - ); - } -} +export * from "@langchain/community/llms/googlevertexai/web"; diff --git a/langchain/src/llms/gradient_ai.ts b/langchain/src/llms/gradient_ai.ts index 782bb8cd114e..f6a15989551a 100644 --- a/langchain/src/llms/gradient_ai.ts +++ b/langchain/src/llms/gradient_ai.ts @@ -1,136 +1 @@ -import { Gradient } from "@gradientai/nodejs-sdk"; -import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * The GradientLLMParams interface defines the input parameters for - * the GradientLLM class. - */ -export interface GradientLLMParams extends BaseLLMParams { - /** - * Gradient AI Access Token. - * Provide Access Token if you do not wish to automatically pull from env. - */ - gradientAccessKey?: string; - /** - * Gradient Workspace Id. - * Provide workspace id if you do not wish to automatically pull from env. 
- */ - workspaceId?: string; - /** - * Parameters accepted by the Gradient npm package. - */ - inferenceParameters?: Record; - /** - * Gradient AI Model Slug. - */ - modelSlug?: string; - /** - * Gradient Adapter ID for custom fine tuned models. - */ - adapterId?: string; -} - -/** - * The GradientLLM class is used to interact with Gradient AI inference Endpoint models. - * This requires your Gradient AI Access Token which is autoloaded if not specified. - */ -export class GradientLLM extends LLM { - static lc_name() { - return "GradientLLM"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - gradientAccessKey: "GRADIENT_ACCESS_TOKEN", - workspaceId: "GRADIENT_WORKSPACE_ID", - }; - } - - modelSlug = "llama2-7b-chat"; - - adapterId?: string; - - gradientAccessKey?: string; - - workspaceId?: string; - - inferenceParameters?: Record; - - // Gradient AI does not export the BaseModel type. Once it does, we can use it here. - // eslint-disable-next-line @typescript-eslint/no-explicit-any - model: any; - - constructor(fields: GradientLLMParams) { - super(fields); - - this.modelSlug = fields?.modelSlug ?? this.modelSlug; - this.adapterId = fields?.adapterId; - this.gradientAccessKey = - fields?.gradientAccessKey ?? - getEnvironmentVariable("GRADIENT_ACCESS_TOKEN"); - this.workspaceId = - fields?.workspaceId ?? getEnvironmentVariable("GRADIENT_WORKSPACE_ID"); - - this.inferenceParameters = fields.inferenceParameters; - - if (!this.gradientAccessKey) { - throw new Error("Missing Gradient AI Access Token"); - } - - if (!this.workspaceId) { - throw new Error("Missing Gradient AI Workspace ID"); - } - } - - _llmType() { - return "gradient_ai"; - } - - /** - * Calls the Gradient AI endpoint and retrieves the result. - * @param {string} prompt The input prompt. - * @returns {Promise} A promise that resolves to the generated string. 
- */ - /** @ignore */ - async _call( - prompt: string, - _options: this["ParsedCallOptions"] - ): Promise { - await this.setModel(); - - // GradientLLM does not export the CompleteResponse type. Once it does, we can use it here. - interface CompleteResponse { - finishReason: string; - generatedOutput: string; - } - - const response = (await this.caller.call(async () => - this.model.complete({ - query: prompt, - ...this.inferenceParameters, - }) - )) as CompleteResponse; - - return response.generatedOutput; - } - - async setModel() { - if (this.model) return; - - const gradient = new Gradient({ - accessToken: this.gradientAccessKey, - workspaceId: this.workspaceId, - }); - - if (this.adapterId) { - this.model = await gradient.getModelAdapter({ - modelAdapterId: this.adapterId, - }); - } else { - this.model = await gradient.getBaseModel({ - baseModelSlug: this.modelSlug, - }); - } - } -} +export * from "@langchain/community/llms/gradient_ai"; diff --git a/langchain/src/llms/hf.ts b/langchain/src/llms/hf.ts index e849a7acb37e..2f0e767bf6cd 100644 --- a/langchain/src/llms/hf.ts +++ b/langchain/src/llms/hf.ts @@ -1,155 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { LLM, BaseLLMParams } from "./base.js"; - -/** - * Interface defining the parameters for configuring the Hugging Face - * model for text generation. - */ -export interface HFInput { - /** Model to use */ - model: string; - - /** Custom inference endpoint URL to use */ - endpointUrl?: string; - - /** Sampling temperature to use */ - temperature?: number; - - /** - * Maximum number of tokens to generate in the completion. - */ - maxTokens?: number; - - /** Total probability mass of tokens to consider at each step */ - topP?: number; - - /** Integer to define the top tokens considered within the sample operation to create new text. */ - topK?: number; - - /** Penalizes repeated tokens according to frequency */ - frequencyPenalty?: number; - - /** API key to use. 
*/ - apiKey?: string; - - /** - * Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all. - */ - includeCredentials?: string | boolean; -} - -/** - * Class implementing the Large Language Model (LLM) interface using the - * Hugging Face Inference API for text generation. - * @example - * ```typescript - * const model = new HuggingFaceInference({ - * model: "gpt2", - * temperature: 0.7, - * maxTokens: 50, - * }); - * - * const res = await model.call( - * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" - * ); - * console.log({ res }); - * ``` - */ -export class HuggingFaceInference extends LLM implements HFInput { - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "HUGGINGFACEHUB_API_KEY", - }; - } - - model = "gpt2"; - - temperature: number | undefined = undefined; - - maxTokens: number | undefined = undefined; - - topP: number | undefined = undefined; - - topK: number | undefined = undefined; - - frequencyPenalty: number | undefined = undefined; - - apiKey: string | undefined = undefined; - - endpointUrl: string | undefined = undefined; - - includeCredentials: string | boolean | undefined = undefined; - - constructor(fields?: Partial & BaseLLMParams) { - super(fields ?? {}); - - this.model = fields?.model ?? this.model; - this.temperature = fields?.temperature ?? this.temperature; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.topP = fields?.topP ?? this.topP; - this.topK = fields?.topK ?? this.topK; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.apiKey = - fields?.apiKey ?? 
getEnvironmentVariable("HUGGINGFACEHUB_API_KEY"); - this.endpointUrl = fields?.endpointUrl; - this.includeCredentials = fields?.includeCredentials; - - if (!this.apiKey) { - throw new Error( - "Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor." - ); - } - } - - _llmType() { - return "hf"; - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const { HfInference } = await HuggingFaceInference.imports(); - const hf = this.endpointUrl - ? new HfInference(this.apiKey, { - includeCredentials: this.includeCredentials, - }).endpoint(this.endpointUrl) - : new HfInference(this.apiKey, { - includeCredentials: this.includeCredentials, - }); - - const res = await this.caller.callWithOptions( - { signal: options.signal }, - hf.textGeneration.bind(hf), - { - model: this.model, - parameters: { - // make it behave similar to openai, returning only the generated text - return_full_text: false, - temperature: this.temperature, - max_new_tokens: this.maxTokens, - top_p: this.topP, - top_k: this.topK, - repetition_penalty: this.frequencyPenalty, - }, - inputs: prompt, - } - ); - return res.generated_text; - } - - /** @ignore */ - static async imports(): Promise<{ - HfInference: typeof import("@huggingface/inference").HfInference; - }> { - try { - const { HfInference } = await import("@huggingface/inference"); - return { HfInference }; - } catch (e) { - throw new Error( - "Please install huggingface as a dependency with, e.g. 
`yarn add @huggingface/inference`" - ); - } - } -} +export * from "@langchain/community/llms/hf"; diff --git a/langchain/src/llms/llama_cpp.ts b/langchain/src/llms/llama_cpp.ts index f2d6518ffd2b..ce09302c20d3 100644 --- a/langchain/src/llms/llama_cpp.ts +++ b/langchain/src/llms/llama_cpp.ts @@ -1,116 +1 @@ -import { LlamaModel, LlamaContext, LlamaChatSession } from "node-llama-cpp"; -import { - LlamaBaseCppInputs, - createLlamaModel, - createLlamaContext, - createLlamaSession, -} from "../util/llama_cpp.js"; -import { LLM, BaseLLMCallOptions, BaseLLMParams } from "./base.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk } from "../schema/index.js"; - -/** - * Note that the modelPath is the only required parameter. For testing you - * can set this in the environment variable `LLAMA_PATH`. - */ -export interface LlamaCppInputs extends LlamaBaseCppInputs, BaseLLMParams {} - -export interface LlamaCppCallOptions extends BaseLLMCallOptions { - /** The maximum number of tokens the response should contain. */ - maxTokens?: number; - /** A function called when matching the provided token array */ - onToken?: (tokens: number[]) => void; -} - -/** - * To use this model you need to have the `node-llama-cpp` module installed. - * This can be installed using `npm install -S node-llama-cpp` and the minimum - * version supported in version 2.0.0. - * This also requires that have a locally built version of Llama2 installed. 
- */ -export class LlamaCpp extends LLM { - declare CallOptions: LlamaCppCallOptions; - - static inputs: LlamaCppInputs; - - maxTokens?: number; - - temperature?: number; - - topK?: number; - - topP?: number; - - trimWhitespaceSuffix?: boolean; - - _model: LlamaModel; - - _context: LlamaContext; - - _session: LlamaChatSession; - - static lc_name() { - return "LlamaCpp"; - } - - constructor(inputs: LlamaCppInputs) { - super(inputs); - this.maxTokens = inputs?.maxTokens; - this.temperature = inputs?.temperature; - this.topK = inputs?.topK; - this.topP = inputs?.topP; - this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix; - this._model = createLlamaModel(inputs); - this._context = createLlamaContext(this._model, inputs); - this._session = createLlamaSession(this._context); - } - - _llmType() { - return "llama2_cpp"; - } - - /** @ignore */ - async _call( - prompt: string, - options?: this["ParsedCallOptions"] - ): Promise { - try { - const promptOptions = { - onToken: options?.onToken, - maxTokens: this?.maxTokens, - temperature: this?.temperature, - topK: this?.topK, - topP: this?.topP, - trimWhitespaceSuffix: this?.trimWhitespaceSuffix, - }; - const completion = await this._session.prompt(prompt, promptOptions); - return completion; - } catch (e) { - throw new Error("Error getting prompt completion."); - } - } - - async *_streamResponseChunks( - prompt: string, - _options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const promptOptions = { - temperature: this?.temperature, - topK: this?.topK, - topP: this?.topP, - }; - - const stream = await this.caller.call(async () => - this._context.evaluate(this._context.encode(prompt), promptOptions) - ); - - for await (const chunk of stream) { - yield new GenerationChunk({ - text: this._context.decode([chunk]), - generationInfo: {}, - }); - await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? 
""); - } - } -} +export * from "@langchain/community/llms/llama_cpp"; diff --git a/langchain/src/llms/ollama.ts b/langchain/src/llms/ollama.ts index 8b5c178cfb1f..1369ed4c0c60 100644 --- a/langchain/src/llms/ollama.ts +++ b/langchain/src/llms/ollama.ts @@ -1,245 +1 @@ -import { LLM, BaseLLMParams } from "./base.js"; -import { - createOllamaStream, - OllamaInput, - OllamaCallOptions, -} from "../util/ollama.js"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk } from "../schema/index.js"; -import type { StringWithAutocomplete } from "../util/types.js"; - -/** - * Class that represents the Ollama language model. It extends the base - * LLM class and implements the OllamaInput interface. - * @example - * ```typescript - * const ollama = new Ollama({ - * baseUrl: "http://api.example.com", - * model: "llama2", - * }); - * - * // Streaming translation from English to German - * const stream = await ollama.stream( - * `Translate "I love programming" into German.` - * ); - * - * const chunks = []; - * for await (const chunk of stream) { - * chunks.push(chunk); - * } - * - * console.log(chunks.join("")); - * ``` - */ -export class Ollama extends LLM implements OllamaInput { - static lc_name() { - return "Ollama"; - } - - lc_serializable = true; - - model = "llama2"; - - baseUrl = "http://localhost:11434"; - - embeddingOnly?: boolean; - - f16KV?: boolean; - - frequencyPenalty?: number; - - logitsAll?: boolean; - - lowVram?: boolean; - - mainGpu?: number; - - mirostat?: number; - - mirostatEta?: number; - - mirostatTau?: number; - - numBatch?: number; - - numCtx?: number; - - numGpu?: number; - - numGqa?: number; - - numKeep?: number; - - numThread?: number; - - penalizeNewline?: boolean; - - presencePenalty?: number; - - repeatLastN?: number; - - repeatPenalty?: number; - - ropeFrequencyBase?: number; - - ropeFrequencyScale?: number; - - temperature?: number; - - stop?: string[]; - - tfsZ?: number; - - topK?: number; - - topP?: 
number; - - typicalP?: number; - - useMLock?: boolean; - - useMMap?: boolean; - - vocabOnly?: boolean; - - format?: StringWithAutocomplete<"json">; - - constructor(fields: OllamaInput & BaseLLMParams) { - super(fields); - this.model = fields.model ?? this.model; - this.baseUrl = fields.baseUrl?.endsWith("/") - ? fields.baseUrl.slice(0, -1) - : fields.baseUrl ?? this.baseUrl; - - this.embeddingOnly = fields.embeddingOnly; - this.f16KV = fields.f16KV; - this.frequencyPenalty = fields.frequencyPenalty; - this.logitsAll = fields.logitsAll; - this.lowVram = fields.lowVram; - this.mainGpu = fields.mainGpu; - this.mirostat = fields.mirostat; - this.mirostatEta = fields.mirostatEta; - this.mirostatTau = fields.mirostatTau; - this.numBatch = fields.numBatch; - this.numCtx = fields.numCtx; - this.numGpu = fields.numGpu; - this.numGqa = fields.numGqa; - this.numKeep = fields.numKeep; - this.numThread = fields.numThread; - this.penalizeNewline = fields.penalizeNewline; - this.presencePenalty = fields.presencePenalty; - this.repeatLastN = fields.repeatLastN; - this.repeatPenalty = fields.repeatPenalty; - this.ropeFrequencyBase = fields.ropeFrequencyBase; - this.ropeFrequencyScale = fields.ropeFrequencyScale; - this.temperature = fields.temperature; - this.stop = fields.stop; - this.tfsZ = fields.tfsZ; - this.topK = fields.topK; - this.topP = fields.topP; - this.typicalP = fields.typicalP; - this.useMLock = fields.useMLock; - this.useMMap = fields.useMMap; - this.vocabOnly = fields.vocabOnly; - this.format = fields.format; - } - - _llmType() { - return "ollama"; - } - - invocationParams(options?: this["ParsedCallOptions"]) { - return { - model: this.model, - format: this.format, - options: { - embedding_only: this.embeddingOnly, - f16_kv: this.f16KV, - frequency_penalty: this.frequencyPenalty, - logits_all: this.logitsAll, - low_vram: this.lowVram, - main_gpu: this.mainGpu, - mirostat: this.mirostat, - mirostat_eta: this.mirostatEta, - mirostat_tau: this.mirostatTau, - 
num_batch: this.numBatch, - num_ctx: this.numCtx, - num_gpu: this.numGpu, - num_gqa: this.numGqa, - num_keep: this.numKeep, - num_thread: this.numThread, - penalize_newline: this.penalizeNewline, - presence_penalty: this.presencePenalty, - repeat_last_n: this.repeatLastN, - repeat_penalty: this.repeatPenalty, - rope_frequency_base: this.ropeFrequencyBase, - rope_frequency_scale: this.ropeFrequencyScale, - temperature: this.temperature, - stop: options?.stop ?? this.stop, - tfs_z: this.tfsZ, - top_k: this.topK, - top_p: this.topP, - typical_p: this.typicalP, - use_mlock: this.useMLock, - use_mmap: this.useMMap, - vocab_only: this.vocabOnly, - }, - }; - } - - async *_streamResponseChunks( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const stream = await this.caller.call(async () => - createOllamaStream( - this.baseUrl, - { ...this.invocationParams(options), prompt }, - options - ) - ); - for await (const chunk of stream) { - if (!chunk.done) { - yield new GenerationChunk({ - text: chunk.response, - generationInfo: { - ...chunk, - response: undefined, - }, - }); - await runManager?.handleLLMNewToken(chunk.response ?? 
""); - } else { - yield new GenerationChunk({ - text: "", - generationInfo: { - model: chunk.model, - total_duration: chunk.total_duration, - load_duration: chunk.load_duration, - prompt_eval_count: chunk.prompt_eval_count, - prompt_eval_duration: chunk.prompt_eval_duration, - eval_count: chunk.eval_count, - eval_duration: chunk.eval_duration, - }, - }); - } - } - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const chunks = []; - for await (const chunk of this._streamResponseChunks( - prompt, - options, - runManager - )) { - chunks.push(chunk.text); - } - return chunks.join(""); - } -} +export * from "@langchain/community/llms/ollama"; diff --git a/langchain/src/llms/openai-chat.ts b/langchain/src/llms/openai-chat.ts index 4077a5efde89..06cb7faf0f1d 100644 --- a/langchain/src/llms/openai-chat.ts +++ b/langchain/src/llms/openai-chat.ts @@ -1,471 +1,18 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { OpenAIChat } from "@langchain/openai"; + import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { Generation, GenerationChunk, LLMResult } from "../schema/index.js"; -import { - AzureOpenAIInput, - OpenAICallOptions, - OpenAIChatInput, - OpenAICoreRequestOptions, - LegacyOpenAIInput, -} from "../types/openai-types.js"; -import { OpenAIEndpointConfig, getEndpoint } from "../util/azure.js"; +import type { Generation, LLMResult } from "../schema/index.js"; import { getEnvironmentVariable } from "../util/env.js"; import { promptLayerTrackRequest } from "../util/prompt-layer.js"; -import { BaseLLMParams, LLM } from "./base.js"; -import { wrapOpenAIClientError } from "../util/openai.js"; - -export { type AzureOpenAIInput, type OpenAIChatInput }; -/** - * Interface that extends the OpenAICallOptions interface and includes an - * optional promptIndex property. 
It represents the options that can be - * passed when making a call to the OpenAI Chat API. - */ -export interface OpenAIChatCallOptions extends OpenAICallOptions { - promptIndex?: number; -} - -/** - * Wrapper around OpenAI large language models that use the Chat endpoint. - * - * To use you should have the `openai` package installed, with the - * `OPENAI_API_KEY` environment variable set. - * - * To use with Azure you should have the `openai` package installed, with the - * `AZURE_OPENAI_API_KEY`, - * `AZURE_OPENAI_API_INSTANCE_NAME`, - * `AZURE_OPENAI_API_DEPLOYMENT_NAME` - * and `AZURE_OPENAI_API_VERSION` environment variable set. - * - * @remarks - * Any parameters that are valid to be passed to {@link - * https://platform.openai.com/docs/api-reference/chat/create | - * `openai.createCompletion`} can be passed through {@link modelKwargs}, even - * if not explicitly available on this class. - * - * @augments BaseLLM - * @augments OpenAIInput - * @augments AzureOpenAIChatInput - * @example - * ```typescript - * const model = new OpenAIChat({ - * prefixMessages: [ - * { - * role: "system", - * content: "You are a helpful assistant that answers in pirate language", - * }, - * ], - * maxTokens: 50, - * }); - * - * const res = await model.call( - * "What would be a good company name for a company that makes colorful socks?" 
- * ); - * console.log({ res }); - * ``` - */ -export class OpenAIChat - extends LLM - implements OpenAIChatInput, AzureOpenAIInput -{ - static lc_name() { - return "OpenAIChat"; - } - - get callKeys() { - return [...super.callKeys, "options", "promptIndex"]; - } - - lc_serializable = true; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - openAIApiKey: "OPENAI_API_KEY", - azureOpenAIApiKey: "AZURE_OPENAI_API_KEY", - organization: "OPENAI_ORGANIZATION", - }; - } - - get lc_aliases(): Record { - return { - modelName: "model", - openAIApiKey: "openai_api_key", - azureOpenAIApiVersion: "azure_openai_api_version", - azureOpenAIApiKey: "azure_openai_api_key", - azureOpenAIApiInstanceName: "azure_openai_api_instance_name", - azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name", - }; - } - - temperature = 1; - - topP = 1; - - frequencyPenalty = 0; - - presencePenalty = 0; - - n = 1; - - logitBias?: Record; - - maxTokens?: number; - - modelName = "gpt-3.5-turbo"; - - prefixMessages?: OpenAIClient.Chat.ChatCompletionMessageParam[]; - - modelKwargs?: OpenAIChatInput["modelKwargs"]; - - timeout?: number; - - stop?: string[]; - - user?: string; - - streaming = false; - - openAIApiKey?: string; - - azureOpenAIApiVersion?: string; - - azureOpenAIApiKey?: string; - - azureOpenAIApiInstanceName?: string; - - azureOpenAIApiDeploymentName?: string; - - azureOpenAIBasePath?: string; - - organization?: string; - - private client: OpenAIClient; - - private clientConfig: ClientOptions; - - constructor( - fields?: Partial & - Partial & - BaseLLMParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, - /** @deprecated */ - configuration?: ClientOptions & LegacyOpenAIInput - ) { - super(fields ?? {}); - - this.openAIApiKey = - fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY"); - - this.azureOpenAIApiKey = - fields?.azureOpenAIApiKey ?? 
- getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - - if (!this.azureOpenAIApiKey && !this.openAIApiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); - } - - this.azureOpenAIApiInstanceName = - fields?.azureOpenAIApiInstanceName ?? - getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME"); - - this.azureOpenAIApiDeploymentName = - (fields?.azureOpenAIApiCompletionsDeploymentName || - fields?.azureOpenAIApiDeploymentName) ?? - (getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") || - getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME")); - - this.azureOpenAIApiVersion = - fields?.azureOpenAIApiVersion ?? - getEnvironmentVariable("AZURE_OPENAI_API_VERSION"); - - this.azureOpenAIBasePath = - fields?.azureOpenAIBasePath ?? - getEnvironmentVariable("AZURE_OPENAI_BASE_PATH"); - this.organization = - fields?.configuration?.organization ?? - getEnvironmentVariable("OPENAI_ORGANIZATION"); +export { + type AzureOpenAIInput, + type OpenAICallOptions, + type OpenAIInput, + type OpenAIChatCallOptions, +} from "@langchain/openai"; - this.modelName = fields?.modelName ?? this.modelName; - this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages; - this.modelKwargs = fields?.modelKwargs ?? {}; - this.timeout = fields?.timeout; - - this.temperature = fields?.temperature ?? this.temperature; - this.topP = fields?.topP ?? this.topP; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; - this.n = fields?.n ?? this.n; - this.logitBias = fields?.logitBias; - this.maxTokens = fields?.maxTokens; - this.stop = fields?.stop; - this.user = fields?.user; - - this.streaming = fields?.streaming ?? false; - - if (this.n > 1) { - throw new Error( - "Cannot use n > 1 in OpenAIChat LLM. Use ChatOpenAI Chat Model instead." 
- ); - } - - if (this.azureOpenAIApiKey) { - if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { - throw new Error("Azure OpenAI API instance name not found"); - } - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI API deployment name not found"); - } - if (!this.azureOpenAIApiVersion) { - throw new Error("Azure OpenAI API version not found"); - } - this.openAIApiKey = this.openAIApiKey ?? ""; - } - - this.clientConfig = { - apiKey: this.openAIApiKey, - organization: this.organization, - baseURL: configuration?.basePath ?? fields?.configuration?.basePath, - dangerouslyAllowBrowser: true, - defaultHeaders: - configuration?.baseOptions?.headers ?? - fields?.configuration?.baseOptions?.headers, - defaultQuery: - configuration?.baseOptions?.params ?? - fields?.configuration?.baseOptions?.params, - ...configuration, - ...fields?.configuration, - }; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams( - options?: this["ParsedCallOptions"] - ): Omit { - return { - model: this.modelName, - temperature: this.temperature, - top_p: this.topP, - frequency_penalty: this.frequencyPenalty, - presence_penalty: this.presencePenalty, - n: this.n, - logit_bias: this.logitBias, - max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens, - stop: options?.stop ?? 
this.stop, - user: this.user, - stream: this.streaming, - ...this.modelKwargs, - }; - } - - /** @ignore */ - _identifyingParams(): Omit< - OpenAIClient.Chat.ChatCompletionCreateParams, - "messages" - > & { - model_name: string; - } & ClientOptions { - return { - model_name: this.modelName, - ...this.invocationParams(), - ...this.clientConfig, - }; - } - - /** - * Get the identifying parameters for the model - */ - identifyingParams(): Omit< - OpenAIClient.Chat.ChatCompletionCreateParams, - "messages" - > & { - model_name: string; - } & ClientOptions { - return { - model_name: this.modelName, - ...this.invocationParams(), - ...this.clientConfig, - }; - } - - /** - * Formats the messages for the OpenAI API. - * @param prompt The prompt to be formatted. - * @returns Array of formatted messages. - */ - private formatMessages( - prompt: string - ): OpenAIClient.Chat.ChatCompletionMessageParam[] { - const message: OpenAIClient.Chat.ChatCompletionMessageParam = { - role: "user", - content: prompt, - }; - return this.prefixMessages ? [...this.prefixMessages, message] : [message]; - } - - async *_streamResponseChunks( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const params = { - ...this.invocationParams(options), - messages: this.formatMessages(prompt), - stream: true as const, - }; - const stream = await this.completionWithRetry(params, options); - for await (const data of stream) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - const { delta } = choice; - const generationChunk = new GenerationChunk({ - text: delta.content ?? "", - }); - yield generationChunk; - const newTokenIndices = { - prompt: options.promptIndex ?? 0, - completion: choice.index ?? 0, - }; - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken( - generationChunk.text ?? 
"", - newTokenIndices - ); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const params = this.invocationParams(options); - - if (params.stream) { - const stream = await this._streamResponseChunks( - prompt, - options, - runManager - ); - let finalChunk: GenerationChunk | undefined; - for await (const chunk of stream) { - if (finalChunk === undefined) { - finalChunk = chunk; - } else { - finalChunk = finalChunk.concat(chunk); - } - } - return finalChunk?.text ?? ""; - } else { - const response = await this.completionWithRetry( - { - ...params, - stream: false, - messages: this.formatMessages(prompt), - }, - { - signal: options.signal, - ...options.options, - } - ); - return response?.choices[0]?.message?.content ?? ""; - } - } - - /** - * Calls the OpenAI API with retry logic in case of failures. - * @param request The request to send to the OpenAI API. - * @param options Optional configuration for the API call. - * @returns The response from the OpenAI API. 
- */ - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, - options?: OpenAICoreRequestOptions - ): Promise>; - - async completionWithRetry( - request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise; - - async completionWithRetry( - request: - | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming - | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise< - | AsyncIterable - | OpenAIClient.Chat.Completions.ChatCompletion - > { - const requestOptions = this._getClientOptions(options); - return this.caller.call(async () => { - try { - const res = await this.client.chat.completions.create( - request, - requestOptions - ); - return res; - } catch (e) { - const error = wrapOpenAIClientError(e); - throw error; - } - }); - } - - /** @ignore */ - private _getClientOptions(options: OpenAICoreRequestOptions | undefined) { - if (!this.client) { - const openAIEndpointConfig: OpenAIEndpointConfig = { - azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, - azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, - azureOpenAIApiKey: this.azureOpenAIApiKey, - azureOpenAIBasePath: this.azureOpenAIBasePath, - baseURL: this.clientConfig.baseURL, - }; - - const endpoint = getEndpoint(openAIEndpointConfig); - - const params = { - ...this.clientConfig, - baseURL: endpoint, - timeout: this.timeout, - maxRetries: 0, - }; - if (!params.baseURL) { - delete params.baseURL; - } - - this.client = new OpenAIClient(params); - } - const requestOptions = { - ...this.clientConfig, - ...options, - } as OpenAICoreRequestOptions; - if (this.azureOpenAIApiKey) { - requestOptions.headers = { - "api-key": this.azureOpenAIApiKey, - ...requestOptions.headers, - }; - requestOptions.query = { - "api-version": this.azureOpenAIApiVersion, - ...requestOptions.query, - }; - } - return requestOptions; - } - - _llmType() { - return "openai"; 
- } -} +export { OpenAIChat }; /** * PromptLayer wrapper to OpenAIChat diff --git a/langchain/src/llms/openai.ts b/langchain/src/llms/openai.ts index c962f30c096f..edfb429272b5 100644 --- a/langchain/src/llms/openai.ts +++ b/langchain/src/llms/openai.ts @@ -1,558 +1,17 @@ -import type { TiktokenModel } from "js-tiktoken/lite"; -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; -import { calculateMaxTokens } from "../base_language/count_tokens.js"; +import { OpenAI } from "@langchain/openai"; + import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk, LLMResult } from "../schema/index.js"; -import { - AzureOpenAIInput, - OpenAICallOptions, - OpenAICoreRequestOptions, - OpenAIInput, - LegacyOpenAIInput, -} from "../types/openai-types.js"; -import { OpenAIEndpointConfig, getEndpoint } from "../util/azure.js"; -import { chunkArray } from "../util/chunk.js"; +import type { LLMResult } from "../schema/index.js"; import { getEnvironmentVariable } from "../util/env.js"; import { promptLayerTrackRequest } from "../util/prompt-layer.js"; -import { BaseLLM, BaseLLMParams } from "./base.js"; -import { OpenAIChat } from "./openai-chat.js"; -import { wrapOpenAIClientError } from "../util/openai.js"; - -export type { AzureOpenAIInput, OpenAICallOptions, OpenAIInput }; - -/** - * Interface for tracking token usage in OpenAI calls. - */ -interface TokenUsage { - completionTokens?: number; - promptTokens?: number; - totalTokens?: number; -} - -/** - * Wrapper around OpenAI large language models. - * - * To use you should have the `openai` package installed, with the - * `OPENAI_API_KEY` environment variable set. - * - * To use with Azure you should have the `openai` package installed, with the - * `AZURE_OPENAI_API_KEY`, - * `AZURE_OPENAI_API_INSTANCE_NAME`, - * `AZURE_OPENAI_API_DEPLOYMENT_NAME` - * and `AZURE_OPENAI_API_VERSION` environment variable set. 
- * - * @remarks - * Any parameters that are valid to be passed to {@link - * https://platform.openai.com/docs/api-reference/completions/create | - * `openai.createCompletion`} can be passed through {@link modelKwargs}, even - * if not explicitly available on this class. - * @example - * ```typescript - * const model = new OpenAI({ - * modelName: "gpt-4", - * temperature: 0.7, - * maxTokens: 1000, - * maxRetries: 5, - * }); - * - * const res = await model.call( - * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" - * ); - * console.log({ res }); - * ``` - */ -export class OpenAI - extends BaseLLM - implements OpenAIInput, AzureOpenAIInput -{ - static lc_name() { - return "OpenAI"; - } - - get callKeys() { - return [...super.callKeys, "options"]; - } - - lc_serializable = true; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - openAIApiKey: "OPENAI_API_KEY", - azureOpenAIApiKey: "AZURE_OPENAI_API_KEY", - organization: "OPENAI_ORGANIZATION", - }; - } - - get lc_aliases(): Record { - return { - modelName: "model", - openAIApiKey: "openai_api_key", - azureOpenAIApiVersion: "azure_openai_api_version", - azureOpenAIApiKey: "azure_openai_api_key", - azureOpenAIApiInstanceName: "azure_openai_api_instance_name", - azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name", - }; - } - - temperature = 0.7; - - maxTokens = 256; - - topP = 1; - - frequencyPenalty = 0; - - presencePenalty = 0; - - n = 1; - - bestOf?: number; - - logitBias?: Record; - - modelName = "gpt-3.5-turbo-instruct"; - - modelKwargs?: OpenAIInput["modelKwargs"]; - - batchSize = 20; - - timeout?: number; - stop?: string[]; +export { + type AzureOpenAIInput, + type OpenAICallOptions, + type OpenAIInput, +} from "@langchain/openai"; - user?: string; - - streaming = false; - - openAIApiKey?: string; - - azureOpenAIApiVersion?: string; - - azureOpenAIApiKey?: string; - - azureOpenAIApiInstanceName?: string; - - 
azureOpenAIApiDeploymentName?: string; - - azureOpenAIBasePath?: string; - - organization?: string; - - private client: OpenAIClient; - - private clientConfig: ClientOptions; - - constructor( - fields?: Partial & - Partial & - BaseLLMParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, - /** @deprecated */ - configuration?: ClientOptions & LegacyOpenAIInput - ) { - if ( - (fields?.modelName?.startsWith("gpt-3.5-turbo") || - fields?.modelName?.startsWith("gpt-4")) && - !fields?.modelName?.includes("-instruct") - ) { - // eslint-disable-next-line no-constructor-return - return new OpenAIChat( - fields, - configuration - ) as unknown as OpenAI; - } - super(fields ?? {}); - - this.openAIApiKey = - fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY"); - - this.azureOpenAIApiKey = - fields?.azureOpenAIApiKey ?? - getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - - if (!this.azureOpenAIApiKey && !this.openAIApiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); - } - - this.azureOpenAIApiInstanceName = - fields?.azureOpenAIApiInstanceName ?? - getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME"); - - this.azureOpenAIApiDeploymentName = - (fields?.azureOpenAIApiCompletionsDeploymentName || - fields?.azureOpenAIApiDeploymentName) ?? - (getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") || - getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME")); - - this.azureOpenAIApiVersion = - fields?.azureOpenAIApiVersion ?? - getEnvironmentVariable("AZURE_OPENAI_API_VERSION"); - - this.azureOpenAIBasePath = - fields?.azureOpenAIBasePath ?? - getEnvironmentVariable("AZURE_OPENAI_BASE_PATH"); - - this.organization = - fields?.configuration?.organization ?? - getEnvironmentVariable("OPENAI_ORGANIZATION"); - - this.modelName = fields?.modelName ?? this.modelName; - this.modelKwargs = fields?.modelKwargs ?? {}; - this.batchSize = fields?.batchSize ?? 
this.batchSize; - this.timeout = fields?.timeout; - - this.temperature = fields?.temperature ?? this.temperature; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.topP = fields?.topP ?? this.topP; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; - this.n = fields?.n ?? this.n; - this.bestOf = fields?.bestOf ?? this.bestOf; - this.logitBias = fields?.logitBias; - this.stop = fields?.stop; - this.user = fields?.user; - - this.streaming = fields?.streaming ?? false; - - if (this.streaming && this.bestOf && this.bestOf > 1) { - throw new Error("Cannot stream results when bestOf > 1"); - } - - if (this.azureOpenAIApiKey) { - if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { - throw new Error("Azure OpenAI API instance name not found"); - } - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI API deployment name not found"); - } - if (!this.azureOpenAIApiVersion) { - throw new Error("Azure OpenAI API version not found"); - } - this.openAIApiKey = this.openAIApiKey ?? ""; - } - - this.clientConfig = { - apiKey: this.openAIApiKey, - organization: this.organization, - baseURL: configuration?.basePath ?? fields?.configuration?.basePath, - dangerouslyAllowBrowser: true, - defaultHeaders: - configuration?.baseOptions?.headers ?? - fields?.configuration?.baseOptions?.headers, - defaultQuery: - configuration?.baseOptions?.params ?? 
- fields?.configuration?.baseOptions?.params, - ...configuration, - ...fields?.configuration, - }; - } - - /** - * Get the parameters used to invoke the model - */ - invocationParams( - options?: this["ParsedCallOptions"] - ): Omit { - return { - model: this.modelName, - temperature: this.temperature, - max_tokens: this.maxTokens, - top_p: this.topP, - frequency_penalty: this.frequencyPenalty, - presence_penalty: this.presencePenalty, - n: this.n, - best_of: this.bestOf, - logit_bias: this.logitBias, - stop: options?.stop ?? this.stop, - user: this.user, - stream: this.streaming, - ...this.modelKwargs, - }; - } - - /** @ignore */ - _identifyingParams(): Omit & { - model_name: string; - } & ClientOptions { - return { - model_name: this.modelName, - ...this.invocationParams(), - ...this.clientConfig, - }; - } - - /** - * Get the identifying parameters for the model - */ - identifyingParams(): Omit & { - model_name: string; - } & ClientOptions { - return this._identifyingParams(); - } - - /** - * Call out to OpenAI's endpoint with k unique prompts - * - * @param [prompts] - The prompts to pass into the model. - * @param [options] - Optional list of stop words to use when generating. - * @param [runManager] - Optional callback manager to use when generating. - * - * @returns The full LLM output. 
- * - * @example - * ```ts - * import { OpenAI } from "langchain/llms/openai"; - * const openai = new OpenAI(); - * const response = await openai.generate(["Tell me a joke."]); - * ``` - */ - async _generate( - prompts: string[], - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const subPrompts = chunkArray(prompts, this.batchSize); - const choices: OpenAIClient.CompletionChoice[] = []; - const tokenUsage: TokenUsage = {}; - - const params = this.invocationParams(options); - - if (params.max_tokens === -1) { - if (prompts.length !== 1) { - throw new Error( - "max_tokens set to -1 not supported for multiple inputs" - ); - } - params.max_tokens = await calculateMaxTokens({ - prompt: prompts[0], - // Cast here to allow for other models that may not fit the union - modelName: this.modelName as TiktokenModel, - }); - } - - for (let i = 0; i < subPrompts.length; i += 1) { - const data = params.stream - ? await (async () => { - const choices: OpenAIClient.CompletionChoice[] = []; - let response: Omit | undefined; - const stream = await this.completionWithRetry( - { - ...params, - stream: true, - prompt: subPrompts[i], - }, - options - ); - for await (const message of stream) { - // on the first message set the response properties - if (!response) { - response = { - id: message.id, - object: message.object, - created: message.created, - model: message.model, - }; - } - - // on all messages, update choice - for (const part of message.choices) { - if (!choices[part.index]) { - choices[part.index] = part; - } else { - const choice = choices[part.index]; - choice.text += part.text; - choice.finish_reason = part.finish_reason; - choice.logprobs = part.logprobs; - } - void runManager?.handleLLMNewToken(part.text, { - prompt: Math.floor(part.index / this.n), - completion: part.index % this.n, - }); - } - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - return { ...response, choices }; - })() - : await 
this.completionWithRetry( - { - ...params, - stream: false, - prompt: subPrompts[i], - }, - { - signal: options.signal, - ...options.options, - } - ); - - choices.push(...data.choices); - const { - completion_tokens: completionTokens, - prompt_tokens: promptTokens, - total_tokens: totalTokens, - } = data.usage - ? data.usage - : { - completion_tokens: undefined, - prompt_tokens: undefined, - total_tokens: undefined, - }; - - if (completionTokens) { - tokenUsage.completionTokens = - (tokenUsage.completionTokens ?? 0) + completionTokens; - } - - if (promptTokens) { - tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens; - } - - if (totalTokens) { - tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens; - } - } - - const generations = chunkArray(choices, this.n).map((promptChoices) => - promptChoices.map((choice) => ({ - text: choice.text ?? "", - generationInfo: { - finishReason: choice.finish_reason, - logprobs: choice.logprobs, - }, - })) - ); - return { - generations, - llmOutput: { tokenUsage }, - }; - } - - // TODO(jacoblee): Refactor with _generate(..., {stream: true}) implementation? - async *_streamResponseChunks( - input: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const params = { - ...this.invocationParams(options), - prompt: input, - stream: true as const, - }; - const stream = await this.completionWithRetry(params, options); - for await (const data of stream) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - const chunk = new GenerationChunk({ - text: choice.text, - generationInfo: { - finishReason: choice.finish_reason, - }, - }); - yield chunk; - // eslint-disable-next-line no-void - void runManager?.handleLLMNewToken(chunk.text ?? ""); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - } - - /** - * Calls the OpenAI API with retry logic in case of failures. 
- * @param request The request to send to the OpenAI API. - * @param options Optional configuration for the API call. - * @returns The response from the OpenAI API. - */ - async completionWithRetry( - request: OpenAIClient.CompletionCreateParamsStreaming, - options?: OpenAICoreRequestOptions - ): Promise>; - - async completionWithRetry( - request: OpenAIClient.CompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise; - - async completionWithRetry( - request: - | OpenAIClient.CompletionCreateParamsStreaming - | OpenAIClient.CompletionCreateParamsNonStreaming, - options?: OpenAICoreRequestOptions - ): Promise< - AsyncIterable | OpenAIClient.Completions.Completion - > { - const requestOptions = this._getClientOptions(options); - return this.caller.call(async () => { - try { - const res = await this.client.completions.create( - request, - requestOptions - ); - return res; - } catch (e) { - const error = wrapOpenAIClientError(e); - throw error; - } - }); - } - - /** - * Calls the OpenAI API with retry logic in case of failures. - * @param request The request to send to the OpenAI API. - * @param options Optional configuration for the API call. - * @returns The response from the OpenAI API. 
- */ - private _getClientOptions(options: OpenAICoreRequestOptions | undefined) { - if (!this.client) { - const openAIEndpointConfig: OpenAIEndpointConfig = { - azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, - azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, - azureOpenAIApiKey: this.azureOpenAIApiKey, - azureOpenAIBasePath: this.azureOpenAIBasePath, - baseURL: this.clientConfig.baseURL, - }; - - const endpoint = getEndpoint(openAIEndpointConfig); - - const params = { - ...this.clientConfig, - baseURL: endpoint, - timeout: this.timeout, - maxRetries: 0, - }; - - if (!params.baseURL) { - delete params.baseURL; - } - - this.client = new OpenAIClient(params); - } - const requestOptions = { - ...this.clientConfig, - ...options, - } as OpenAICoreRequestOptions; - if (this.azureOpenAIApiKey) { - requestOptions.headers = { - "api-key": this.azureOpenAIApiKey, - ...requestOptions.headers, - }; - requestOptions.query = { - "api-version": this.azureOpenAIApiVersion, - ...requestOptions.query, - }; - } - return requestOptions; - } - - _llmType() { - return "openai"; - } -} +export { OpenAI }; /** * PromptLayer wrapper to OpenAI diff --git a/langchain/src/llms/portkey.ts b/langchain/src/llms/portkey.ts index c6a2df313826..79b975e5fabc 100644 --- a/langchain/src/llms/portkey.ts +++ b/langchain/src/llms/portkey.ts @@ -1,179 +1 @@ -import _ from "lodash"; -import { LLMOptions, Portkey as _Portkey } from "portkey-ai"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk, LLMResult } from "../schema/index.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { BaseLLM } from "./base.js"; - -interface PortkeyOptions { - apiKey?: string; - baseURL?: string; - mode?: string; - llms?: [LLMOptions] | null; -} - -const readEnv = (env: string, default_val?: string): string | undefined => - getEnvironmentVariable(env) ?? 
default_val; - -export class PortkeySession { - portkey: _Portkey; - - constructor(options: PortkeyOptions = {}) { - if (!options.apiKey) { - /* eslint-disable no-param-reassign */ - options.apiKey = readEnv("PORTKEY_API_KEY"); - } - - if (!options.baseURL) { - /* eslint-disable no-param-reassign */ - options.baseURL = readEnv("PORTKEY_BASE_URL", "https://api.portkey.ai"); - } - - this.portkey = new _Portkey({}); - this.portkey.llms = [{}]; - if (!options.apiKey) { - throw new Error("Set Portkey ApiKey in PORTKEY_API_KEY env variable"); - } - - this.portkey = new _Portkey(options); - } -} - -const defaultPortkeySession: { - session: PortkeySession; - options: PortkeyOptions; -}[] = []; - -/** - * Get a session for the Portkey API. If one already exists with the same options, - * it will be returned. Otherwise, a new session will be created. - * @param options - * @returns - */ -export function getPortkeySession(options: PortkeyOptions = {}) { - let session = defaultPortkeySession.find((session) => - _.isEqual(session.options, options) - )?.session; - - if (!session) { - session = new PortkeySession(options); - defaultPortkeySession.push({ session, options }); - } - return session; -} - -/** - * @example - * ```typescript - * const model = new Portkey({ - * mode: "single", - * llms: [ - * { - * provider: "openai", - * virtual_key: "open-ai-key-1234", - * model: "text-davinci-003", - * max_tokens: 2000, - * }, - * ], - * }); - * - * // Stream the output of the model and process it - * const res = await model.stream( - * "Question: Write a story about a king\nAnswer:" - * ); - * for await (const i of res) { - * process.stdout.write(i); - * } - * ``` - */ -export class Portkey extends BaseLLM { - apiKey?: string = undefined; - - baseURL?: string = undefined; - - mode?: string = undefined; - - llms?: [LLMOptions] | null = undefined; - - session: PortkeySession; - - constructor(init?: Partial) { - super(init ?? 
{}); - this.apiKey = init?.apiKey; - - this.baseURL = init?.baseURL; - - this.mode = init?.mode; - - this.llms = init?.llms; - - this.session = getPortkeySession({ - apiKey: this.apiKey, - baseURL: this.baseURL, - llms: this.llms, - mode: this.mode, - }); - } - - _llmType() { - return "portkey"; - } - - async _generate( - prompts: string[], - options: this["ParsedCallOptions"], - _?: CallbackManagerForLLMRun - ): Promise { - const choices = []; - for (let i = 0; i < prompts.length; i += 1) { - const response = await this.session.portkey.completions.create({ - prompt: prompts[i], - ...options, - stream: false, - }); - choices.push(response.choices); - } - const generations = choices.map((promptChoices) => - promptChoices.map((choice) => ({ - text: choice.text ?? "", - generationInfo: { - finishReason: choice.finish_reason, - logprobs: choice.logprobs, - }, - })) - ); - - return { - generations, - }; - } - - async *_streamResponseChunks( - input: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const response = await this.session.portkey.completions.create({ - prompt: input, - ...options, - stream: true, - }); - for await (const data of response) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - const chunk = new GenerationChunk({ - text: choice.text ?? "", - generationInfo: { - finishReason: choice.finish_reason, - }, - }); - yield chunk; - void runManager?.handleLLMNewToken(chunk.text ?? 
""); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } - } -} +export * from "@langchain/community/llms/portkey"; diff --git a/langchain/src/llms/raycast.ts b/langchain/src/llms/raycast.ts index 7901f0fde238..a36430dc3817 100644 --- a/langchain/src/llms/raycast.ts +++ b/langchain/src/llms/raycast.ts @@ -1,99 +1 @@ -import { AI, environment } from "@raycast/api"; -import { LLM, BaseLLMParams } from "./base.js"; - -/** - * The input parameters for the RaycastAI class, which extends the BaseLLMParams interface. - */ -export interface RaycastAIInput extends BaseLLMParams { - model?: AI.Model; - creativity?: number; - rateLimitPerMinute?: number; -} - -const wait = (ms: number) => - new Promise((resolve) => { - setTimeout(resolve, ms); - }); - -/** - * The RaycastAI class, which extends the LLM class and implements the RaycastAIInput interface. - */ -export class RaycastAI extends LLM implements RaycastAIInput { - /** - * The model to use for generating text. - */ - model: AI.Model; - - /** - * The creativity parameter, also known as the "temperature". - */ - creativity: number; - - /** - * The rate limit for API calls, in requests per minute. - */ - rateLimitPerMinute: number; - - /** - * The timestamp of the last API call, used to enforce the rate limit. - */ - private lastCallTimestamp = 0; - - /** - * Creates a new instance of the RaycastAI class. - * @param {RaycastAIInput} fields The input parameters for the RaycastAI class. - * @throws {Error} If the Raycast AI environment is not accessible. - */ - constructor(fields: RaycastAIInput) { - super(fields ?? {}); - - if (!environment.canAccess(AI)) { - throw new Error("Raycast AI environment is not accessible."); - } - - this.model = fields.model ?? "text-davinci-003"; - this.creativity = fields.creativity ?? 0.5; - this.rateLimitPerMinute = fields.rateLimitPerMinute ?? 10; - } - - /** - * Returns the type of the LLM, which is "raycast_ai". - * @return {string} The type of the LLM. 
- * @ignore - */ - _llmType() { - return "raycast_ai"; - } - - /** - * Calls AI.ask with the given prompt and returns the generated text. - * @param {string} prompt The prompt to generate text from. - * @return {Promise} A Promise that resolves to the generated text. - * @ignore - */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const response = await this.caller.call(async () => { - // Rate limit calls to Raycast AI - const now = Date.now(); - const timeSinceLastCall = now - this.lastCallTimestamp; - const timeToWait = - (60 / this.rateLimitPerMinute) * 1000 - timeSinceLastCall; - - if (timeToWait > 0) { - await wait(timeToWait); - } - - return await AI.ask(prompt, { - model: this.model, - creativity: this.creativity, - signal: options.signal, - }); - }); - - // Since Raycast AI returns the response directly, no need for output transformation - return response; - } -} +export * from "@langchain/community/llms/raycast"; diff --git a/langchain/src/llms/replicate.ts b/langchain/src/llms/replicate.ts index 27fa66c5eb7c..72c1ca24a637 100644 --- a/langchain/src/llms/replicate.ts +++ b/langchain/src/llms/replicate.ts @@ -1,158 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { LLM, BaseLLMParams } from "./base.js"; - -/** - * Interface defining the structure of the input data for the Replicate - * class. It includes details about the model to be used, any additional - * input parameters, and the API key for the Replicate service. - */ -export interface ReplicateInput { - // owner/model_name:version - model: `${string}/${string}:${string}`; - - input?: { - // different models accept different inputs - [key: string]: string | number | boolean; - }; - - apiKey?: string; - - /** The key used to pass prompts to the model. */ - promptKey?: string; -} - -/** - * Class responsible for managing the interaction with the Replicate API. 
- * It handles the API key and model details, makes the actual API calls, - * and converts the API response into a format usable by the rest of the - * LangChain framework. - * @example - * ```typescript - * const model = new Replicate({ - * model: "replicate/flan-t5-xl:3ae0799123a1fe11f8c89fd99632f843fc5f7a761630160521c4253149754523", - * }); - * - * const res = await model.call( - * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" - * ); - * console.log({ res }); - * ``` - */ -export class Replicate extends LLM implements ReplicateInput { - static lc_name() { - return "Replicate"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "REPLICATE_API_TOKEN", - }; - } - - lc_serializable = true; - - model: ReplicateInput["model"]; - - input: ReplicateInput["input"]; - - apiKey: string; - - promptKey?: string; - - constructor(fields: ReplicateInput & BaseLLMParams) { - super(fields); - - const apiKey = - fields?.apiKey ?? - getEnvironmentVariable("REPLICATE_API_KEY") ?? // previous environment variable for backwards compatibility - getEnvironmentVariable("REPLICATE_API_TOKEN"); // current environment variable, matching the Python library - - if (!apiKey) { - throw new Error( - "Please set the REPLICATE_API_TOKEN environment variable" - ); - } - - this.apiKey = apiKey; - this.model = fields.model; - this.input = fields.input ?? 
{}; - this.promptKey = fields.promptKey; - } - - _llmType() { - return "replicate"; - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const imports = await Replicate.imports(); - - const replicate = new imports.Replicate({ - userAgent: "langchain", - auth: this.apiKey, - }); - - if (this.promptKey === undefined) { - const [modelString, versionString] = this.model.split(":"); - const version = await replicate.models.versions.get( - modelString.split("/")[0], - modelString.split("/")[1], - versionString - ); - const openapiSchema = version.openapi_schema; - const inputProperties: { "x-order": number | undefined }[] = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (openapiSchema as any)?.components?.schemas?.Input?.properties; - if (inputProperties === undefined) { - this.promptKey = "prompt"; - } else { - const sortedInputProperties = Object.entries(inputProperties).sort( - ([_keyA, valueA], [_keyB, valueB]) => { - const orderA = valueA["x-order"] || 0; - const orderB = valueB["x-order"] || 0; - return orderA - orderB; - } - ); - this.promptKey = sortedInputProperties[0][0] ?? "prompt"; - } - } - const output = await this.caller.callWithOptions( - { signal: options.signal }, - () => - replicate.run(this.model, { - input: { - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - [this.promptKey!]: prompt, - ...this.input, - }, - }) - ); - - if (typeof output === "string") { - return output; - } else if (Array.isArray(output)) { - return output.join(""); - } else { - // Note this is a little odd, but the output format is not consistent - // across models, so it makes some amount of sense. 
- return String(output); - } - } - - /** @ignore */ - static async imports(): Promise<{ - Replicate: typeof import("replicate").default; - }> { - try { - const { default: Replicate } = await import("replicate"); - return { Replicate }; - } catch (e) { - throw new Error( - "Please install replicate as a dependency with, e.g. `yarn add replicate`" - ); - } - } -} +export * from "@langchain/community/llms/replicate"; diff --git a/langchain/src/llms/sagemaker_endpoint.ts b/langchain/src/llms/sagemaker_endpoint.ts index 38706d608417..f9a5e590c0fb 100644 --- a/langchain/src/llms/sagemaker_endpoint.ts +++ b/langchain/src/llms/sagemaker_endpoint.ts @@ -1,283 +1 @@ -import { - InvokeEndpointCommand, - InvokeEndpointWithResponseStreamCommand, - SageMakerRuntimeClient, - SageMakerRuntimeClientConfig, -} from "@aws-sdk/client-sagemaker-runtime"; -import { CallbackManagerForLLMRun } from "../callbacks/manager.js"; -import { GenerationChunk } from "../schema/index.js"; -import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js"; - -/** - * A handler class to transform input from LLM to a format that SageMaker - * endpoint expects. Similarily, the class also handles transforming output from - * the SageMaker endpoint to a format that LLM class expects. 
- * - * Example: - * ``` - * class ContentHandler implements ContentHandlerBase { - * contentType = "application/json" - * accepts = "application/json" - * - * transformInput(prompt: string, modelKwargs: Record) { - * const inputString = JSON.stringify({ - * prompt, - * ...modelKwargs - * }) - * return Buffer.from(inputString) - * } - * - * transformOutput(output: Uint8Array) { - * const responseJson = JSON.parse(Buffer.from(output).toString("utf-8")) - * return responseJson[0].generated_text - * } - * - * } - * ``` - */ -export abstract class BaseSageMakerContentHandler { - contentType = "text/plain"; - - accepts = "text/plain"; - - /** - * Transforms the prompt and model arguments into a specific format for sending to SageMaker. - * @param {InputType} prompt The prompt to be transformed. - * @param {Record} modelKwargs Additional arguments. - * @returns {Promise} A promise that resolves to the formatted data for sending. - */ - abstract transformInput( - prompt: InputType, - modelKwargs: Record - ): Promise; - - /** - * Transforms SageMaker output into a desired format. - * @param {Uint8Array} output The raw output from SageMaker. - * @returns {Promise} A promise that resolves to the transformed data. - */ - abstract transformOutput(output: Uint8Array): Promise; -} - -export type SageMakerLLMContentHandler = BaseSageMakerContentHandler< - string, - string ->; - -/** - * The SageMakerEndpointInput interface defines the input parameters for - * the SageMakerEndpoint class, which includes the endpoint name, client - * options for the SageMaker client, the content handler, and optional - * keyword arguments for the model and the endpoint. - */ -export interface SageMakerEndpointInput extends BaseLLMParams { - /** - * The name of the endpoint from the deployed SageMaker model. Must be unique - * within an AWS Region. - */ - endpointName: string; - /** - * Options passed to the SageMaker client. 
- */ - clientOptions: SageMakerRuntimeClientConfig; - /** - * Key word arguments to pass to the model. - */ - modelKwargs?: Record; - /** - * Optional attributes passed to the InvokeEndpointCommand - */ - endpointKwargs?: Record; - /** - * The content handler class that provides an input and output transform - * functions to handle formats between LLM and the endpoint. - */ - contentHandler: SageMakerLLMContentHandler; - streaming?: boolean; -} - -/** - * The SageMakerEndpoint class is used to interact with SageMaker - * Inference Endpoint models. It uses the AWS client for authentication, - * which automatically loads credentials. - * If a specific credential profile is to be used, the name of the profile - * from the ~/.aws/credentials file must be passed. The credentials or - * roles used should have the required policies to access the SageMaker - * endpoint. - */ -export class SageMakerEndpoint extends LLM { - static lc_name() { - return "SageMakerEndpoint"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - "clientOptions.credentials.accessKeyId": "AWS_ACCESS_KEY_ID", - "clientOptions.credentials.secretAccessKey": "AWS_SECRET_ACCESS_KEY", - "clientOptions.credentials.sessionToken": "AWS_SESSION_TOKEN", - }; - } - - endpointName: string; - - modelKwargs?: Record; - - endpointKwargs?: Record; - - client: SageMakerRuntimeClient; - - contentHandler: SageMakerLLMContentHandler; - - streaming: boolean; - - constructor(fields: SageMakerEndpointInput) { - super(fields); - - if (!fields.clientOptions.region) { - throw new Error( - `Please pass a "clientOptions" object with a "region" field to the constructor` - ); - } - - const endpointName = fields?.endpointName; - if (!endpointName) { - throw new Error(`Please pass an "endpointName" field to the constructor`); - } - - const contentHandler = fields?.contentHandler; - if (!contentHandler) { - throw new Error( - `Please pass a "contentHandler" field to the constructor` - ); - } - - 
this.endpointName = fields.endpointName; - this.contentHandler = fields.contentHandler; - this.endpointKwargs = fields.endpointKwargs; - this.modelKwargs = fields.modelKwargs; - this.streaming = fields.streaming ?? false; - this.client = new SageMakerRuntimeClient(fields.clientOptions); - } - - _llmType() { - return "sagemaker_endpoint"; - } - - /** - * Calls the SageMaker endpoint and retrieves the result. - * @param {string} prompt The input prompt. - * @param {this["ParsedCallOptions"]} options Parsed call options. - * @param {CallbackManagerForLLMRun} runManager Optional run manager. - * @returns {Promise} A promise that resolves to the generated string. - */ - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - return this.streaming - ? await this.streamingCall(prompt, options, runManager) - : await this.noStreamingCall(prompt, options); - } - - private async streamingCall( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): Promise { - const chunks = []; - for await (const chunk of this._streamResponseChunks( - prompt, - options, - runManager - )) { - chunks.push(chunk.text); - } - return chunks.join(""); - } - - private async noStreamingCall( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const body = await this.contentHandler.transformInput( - prompt, - this.modelKwargs ?? 
{} - ); - const { contentType, accepts } = this.contentHandler; - - const response = await this.caller.call(() => - this.client.send( - new InvokeEndpointCommand({ - EndpointName: this.endpointName, - Body: body, - ContentType: contentType, - Accept: accepts, - ...this.endpointKwargs, - }), - { abortSignal: options.signal } - ) - ); - - if (response.Body === undefined) { - throw new Error("Inference result missing Body"); - } - return this.contentHandler.transformOutput(response.Body); - } - - /** - * Streams response chunks from the SageMaker endpoint. - * @param {string} prompt The input prompt. - * @param {this["ParsedCallOptions"]} options Parsed call options. - * @returns {AsyncGenerator} An asynchronous generator yielding generation chunks. - */ - async *_streamResponseChunks( - prompt: string, - options: this["ParsedCallOptions"], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const body = await this.contentHandler.transformInput( - prompt, - this.modelKwargs ?? {} - ); - const { contentType, accepts } = this.contentHandler; - - const stream = await this.caller.call(() => - this.client.send( - new InvokeEndpointWithResponseStreamCommand({ - EndpointName: this.endpointName, - Body: body, - ContentType: contentType, - Accept: accepts, - ...this.endpointKwargs, - }), - { abortSignal: options.signal } - ) - ); - - if (!stream.Body) { - throw new Error("Inference result missing Body"); - } - - for await (const chunk of stream.Body) { - if (chunk.PayloadPart && chunk.PayloadPart.Bytes) { - const text = await this.contentHandler.transformOutput( - chunk.PayloadPart.Bytes - ); - yield new GenerationChunk({ - text, - generationInfo: { - ...chunk, - response: undefined, - }, - }); - await runManager?.handleLLMNewToken(text); - } else if (chunk.InternalStreamFailure) { - throw new Error(chunk.InternalStreamFailure.message); - } else if (chunk.ModelStreamError) { - throw new Error(chunk.ModelStreamError.message); - } - } - } -} +export * from 
"@langchain/community/llms/sagemaker_endpoint"; diff --git a/langchain/src/llms/watsonx_ai.ts b/langchain/src/llms/watsonx_ai.ts index dca510ba21c1..741308ef5ffb 100644 --- a/langchain/src/llms/watsonx_ai.ts +++ b/langchain/src/llms/watsonx_ai.ts @@ -1,194 +1 @@ -import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * The WatsonxAIParams interface defines the input parameters for - * the WatsonxAI class. - */ -export interface WatsonxAIParams extends BaseLLMParams { - /** - * WatsonX AI Complete Endpoint. - * Can be used if you want a fully custom endpoint. - */ - endpoint?: string; - /** - * IBM Cloud Compute Region. - * eg. us-south, us-east, etc. - */ - region?: string; - /** - * WatsonX AI Version. - * Date representing the WatsonX AI Version. - * eg. 2023-05-29 - */ - version?: string; - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - ibmCloudApiKey?: string; - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - projectId?: string; - /** - * Parameters accepted by the WatsonX AI Endpoint. - */ - modelParameters?: Record; - /** - * WatsonX AI Model ID. - */ - modelId?: string; -} - -const endpointConstructor = (region: string, version: string) => - `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`; - -/** - * The WatsonxAI class is used to interact with Watsonx AI - * Inference Endpoint models. It uses IBM Cloud for authentication. - * This requires your IBM Cloud API Key which is autoloaded if not specified. 
- */ - -export class WatsonxAI extends LLM { - static lc_name() { - return "WatsonxAI"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - ibmCloudApiKey: "IBM_CLOUD_API_KEY", - projectId: "WATSONX_PROJECT_ID", - }; - } - - endpoint: string; - - region = "us-south"; - - version = "2023-05-29"; - - modelId = "meta-llama/llama-2-70b-chat"; - - modelKwargs?: Record; - - ibmCloudApiKey?: string; - - ibmCloudToken?: string; - - ibmCloudTokenExpiresAt?: number; - - projectId?: string; - - modelParameters?: Record; - - constructor(fields: WatsonxAIParams) { - super(fields); - - this.region = fields?.region ?? this.region; - this.version = fields?.version ?? this.version; - this.modelId = fields?.modelId ?? this.modelId; - this.ibmCloudApiKey = - fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY"); - this.projectId = - fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID"); - - this.endpoint = - fields?.endpoint ?? endpointConstructor(this.region, this.version); - this.modelParameters = fields.modelParameters; - - if (!this.ibmCloudApiKey) { - throw new Error("Missing IBM Cloud API Key"); - } - - if (!this.projectId) { - throw new Error("Missing WatsonX AI Project ID"); - } - } - - _llmType() { - return "watsonx_ai"; - } - - /** - * Calls the WatsonX AI endpoint and retrieves the result. - * @param {string} prompt The input prompt. - * @returns {Promise} A promise that resolves to the generated string. 
- */ - /** @ignore */ - async _call( - prompt: string, - _options: this["ParsedCallOptions"] - ): Promise { - interface WatsonxAIResponse { - results: { - generated_text: string; - generated_token_count: number; - input_token_count: number; - }[]; - errors: { - code: string; - message: string; - }[]; - } - const response = (await this.caller.call(async () => - fetch(this.endpoint, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json", - Authorization: `Bearer ${await this.generateToken()}`, - }, - body: JSON.stringify({ - project_id: this.projectId, - model_id: this.modelId, - input: prompt, - parameters: this.modelParameters, - }), - }).then((res) => res.json()) - )) as WatsonxAIResponse; - - /** - * Handle Errors for invalid requests. - */ - if (response.errors) { - throw new Error(response.errors[0].message); - } - - return response.results[0].generated_text; - } - - async generateToken(): Promise { - if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) { - if (this.ibmCloudTokenExpiresAt > Date.now()) { - return this.ibmCloudToken; - } - } - - interface TokenResponse { - access_token: string; - expiration: number; - } - - const urlTokenParams = new URLSearchParams(); - urlTokenParams.append( - "grant_type", - "urn:ibm:params:oauth:grant-type:apikey" - ); - urlTokenParams.append("apikey", this.ibmCloudApiKey as string); - - const data = (await fetch("https://iam.cloud.ibm.com/identity/token", { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: urlTokenParams, - }).then((res) => res.json())) as TokenResponse; - - this.ibmCloudTokenExpiresAt = data.expiration * 1000; - this.ibmCloudToken = data.access_token; - - return this.ibmCloudToken; - } -} +export * from "@langchain/community/llms/watsonx_ai"; diff --git a/langchain/src/llms/writer.ts b/langchain/src/llms/writer.ts index 323167d41bdc..225f212949d8 100644 --- a/langchain/src/llms/writer.ts +++ 
b/langchain/src/llms/writer.ts @@ -1,172 +1 @@ -import { Writer as WriterClient } from "@writerai/writer-sdk"; - -import { BaseLLMParams, LLM } from "./base.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface for the input parameters specific to the Writer model. - */ -export interface WriterInput extends BaseLLMParams { - /** Writer API key */ - apiKey?: string; - - /** Writer organization ID */ - orgId?: string | number; - - /** Model to use */ - model?: string; - - /** Sampling temperature to use */ - temperature?: number; - - /** Minimum number of tokens to generate. */ - minTokens?: number; - - /** Maximum number of tokens to generate in the completion. */ - maxTokens?: number; - - /** Generates this many completions server-side and returns the "best"." */ - bestOf?: number; - - /** Penalizes repeated tokens according to frequency. */ - frequencyPenalty?: number; - - /** Whether to return log probabilities. */ - logprobs?: number; - - /** Number of completions to generate. */ - n?: number; - - /** Penalizes repeated tokens regardless of frequency. */ - presencePenalty?: number; - - /** Total probability mass of tokens to consider at each step. */ - topP?: number; -} - -/** - * Class representing a Writer Large Language Model (LLM). It interacts - * with the Writer API to generate text completions. 
- */ -export class Writer extends LLM implements WriterInput { - static lc_name() { - return "Writer"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "WRITER_API_KEY", - orgId: "WRITER_ORG_ID", - }; - } - - get lc_aliases(): { [key: string]: string } | undefined { - return { - apiKey: "writer_api_key", - orgId: "writer_org_id", - }; - } - - lc_serializable = true; - - apiKey: string; - - orgId: number; - - model = "palmyra-instruct"; - - temperature?: number; - - minTokens?: number; - - maxTokens?: number; - - bestOf?: number; - - frequencyPenalty?: number; - - logprobs?: number; - - n?: number; - - presencePenalty?: number; - - topP?: number; - - constructor(fields?: WriterInput) { - super(fields ?? {}); - - const apiKey = fields?.apiKey ?? getEnvironmentVariable("WRITER_API_KEY"); - const orgId = fields?.orgId ?? getEnvironmentVariable("WRITER_ORG_ID"); - - if (!apiKey) { - throw new Error( - "Please set the WRITER_API_KEY environment variable or pass it to the constructor as the apiKey field." - ); - } - - if (!orgId) { - throw new Error( - "Please set the WRITER_ORG_ID environment variable or pass it to the constructor as the orgId field." - ); - } - - this.apiKey = apiKey; - this.orgId = typeof orgId === "string" ? parseInt(orgId, 10) : orgId; - this.model = fields?.model ?? this.model; - this.temperature = fields?.temperature ?? this.temperature; - this.minTokens = fields?.minTokens ?? this.minTokens; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.bestOf = fields?.bestOf ?? this.bestOf; - this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; - this.logprobs = fields?.logprobs ?? this.logprobs; - this.n = fields?.n ?? this.n; - this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; - this.topP = fields?.topP ?? 
this.topP; - } - - _llmType() { - return "writer"; - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - const sdk = new WriterClient({ - security: { - apiKey: this.apiKey, - }, - organizationId: this.orgId, - }); - - return this.caller.callWithOptions({ signal: options.signal }, async () => { - try { - const res = await sdk.completions.create({ - completionRequest: { - prompt, - stop: options.stop, - temperature: this.temperature, - minTokens: this.minTokens, - maxTokens: this.maxTokens, - bestOf: this.bestOf, - n: this.n, - frequencyPenalty: this.frequencyPenalty, - logprobs: this.logprobs, - presencePenalty: this.presencePenalty, - topP: this.topP, - }, - modelId: this.model, - }); - return ( - res.completionResponse?.choices?.[0].text ?? "No completion found." - ); - } catch (e) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (e as any).response = (e as any).rawResponse; - throw e; - } - }); - } -} +export * from "@langchain/community/llms/writer"; diff --git a/langchain/src/llms/yandex.ts b/langchain/src/llms/yandex.ts index 96b70e7ced55..d19f0acf6b0c 100644 --- a/langchain/src/llms/yandex.ts +++ b/langchain/src/llms/yandex.ts @@ -1,123 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { LLM, BaseLLMParams } from "./base.js"; - -const apiUrl = "https://llm.api.cloud.yandex.net/llm/v1alpha/instruct"; - -export interface YandexGPTInputs extends BaseLLMParams { - /** - * What sampling temperature to use. - * Should be a double number between 0 (inclusive) and 1 (inclusive). - */ - temperature?: number; - - /** - * Maximum limit on the total number of tokens - * used for both the input prompt and the generated response. - */ - maxTokens?: number; - - /** Model name to use. */ - model?: string; - - /** - * Yandex Cloud Api Key for service account - * with the `ai.languageModels.user` role. 
- */ - apiKey?: string; - - /** - * Yandex Cloud IAM token for service account - * with the `ai.languageModels.user` role. - */ - iamToken?: string; -} - -export class YandexGPT extends LLM implements YandexGPTInputs { - static lc_name() { - return "Yandex GPT"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "YC_API_KEY", - iamToken: "YC_IAM_TOKEN", - }; - } - - temperature = 0.6; - - maxTokens = 1700; - - model = "general"; - - apiKey?: string; - - iamToken?: string; - - constructor(fields?: YandexGPTInputs) { - super(fields ?? {}); - - const apiKey = fields?.apiKey ?? getEnvironmentVariable("YC_API_KEY"); - - const iamToken = fields?.iamToken ?? getEnvironmentVariable("YC_IAM_TOKEN"); - - if (apiKey === undefined && iamToken === undefined) { - throw new Error( - "Please set the YC_API_KEY or YC_IAM_TOKEN environment variable or pass it to the constructor as the apiKey or iamToken field." - ); - } - - this.apiKey = apiKey; - this.iamToken = iamToken; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.temperature = fields?.temperature ?? this.temperature; - this.model = fields?.model ?? 
this.model; - } - - _llmType() { - return "yandexgpt"; - } - - /** @ignore */ - async _call( - prompt: string, - options: this["ParsedCallOptions"] - ): Promise { - // Hit the `generate` endpoint on the `large` model - return this.caller.callWithOptions({ signal: options.signal }, async () => { - const headers = { "Content-Type": "application/json", Authorization: "" }; - if (this.apiKey !== undefined) { - headers.Authorization = `Api-Key ${this.apiKey}`; - } else { - headers.Authorization = `Bearer ${this.iamToken}`; - } - const bodyData = { - model: this.model, - generationOptions: { - temperature: this.temperature, - maxTokens: this.maxTokens, - }, - - requestText: prompt, - }; - - try { - const response = await fetch(apiUrl, { - method: "POST", - headers, - body: JSON.stringify(bodyData), - }); - if (!response.ok) { - throw new Error( - `Failed to fetch ${apiUrl} from YandexGPT: ${response.status}` - ); - } - - const responseData = await response.json(); - return responseData.result.alternatives[0].text; - } catch (error) { - throw new Error(`Failed to fetch ${apiUrl} from YandexGPT ${error}`); - } - }); - } -} +export * from "@langchain/community/llms/yandex"; diff --git a/langchain/src/load/import_type.d.ts b/langchain/src/load/import_type.d.ts index 9b2b3abbde95..5add518536fb 100644 --- a/langchain/src/load/import_type.d.ts +++ b/langchain/src/load/import_type.d.ts @@ -522,59 +522,9 @@ export interface OptionalImportMap { export interface SecretMap { ANTHROPIC_API_KEY?: string; AWS_ACCESS_KEY_ID?: string; - AWS_SECRETE_ACCESS_KEY?: string; AWS_SECRET_ACCESS_KEY?: string; - AWS_SESSION_TOKEN?: string; - AZURE_OPENAI_API_KEY?: string; - BAIDU_API_KEY?: string; - BAIDU_SECRET_KEY?: string; - BEDROCK_AWS_ACCESS_KEY_ID?: string; - BEDROCK_AWS_SECRET_ACCESS_KEY?: string; - CLOUDFLARE_API_TOKEN?: string; - COHERE_API_KEY?: string; - DATABERRY_API_KEY?: string; - FIREWORKS_API_KEY?: string; - GOOGLE_API_KEY?: string; - GOOGLE_PALM_API_KEY?: string; - 
GOOGLE_PLACES_API_KEY?: string; - GOOGLE_VERTEX_AI_WEB_CREDENTIALS?: string; - GRADIENT_ACCESS_TOKEN?: string; - GRADIENT_WORKSPACE_ID?: string; - HUGGINGFACEHUB_API_KEY?: string; - IBM_CLOUD_API_KEY?: string; - IFLYTEK_API_KEY?: string; - IFLYTEK_API_SECRET?: string; - MILVUS_PASSWORD?: string; - MILVUS_SSL?: string; - MILVUS_USERNAME?: string; - MINIMAX_API_KEY?: string; - MINIMAX_GROUP_ID?: string; OPENAI_API_KEY?: string; - OPENAI_ORGANIZATION?: string; - PLANETSCALE_DATABASE_URL?: string; - PLANETSCALE_HOST?: string; - PLANETSCALE_PASSWORD?: string; - PLANETSCALE_USERNAME?: string; PROMPTLAYER_API_KEY?: string; - QDRANT_API_KEY?: string; - QDRANT_URL?: string; - REDIS_PASSWORD?: string; - REDIS_URL?: string; - REDIS_USERNAME?: string; REMOTE_RETRIEVER_AUTH_BEARER?: string; - REPLICATE_API_TOKEN?: string; - SEARXNG_API_BASE?: string; - UPSTASH_REDIS_REST_TOKEN?: string; - UPSTASH_REDIS_REST_URL?: string; - VECTARA_API_KEY?: string; - VECTARA_CORPUS_ID?: string; - VECTARA_CUSTOMER_ID?: string; - WATSONX_PROJECT_ID?: string; - WRITER_API_KEY?: string; - WRITER_ORG_ID?: string; - YC_API_KEY?: string; - YC_IAM_TOKEN?: string; ZAPIER_NLA_API_KEY?: string; - ZEP_API_KEY?: string; - ZEP_API_URL?: string; } diff --git a/langchain/src/load/index.ts b/langchain/src/load/index.ts index 0f4fe863a1e7..c0522c9bd151 100644 --- a/langchain/src/load/index.ts +++ b/langchain/src/load/index.ts @@ -7,7 +7,8 @@ export async function load( text: string, // eslint-disable-next-line @typescript-eslint/no-explicit-any secretsMap: Record = {}, - optionalImportsMap: OptionalImportMap = {} + // eslint-disable-next-line @typescript-eslint/no-explicit-any + optionalImportsMap: OptionalImportMap & Record = {} ): Promise { return coreLoad(text, { secretsMap, diff --git a/langchain/src/retrievers/amazon_kendra.ts b/langchain/src/retrievers/amazon_kendra.ts index fb2ba2123b4f..8a734c4eea6f 100644 --- a/langchain/src/retrievers/amazon_kendra.ts +++ b/langchain/src/retrievers/amazon_kendra.ts @@ 
-1,317 +1 @@ -import { - AttributeFilter, - DocumentAttribute, - DocumentAttributeValue, - KendraClient, - KendraClientConfig, - QueryCommand, - QueryCommandOutput, - QueryResultItem, - RetrieveCommand, - RetrieveCommandOutput, - RetrieveResultItem, -} from "@aws-sdk/client-kendra"; - -import { BaseRetriever } from "../schema/retriever.js"; -import { Document } from "../document.js"; - -/** - * Interface for the arguments required to initialize an - * AmazonKendraRetriever instance. - */ -export interface AmazonKendraRetrieverArgs { - indexId: string; - topK: number; - region: string; - attributeFilter?: AttributeFilter; - clientOptions?: KendraClientConfig; -} - -/** - * Class for interacting with Amazon Kendra, an intelligent search service - * provided by AWS. Extends the BaseRetriever class. - * @example - * ```typescript - * const retriever = new AmazonKendraRetriever({ - * topK: 10, - * indexId: "YOUR_INDEX_ID", - * region: "us-east-2", - * clientOptions: { - * credentials: { - * accessKeyId: "YOUR_ACCESS_KEY_ID", - * secretAccessKey: "YOUR_SECRET_ACCESS_KEY", - * }, - * }, - * }); - * - * const docs = await retriever.getRelevantDocuments("How are clouds formed?"); - * ``` - */ -export class AmazonKendraRetriever extends BaseRetriever { - static lc_name() { - return "AmazonKendraRetriever"; - } - - lc_namespace = ["langchain", "retrievers", "amazon_kendra"]; - - indexId: string; - - topK: number; - - kendraClient: KendraClient; - - attributeFilter?: AttributeFilter; - - constructor({ - indexId, - topK = 10, - clientOptions, - attributeFilter, - region, - }: AmazonKendraRetrieverArgs) { - super(); - - if (!region) { - throw new Error("Please pass regionName field to the constructor!"); - } - - if (!indexId) { - throw new Error("Please pass Kendra Index Id to the constructor"); - } - - this.topK = topK; - this.kendraClient = new KendraClient({ - region, - ...clientOptions, - }); - this.attributeFilter = attributeFilter; - this.indexId = indexId; - } - - // A 
method to combine title and excerpt into a single string. - /** - * Combines title and excerpt into a single string. - * @param title The title of the document. - * @param excerpt An excerpt from the document. - * @returns A single string combining the title and excerpt. - */ - combineText(title?: string, excerpt?: string): string { - let text = ""; - if (title) { - text += `Document Title: ${title}\n`; - } - if (excerpt) { - text += `Document Excerpt: \n${excerpt}\n`; - } - return text; - } - - // A method to clean the result text by replacing sequences of whitespace with a single space and removing ellipses. - /** - * Cleans the result text by replacing sequences of whitespace with a - * single space and removing ellipses. - * @param resText The result text to clean. - * @returns The cleaned result text. - */ - cleanResult(resText: string) { - const res = resText.replace(/\s+/g, " ").replace(/\.\.\./g, ""); - return res; - } - - // A method to extract the attribute value from a DocumentAttributeValue object. - /** - * Extracts the attribute value from a DocumentAttributeValue object. - * @param value The DocumentAttributeValue object to extract the value from. - * @returns The extracted attribute value. - */ - getDocAttributeValue(value: DocumentAttributeValue) { - if (value.DateValue) { - return value.DateValue; - } - if (value.LongValue) { - return value.LongValue; - } - if (value.StringListValue) { - return value.StringListValue; - } - if (value.StringValue) { - return value.StringValue; - } - return ""; - } - - // A method to extract the attribute key-value pairs from an array of DocumentAttribute objects. - /** - * Extracts the attribute key-value pairs from an array of - * DocumentAttribute objects. - * @param documentAttributes The array of DocumentAttribute objects to extract the key-value pairs from. - * @returns An object containing the extracted attribute key-value pairs. 
- */ - getDocAttributes(documentAttributes?: DocumentAttribute[]): { - [key: string]: unknown; - } { - const attributes: { [key: string]: unknown } = {}; - if (documentAttributes) { - for (const attr of documentAttributes) { - if (attr.Key && attr.Value) { - attributes[attr.Key] = this.getDocAttributeValue(attr.Value); - } - } - } - return attributes; - } - - // A method to convert a RetrieveResultItem object into a Document object. - /** - * Converts a RetrieveResultItem object into a Document object. - * @param item The RetrieveResultItem object to convert. - * @returns A Document object. - */ - convertRetrieverItem(item: RetrieveResultItem) { - const title = item.DocumentTitle || ""; - const excerpt = item.Content ? this.cleanResult(item.Content) : ""; - const pageContent = this.combineText(title, excerpt); - const source = item.DocumentURI; - const attributes = this.getDocAttributes(item.DocumentAttributes); - const metadata = { - source, - title, - excerpt, - document_attributes: attributes, - }; - - return new Document({ pageContent, metadata }); - } - - // A method to extract the top-k documents from a RetrieveCommandOutput object. - /** - * Extracts the top-k documents from a RetrieveCommandOutput object. - * @param response The RetrieveCommandOutput object to extract the documents from. - * @param pageSize The number of documents to extract. - * @returns An array of Document objects. - */ - getRetrieverDocs( - response: RetrieveCommandOutput, - pageSize: number - ): Document[] { - if (!response.ResultItems) return []; - const { length } = response.ResultItems; - const count = length < pageSize ? length : pageSize; - - return response.ResultItems.slice(0, count).map((item) => - this.convertRetrieverItem(item) - ); - } - - // A method to extract the excerpt text from a QueryResultItem object. - /** - * Extracts the excerpt text from a QueryResultItem object. - * @param item The QueryResultItem object to extract the excerpt text from. 
- * @returns The extracted excerpt text. - */ - getQueryItemExcerpt(item: QueryResultItem) { - if ( - item.AdditionalAttributes && - item.AdditionalAttributes.length && - item.AdditionalAttributes[0].Key === "AnswerText" - ) { - if (!item.AdditionalAttributes) { - return ""; - } - if (!item.AdditionalAttributes[0]) { - return ""; - } - - return this.cleanResult( - item.AdditionalAttributes[0].Value?.TextWithHighlightsValue?.Text || "" - ); - } else if (item.DocumentExcerpt) { - return this.cleanResult(item.DocumentExcerpt.Text || ""); - } else { - return ""; - } - } - - // A method to convert a QueryResultItem object into a Document object. - /** - * Converts a QueryResultItem object into a Document object. - * @param item The QueryResultItem object to convert. - * @returns A Document object. - */ - convertQueryItem(item: QueryResultItem) { - const title = item.DocumentTitle?.Text || ""; - const excerpt = this.getQueryItemExcerpt(item); - const pageContent = this.combineText(title, excerpt); - const source = item.DocumentURI; - const attributes = this.getDocAttributes(item.DocumentAttributes); - const metadata = { - source, - title, - excerpt, - document_attributes: attributes, - }; - - return new Document({ pageContent, metadata }); - } - - // A method to extract the top-k documents from a QueryCommandOutput object. - /** - * Extracts the top-k documents from a QueryCommandOutput object. - * @param response The QueryCommandOutput object to extract the documents from. - * @param pageSize The number of documents to extract. - * @returns An array of Document objects. - */ - getQueryDocs(response: QueryCommandOutput, pageSize: number) { - if (!response.ResultItems) return []; - const { length } = response.ResultItems; - const count = length < pageSize ? length : pageSize; - return response.ResultItems.slice(0, count).map((item) => - this.convertQueryItem(item) - ); - } - - // A method to send a retrieve or query request to Kendra and return the top-k documents. 
- /** - * Sends a retrieve or query request to Kendra and returns the top-k - * documents. - * @param query The query to send to Kendra. - * @param topK The number of top documents to return. - * @param attributeFilter Optional filter to apply when retrieving documents. - * @returns A Promise that resolves to an array of Document objects. - */ - async queryKendra( - query: string, - topK: number, - attributeFilter?: AttributeFilter - ) { - const retrieveCommand = new RetrieveCommand({ - IndexId: this.indexId, - QueryText: query, - PageSize: topK, - AttributeFilter: attributeFilter, - }); - - const retrieveResponse = await this.kendraClient.send(retrieveCommand); - const retriveLength = retrieveResponse.ResultItems?.length; - - if (retriveLength === 0) { - // Retrieve API returned 0 results, call query API - const queryCommand = new QueryCommand({ - IndexId: this.indexId, - QueryText: query, - PageSize: topK, - AttributeFilter: attributeFilter, - }); - - const queryResponse = await this.kendraClient.send(queryCommand); - return this.getQueryDocs(queryResponse, this.topK); - } else { - return this.getRetrieverDocs(retrieveResponse, this.topK); - } - } - - async _getRelevantDocuments(query: string): Promise { - const docs = await this.queryKendra(query, this.topK, this.attributeFilter); - return docs; - } -} +export * from "@langchain/community/retrievers/amazon_kendra"; diff --git a/langchain/src/retrievers/chaindesk.ts b/langchain/src/retrievers/chaindesk.ts index 26175ede4b98..f61fccbf7e10 100644 --- a/langchain/src/retrievers/chaindesk.ts +++ b/langchain/src/retrievers/chaindesk.ts @@ -1,97 +1 @@ -import { BaseRetriever, type BaseRetrieverInput } from "../schema/retriever.js"; -import { Document } from "../document.js"; -import { AsyncCaller, type AsyncCallerParams } from "../util/async_caller.js"; - -export interface ChaindeskRetrieverArgs - extends AsyncCallerParams, - BaseRetrieverInput { - datastoreId: string; - topK?: number; - filter?: Record; - apiKey?: 
string; -} - -interface Berry { - text: string; - score: number; - source?: string; - [key: string]: unknown; -} - -/** - * @example - * ```typescript - * const retriever = new ChaindeskRetriever({ - * datastoreId: "DATASTORE_ID", - * apiKey: "CHAINDESK_API_KEY", - * topK: 8, - * }); - * const docs = await retriever.getRelevantDocuments("hello"); - * ``` - */ -export class ChaindeskRetriever extends BaseRetriever { - static lc_name() { - return "ChaindeskRetriever"; - } - - lc_namespace = ["langchain", "retrievers", "chaindesk"]; - - caller: AsyncCaller; - - datastoreId: string; - - topK?: number; - - filter?: Record; - - apiKey?: string; - - constructor({ - datastoreId, - apiKey, - topK, - filter, - ...rest - }: ChaindeskRetrieverArgs) { - super(); - - this.caller = new AsyncCaller(rest); - this.datastoreId = datastoreId; - this.apiKey = apiKey; - this.topK = topK; - this.filter = filter; - } - - async getRelevantDocuments(query: string): Promise { - const r = await this.caller.call( - fetch, - `https://app.chaindesk.ai/api/datastores/${this.datastoreId}/query`, - { - method: "POST", - body: JSON.stringify({ - query, - ...(this.topK ? { topK: this.topK } : {}), - ...(this.filter ? { filters: this.filter } : {}), - }), - headers: { - "Content-Type": "application/json", - ...(this.apiKey ? 
{ Authorization: `Bearer ${this.apiKey}` } : {}), - }, - } - ); - - const { results } = (await r.json()) as { results: Berry[] }; - - return results.map( - ({ text, score, source, ...rest }) => - new Document({ - pageContent: text, - metadata: { - score, - source, - ...rest, - }, - }) - ); - } -} +export * from "@langchain/community/retrievers/chaindesk"; diff --git a/langchain/src/retrievers/databerry.ts b/langchain/src/retrievers/databerry.ts index 3a8358d5b82a..39eb9679c95e 100644 --- a/langchain/src/retrievers/databerry.ts +++ b/langchain/src/retrievers/databerry.ts @@ -1,94 +1 @@ -import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js"; -import { Document } from "../document.js"; -import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; - -/** - * Interface for the arguments required to create a new instance of - * DataberryRetriever. - */ -export interface DataberryRetrieverArgs - extends AsyncCallerParams, - BaseRetrieverInput { - datastoreUrl: string; - topK?: number; - apiKey?: string; -} - -/** - * Interface for the structure of a Berry object returned by the Databerry - * API. - */ -interface Berry { - text: string; - score: number; - source?: string; - [key: string]: unknown; -} - -/** - * A specific implementation of a document retriever for the Databerry - * API. It extends the BaseRetriever class, which is an abstract base - * class for a document retrieval system in LangChain. 
- */ -/** @deprecated Use "langchain/retrievers/chaindesk" instead */ -export class DataberryRetriever extends BaseRetriever { - static lc_name() { - return "DataberryRetriever"; - } - - lc_namespace = ["langchain", "retrievers", "databerry"]; - - get lc_secrets() { - return { apiKey: "DATABERRY_API_KEY" }; - } - - get lc_aliases() { - return { apiKey: "api_key" }; - } - - caller: AsyncCaller; - - datastoreUrl: string; - - topK?: number; - - apiKey?: string; - - constructor(fields: DataberryRetrieverArgs) { - super(fields); - const { datastoreUrl, apiKey, topK, ...rest } = fields; - - this.caller = new AsyncCaller(rest); - this.datastoreUrl = datastoreUrl; - this.apiKey = apiKey; - this.topK = topK; - } - - async _getRelevantDocuments(query: string): Promise { - const r = await this.caller.call(fetch, this.datastoreUrl, { - method: "POST", - body: JSON.stringify({ - query, - ...(this.topK ? { topK: this.topK } : {}), - }), - headers: { - "Content-Type": "application/json", - ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}), - }, - }); - - const { results } = (await r.json()) as { results: Berry[] }; - - return results.map( - ({ text, score, source, ...rest }) => - new Document({ - pageContent: text, - metadata: { - score, - source, - ...rest, - }, - }) - ); - } -} +export * from "@langchain/community/retrievers/databerry"; diff --git a/langchain/src/retrievers/metal.ts b/langchain/src/retrievers/metal.ts index 2632e03826eb..ea8dbb1f03e0 100644 --- a/langchain/src/retrievers/metal.ts +++ b/langchain/src/retrievers/metal.ts @@ -1,70 +1 @@ -import Metal from "@getmetal/metal-sdk"; - -import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js"; -import { Document } from "../document.js"; - -/** - * Interface for the fields required during the initialization of a - * `MetalRetriever` instance. It extends the `BaseRetrieverInput` - * interface and adds a `client` field of type `Metal`. 
- */ -export interface MetalRetrieverFields extends BaseRetrieverInput { - client: Metal; -} - -/** - * Interface to represent a response item from the Metal service. It - * contains a `text` field and an index signature to allow for additional - * unknown properties. - */ -interface ResponseItem { - text: string; - [key: string]: unknown; -} - -/** - * Class used to interact with the Metal service, a managed retrieval & - * memory platform. It allows you to index your data into Metal and run - * semantic search and retrieval on it. It extends the `BaseRetriever` - * class and requires a `Metal` instance and a dictionary of parameters to - * pass to the Metal API during its initialization. - * @example - * ```typescript - * const retriever = new MetalRetriever({ - * client: new Metal( - * process.env.METAL_API_KEY, - * process.env.METAL_CLIENT_ID, - * process.env.METAL_INDEX_ID, - * ), - * }); - * const docs = await retriever.getRelevantDocuments("hello"); - * ``` - */ -export class MetalRetriever extends BaseRetriever { - static lc_name() { - return "MetalRetriever"; - } - - lc_namespace = ["langchain", "retrievers", "metal"]; - - private client: Metal; - - constructor(fields: MetalRetrieverFields) { - super(fields); - - this.client = fields.client; - } - - async _getRelevantDocuments(query: string): Promise { - const res = await this.client.search({ text: query }); - - const items = ("data" in res ? 
res.data : res) as ResponseItem[]; - return items.map( - ({ text, metadata }) => - new Document({ - pageContent: text, - metadata: metadata as Record, - }) - ); - } -} +export * from "@langchain/community/retrievers/metal"; diff --git a/langchain/src/retrievers/supabase.ts b/langchain/src/retrievers/supabase.ts index ec906c42a8d7..bd3ea5d07c7c 100644 --- a/langchain/src/retrievers/supabase.ts +++ b/langchain/src/retrievers/supabase.ts @@ -1,238 +1 @@ -import type { SupabaseClient } from "@supabase/supabase-js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js"; -import { - CallbackManagerForRetrieverRun, - Callbacks, -} from "../callbacks/manager.js"; - -interface SearchEmbeddingsParams { - query_embedding: number[]; - match_count: number; // int - filter?: Record; // jsonb -} - -interface SearchKeywordParams { - query_text: string; - match_count: number; // int -} - -interface SearchResponseRow { - id: number; - content: string; - metadata: object; - similarity: number; -} - -type SearchResult = [Document, number, number]; - -export interface SupabaseLibArgs extends BaseRetrieverInput { - client: SupabaseClient; - /** - * The table name on Supabase. Defaults to "documents". - */ - tableName?: string; - /** - * The name of the Similarity search function on Supabase. Defaults to "match_documents". - */ - similarityQueryName?: string; - /** - * The name of the Keyword search function on Supabase. Defaults to "kw_match_documents". - */ - keywordQueryName?: string; - /** - * The number of documents to return from the similarity search. Defaults to 2. - */ - similarityK?: number; - /** - * The number of documents to return from the keyword search. Defaults to 2. 
- */ - keywordK?: number; -} - -export interface SupabaseHybridSearchParams { - query: string; - similarityK: number; - keywordK: number; -} - -/** - * Class for performing hybrid search operations on a Supabase database. - * It extends the `BaseRetriever` class and implements methods for - * similarity search, keyword search, and hybrid search. - */ -export class SupabaseHybridSearch extends BaseRetriever { - static lc_name() { - return "SupabaseHybridSearch"; - } - - lc_namespace = ["langchain", "retrievers", "supabase"]; - - similarityK: number; - - query: string; - - keywordK: number; - - similarityQueryName: string; - - client: SupabaseClient; - - tableName: string; - - keywordQueryName: string; - - embeddings: Embeddings; - - constructor(embeddings: Embeddings, args: SupabaseLibArgs) { - super(args); - this.embeddings = embeddings; - this.client = args.client; - this.tableName = args.tableName || "documents"; - this.similarityQueryName = args.similarityQueryName || "match_documents"; - this.keywordQueryName = args.keywordQueryName || "kw_match_documents"; - this.similarityK = args.similarityK || 2; - this.keywordK = args.keywordK || 2; - } - - /** - * Performs a similarity search on the Supabase database using the - * provided query and returns the top 'k' similar documents. - * @param query The query to use for the similarity search. - * @param k The number of top similar documents to return. - * @param _callbacks Optional callbacks to pass to the embedQuery method. - * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its similarity score, and its ID. 
- */ - protected async similaritySearch( - query: string, - k: number, - _callbacks?: Callbacks // implement passing to embedQuery later - ): Promise { - const embeddedQuery = await this.embeddings.embedQuery(query); - - const matchDocumentsParams: SearchEmbeddingsParams = { - query_embedding: embeddedQuery, - match_count: k, - }; - - if (Object.keys(this.metadata ?? {}).length > 0) { - matchDocumentsParams.filter = this.metadata; - } - - const { data: searches, error } = await this.client.rpc( - this.similarityQueryName, - matchDocumentsParams - ); - - if (error) { - throw new Error( - `Error searching for documents: ${error.code} ${error.message} ${error.details}` - ); - } - - return (searches as SearchResponseRow[]).map((resp) => [ - new Document({ - metadata: resp.metadata, - pageContent: resp.content, - }), - resp.similarity, - resp.id, - ]); - } - - /** - * Performs a keyword search on the Supabase database using the provided - * query and returns the top 'k' documents that match the keywords. - * @param query The query to use for the keyword search. - * @param k The number of top documents to return that match the keywords. - * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its similarity score multiplied by 10, and its ID. 
- */ - protected async keywordSearch( - query: string, - k: number - ): Promise { - const kwMatchDocumentsParams: SearchKeywordParams = { - query_text: query, - match_count: k, - }; - - const { data: searches, error } = await this.client.rpc( - this.keywordQueryName, - kwMatchDocumentsParams - ); - - if (error) { - throw new Error( - `Error searching for documents: ${error.code} ${error.message} ${error.details}` - ); - } - - return (searches as SearchResponseRow[]).map((resp) => [ - new Document({ - metadata: resp.metadata, - pageContent: resp.content, - }), - resp.similarity * 10, - resp.id, - ]); - } - - /** - * Combines the results of the `similaritySearch` and `keywordSearch` - * methods and returns the top 'k' documents based on a combination of - * similarity and keyword matching. - * @param query The query to use for the hybrid search. - * @param similarityK The number of top similar documents to return. - * @param keywordK The number of top documents to return that match the keywords. - * @param callbacks Optional callbacks to pass to the similaritySearch method. - * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its combined score, and its ID. 
- */ - protected async hybridSearch( - query: string, - similarityK: number, - keywordK: number, - callbacks?: Callbacks - ): Promise { - const similarity_search = this.similaritySearch( - query, - similarityK, - callbacks - ); - - const keyword_search = this.keywordSearch(query, keywordK); - - return Promise.all([similarity_search, keyword_search]) - .then((results) => results.flat()) - .then((results) => { - const picks = new Map(); - - results.forEach((result) => { - const id = result[2]; - const nextScore = result[1]; - const prevScore = picks.get(id)?.[1]; - - if (prevScore === undefined || nextScore > prevScore) { - picks.set(id, result); - } - }); - - return Array.from(picks.values()); - }) - .then((results) => results.sort((a, b) => b[1] - a[1])); - } - - async _getRelevantDocuments( - query: string, - runManager?: CallbackManagerForRetrieverRun - ): Promise { - const searchResults = await this.hybridSearch( - query, - this.similarityK, - this.keywordK, - runManager?.getChild("hybrid_search") - ); - - return searchResults.map(([doc]) => doc); - } -} +export * from "@langchain/community/retrievers/supabase"; diff --git a/langchain/src/retrievers/tavily_search_api.ts b/langchain/src/retrievers/tavily_search_api.ts index 7b65e36a6b89..5a59061172d6 100644 --- a/langchain/src/retrievers/tavily_search_api.ts +++ b/langchain/src/retrievers/tavily_search_api.ts @@ -1,140 +1 @@ -import { Document } from "../document.js"; -import { CallbackManagerForRetrieverRun } from "../callbacks/manager.js"; -import { BaseRetriever, type BaseRetrieverInput } from "../schema/retriever.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Options for the HydeRetriever class, which includes a BaseLanguageModel - * instance, a VectorStore instance, and an optional promptTemplate which - * can either be a BasePromptTemplate instance or a PromptKey. 
- */ -export type TavilySearchAPIRetrieverFields = BaseRetrieverInput & { - k?: number; - includeGeneratedAnswer?: boolean; - includeRawContent?: boolean; - includeImages?: boolean; - searchDepth?: "basic" | "advanced"; - includeDomains?: string[]; - excludeDomains?: string[]; - kwargs?: Record; - apiKey?: string; -}; - -/** - * A class for retrieving documents related to a given search term - * using the Tavily Search API. - */ -export class TavilySearchAPIRetriever extends BaseRetriever { - static lc_name() { - return "TavilySearchAPIRetriever"; - } - - get lc_namespace(): string[] { - return ["langchain", "retrievers", "tavily_search_api"]; - } - - k = 10; - - includeGeneratedAnswer = false; - - includeRawContent = false; - - includeImages = false; - - searchDepth = "basic"; - - includeDomains?: string[]; - - excludeDomains?: string[]; - - kwargs: Record = {}; - - apiKey?: string; - - constructor(fields?: TavilySearchAPIRetrieverFields) { - super(fields); - this.k = fields?.k ?? this.k; - this.includeGeneratedAnswer = - fields?.includeGeneratedAnswer ?? this.includeGeneratedAnswer; - this.includeRawContent = - fields?.includeRawContent ?? this.includeRawContent; - this.includeImages = fields?.includeImages ?? this.includeImages; - this.searchDepth = fields?.searchDepth ?? this.searchDepth; - this.includeDomains = fields?.includeDomains ?? this.includeDomains; - this.excludeDomains = fields?.excludeDomains ?? this.excludeDomains; - this.kwargs = fields?.kwargs ?? this.kwargs; - this.apiKey = fields?.apiKey ?? getEnvironmentVariable("TAVILY_API_KEY"); - if (this.apiKey === undefined) { - throw new Error( - `No Tavily API key found. 
Either set an environment variable named "TAVILY_API_KEY" or pass an API key as "apiKey".` - ); - } - } - - async _getRelevantDocuments( - query: string, - _runManager?: CallbackManagerForRetrieverRun - ): Promise { - const body: Record = { - query, - include_answer: this.includeGeneratedAnswer, - include_raw_content: this.includeRawContent, - include_images: this.includeImages, - max_results: this.k, - search_depth: this.searchDepth, - api_key: this.apiKey, - }; - if (this.includeDomains) { - body.include_domains = this.includeDomains; - } - if (this.excludeDomains) { - body.exclude_domains = this.excludeDomains; - } - - const response = await fetch("https://api.tavily.com/search", { - method: "POST", - headers: { - "content-type": "application/json", - }, - body: JSON.stringify({ ...body, ...this.kwargs }), - }); - const json = await response.json(); - if (!response.ok) { - throw new Error( - `Request failed with status code ${response.status}: ${json.error}` - ); - } - if (!Array.isArray(json.results)) { - throw new Error(`Could not parse Tavily results. Please try again.`); - } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const docs: Document[] = json.results.map((result: any) => { - const pageContent = this.includeRawContent - ? 
result.raw_content - : result.content; - const metadata = { - title: result.title, - source: result.url, - ...Object.fromEntries( - Object.entries(result).filter( - ([k]) => !["content", "title", "url", "raw_content"].includes(k) - ) - ), - images: json.images, - }; - return new Document({ pageContent, metadata }); - }); - if (this.includeGeneratedAnswer) { - docs.push( - new Document({ - pageContent: json.answer, - metadata: { - title: "Suggested Answer", - source: "https://tavily.com/", - }, - }) - ); - } - return docs; - } -} +export * from "@langchain/community/retrievers/tavily_search_api"; diff --git a/langchain/src/retrievers/zep.ts b/langchain/src/retrievers/zep.ts index f87d38ca79d8..319cff8852a6 100644 --- a/langchain/src/retrievers/zep.ts +++ b/langchain/src/retrievers/zep.ts @@ -1,169 +1 @@ -import { - MemorySearchPayload, - MemorySearchResult, - NotFoundError, - ZepClient, -} from "@getzep/zep-js"; -import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js"; -import { Document } from "../document.js"; - -/** - * Configuration interface for the ZepRetriever class. Extends the - * BaseRetrieverInput interface. - * - * @argument {string} sessionId - The ID of the Zep session. - * @argument {string} url - The URL of the Zep API. - * @argument {number} [topK] - The number of results to return. - * @argument {string} [apiKey] - The API key for the Zep API. - * @argument [searchScope] [searchScope] - The scope of the search: "messages" or "summary". - * @argument [searchType] [searchType] - The type of search to perform: "similarity" or "mmr". - * @argument {number} [mmrLambda] - The lambda value for the MMR search. - * @argument {Record} [filter] - The metadata filter to apply to the search. 
- */ -export interface ZepRetrieverConfig extends BaseRetrieverInput { - sessionId: string; - url: string; - topK?: number; - apiKey?: string; - searchScope?: "messages" | "summary"; - searchType?: "similarity" | "mmr"; - mmrLambda?: number; - filter?: Record; -} - -/** - * Class for retrieving information from a Zep long-term memory store. - * Extends the BaseRetriever class. - * @example - * ```typescript - * const retriever = new ZepRetriever({ - * url: "http: - * sessionId: "session_exampleUUID", - * topK: 3, - * }); - * const query = "Can I drive red cars in France?"; - * const docs = await retriever.getRelevantDocuments(query); - * ``` - */ -export class ZepRetriever extends BaseRetriever { - static lc_name() { - return "ZepRetriever"; - } - - lc_namespace = ["langchain", "retrievers", "zep"]; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "ZEP_API_KEY", - url: "ZEP_API_URL", - }; - } - - get lc_aliases(): { [key: string]: string } | undefined { - return { apiKey: "api_key" }; - } - - zepClientPromise: Promise; - - private sessionId: string; - - private topK?: number; - - private searchScope?: "messages" | "summary"; - - private searchType?: "similarity" | "mmr"; - - private mmrLambda?: number; - - private filter?: Record; - - constructor(config: ZepRetrieverConfig) { - super(config); - this.sessionId = config.sessionId; - this.topK = config.topK; - this.searchScope = config.searchScope; - this.searchType = config.searchType; - this.mmrLambda = config.mmrLambda; - this.filter = config.filter; - this.zepClientPromise = ZepClient.init(config.url, config.apiKey); - } - - /** - * Converts an array of message search results to an array of Document objects. - * @param {MemorySearchResult[]} results - The array of search results. - * @returns {Document[]} An array of Document objects representing the search results. 
- */ - private searchMessageResultToDoc(results: MemorySearchResult[]): Document[] { - return results - .filter((r) => r.message) - .map( - ({ - message: { content, metadata: messageMetadata } = {}, - dist, - ...rest - }) => - new Document({ - pageContent: content ?? "", - metadata: { score: dist, ...messageMetadata, ...rest }, - }) - ); - } - - /** - * Converts an array of summary search results to an array of Document objects. - * @param {MemorySearchResult[]} results - The array of search results. - * @returns {Document[]} An array of Document objects representing the search results. - */ - private searchSummaryResultToDoc(results: MemorySearchResult[]): Document[] { - return results - .filter((r) => r.summary) - .map( - ({ - summary: { content, metadata: summaryMetadata } = {}, - dist, - ...rest - }) => - new Document({ - pageContent: content ?? "", - metadata: { score: dist, ...summaryMetadata, ...rest }, - }) - ); - } - - /** - * Retrieves the relevant documents based on the given query. - * @param {string} query - The query string. - * @returns {Promise} A promise that resolves to an array of relevant Document objects. - */ - async _getRelevantDocuments(query: string): Promise { - const payload: MemorySearchPayload = { - text: query, - metadata: this.filter, - search_scope: this.searchScope, - search_type: this.searchType, - mmr_lambda: this.mmrLambda, - }; - // Wait for ZepClient to be initialized - const zepClient = await this.zepClientPromise; - if (!zepClient) { - throw new Error("ZepClient is not initialized"); - } - try { - const results: MemorySearchResult[] = await zepClient.memory.searchMemory( - this.sessionId, - payload, - this.topK - ); - return this.searchScope === "summary" - ? 
this.searchSummaryResultToDoc(results) - : this.searchMessageResultToDoc(results); - } catch (error) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (error instanceof NotFoundError) { - return Promise.resolve([]); // Return an empty Document array - } - // If it's not a NotFoundError, throw the error again - throw error; - } - } -} +export * from "@langchain/community/retrievers/zep"; diff --git a/langchain/src/schema/document.ts b/langchain/src/schema/document.ts index 9e50f747c03f..8394191d16e2 100644 --- a/langchain/src/schema/document.ts +++ b/langchain/src/schema/document.ts @@ -1,21 +1,4 @@ -import { BaseDocumentTransformer } from "@langchain/core/documents"; -import { Document } from "../document.js"; - -export { BaseDocumentTransformer }; - -/** - * Class for document transformers that return exactly one transformed document - * for each input document. - */ -export abstract class MappingDocumentTransformer extends BaseDocumentTransformer { - async transformDocuments(documents: Document[]): Promise { - const newDocuments = []; - for (const document of documents) { - const transformedDocument = await this._transformDocument(document); - newDocuments.push(transformedDocument); - } - return newDocuments; - } - - abstract _transformDocument(document: Document): Promise; -} +export { + BaseDocumentTransformer, + MappingDocumentTransformer, +} from "@langchain/core/documents"; diff --git a/langchain/src/schema/index.ts b/langchain/src/schema/index.ts index 299068a67252..1fa111c7e7fa 100644 --- a/langchain/src/schema/index.ts +++ b/langchain/src/schema/index.ts @@ -1,11 +1,10 @@ -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import { BaseMessage, HumanMessage, AIMessage, SystemMessage, } from "@langchain/core/messages"; -import { Document } from "../document.js"; import { Serializable } from "../load/serializable.js"; export { @@ -128,12 +127,4 @@ export abstract class BaseEntityStore 
extends Serializable { abstract clear(): Promise; } -/** - * Abstract class for a document store. All document stores should extend - * this class. - */ -export abstract class Docstore { - abstract search(search: string): Promise; - - abstract add(texts: Record): Promise; -} +export { Docstore } from "@langchain/community/stores/doc/base"; diff --git a/langchain/src/storage/convex.ts b/langchain/src/storage/convex.ts index e32b1647d151..3593ef2bfc30 100644 --- a/langchain/src/storage/convex.ts +++ b/langchain/src/storage/convex.ts @@ -1,224 +1 @@ -// eslint-disable-next-line import/no-extraneous-dependencies -import { - FieldPaths, - FunctionReference, - GenericActionCtx, - GenericDataModel, - NamedTableInfo, - TableNamesInDataModel, - VectorIndexNames, - makeFunctionReference, -} from "convex/server"; -// eslint-disable-next-line import/no-extraneous-dependencies -import { Value } from "convex/values"; -import { BaseStore } from "../schema/storage.js"; - -/** - * Type that defines the config required to initialize the - * ConvexKVStore class. It includes the table name, - * index name, field name. 
- */ -export type ConvexKVStoreConfig< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - KeyFieldName extends FieldPaths>, - ValueFieldName extends FieldPaths>, - UpsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - LookupQuery extends FunctionReference< - "query", - "internal", - { table: string; index: string; keyField: string; key: string }, - object[] - >, - DeleteManyMutation extends FunctionReference< - "mutation", - "internal", - { table: string; index: string; keyField: string; key: string } - > -> = { - readonly ctx: GenericActionCtx; - /** - * Defaults to "cache" - */ - readonly table?: TableName; - /** - * Defaults to "byKey" - */ - readonly index?: IndexName; - /** - * Defaults to "key" - */ - readonly keyField?: KeyFieldName; - /** - * Defaults to "value" - */ - readonly valueField?: ValueFieldName; - /** - * Defaults to `internal.langchain.db.upsert` - */ - readonly upsert?: UpsertMutation; - /** - * Defaults to `internal.langchain.db.lookup` - */ - readonly lookup?: LookupQuery; - /** - * Defaults to `internal.langchain.db.deleteMany` - */ - readonly deleteMany?: DeleteManyMutation; -}; - -/** - * Class that extends the BaseStore class to interact with a Convex - * database. It provides methods for getting, setting, and deleting key value pairs, - * as well as yielding keys from the database. 
- */ -export class ConvexKVStore< - T extends Value, - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - KeyFieldName extends FieldPaths>, - ValueFieldName extends FieldPaths>, - UpsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - LookupQuery extends FunctionReference< - "query", - "internal", - { table: string; index: string; keyField: string; key: string }, - object[] - >, - DeleteManyMutation extends FunctionReference< - "mutation", - "internal", - { table: string; index: string; keyField: string; key: string } - > -> extends BaseStore { - lc_namespace = ["langchain", "storage", "convex"]; - - private readonly ctx: GenericActionCtx; - - private readonly table: TableName; - - private readonly index: IndexName; - - private readonly keyField: KeyFieldName; - - private readonly valueField: ValueFieldName; - - private readonly upsert: UpsertMutation; - - private readonly lookup: LookupQuery; - - private readonly deleteMany: DeleteManyMutation; - - constructor( - config: ConvexKVStoreConfig< - DataModel, - TableName, - IndexName, - KeyFieldName, - ValueFieldName, - UpsertMutation, - LookupQuery, - DeleteManyMutation - > - ) { - super(config); - this.ctx = config.ctx; - this.table = config.table ?? ("cache" as TableName); - this.index = config.index ?? ("byKey" as IndexName); - this.keyField = config.keyField ?? ("key" as KeyFieldName); - this.valueField = config.valueField ?? ("value" as ValueFieldName); - this.upsert = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - config.upsert ?? (makeFunctionReference("langchain/db:upsert") as any); - this.lookup = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - config.lookup ?? (makeFunctionReference("langchain/db:lookup") as any); - this.deleteMany = - config.deleteMany ?? 
- // eslint-disable-next-line @typescript-eslint/no-explicit-any - (makeFunctionReference("langchain/db:deleteMany") as any); - } - - /** - * Gets multiple keys from the Convex database. - * @param keys Array of keys to be retrieved. - * @returns An array of retrieved values. - */ - async mget(keys: string[]) { - return (await Promise.all( - keys.map(async (key) => { - const found = (await this.ctx.runQuery(this.lookup, { - table: this.table, - index: this.index, - keyField: this.keyField, - key, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any)) as any; - return found.length > 0 ? found[0][this.valueField] : undefined; - }) - )) as (T | undefined)[]; - } - - /** - * Sets multiple keys in the Convex database. - * @param keyValuePairs Array of key-value pairs to be set. - * @returns Promise that resolves when all keys have been set. - */ - async mset(keyValuePairs: [string, T][]): Promise { - // TODO: Remove chunking when Convex handles the concurrent requests correctly - const PAGE_SIZE = 16; - for (let i = 0; i < keyValuePairs.length; i += PAGE_SIZE) { - await Promise.all( - keyValuePairs.slice(i, i + PAGE_SIZE).map(([key, value]) => - this.ctx.runMutation(this.upsert, { - table: this.table, - index: this.index, - keyField: this.keyField, - key, - document: { [this.keyField]: key, [this.valueField]: value }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any) - ) - ); - } - } - - /** - * Deletes multiple keys from the Convex database. - * @param keys Array of keys to be deleted. - * @returns Promise that resolves when all keys have been deleted. - */ - async mdelete(keys: string[]): Promise { - await Promise.all( - keys.map((key) => - this.ctx.runMutation(this.deleteMany, { - table: this.table, - index: this.index, - keyField: this.keyField, - key, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any) - ) - ); - } - - /** - * Yields keys from the Convex database. 
- * @param prefix Optional prefix to filter the keys. - * @returns An AsyncGenerator that yields keys from the Convex database. - */ - // eslint-disable-next-line require-yield - async *yieldKeys(_prefix?: string): AsyncGenerator { - throw new Error("yieldKeys not implemented yet for ConvexKVStore"); - } -} +export * from "@langchain/community/storage/convex"; diff --git a/langchain/src/storage/ioredis.ts b/langchain/src/storage/ioredis.ts index d5864f59b6ea..3e2d8437d910 100644 --- a/langchain/src/storage/ioredis.ts +++ b/langchain/src/storage/ioredis.ts @@ -1,159 +1 @@ -import type { Redis } from "ioredis"; - -import { BaseStore } from "../schema/storage.js"; - -/** - * Class that extends the BaseStore class to interact with a Redis - * database. It provides methods for getting, setting, and deleting data, - * as well as yielding keys from the database. - * @example - * ```typescript - * const store = new RedisByteStore({ client: new Redis({}) }); - * await store.mset([ - * [ - * "message:id:0", - * new TextEncoder().encode(JSON.stringify(new AIMessage("ai stuff..."))), - * ], - * [ - * "message:id:1", - * new TextEncoder().encode( - * JSON.stringify(new HumanMessage("human stuff...")), - * ), - * ], - * ]); - * const retrievedMessages = await store.mget(["message:id:0", "message:id:1"]); - * console.log(retrievedMessages.map((v) => new TextDecoder().decode(v))); - * const yieldedKeys = []; - * for await (const key of store.yieldKeys("message:id:")) { - * yieldedKeys.push(key); - * } - * console.log(yieldedKeys); - * await store.mdelete(yieldedKeys); - * ``` - */ -export class RedisByteStore extends BaseStore { - lc_namespace = ["langchain", "storage"]; - - protected client: Redis; - - protected ttl?: number; - - protected namespace?: string; - - protected yieldKeysScanBatchSize = 1000; - - constructor(fields: { - client: Redis; - ttl?: number; - namespace?: string; - yieldKeysScanBatchSize?: number; - }) { - super(fields); - this.client = fields.client; - 
this.ttl = fields.ttl; - this.namespace = fields.namespace; - this.yieldKeysScanBatchSize = - fields.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize; - } - - _getPrefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return `${this.namespace}${delimiter}${key}`; - } - return key; - } - - _getDeprefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return key.slice(this.namespace.length + delimiter.length); - } - return key; - } - - /** - * Gets multiple keys from the Redis database. - * @param keys Array of keys to be retrieved. - * @returns An array of retrieved values. - */ - async mget(keys: string[]) { - const prefixedKeys = keys.map(this._getPrefixedKey.bind(this)); - const retrievedValues = await this.client.mgetBuffer(prefixedKeys); - return retrievedValues.map((value) => { - if (!value) { - return undefined; - } else { - return value; - } - }); - } - - /** - * Sets multiple keys in the Redis database. - * @param keyValuePairs Array of key-value pairs to be set. - * @returns Promise that resolves when all keys have been set. - */ - async mset(keyValuePairs: [string, Uint8Array][]): Promise { - const decoder = new TextDecoder(); - const encodedKeyValuePairs = keyValuePairs.map(([key, value]) => [ - this._getPrefixedKey(key), - decoder.decode(value), - ]); - const pipeline = this.client.pipeline(); - for (const [key, value] of encodedKeyValuePairs) { - if (this.ttl) { - pipeline.set(key, value, "EX", this.ttl); - } else { - pipeline.set(key, value); - } - } - await pipeline.exec(); - } - - /** - * Deletes multiple keys from the Redis database. - * @param keys Array of keys to be deleted. - * @returns Promise that resolves when all keys have been deleted. - */ - async mdelete(keys: string[]): Promise { - await this.client.del(...keys.map(this._getPrefixedKey.bind(this))); - } - - /** - * Yields keys from the Redis database. - * @param prefix Optional prefix to filter the keys. 
- * @returns An AsyncGenerator that yields keys from the Redis database. - */ - async *yieldKeys(prefix?: string): AsyncGenerator { - let pattern; - if (prefix) { - const wildcardPrefix = prefix.endsWith("*") ? prefix : `${prefix}*`; - pattern = this._getPrefixedKey(wildcardPrefix); - } else { - pattern = this._getPrefixedKey("*"); - } - let [cursor, batch] = await this.client.scan( - 0, - "MATCH", - pattern, - "COUNT", - this.yieldKeysScanBatchSize - ); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - while (cursor !== "0") { - [cursor, batch] = await this.client.scan( - cursor, - "MATCH", - pattern, - "COUNT", - this.yieldKeysScanBatchSize - ); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - } - } -} +export * from "@langchain/community/storage/ioredis"; diff --git a/langchain/src/storage/upstash_redis.ts b/langchain/src/storage/upstash_redis.ts index 313444399661..c8732435bbeb 100644 --- a/langchain/src/storage/upstash_redis.ts +++ b/langchain/src/storage/upstash_redis.ts @@ -1,176 +1 @@ -import { Redis as UpstashRedis, type RedisConfigNodejs } from "@upstash/redis"; - -import { BaseStore } from "../schema/storage.js"; - -/** - * Type definition for the input parameters required to initialize an - * instance of the UpstashStoreInput class. - */ -export interface UpstashRedisStoreInput { - sessionTTL?: number; - config?: RedisConfigNodejs; - client?: UpstashRedis; - /** - * The amount of keys to retrieve per batch when yielding keys. - * @default 1000 - */ - yieldKeysScanBatchSize?: number; - /** - * The namespace to use for the keys in the database. - */ - namespace?: string; -} - -/** - * Class that extends the BaseStore class to interact with an Upstash Redis - * database. It provides methods for getting, setting, and deleting data, - * as well as yielding keys from the database. 
- * @example - * ```typescript - * const store = new UpstashRedisStore({ - * client: new Redis({ - * url: "your-upstash-redis-url", - * token: "your-upstash-redis-token", - * }), - * }); - * await store.mset([ - * ["message:id:0", "encoded-ai-message"], - * ["message:id:1", "encoded-human-message"], - * ]); - * const retrievedMessages = await store.mget(["message:id:0", "message:id:1"]); - * const yieldedKeys = []; - * for await (const key of store.yieldKeys("message:id")) { - * yieldedKeys.push(key); - * } - * await store.mdelete(yieldedKeys); - * ``` - */ -export class UpstashRedisStore extends BaseStore { - lc_namespace = ["langchain", "storage"]; - - protected client: UpstashRedis; - - protected namespace?: string; - - protected yieldKeysScanBatchSize = 1000; - - private sessionTTL?: number; - - constructor(fields: UpstashRedisStoreInput) { - super(fields); - if (fields.client) { - this.client = fields.client; - } else if (fields.config) { - this.client = new UpstashRedis(fields.config); - } else { - throw new Error( - `Upstash Redis store requires either a config object or a pre-configured client.` - ); - } - this.sessionTTL = fields.sessionTTL; - this.yieldKeysScanBatchSize = - fields.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize; - this.namespace = fields.namespace; - } - - _getPrefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return `${this.namespace}${delimiter}${key}`; - } - return key; - } - - _getDeprefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return key.slice(this.namespace.length + delimiter.length); - } - return key; - } - - /** - * Gets multiple keys from the Upstash Redis database. - * @param keys Array of keys to be retrieved. - * @returns An array of retrieved values. 
- */ - async mget(keys: string[]) { - const encoder = new TextEncoder(); - - const prefixedKeys = keys.map(this._getPrefixedKey.bind(this)); - const retrievedValues = await this.client.mget( - ...prefixedKeys - ); - return retrievedValues.map((value) => { - if (!value) { - return undefined; - } else if (typeof value === "object") { - return encoder.encode(JSON.stringify(value)); - } else { - return encoder.encode(value); - } - }); - } - - /** - * Sets multiple keys in the Upstash Redis database. - * @param keyValuePairs Array of key-value pairs to be set. - * @returns Promise that resolves when all keys have been set. - */ - async mset(keyValuePairs: [string, Uint8Array][]): Promise { - const decoder = new TextDecoder(); - const encodedKeyValuePairs = keyValuePairs.map(([key, value]) => [ - this._getPrefixedKey(key), - decoder.decode(value), - ]); - const pipeline = this.client.pipeline(); - for (const [key, value] of encodedKeyValuePairs) { - if (this.sessionTTL) { - pipeline.setex(key, this.sessionTTL, value); - } else { - pipeline.set(key, value); - } - } - await pipeline.exec(); - } - - /** - * Deletes multiple keys from the Upstash Redis database. - * @param keys Array of keys to be deleted. - * @returns Promise that resolves when all keys have been deleted. - */ - async mdelete(keys: string[]): Promise { - await this.client.del(...keys.map(this._getPrefixedKey.bind(this))); - } - - /** - * Yields keys from the Upstash Redis database. - * @param prefix Optional prefix to filter the keys. A wildcard (*) is always appended to the end. - * @returns An AsyncGenerator that yields keys from the Upstash Redis database. - */ - async *yieldKeys(prefix?: string): AsyncGenerator { - let pattern; - if (prefix) { - const wildcardPrefix = prefix.endsWith("*") ? 
prefix : `${prefix}*`; - pattern = `${this._getPrefixedKey(wildcardPrefix)}*`; - } else { - pattern = this._getPrefixedKey("*"); - } - let [cursor, batch] = await this.client.scan(0, { - match: pattern, - count: this.yieldKeysScanBatchSize, - }); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - while (cursor !== 0) { - [cursor, batch] = await this.client.scan(cursor, { - match: pattern, - count: this.yieldKeysScanBatchSize, - }); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - } - } -} +export * from "@langchain/community/storage/upstash_redis"; diff --git a/langchain/src/storage/vercel_kv.ts b/langchain/src/storage/vercel_kv.ts index 9b9646377713..1853b88290d7 100644 --- a/langchain/src/storage/vercel_kv.ts +++ b/langchain/src/storage/vercel_kv.ts @@ -1,150 +1 @@ -import { kv, type VercelKV } from "@vercel/kv"; - -import { BaseStore } from "../schema/storage.js"; - -/** - * Class that extends the BaseStore class to interact with a Vercel KV - * database. It provides methods for getting, setting, and deleting data, - * as well as yielding keys from the database. 
- * @example - * ```typescript - * const store = new VercelKVStore({ - * client: getClient(), - * }); - * await store.mset([ - * { key: "message:id:0", value: "encoded message 0" }, - * { key: "message:id:1", value: "encoded message 1" }, - * ]); - * const retrievedMessages = await store.mget(["message:id:0", "message:id:1"]); - * const yieldedKeys = []; - * for await (const key of store.yieldKeys("message:id:")) { - * yieldedKeys.push(key); - * } - * await store.mdelete(yieldedKeys); - * ``` - */ -export class VercelKVStore extends BaseStore { - lc_namespace = ["langchain", "storage"]; - - protected client: VercelKV; - - protected ttl?: number; - - protected namespace?: string; - - protected yieldKeysScanBatchSize = 1000; - - constructor(fields?: { - client?: VercelKV; - ttl?: number; - namespace?: string; - yieldKeysScanBatchSize?: number; - }) { - super(fields); - this.client = fields?.client ?? kv; - this.ttl = fields?.ttl; - this.namespace = fields?.namespace; - this.yieldKeysScanBatchSize = - fields?.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize; - } - - _getPrefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return `${this.namespace}${delimiter}${key}`; - } - return key; - } - - _getDeprefixedKey(key: string) { - if (this.namespace) { - const delimiter = "/"; - return key.slice(this.namespace.length + delimiter.length); - } - return key; - } - - /** - * Gets multiple keys from the Redis database. - * @param keys Array of keys to be retrieved. - * @returns An array of retrieved values. 
- */ - async mget(keys: string[]) { - const prefixedKeys = keys.map(this._getPrefixedKey.bind(this)); - const retrievedValues = await this.client.mget<(string | undefined)[]>( - ...prefixedKeys - ); - const encoder = new TextEncoder(); - return retrievedValues.map((value) => { - if (value === undefined || value === null) { - return undefined; - } else if (typeof value === "object") { - return encoder.encode(JSON.stringify(value)); - } else { - return encoder.encode(value); - } - }); - } - - /** - * Sets multiple keys in the Redis database. - * @param keyValuePairs Array of key-value pairs to be set. - * @returns Promise that resolves when all keys have been set. - */ - async mset(keyValuePairs: [string, Uint8Array][]): Promise { - const decoder = new TextDecoder(); - const decodedKeyValuePairs = keyValuePairs.map(([key, value]) => [ - this._getPrefixedKey(key), - decoder.decode(value), - ]); - const pipeline = this.client.pipeline(); - for (const [key, value] of decodedKeyValuePairs) { - if (this.ttl) { - pipeline.setex(key, this.ttl, value); - } else { - pipeline.set(key, value); - } - } - await pipeline.exec(); - } - - /** - * Deletes multiple keys from the Redis database. - * @param keys Array of keys to be deleted. - * @returns Promise that resolves when all keys have been deleted. - */ - async mdelete(keys: string[]): Promise { - await this.client.del(...keys.map(this._getPrefixedKey.bind(this))); - } - - /** - * Yields keys from the Redis database. - * @param prefix Optional prefix to filter the keys. - * @returns An AsyncGenerator that yields keys from the Redis database. - */ - async *yieldKeys(prefix?: string): AsyncGenerator { - let pattern; - if (prefix) { - const wildcardPrefix = prefix.endsWith("*") ? 
prefix : `${prefix}*`; - pattern = this._getPrefixedKey(wildcardPrefix); - } else { - pattern = this._getPrefixedKey("*"); - } - let [cursor, batch] = await this.client.scan(0, { - match: pattern, - count: this.yieldKeysScanBatchSize, - }); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - while (cursor !== 0) { - [cursor, batch] = await this.client.scan(cursor, { - match: pattern, - count: this.yieldKeysScanBatchSize, - }); - for (const key of batch) { - yield this._getDeprefixedKey(key); - } - } - } -} +export * from "@langchain/community/storage/vercel_kv"; diff --git a/langchain/src/stores/doc/in_memory.ts b/langchain/src/stores/doc/in_memory.ts index f11220f66170..5b3db5facd35 100644 --- a/langchain/src/stores/doc/in_memory.ts +++ b/langchain/src/stores/doc/in_memory.ts @@ -1,113 +1 @@ -import { Document } from "../../document.js"; -import { Docstore } from "../../schema/index.js"; -import { BaseStoreInterface } from "../../schema/storage.js"; - -/** - * Class for storing and retrieving documents in memory asynchronously. - * Extends the Docstore class. - */ -export class InMemoryDocstore - extends Docstore - implements BaseStoreInterface -{ - _docs: Map; - - constructor(docs?: Map) { - super(); - this._docs = docs ?? new Map(); - } - - /** - * Searches for a document in the store based on its ID. - * @param search The ID of the document to search for. - * @returns The document with the given ID. - */ - async search(search: string): Promise { - const result = this._docs.get(search); - if (!result) { - throw new Error(`ID ${search} not found.`); - } else { - return result; - } - } - - /** - * Adds new documents to the store. - * @param texts An object where the keys are document IDs and the values are the documents themselves. 
- * @returns Void - */ - async add(texts: Record): Promise { - const keys = [...this._docs.keys()]; - const overlapping = Object.keys(texts).filter((x) => keys.includes(x)); - - if (overlapping.length > 0) { - throw new Error(`Tried to add ids that already exist: ${overlapping}`); - } - - for (const [key, value] of Object.entries(texts)) { - this._docs.set(key, value); - } - } - - async mget(keys: string[]): Promise { - return Promise.all(keys.map((key) => this.search(key))); - } - - async mset(keyValuePairs: [string, Document][]): Promise { - await Promise.all( - keyValuePairs.map(([key, value]) => this.add({ [key]: value })) - ); - } - - async mdelete(_keys: string[]): Promise { - throw new Error("Not implemented."); - } - - // eslint-disable-next-line require-yield - async *yieldKeys(_prefix?: string): AsyncGenerator { - throw new Error("Not implemented"); - } -} - -/** - * Class for storing and retrieving documents in memory synchronously. - */ -export class SynchronousInMemoryDocstore { - _docs: Map; - - constructor(docs?: Map) { - this._docs = docs ?? new Map(); - } - - /** - * Searches for a document in the store based on its ID. - * @param search The ID of the document to search for. - * @returns The document with the given ID. - */ - search(search: string): Document { - const result = this._docs.get(search); - if (!result) { - throw new Error(`ID ${search} not found.`); - } else { - return result; - } - } - - /** - * Adds new documents to the store. - * @param texts An object where the keys are document IDs and the values are the documents themselves. 
- * @returns Void - */ - add(texts: Record): void { - const keys = [...this._docs.keys()]; - const overlapping = Object.keys(texts).filter((x) => keys.includes(x)); - - if (overlapping.length > 0) { - throw new Error(`Tried to add ids that already exist: ${overlapping}`); - } - - for (const [key, value] of Object.entries(texts)) { - this._docs.set(key, value); - } - } -} +export * from "@langchain/community/stores/doc/in_memory"; diff --git a/langchain/src/stores/message/cassandra.ts b/langchain/src/stores/message/cassandra.ts index 5e63b2b1f11f..88b1293512e5 100644 --- a/langchain/src/stores/message/cassandra.ts +++ b/langchain/src/stores/message/cassandra.ts @@ -1,154 +1 @@ -import { Client, DseClientOptions } from "cassandra-driver"; -import { - BaseMessage, - BaseListChatMessageHistory, - StoredMessage, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -export interface CassandraChatMessageHistoryOptions extends DseClientOptions { - keyspace: string; - table: string; - sessionId: string; -} - -/** - * Class for storing chat message history within Cassandra. It extends the - * BaseListChatMessageHistory class and provides methods to get, add, and - * clear messages. 
- * @example - * ```typescript - * const chatHistory = new CassandraChatMessageHistory({ - * cloud: { - * secureConnectBundle: "", - * }, - * credentials: { - * username: "token", - * password: "", - * }, - * keyspace: "langchain", - * table: "message_history", - * sessionId: "", - * }); - * - * const chain = new ConversationChain({ - * llm: new ChatOpenAI(), - * memory: chatHistory, - * }); - * - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class CassandraChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "cassandra"]; - - private keyspace: string; - - private table: string; - - private client: Client; - - private sessionId: string; - - private tableExists: boolean; - - private options: CassandraChatMessageHistoryOptions; - - private queries: { insert: string; select: string; delete: string }; - - constructor(options: CassandraChatMessageHistoryOptions) { - super(); - this.client = new Client(options); - this.keyspace = options.keyspace; - this.table = options.table; - this.sessionId = options.sessionId; - this.tableExists = false; - this.options = options; - } - - /** - * Method to get all the messages stored in the Cassandra database. - * @returns Array of stored BaseMessage instances. - */ - public async getMessages(): Promise { - await this.ensureTable(); - const resultSet = await this.client.execute( - this.queries.select, - [this.sessionId], - { prepare: true } - ); - const storedMessages: StoredMessage[] = resultSet.rows.map((row) => ({ - type: row.message_type, - data: JSON.parse(row.data), - })); - - const baseMessages = mapStoredMessagesToChatMessages(storedMessages); - return baseMessages; - } - - /** - * Method to add a new message to the Cassandra database. - * @param message The BaseMessage instance to add. - * @returns A promise that resolves when the message has been added. 
- */ - public async addMessage(message: BaseMessage): Promise { - await this.ensureTable(); - const messages = mapChatMessagesToStoredMessages([message]); - const { type, data } = messages[0]; - return this.client - .execute( - this.queries.insert, - [this.sessionId, type, JSON.stringify(data)], - { prepare: true, ...this.options } - ) - .then(() => {}); - } - - /** - * Method to clear all the messages from the Cassandra database. - * @returns A promise that resolves when all messages have been cleared. - */ - public async clear(): Promise { - await this.ensureTable(); - return this.client - .execute(this.queries.delete, [this.sessionId], { - prepare: true, - ...this.options, - }) - .then(() => {}); - } - - /** - * Method to initialize the Cassandra database. - * @returns Promise that resolves when the database has been initialized. - */ - private async ensureTable(): Promise { - if (this.tableExists) { - return; - } - - await this.client.execute(` - CREATE TABLE IF NOT EXISTS ${this.keyspace}.${this.table} ( - session_id text, - message_ts timestamp, - message_type text, - data text, - PRIMARY KEY ((session_id), message_ts) - ); - `); - - this.queries = { - insert: `INSERT INTO ${this.keyspace}.${this.table} (session_id, message_ts, message_type, data) VALUES (?, toTimestamp(now()), ?, ?);`, - select: `SELECT message_type, data FROM ${this.keyspace}.${this.table} WHERE session_id = ?;`, - delete: `DELETE FROM ${this.keyspace}.${this.table} WHERE session_id = ?;`, - }; - - this.tableExists = true; - } -} +export * from "@langchain/community/stores/message/cassandra"; diff --git a/langchain/src/stores/message/cloudflare_d1.ts b/langchain/src/stores/message/cloudflare_d1.ts index a88cb595c2ca..3dccbdea5a83 100644 --- a/langchain/src/stores/message/cloudflare_d1.ts +++ b/langchain/src/stores/message/cloudflare_d1.ts @@ -1,197 +1 @@ -import { v4 } from "uuid"; -import type { D1Database } from "@cloudflare/workers-types"; - -import { - BaseMessage, - 
BaseListChatMessageHistory, - StoredMessage, - StoredMessageData, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type definition for the input parameters required when instantiating a - * CloudflareD1MessageHistory object. - */ -export type CloudflareD1MessageHistoryInput = { - tableName?: string; - sessionId: string; - database?: D1Database; -}; - -/** - * Interface for the data transfer object used when selecting stored - * messages from the Cloudflare D1 database. - */ -interface selectStoredMessagesDTO { - id: string; - session_id: string; - type: string; - content: string; - role: string | null; - name: string | null; - additional_kwargs: string; -} - -/** - * Class for storing and retrieving chat message history from a - * Cloudflare D1 database. Extends the BaseListChatMessageHistory class. - * @example - * ```typescript - * const memory = new BufferMemory({ - * returnMessages: true, - * chatHistory: new CloudflareD1MessageHistory({ - * tableName: "stored_message", - * sessionId: "example", - * database: env.DB, - * }), - * }); - * - * const chainInput = { input }; - * - * const res = await memory.chatHistory.invoke(chainInput); - * await memory.saveContext(chainInput, { - * output: res, - * }); - * ``` - */ -export class CloudflareD1MessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "cloudflare_d1"]; - - public database: D1Database; - - private tableName: string; - - private sessionId: string; - - private tableInitialized: boolean; - - constructor(fields: CloudflareD1MessageHistoryInput) { - super(fields); - - const { sessionId, database, tableName } = fields; - - if (database) { - this.database = database; - } else { - throw new Error( - "Either a client or config must be provided to CloudflareD1MessageHistory" - ); - } - - this.tableName = tableName || "langchain_chat_histories"; - this.tableInitialized = 
false; - this.sessionId = sessionId; - } - - /** - * Private method to ensure that the necessary table exists in the - * Cloudflare D1 database before performing any operations. If the table - * does not exist, it is created. - * @returns Promise that resolves to void. - */ - private async ensureTable(): Promise { - if (this.tableInitialized) { - return; - } - - const query = `CREATE TABLE IF NOT EXISTS ${this.tableName} (id TEXT PRIMARY KEY, session_id TEXT, type TEXT, content TEXT, role TEXT, name TEXT, additional_kwargs TEXT);`; - await this.database.prepare(query).bind().all(); - - const idIndexQuery = `CREATE INDEX IF NOT EXISTS id_index ON ${this.tableName} (id);`; - await this.database.prepare(idIndexQuery).bind().all(); - - const sessionIdIndexQuery = `CREATE INDEX IF NOT EXISTS session_id_index ON ${this.tableName} (session_id);`; - await this.database.prepare(sessionIdIndexQuery).bind().all(); - - this.tableInitialized = true; - } - - /** - * Method to retrieve all messages from the Cloudflare D1 database for the - * current session. - * @returns Promise that resolves to an array of BaseMessage objects. 
- */ - async getMessages(): Promise { - await this.ensureTable(); - - const query = `SELECT * FROM ${this.tableName} WHERE session_id = ?`; - const rawStoredMessages = await this.database - .prepare(query) - .bind(this.sessionId) - .all(); - const storedMessagesObject = - rawStoredMessages.results as unknown as selectStoredMessagesDTO[]; - - const orderedMessages: StoredMessage[] = storedMessagesObject.map( - (message) => { - const data = { - content: message.content, - additional_kwargs: JSON.parse(message.additional_kwargs), - } as StoredMessageData; - - if (message.role) { - data.role = message.role; - } - - if (message.name) { - data.name = message.name; - } - - return { - type: message.type, - data, - }; - } - ); - - return mapStoredMessagesToChatMessages(orderedMessages); - } - - /** - * Method to add a new message to the Cloudflare D1 database for the current - * session. - * @param message The BaseMessage object to be added to the database. - * @returns Promise that resolves to void. - */ - async addMessage(message: BaseMessage): Promise { - await this.ensureTable(); - - const messageToAdd = mapChatMessagesToStoredMessages([message]); - - const query = `INSERT INTO ${this.tableName} (id, session_id, type, content, role, name, additional_kwargs) VALUES(?, ?, ?, ?, ?, ?, ?)`; - - const id = v4(); - - await this.database - .prepare(query) - .bind( - id, - this.sessionId, - messageToAdd[0].type || null, - messageToAdd[0].data.content || null, - messageToAdd[0].data.role || null, - messageToAdd[0].data.name || null, - JSON.stringify(messageToAdd[0].data.additional_kwargs) - ) - .all(); - } - - /** - * Method to delete all messages from the Cloudflare D1 database for the - * current session. - * @returns Promise that resolves to void. - */ - async clear(): Promise { - await this.ensureTable(); - - const query = `DELETE FROM ? WHERE session_id = ? 
`; - await this.database - .prepare(query) - .bind(this.tableName, this.sessionId) - .all(); - } -} +export * from "@langchain/community/stores/message/cloudflare_d1"; diff --git a/langchain/src/stores/message/convex.ts b/langchain/src/stores/message/convex.ts index c060e076900f..f312698008e5 100644 --- a/langchain/src/stores/message/convex.ts +++ b/langchain/src/stores/message/convex.ts @@ -1,209 +1 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ - -// eslint-disable-next-line import/no-extraneous-dependencies -import { - DocumentByInfo, - DocumentByName, - FieldPaths, - FunctionReference, - GenericActionCtx, - GenericDataModel, - NamedTableInfo, - TableNamesInDataModel, - IndexNames, - makeFunctionReference, -} from "convex/server"; -import { BaseMessage, BaseListChatMessageHistory } from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type that defines the config required to initialize the - * ConvexChatMessageHistory class. At minimum it needs a sessionId - * and an ActionCtx. 
- */ -export type ConvexChatMessageHistoryInput< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel = "messages", - IndexName extends IndexNames< - NamedTableInfo - > = "bySessionId", - SessionIdFieldName extends FieldPaths< - NamedTableInfo - > = "sessionId", - MessageTextFieldName extends FieldPaths< - NamedTableInfo - > = "message", - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - > = any, - LookupQuery extends FunctionReference< - "query", - "internal", - { table: string; index: string; keyField: string; key: string }, - object[] - > = any, - DeleteManyMutation extends FunctionReference< - "mutation", - "internal", - { table: string; index: string; keyField: string; key: string } - > = any -> = { - readonly ctx: GenericActionCtx; - readonly sessionId: DocumentByName[SessionIdFieldName]; - /** - * Defaults to "messages" - */ - readonly table?: TableName; - /** - * Defaults to "bySessionId" - */ - readonly index?: IndexName; - /** - * Defaults to "sessionId" - */ - readonly sessionIdField?: SessionIdFieldName; - /** - * Defaults to "message" - */ - readonly messageTextFieldName?: MessageTextFieldName; - /** - * Defaults to `internal.langchain.db.insert` - */ - readonly insert?: InsertMutation; - /** - * Defaults to `internal.langchain.db.lookup` - */ - readonly lookup?: LookupQuery; - /** - * Defaults to `internal.langchain.db.deleteMany` - */ - readonly deleteMany?: DeleteManyMutation; -}; - -export class ConvexChatMessageHistory< - DataModel extends GenericDataModel, - SessionIdFieldName extends FieldPaths< - NamedTableInfo - > = "sessionId", - TableName extends TableNamesInDataModel = "messages", - IndexName extends IndexNames< - NamedTableInfo - > = "bySessionId", - MessageTextFieldName extends FieldPaths< - NamedTableInfo - > = "message", - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - > = any, - 
LookupQuery extends FunctionReference< - "query", - "internal", - { table: string; index: string; keyField: string; key: string }, - object[] - > = any, - DeleteManyMutation extends FunctionReference< - "mutation", - "internal", - { table: string; index: string; keyField: string; key: string } - > = any -> extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "convex"]; - - private readonly ctx: GenericActionCtx; - - private readonly sessionId: DocumentByInfo< - NamedTableInfo - >[SessionIdFieldName]; - - private readonly table: TableName; - - private readonly index: IndexName; - - private readonly sessionIdField: SessionIdFieldName; - - private readonly messageTextFieldName: MessageTextFieldName; - - private readonly insert: InsertMutation; - - private readonly lookup: LookupQuery; - - private readonly deleteMany: DeleteManyMutation; - - constructor( - config: ConvexChatMessageHistoryInput< - DataModel, - TableName, - IndexName, - SessionIdFieldName, - MessageTextFieldName, - InsertMutation, - LookupQuery, - DeleteManyMutation - > - ) { - super(); - this.ctx = config.ctx; - this.sessionId = config.sessionId; - this.table = config.table ?? ("messages" as TableName); - this.index = config.index ?? ("bySessionId" as IndexName); - this.sessionIdField = - config.sessionIdField ?? ("sessionId" as SessionIdFieldName); - this.messageTextFieldName = - config.messageTextFieldName ?? ("message" as MessageTextFieldName); - this.insert = - config.insert ?? (makeFunctionReference("langchain/db:insert") as any); - this.lookup = - config.lookup ?? (makeFunctionReference("langchain/db:lookup") as any); - this.deleteMany = - config.deleteMany ?? 
- (makeFunctionReference("langchain/db:deleteMany") as any); - } - - async getMessages(): Promise { - const convexDocuments: any[] = await this.ctx.runQuery(this.lookup, { - table: this.table, - index: this.index, - keyField: this.sessionIdField, - key: this.sessionId, - } as any); - - return mapStoredMessagesToChatMessages( - convexDocuments.map((doc) => doc[this.messageTextFieldName]) - ); - } - - async addMessage(message: BaseMessage): Promise { - const messages = mapChatMessagesToStoredMessages([message]); - // TODO: Remove chunking when Convex handles the concurrent requests correctly - const PAGE_SIZE = 16; - for (let i = 0; i < messages.length; i += PAGE_SIZE) { - await Promise.all( - messages.slice(i, i + PAGE_SIZE).map((message) => - this.ctx.runMutation(this.insert, { - table: this.table, - document: { - [this.sessionIdField]: this.sessionId, - [this.messageTextFieldName]: message, - }, - } as any) - ) - ); - } - } - - async clear(): Promise { - await this.ctx.runMutation(this.deleteMany, { - table: this.table, - index: this.index, - keyField: this.sessionIdField, - key: this.sessionId, - } as any); - } -} +export * from "@langchain/community/stores/message/convex"; diff --git a/langchain/src/stores/message/dynamodb.ts b/langchain/src/stores/message/dynamodb.ts index 519f351fee27..11875415e05f 100644 --- a/langchain/src/stores/message/dynamodb.ts +++ b/langchain/src/stores/message/dynamodb.ts @@ -1,198 +1 @@ -import { - DynamoDBClient, - DynamoDBClientConfig, - GetItemCommand, - GetItemCommandInput, - UpdateItemCommand, - UpdateItemCommandInput, - DeleteItemCommand, - DeleteItemCommandInput, - AttributeValue, -} from "@aws-sdk/client-dynamodb"; - -import { - StoredMessage, - BaseMessage, - BaseListChatMessageHistory, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Interface defining the fields required to create an instance of - * 
`DynamoDBChatMessageHistory`. It includes the DynamoDB table name, - * session ID, partition key, sort key, message attribute name, and - * DynamoDB client configuration. - */ -export interface DynamoDBChatMessageHistoryFields { - tableName: string; - sessionId: string; - partitionKey?: string; - sortKey?: string; - messageAttributeName?: string; - config?: DynamoDBClientConfig; - key?: Record; -} - -/** - * Interface defining the structure of a chat message as it is stored in - * DynamoDB. - */ -interface DynamoDBSerializedChatMessage { - M: { - type: { - S: string; - }; - text: { - S: string; - }; - role?: { - S: string; - }; - }; -} - -/** - * Class providing methods to interact with a DynamoDB table to store and - * retrieve chat messages. It extends the `BaseListChatMessageHistory` - * class. - */ -export class DynamoDBChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "dynamodb"]; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - "config.credentials.accessKeyId": "AWS_ACCESS_KEY_ID", - "config.credentials.secretAccessKey": "AWS_SECRETE_ACCESS_KEY", - "config.credentials.sessionToken": "AWS_SESSION_TOKEN", - }; - } - - private tableName: string; - - private sessionId: string; - - private client: DynamoDBClient; - - private partitionKey = "id"; - - private sortKey?: string; - - private messageAttributeName = "messages"; - - private dynamoKey: Record = {}; - - constructor({ - tableName, - sessionId, - partitionKey, - sortKey, - messageAttributeName, - config, - key = {}, - }: DynamoDBChatMessageHistoryFields) { - super(); - - this.tableName = tableName; - this.sessionId = sessionId; - this.client = new DynamoDBClient(config ?? {}); - this.partitionKey = partitionKey ?? this.partitionKey; - this.sortKey = sortKey; - this.messageAttributeName = - messageAttributeName ?? 
this.messageAttributeName; - this.dynamoKey = key; - - // override dynamoKey with partition key and sort key when key not specified - if (Object.keys(this.dynamoKey).length === 0) { - this.dynamoKey[this.partitionKey] = { S: this.sessionId }; - if (this.sortKey) { - this.dynamoKey[this.sortKey] = { S: this.sortKey }; - } - } - } - - /** - * Retrieves all messages from the DynamoDB table and returns them as an - * array of `BaseMessage` instances. - * @returns Array of stored messages - */ - async getMessages(): Promise { - const params: GetItemCommandInput = { - TableName: this.tableName, - Key: this.dynamoKey, - }; - - const response = await this.client.send(new GetItemCommand(params)); - const items = response.Item - ? response.Item[this.messageAttributeName]?.L ?? [] - : []; - const messages = items - .map((item) => ({ - type: item.M?.type.S, - data: { - role: item.M?.role?.S, - content: item.M?.text.S, - }, - })) - .filter( - (x): x is StoredMessage => - x.type !== undefined && x.data.content !== undefined - ); - return mapStoredMessagesToChatMessages(messages); - } - - /** - * Deletes all messages from the DynamoDB table. - */ - async clear(): Promise { - const params: DeleteItemCommandInput = { - TableName: this.tableName, - Key: this.dynamoKey, - }; - await this.client.send(new DeleteItemCommand(params)); - } - - /** - * Adds a new message to the DynamoDB table. - * @param message The message to be added to the DynamoDB table. 
- */ - async addMessage(message: BaseMessage) { - const messages = mapChatMessagesToStoredMessages([message]); - - const params: UpdateItemCommandInput = { - TableName: this.tableName, - Key: this.dynamoKey, - ExpressionAttributeNames: { - "#m": this.messageAttributeName, - }, - ExpressionAttributeValues: { - ":empty_list": { - L: [], - }, - ":m": { - L: messages.map((message) => { - const dynamoSerializedMessage: DynamoDBSerializedChatMessage = { - M: { - type: { - S: message.type, - }, - text: { - S: message.data.content, - }, - }, - }; - if (message.data.role) { - dynamoSerializedMessage.M.role = { S: message.data.role }; - } - return dynamoSerializedMessage; - }), - }, - }, - UpdateExpression: - "SET #m = list_append(if_not_exists(#m, :empty_list), :m)", - }; - await this.client.send(new UpdateItemCommand(params)); - } -} +export * from "@langchain/community/stores/message/dynamodb"; diff --git a/langchain/src/stores/message/firestore.ts b/langchain/src/stores/message/firestore.ts index 8d342bc5b981..94de701fb5f5 100644 --- a/langchain/src/stores/message/firestore.ts +++ b/langchain/src/stores/message/firestore.ts @@ -1,195 +1 @@ -import type { AppOptions } from "firebase-admin"; -import { getApps, initializeApp } from "firebase-admin/app"; -import { - getFirestore, - DocumentData, - Firestore, - DocumentReference, - FieldValue, -} from "firebase-admin/firestore"; - -import { - StoredMessage, - BaseMessage, - BaseListChatMessageHistory, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Interface for FirestoreDBChatMessageHistory. It includes the collection - * name, session ID, user ID, and optionally, the app index and - * configuration for the Firebase app. 
- */ -export interface FirestoreDBChatMessageHistory { - collectionName: string; - sessionId: string; - userId: string; - appIdx?: number; - config?: AppOptions; -} -/** - * Class for managing chat message history using Google's Firestore as a - * storage backend. Extends the BaseListChatMessageHistory class. - * @example - * ```typescript - * const chatHistory = new FirestoreChatMessageHistory({ - * collectionName: "langchain", - * sessionId: "lc-example", - * userId: "a@example.com", - * config: { projectId: "your-project-id" }, - * }); - * - * const chain = new ConversationChain({ - * llm: new ChatOpenAI(), - * memory: new BufferMemory({ chatHistory }), - * }); - * - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class FirestoreChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "firestore"]; - - private collectionName: string; - - private sessionId: string; - - private userId: string; - - private appIdx: number; - - private config: AppOptions; - - private firestoreClient: Firestore; - - private document: DocumentReference | null; - - constructor({ - collectionName, - sessionId, - userId, - appIdx = 0, - config, - }: FirestoreDBChatMessageHistory) { - super(); - this.collectionName = collectionName; - this.sessionId = sessionId; - this.userId = userId; - this.document = null; - this.appIdx = appIdx; - if (config) this.config = config; - - try { - this.ensureFirestore(); - } catch (error) { - throw new Error(`Unknown response type`); - } - } - - private ensureFirestore(): void { - let app; - // Check if the app is already initialized else get appIdx - if (!getApps().length) app = initializeApp(this.config); - else app = getApps()[this.appIdx]; - - this.firestoreClient = getFirestore(app); - - this.document = this.firestoreClient - .collection(this.collectionName) - .doc(this.sessionId); - } - - /** - * Method 
to retrieve all messages from the Firestore collection - * associated with the current session. Returns an array of BaseMessage - * objects. - * @returns Array of stored messages - */ - async getMessages(): Promise { - if (!this.document) { - throw new Error("Document not initialized"); - } - - const querySnapshot = await this.document - .collection("messages") - .orderBy("createdAt", "asc") - .get() - .catch((err) => { - throw new Error(`Unknown response type: ${err.toString()}`); - }); - - const response: StoredMessage[] = []; - querySnapshot.forEach((doc) => { - const { type, data } = doc.data(); - response.push({ type, data }); - }); - - return mapStoredMessagesToChatMessages(response); - } - - /** - * Method to add a new message to the Firestore collection. The message is - * passed as a BaseMessage object. - * @param message The message to be added as a BaseMessage object. - */ - public async addMessage(message: BaseMessage) { - const messages = mapChatMessagesToStoredMessages([message]); - await this.upsertMessage(messages[0]); - } - - private async upsertMessage(message: StoredMessage): Promise { - if (!this.document) { - throw new Error("Document not initialized"); - } - await this.document.set( - { - id: this.sessionId, - user_id: this.userId, - }, - { merge: true } - ); - await this.document - .collection("messages") - .add({ - type: message.type, - data: message.data, - createdBy: this.userId, - createdAt: FieldValue.serverTimestamp(), - }) - .catch((err) => { - throw new Error(`Unknown response type: ${err.toString()}`); - }); - } - - /** - * Method to delete all messages from the Firestore collection associated - * with the current session. 
- */ - public async clear(): Promise { - if (!this.document) { - throw new Error("Document not initialized"); - } - await this.document - .collection("messages") - .get() - .then((querySnapshot) => { - querySnapshot.docs.forEach((snapshot) => { - snapshot.ref.delete().catch((err) => { - throw new Error(`Unknown response type: ${err.toString()}`); - }); - }); - }) - .catch((err) => { - throw new Error(`Unknown response type: ${err.toString()}`); - }); - await this.document.delete().catch((err) => { - throw new Error(`Unknown response type: ${err.toString()}`); - }); - } -} +export * from "@langchain/community/stores/message/firestore"; diff --git a/langchain/src/stores/message/ioredis.ts b/langchain/src/stores/message/ioredis.ts index c705b9228769..2e854bc6c17e 100644 --- a/langchain/src/stores/message/ioredis.ts +++ b/langchain/src/stores/message/ioredis.ts @@ -1,102 +1 @@ -import { Redis, RedisOptions } from "ioredis"; -import { BaseMessage, BaseListChatMessageHistory } from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type for the input parameter of the RedisChatMessageHistory - * constructor. It includes fields for the session ID, session TTL, Redis - * URL, Redis configuration, and Redis client. - */ -export type RedisChatMessageHistoryInput = { - sessionId: string; - sessionTTL?: number; - url?: string; - config?: RedisOptions; - client?: Redis; -}; - -/** - * Class used to store chat message history in Redis. It provides methods - * to add, retrieve, and clear messages from the chat history. 
- * @example - * ```typescript - * const chatHistory = new RedisChatMessageHistory({ - * sessionId: new Date().toISOString(), - * sessionTTL: 300, - * url: "redis: - * }); - * - * const chain = new ConversationChain({ - * llm: new ChatOpenAI({ temperature: 0 }), - * memory: { chatHistory }, - * }); - * - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class RedisChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "ioredis"]; - - get lc_secrets() { - return { - url: "REDIS_URL", - "config.username": "REDIS_USERNAME", - "config.password": "REDIS_PASSWORD", - }; - } - - public client: Redis; - - private sessionId: string; - - private sessionTTL?: number; - - constructor(fields: RedisChatMessageHistoryInput) { - super(fields); - - const { sessionId, sessionTTL, url, config, client } = fields; - this.client = (client ?? - (url ? new Redis(url) : new Redis(config ?? {}))) as Redis; - this.sessionId = sessionId; - this.sessionTTL = sessionTTL; - } - - /** - * Retrieves all messages from the chat history. - * @returns Promise that resolves with an array of BaseMessage instances. - */ - async getMessages(): Promise { - const rawStoredMessages = await this.client.lrange(this.sessionId, 0, -1); - const orderedMessages = rawStoredMessages - .reverse() - .map((message) => JSON.parse(message)); - return mapStoredMessagesToChatMessages(orderedMessages); - } - - /** - * Adds a message to the chat history. - * @param message The message to add to the chat history. - * @returns Promise that resolves when the message has been added. 
- */ - async addMessage(message: BaseMessage): Promise { - const messageToAdd = mapChatMessagesToStoredMessages([message]); - await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0])); - if (this.sessionTTL) { - await this.client.expire(this.sessionId, this.sessionTTL); - } - } - - /** - * Clears all messages from the chat history. - * @returns Promise that resolves when the chat history has been cleared. - */ - async clear(): Promise { - await this.client.del(this.sessionId); - } -} +export * from "@langchain/community/stores/message/ioredis"; diff --git a/langchain/src/stores/message/momento.ts b/langchain/src/stores/message/momento.ts index c902980f694b..060690357c95 100644 --- a/langchain/src/stores/message/momento.ts +++ b/langchain/src/stores/message/momento.ts @@ -1,198 +1 @@ -/* eslint-disable no-instanceof/no-instanceof */ -import { - CacheDelete, - CacheListFetch, - CacheListPushBack, - ICacheClient, - InvalidArgumentError, - CollectionTtl, -} from "@gomomento/sdk-core"; -import { - BaseMessage, - BaseListChatMessageHistory, - StoredMessage, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; -import { ensureCacheExists } from "../../util/momento.js"; - -/** - * The settings to instantiate the Momento chat message history. - */ -export interface MomentoChatMessageHistoryProps { - /** - * The session ID to use to store the data. - */ - sessionId: string; - /** - * The Momento cache client. - */ - client: ICacheClient; - /** - * The name of the cache to use to store the data. - */ - cacheName: string; - /** - * The time to live for the cache items in seconds. - * If not specified, the cache client default is used. - */ - sessionTtl?: number; - /** - * If true, ensure that the cache exists before returning. - * If false, the cache is not checked for existence. - * Defaults to true. 
- */ - ensureCacheExists?: true; -} - -/** - * A class that stores chat message history using Momento Cache. It - * interacts with a Momento cache client to perform operations like - * fetching, adding, and deleting messages. - * @example - * ```typescript - * const chatHistory = await MomentoChatMessageHistory.fromProps({ - * client: new CacheClient({ - * configuration: Configurations.Laptop.v1(), - * credentialProvider: CredentialProvider.fromEnvironmentVariable({ - * environmentVariableName: "MOMENTO_API_KEY", - * }), - * defaultTtlSeconds: 60 * 60 * 24, - * }), - * cacheName: "langchain", - * sessionId: new Date().toISOString(), - * sessionTtl: 300, - * }); - * - * const messages = await chatHistory.getMessages(); - * console.log({ messages }); - * ``` - */ -export class MomentoChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "momento"]; - - private readonly sessionId: string; - - private readonly client: ICacheClient; - - private readonly cacheName: string; - - private readonly sessionTtl: CollectionTtl; - - private constructor(props: MomentoChatMessageHistoryProps) { - super(); - this.sessionId = props.sessionId; - this.client = props.client; - this.cacheName = props.cacheName; - - this.validateTtlSeconds(props.sessionTtl); - this.sessionTtl = - props.sessionTtl !== undefined - ? CollectionTtl.of(props.sessionTtl) - : CollectionTtl.fromCacheTtl(); - } - - /** - * Create a new chat message history backed by Momento. - * - * @param {MomentoCacheProps} props The settings to instantiate the Momento chat message history. - * @param {string} props.sessionId The session ID to use to store the data. - * @param {ICacheClient} props.client The Momento cache client. - * @param {string} props.cacheName The name of the cache to use to store the data. - * @param {number} props.sessionTtl The time to live for the cache items in seconds. - * If not specified, the cache client default is used. 
- * @param {boolean} props.ensureCacheExists If true, ensure that the cache exists before returning. - * If false, the cache is not checked for existence. - * @throws {InvalidArgumentError} If {@link props.sessionTtl} is not strictly positive. - * @returns A new chat message history backed by Momento. - */ - public static async fromProps( - props: MomentoChatMessageHistoryProps - ): Promise { - const instance = new MomentoChatMessageHistory(props); - if (props.ensureCacheExists || props.ensureCacheExists === undefined) { - await ensureCacheExists(props.client, props.cacheName); - } - return instance; - } - - /** - * Validate the user-specified TTL, if provided, is strictly positive. - * @param ttlSeconds The TTL to validate. - */ - private validateTtlSeconds(ttlSeconds?: number): void { - if (ttlSeconds !== undefined && ttlSeconds <= 0) { - throw new InvalidArgumentError("ttlSeconds must be positive."); - } - } - - /** - * Fetches messages from the cache. - * @returns A Promise that resolves to an array of BaseMessage instances. - */ - public async getMessages(): Promise { - const fetchResponse = await this.client.listFetch( - this.cacheName, - this.sessionId - ); - - let messages: StoredMessage[] = []; - if (fetchResponse instanceof CacheListFetch.Hit) { - messages = fetchResponse - .valueList() - .map((serializedStoredMessage) => JSON.parse(serializedStoredMessage)); - } else if (fetchResponse instanceof CacheListFetch.Miss) { - // pass - } else if (fetchResponse instanceof CacheListFetch.Error) { - throw fetchResponse.innerException(); - } else { - throw new Error(`Unknown response type: ${fetchResponse.toString()}`); - } - return mapStoredMessagesToChatMessages(messages); - } - - /** - * Adds a message to the cache. - * @param message The BaseMessage instance to add to the cache. - * @returns A Promise that resolves when the message has been added. 
- */ - public async addMessage(message: BaseMessage): Promise { - const messageToAdd = JSON.stringify( - mapChatMessagesToStoredMessages([message])[0] - ); - - const pushResponse = await this.client.listPushBack( - this.cacheName, - this.sessionId, - messageToAdd, - { ttl: this.sessionTtl } - ); - if (pushResponse instanceof CacheListPushBack.Success) { - // pass - } else if (pushResponse instanceof CacheListPushBack.Error) { - throw pushResponse.innerException(); - } else { - throw new Error(`Unknown response type: ${pushResponse.toString()}`); - } - } - - /** - * Deletes all messages from the cache. - * @returns A Promise that resolves when all messages have been deleted. - */ - public async clear(): Promise { - const deleteResponse = await this.client.delete( - this.cacheName, - this.sessionId - ); - if (deleteResponse instanceof CacheDelete.Success) { - // pass - } else if (deleteResponse instanceof CacheDelete.Error) { - throw deleteResponse.innerException(); - } else { - throw new Error(`Unknown response type: ${deleteResponse.toString()}`); - } - } -} +export * from "@langchain/community/stores/message/momento"; diff --git a/langchain/src/stores/message/mongodb.ts b/langchain/src/stores/message/mongodb.ts index ccca599f1c91..f4bb9a41cc6b 100644 --- a/langchain/src/stores/message/mongodb.ts +++ b/langchain/src/stores/message/mongodb.ts @@ -1,59 +1 @@ -import { Collection, Document as MongoDBDocument, ObjectId } from "mongodb"; -import { BaseMessage, BaseListChatMessageHistory } from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -export interface MongoDBChatMessageHistoryInput { - collection: Collection; - sessionId: string; -} - -/** - * @example - * ```typescript - * const chatHistory = new MongoDBChatMessageHistory({ - * collection: myCollection, - * sessionId: 'unique-session-id', - * }); - * const messages = await chatHistory.getMessages(); - * await chatHistory.clear(); - 
* ``` - */ -export class MongoDBChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "mongodb"]; - - private collection: Collection; - - private sessionId: string; - - constructor({ collection, sessionId }: MongoDBChatMessageHistoryInput) { - super(); - this.collection = collection; - this.sessionId = sessionId; - } - - async getMessages(): Promise { - const document = await this.collection.findOne({ - _id: new ObjectId(this.sessionId), - }); - const messages = document?.messages || []; - return mapStoredMessagesToChatMessages(messages); - } - - async addMessage(message: BaseMessage): Promise { - const messages = mapChatMessagesToStoredMessages([message]); - await this.collection.updateOne( - { _id: new ObjectId(this.sessionId) }, - { - $push: { messages: { $each: messages } }, - }, - { upsert: true } - ); - } - - async clear(): Promise { - await this.collection.deleteOne({ _id: new ObjectId(this.sessionId) }); - } -} +export * from "@langchain/community/stores/message/mongodb"; diff --git a/langchain/src/stores/message/planetscale.ts b/langchain/src/stores/message/planetscale.ts index 10ca1cddf810..bf090eb5fa95 100644 --- a/langchain/src/stores/message/planetscale.ts +++ b/langchain/src/stores/message/planetscale.ts @@ -1,210 +1 @@ -import { - Client as PlanetScaleClient, - Config as PlanetScaleConfig, - Connection as PlanetScaleConnection, -} from "@planetscale/database"; -import { - BaseMessage, - BaseListChatMessageHistory, - StoredMessage, - StoredMessageData, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type definition for the input parameters required when instantiating a - * PlanetScaleChatMessageHistory object. 
- */ -export type PlanetScaleChatMessageHistoryInput = { - tableName?: string; - sessionId: string; - config?: PlanetScaleConfig; - client?: PlanetScaleClient; -}; - -/** - * Interface for the data transfer object used when selecting stored - * messages from the PlanetScale database. - */ -interface selectStoredMessagesDTO { - id: string; - session_id: string; - type: string; - content: string; - role: string | null; - name: string | null; - additional_kwargs: string; -} - -/** - * Class for storing and retrieving chat message history from a - * PlanetScale database. Extends the BaseListChatMessageHistory class. - * @example - * ```typescript - * const chatHistory = new PlanetScaleChatMessageHistory({ - * tableName: "stored_message", - * sessionId: "lc-example", - * config: { - * url: "ADD_YOURS_HERE", - * }, - * }); - * const chain = new ConversationChain({ - * llm: new ChatOpenAI(), - * memory: chatHistory, - * }); - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class PlanetScaleChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "planetscale"]; - - get lc_secrets() { - return { - "config.host": "PLANETSCALE_HOST", - "config.username": "PLANETSCALE_USERNAME", - "config.password": "PLANETSCALE_PASSWORD", - "config.url": "PLANETSCALE_DATABASE_URL", - }; - } - - public client: PlanetScaleClient; - - private connection: PlanetScaleConnection; - - private tableName: string; - - private sessionId: string; - - private tableInitialized: boolean; - - constructor(fields: PlanetScaleChatMessageHistoryInput) { - super(fields); - - const { sessionId, config, client, tableName } = fields; - - if (client) { - this.client = client; - } else if (config) { - this.client = new PlanetScaleClient(config); - } else { - throw new Error( - "Either a client or config must be provided to PlanetScaleChatMessageHistory" - ); - } - - 
this.connection = this.client.connection(); - - this.tableName = tableName || "langchain_chat_histories"; - this.tableInitialized = false; - this.sessionId = sessionId; - } - - /** - * Private method to ensure that the necessary table exists in the - * PlanetScale database before performing any operations. If the table - * does not exist, it is created. - * @returns Promise that resolves to void. - */ - private async ensureTable(): Promise { - if (this.tableInitialized) { - return; - } - - const query = `CREATE TABLE IF NOT EXISTS ${this.tableName} (id BINARY(16) PRIMARY KEY, session_id VARCHAR(255), type VARCHAR(255), content VARCHAR(255), role VARCHAR(255), name VARCHAR(255), additional_kwargs VARCHAR(255));`; - - await this.connection.execute(query); - - const indexQuery = `ALTER TABLE ${this.tableName} MODIFY id BINARY(16) DEFAULT (UUID_TO_BIN(UUID()));`; - - await this.connection.execute(indexQuery); - - this.tableInitialized = true; - } - - /** - * Method to retrieve all messages from the PlanetScale database for the - * current session. - * @returns Promise that resolves to an array of BaseMessage objects. 
- */ - async getMessages(): Promise { - await this.ensureTable(); - - const query = `SELECT * FROM ${this.tableName} WHERE session_id = :session_id`; - const params = { - session_id: this.sessionId, - }; - - const rawStoredMessages = await this.connection.execute(query, params); - const storedMessagesObject = - rawStoredMessages.rows as unknown as selectStoredMessagesDTO[]; - - const orderedMessages: StoredMessage[] = storedMessagesObject.map( - (message) => { - const data = { - content: message.content, - additional_kwargs: JSON.parse(message.additional_kwargs), - } as StoredMessageData; - - if (message.role) { - data.role = message.role; - } - - if (message.name) { - data.name = message.name; - } - - return { - type: message.type, - data, - }; - } - ); - return mapStoredMessagesToChatMessages(orderedMessages); - } - - /** - * Method to add a new message to the PlanetScale database for the current - * session. - * @param message The BaseMessage object to be added to the database. - * @returns Promise that resolves to void. - */ - async addMessage(message: BaseMessage): Promise { - await this.ensureTable(); - - const messageToAdd = mapChatMessagesToStoredMessages([message]); - - const query = `INSERT INTO ${this.tableName} (session_id, type, content, role, name, additional_kwargs) VALUES (:session_id, :type, :content, :role, :name, :additional_kwargs)`; - - const params = { - session_id: this.sessionId, - type: messageToAdd[0].type, - content: messageToAdd[0].data.content, - role: messageToAdd[0].data.role, - name: messageToAdd[0].data.name, - additional_kwargs: JSON.stringify(messageToAdd[0].data.additional_kwargs), - }; - - await this.connection.execute(query, params); - } - - /** - * Method to delete all messages from the PlanetScale database for the - * current session. - * @returns Promise that resolves to void. 
- */ - async clear(): Promise { - await this.ensureTable(); - - const query = `DELETE FROM ${this.tableName} WHERE session_id = :session_id`; - const params = { - session_id: this.sessionId, - }; - await this.connection.execute(query, params); - } -} +export * from "@langchain/community/stores/message/planetscale"; diff --git a/langchain/src/stores/message/redis.ts b/langchain/src/stores/message/redis.ts index fef97ac84af7..801f450bc9b7 100644 --- a/langchain/src/stores/message/redis.ts +++ b/langchain/src/stores/message/redis.ts @@ -1,129 +1 @@ -// TODO: Deprecate in favor of stores/message/ioredis.ts when LLMCache and other implementations are ported -import { - createClient, - RedisClientOptions, - RedisClientType, - RedisModules, - RedisFunctions, - RedisScripts, -} from "redis"; -import { BaseMessage, BaseListChatMessageHistory } from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type for the input to the `RedisChatMessageHistory` constructor. - */ -export type RedisChatMessageHistoryInput = { - sessionId: string; - sessionTTL?: number; - config?: RedisClientOptions; - // Typing issues with createClient output: https://github.com/redis/node-redis/issues/1865 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - client?: any; -}; - -/** - * Class for storing chat message history using Redis. Extends the - * `BaseListChatMessageHistory` class. 
- * @example - * ```typescript - * const chatHistory = new RedisChatMessageHistory({ - * sessionId: new Date().toISOString(), - * sessionTTL: 300, - * url: "redis: - * }); - * - * const chain = new ConversationChain({ - * llm: new ChatOpenAI({ modelName: "gpt-3.5-turbo", temperature: 0 }), - * memory: { chatHistory }, - * }); - * - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class RedisChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "redis"]; - - get lc_secrets() { - return { - "config.url": "REDIS_URL", - "config.username": "REDIS_USERNAME", - "config.password": "REDIS_PASSWORD", - }; - } - - public client: RedisClientType; - - private sessionId: string; - - private sessionTTL?: number; - - constructor(fields: RedisChatMessageHistoryInput) { - super(fields); - - const { sessionId, sessionTTL, config, client } = fields; - this.client = (client ?? createClient(config ?? {})) as RedisClientType< - RedisModules, - RedisFunctions, - RedisScripts - >; - this.sessionId = sessionId; - this.sessionTTL = sessionTTL; - } - - /** - * Ensures the Redis client is ready to perform operations. If the client - * is not ready, it attempts to connect to the Redis database. - * @returns Promise resolving to true when the client is ready. - */ - async ensureReadiness() { - if (!this.client.isReady) { - await this.client.connect(); - } - return true; - } - - /** - * Retrieves all chat messages from the Redis database for the current - * session. - * @returns Promise resolving to an array of `BaseMessage` instances. 
- */ - async getMessages(): Promise { - await this.ensureReadiness(); - const rawStoredMessages = await this.client.lRange(this.sessionId, 0, -1); - const orderedMessages = rawStoredMessages - .reverse() - .map((message) => JSON.parse(message)); - return mapStoredMessagesToChatMessages(orderedMessages); - } - - /** - * Adds a new chat message to the Redis database for the current session. - * @param message The `BaseMessage` instance to add. - * @returns Promise resolving when the message has been added. - */ - async addMessage(message: BaseMessage): Promise { - await this.ensureReadiness(); - const messageToAdd = mapChatMessagesToStoredMessages([message]); - await this.client.lPush(this.sessionId, JSON.stringify(messageToAdd[0])); - if (this.sessionTTL) { - await this.client.expire(this.sessionId, this.sessionTTL); - } - } - - /** - * Deletes all chat messages from the Redis database for the current - * session. - * @returns Promise resolving when the messages have been deleted. - */ - async clear(): Promise { - await this.ensureReadiness(); - await this.client.del(this.sessionId); - } -} +export * from "@langchain/community/stores/message/redis"; diff --git a/langchain/src/stores/message/upstash_redis.ts b/langchain/src/stores/message/upstash_redis.ts index 0d8f318e07b0..33cb724f14dd 100644 --- a/langchain/src/stores/message/upstash_redis.ts +++ b/langchain/src/stores/message/upstash_redis.ts @@ -1,95 +1 @@ -import { Redis, type RedisConfigNodejs } from "@upstash/redis"; -import { - StoredMessage, - BaseMessage, - BaseListChatMessageHistory, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * Type definition for the input parameters required to initialize an - * instance of the UpstashRedisChatMessageHistory class. 
- */ -export type UpstashRedisChatMessageHistoryInput = { - sessionId: string; - sessionTTL?: number; - config?: RedisConfigNodejs; - client?: Redis; -}; - -/** - * Class used to store chat message history in Redis. It provides methods - * to add, get, and clear messages. - */ -export class UpstashRedisChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "upstash_redis"]; - - get lc_secrets() { - return { - "config.url": "UPSTASH_REDIS_REST_URL", - "config.token": "UPSTASH_REDIS_REST_TOKEN", - }; - } - - public client: Redis; - - private sessionId: string; - - private sessionTTL?: number; - - constructor(fields: UpstashRedisChatMessageHistoryInput) { - super(fields); - const { sessionId, sessionTTL, config, client } = fields; - if (client) { - this.client = client; - } else if (config) { - this.client = new Redis(config); - } else { - throw new Error( - `Upstash Redis message stores require either a config object or a pre-configured client.` - ); - } - this.sessionId = sessionId; - this.sessionTTL = sessionTTL; - } - - /** - * Retrieves the chat messages from the Redis database. - * @returns An array of BaseMessage instances representing the chat history. - */ - async getMessages(): Promise { - const rawStoredMessages: StoredMessage[] = - await this.client.lrange(this.sessionId, 0, -1); - - const orderedMessages = rawStoredMessages.reverse(); - const previousMessages = orderedMessages.filter( - (x): x is StoredMessage => - x.type !== undefined && x.data.content !== undefined - ); - return mapStoredMessagesToChatMessages(previousMessages); - } - - /** - * Adds a new message to the chat history in the Redis database. - * @param message The message to be added to the chat history. - * @returns Promise resolving to void. 
- */ - async addMessage(message: BaseMessage): Promise { - const messageToAdd = mapChatMessagesToStoredMessages([message]); - await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0])); - if (this.sessionTTL) { - await this.client.expire(this.sessionId, this.sessionTTL); - } - } - - /** - * Deletes all messages from the chat history in the Redis database. - * @returns Promise resolving to void. - */ - async clear(): Promise { - await this.client.del(this.sessionId); - } -} +export * from "@langchain/community/stores/message/upstash_redis"; diff --git a/langchain/src/stores/message/utils.ts b/langchain/src/stores/message/utils.ts index 81a958e5a7e2..3a65c2b25031 100644 --- a/langchain/src/stores/message/utils.ts +++ b/langchain/src/stores/message/utils.ts @@ -1,31 +1,4 @@ -import { - BaseMessage, - StoredMessage, - mapStoredMessageToChatMessage, -} from "../../schema/index.js"; - -/** - * Transforms an array of `StoredMessage` instances into an array of - * `BaseMessage` instances. It uses the `mapV1MessageToStoredMessage` - * function to ensure all messages are in the `StoredMessage` format, then - * creates new instances of the appropriate `BaseMessage` subclass based - * on the type of each message. This function is used to prepare stored - * messages for use in a chat context. - */ -export function mapStoredMessagesToChatMessages( - messages: StoredMessage[] -): BaseMessage[] { - return messages.map(mapStoredMessageToChatMessage); -} - -/** - * Transforms an array of `BaseMessage` instances into an array of - * `StoredMessage` instances. It does this by calling the `toDict` method - * on each `BaseMessage`, which returns a `StoredMessage`. This function - * is used to prepare chat messages for storage. 
- */ -export function mapChatMessagesToStoredMessages( - messages: BaseMessage[] -): StoredMessage[] { - return messages.map((message) => message.toDict()); -} +export { + mapStoredMessagesToChatMessages, + mapChatMessagesToStoredMessages, +} from "@langchain/core/messages"; diff --git a/langchain/src/stores/message/xata.ts b/langchain/src/stores/message/xata.ts index b459f38de8f2..9df54fb3b943 100644 --- a/langchain/src/stores/message/xata.ts +++ b/langchain/src/stores/message/xata.ts @@ -1,243 +1 @@ -import { - BaseClient, - BaseClientOptions, - GetTableSchemaResponse, - Schemas, - XataApiClient, - parseWorkspacesUrlParts, -} from "@xata.io/client"; -import { - BaseMessage, - BaseListChatMessageHistory, - StoredMessage, - StoredMessageData, -} from "../../schema/index.js"; -import { - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "./utils.js"; - -/** - * An object type that represents the input for the XataChatMessageHistory - * class. - */ -export type XataChatMessageHistoryInput = { - sessionId: string; - config?: BaseClientOptions; - client?: XataClient; - table?: string; - createTable?: boolean; - apiKey?: string; -}; - -/** - * An interface that represents the data transfer object for stored - * messages. - */ -interface storedMessagesDTO { - id: string; - sessionId: string; - type: string; - content: string; - role?: string; - name?: string; - additionalKwargs: string; -} - -const chatMemoryColumns: Schemas.Column[] = [ - { name: "sessionId", type: "string" }, - { name: "type", type: "string" }, - { name: "role", type: "string" }, - { name: "content", type: "text" }, - { name: "name", type: "string" }, - { name: "additionalKwargs", type: "text" }, -]; - -/** - * A class for managing chat message history using Xata.io client. It - * extends the BaseListChatMessageHistory class and provides methods to - * get, add, and clear messages. It also ensures the existence of a table - * where the chat messages are stored. 
- * @example - * ```typescript - * const chatHistory = new XataChatMessageHistory({ - * table: "messages", - * sessionId: new Date().toISOString(), - * client: new BaseClient({ - * databaseURL: process.env.XATA_DB_URL, - * apiKey: process.env.XATA_API_KEY, - * branch: "main", - * }), - * apiKey: process.env.XATA_API_KEY, - * }); - * - * const chain = new ConversationChain({ - * llm: new ChatOpenAI(), - * memory: new BufferMemory({ chatHistory }), - * }); - * - * const response = await chain.invoke({ - * input: "What did I just say my name was?", - * }); - * console.log({ response }); - * ``` - */ -export class XataChatMessageHistory< - XataClient extends BaseClient -> extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "xata"]; - - public client: XataClient; - - private sessionId: string; - - private table: string; - - private tableInitialized: boolean; - - private createTable: boolean; - - private apiClient: XataApiClient; - - constructor(fields: XataChatMessageHistoryInput) { - super(fields); - - const { sessionId, config, client, table } = fields; - this.sessionId = sessionId; - this.table = table || "memory"; - if (client) { - this.client = client; - } else if (config) { - this.client = new BaseClient(config) as XataClient; - } else { - throw new Error( - "Either a client or a config must be provided to XataChatMessageHistoryInput" - ); - } - if (fields.createTable !== false) { - this.createTable = true; - const apiKey = fields.apiKey || fields.config?.apiKey; - if (!apiKey) { - throw new Error( - "If createTable is set, an apiKey must be provided to XataChatMessageHistoryInput, either directly or through the config object" - ); - } - this.apiClient = new XataApiClient({ apiKey }); - } else { - this.createTable = false; - } - this.tableInitialized = false; - } - - /** - * Retrieves all messages associated with the session ID, ordered by - * creation time. 
- * @returns A promise that resolves to an array of BaseMessage instances. - */ - async getMessages(): Promise { - await this.ensureTable(); - const records = await this.client.db[this.table] - .filter({ sessionId: this.sessionId }) - .sort("xata.createdAt", "asc") - .getAll(); - - const rawStoredMessages = records as unknown as storedMessagesDTO[]; - const orderedMessages: StoredMessage[] = rawStoredMessages.map( - (message: storedMessagesDTO) => { - const data = { - content: message.content, - additional_kwargs: JSON.parse(message.additionalKwargs), - } as StoredMessageData; - if (message.role) { - data.role = message.role; - } - if (message.name) { - data.name = message.name; - } - - return { - type: message.type, - data, - }; - } - ); - return mapStoredMessagesToChatMessages(orderedMessages); - } - - /** - * Adds a new message to the database. - * @param message The BaseMessage instance to be added. - * @returns A promise that resolves when the message has been added. - */ - async addMessage(message: BaseMessage): Promise { - await this.ensureTable(); - const messageToAdd = mapChatMessagesToStoredMessages([message]); - await this.client.db[this.table].create({ - sessionId: this.sessionId, - type: messageToAdd[0].type, - content: messageToAdd[0].data.content, - role: messageToAdd[0].data.role, - name: messageToAdd[0].data.name, - additionalKwargs: JSON.stringify(messageToAdd[0].data.additional_kwargs), - }); - } - - /** - * Deletes all messages associated with the session ID. - * @returns A promise that resolves when the messages have been deleted. - */ - async clear(): Promise { - await this.ensureTable(); - const records = await this.client.db[this.table] - .select(["id"]) - .filter({ sessionId: this.sessionId }) - .getAll(); - const ids = records.map((m) => m.id); - await this.client.db[this.table].delete(ids); - } - - /** - * Checks if the table exists and creates it if it doesn't. This method is - * called before any operation on the table. 
- * @returns A promise that resolves when the table has been ensured. - */ - private async ensureTable(): Promise { - if (!this.createTable) { - return; - } - if (this.tableInitialized) { - return; - } - - const { databaseURL, branch } = await this.client.getConfig(); - const [, , host, , database] = databaseURL.split("/"); - const urlParts = parseWorkspacesUrlParts(host); - if (urlParts == null) { - throw new Error("Invalid databaseURL"); - } - const { workspace, region } = urlParts; - const tableParams = { - workspace, - region, - database, - branch, - table: this.table, - }; - - let schema: GetTableSchemaResponse | null = null; - try { - schema = await this.apiClient.tables.getTableSchema(tableParams); - } catch (e) { - // pass - } - if (schema == null) { - await this.apiClient.tables.createTable(tableParams); - await this.apiClient.tables.setTableSchema({ - ...tableParams, - schema: { - columns: chatMemoryColumns, - }, - }); - } - } -} +export * from "@langchain/community/stores/message/xata"; diff --git a/langchain/src/tools/IFTTTWebhook.ts b/langchain/src/tools/IFTTTWebhook.ts index 842dbc9d1aa1..a5795255d9d4 100644 --- a/langchain/src/tools/IFTTTWebhook.ts +++ b/langchain/src/tools/IFTTTWebhook.ts @@ -1,79 +1 @@ -/** From https://github.com/SidU/teams-langchain-js/wiki/Connecting-IFTTT-Services. - -# Creating a webhook -- Go to https://ifttt.com/create - -# Configuring the "If This" -- Click on the "If This" button in the IFTTT interface. -- Search for "Webhooks" in the search bar. -- Choose the first option for "Receive a web request with a JSON payload." -- Choose an Event Name that is specific to the service you plan to connect to. -This will make it easier for you to manage the webhook URL. -For example, if you're connecting to Spotify, you could use "Spotify" as your -Event Name. -- Click the "Create Trigger" button to save your settings and create your webhook. - -# Configuring the "Then That" -- Tap on the "Then That" button in the IFTTT interface. 
-- Search for the service you want to connect, such as Spotify. -- Choose an action from the service, such as "Add track to a playlist". -- Configure the action by specifying the necessary details, such as the playlist name, -e.g., "Songs from AI". -- Reference the JSON Payload received by the Webhook in your action. For the Spotify -scenario, choose "{{JsonPayload}}" as your search query. -- Tap the "Create Action" button to save your action settings. -- Once you have finished configuring your action, click the "Finish" button to -complete the setup. -- Congratulations! You have successfully connected the Webhook to the desired -service, and you're ready to start receiving data and triggering actions 🎉 - -# Finishing up -- To get your webhook URL go to https://ifttt.com/maker_webhooks/settings -- Copy the IFTTT key value from there. The URL is of the form -https://maker.ifttt.com/use/YOUR_IFTTT_KEY. Grab the YOUR_IFTTT_KEY value. - */ -import { Tool } from "./base.js"; - -/** - * Represents a tool for creating and managing webhooks with the IFTTT (If - * This Then That) service. The IFTTT service allows users to create - * chains of simple conditional statements, called applets, which are - * triggered based on changes to other web services. 
- */ -export class IFTTTWebhook extends Tool { - static lc_name() { - return "IFTTTWebhook"; - } - - private url: string; - - name: string; - - description: string; - - constructor(url: string, name: string, description: string) { - super(...arguments); - this.url = url; - this.name = name; - this.description = description; - } - - /** @ignore */ - async _call(input: string): Promise { - const headers = { "Content-Type": "application/json" }; - const body = JSON.stringify({ this: input }); - - const response = await fetch(this.url, { - method: "POST", - headers, - body, - }); - - if (!response.ok) { - throw new Error(`HTTP error ${response.status}`); - } - - const result = await response.text(); - return result; - } -} +export * from "@langchain/community/tools/ifttt"; diff --git a/langchain/src/tools/aiplugin.ts b/langchain/src/tools/aiplugin.ts index c08720f7559e..47d7705c5eaf 100644 --- a/langchain/src/tools/aiplugin.ts +++ b/langchain/src/tools/aiplugin.ts @@ -1,81 +1 @@ -import { Tool, ToolParams } from "./base.js"; - -/** - * Interface for parameters required to create an instance of - * AIPluginTool. - */ -export interface AIPluginToolParams extends ToolParams { - name: string; - description: string; - apiSpec: string; -} - -/** - * Class for creating instances of AI tools from plugins. It extends the - * Tool class and implements the AIPluginToolParams interface. 
- */ -export class AIPluginTool extends Tool implements AIPluginToolParams { - static lc_name() { - return "AIPluginTool"; - } - - private _name: string; - - private _description: string; - - apiSpec: string; - - get name() { - return this._name; - } - - get description() { - return this._description; - } - - constructor(params: AIPluginToolParams) { - super(params); - this._name = params.name; - this._description = params.description; - this.apiSpec = params.apiSpec; - } - - /** @ignore */ - async _call(_input: string) { - return this.apiSpec; - } - - /** - * Static method that creates an instance of AIPluginTool from a given - * plugin URL. It fetches the plugin and its API specification from the - * provided URL and returns a new instance of AIPluginTool with the - * fetched data. - * @param url The URL of the AI plugin. - * @returns A new instance of AIPluginTool. - */ - static async fromPluginUrl(url: string) { - const aiPluginRes = await fetch(url); - if (!aiPluginRes.ok) { - throw new Error( - `Failed to fetch plugin from ${url} with status ${aiPluginRes.status}` - ); - } - const aiPluginJson = await aiPluginRes.json(); - - const apiUrlRes = await fetch(aiPluginJson.api.url); - if (!apiUrlRes.ok) { - throw new Error( - `Failed to fetch API spec from ${aiPluginJson.api.url} with status ${apiUrlRes.status}` - ); - } - const apiUrlJson = await apiUrlRes.text(); - - return new AIPluginTool({ - name: aiPluginJson.name_for_model, - description: `Call this tool to get the OpenAPI spec (and usage guide) for interacting with the ${aiPluginJson.name_for_human} API. You should only call this ONCE! What is the ${aiPluginJson.name_for_human} API useful for? 
${aiPluginJson.description_for_human}`, - apiSpec: `Usage Guide: ${aiPluginJson.description_for_model} - -OpenAPI Spec in JSON or YAML format:\n${apiUrlJson}`, - }); - } -} +export * from "@langchain/community/tools/aiplugin"; diff --git a/langchain/src/tools/aws_sfn.ts b/langchain/src/tools/aws_sfn.ts index 8d375f88f540..5c3eb9d14e0f 100644 --- a/langchain/src/tools/aws_sfn.ts +++ b/langchain/src/tools/aws_sfn.ts @@ -1,225 +1 @@ -import { - SFNClient as Client, - StartExecutionCommand as Invoker, - DescribeExecutionCommand as Describer, - SendTaskSuccessCommand as TaskSuccessSender, -} from "@aws-sdk/client-sfn"; - -import { Tool, ToolParams } from "./base.js"; - -/** - * Interface for AWS Step Functions configuration. - */ -export interface SfnConfig { - stateMachineArn: string; - region?: string; - accessKeyId?: string; - secretAccessKey?: string; -} - -/** - * Interface for AWS Step Functions client constructor arguments. - */ -interface SfnClientConstructorArgs { - region?: string; - credentials?: { - accessKeyId: string; - secretAccessKey: string; - }; -} - -/** - * Class for starting the execution of an AWS Step Function. - */ -export class StartExecutionAWSSfnTool extends Tool { - static lc_name() { - return "StartExecutionAWSSfnTool"; - } - - private sfnConfig: SfnConfig; - - public name: string; - - public description: string; - - constructor({ - name, - description, - ...rest - }: SfnConfig & { name: string; description: string }) { - super(); - this.name = name; - this.description = description; - this.sfnConfig = rest; - } - - /** - * Generates a formatted description for the StartExecutionAWSSfnTool. - * @param name Name of the state machine. - * @param description Description of the state machine. - * @returns A formatted description string. - */ - static formatDescription(name: string, description: string): string { - return `Use to start executing the ${name} state machine. Use to run ${name} workflows. 
Whenever you need to start (or execute) an asynchronous workflow (or state machine) about ${description} you should ALWAYS use this. Input should be a valid JSON string.`; - } - - /** @ignore */ - async _call(input: string): Promise { - const clientConstructorArgs: SfnClientConstructorArgs = - getClientConstructorArgs(this.sfnConfig); - const sfnClient = new Client(clientConstructorArgs); - - return new Promise((resolve) => { - let payload; - try { - payload = JSON.parse(input); - } catch (e) { - console.error("Error starting state machine execution:", e); - resolve("failed to complete request"); - } - - const command = new Invoker({ - stateMachineArn: this.sfnConfig.stateMachineArn, - input: JSON.stringify(payload), - }); - - sfnClient - .send(command) - .then((response) => - resolve( - response.executionArn ? response.executionArn : "request completed." - ) - ) - .catch((error: Error) => { - console.error("Error starting state machine execution:", error); - resolve("failed to complete request"); - }); - }); - } -} - -/** - * Class for checking the status of an AWS Step Function execution. - */ -export class DescribeExecutionAWSSfnTool extends Tool { - static lc_name() { - return "DescribeExecutionAWSSfnTool"; - } - - name = "describe-execution-aws-sfn"; - - description = - "This tool should ALWAYS be used for checking the status of any AWS Step Function execution (aka. state machine execution). Input to this tool is a properly formatted AWS Step Function Execution ARN (executionArn). 
The output is a stringified JSON object containing the executionArn, name, status, startDate, stopDate, input, output, error, and cause of the execution."; - - sfnConfig: Omit; - - constructor(config: Omit & ToolParams) { - super(config); - this.sfnConfig = config; - } - - /** @ignore */ - async _call(input: string) { - const clientConstructorArgs: SfnClientConstructorArgs = - getClientConstructorArgs(this.sfnConfig); - const sfnClient = new Client(clientConstructorArgs); - - const command = new Describer({ - executionArn: input, - }); - return await sfnClient - .send(command) - .then((response) => - response.executionArn - ? JSON.stringify({ - executionArn: response.executionArn, - name: response.name, - status: response.status, - startDate: response.startDate, - stopDate: response.stopDate, - input: response.input, - output: response.output, - error: response.error, - cause: response.cause, - }) - : "{}" - ) - .catch((error: Error) => { - console.error("Error describing state machine execution:", error); - return "failed to complete request"; - }); - } -} - -/** - * Class for sending a task success signal to an AWS Step Function - * execution. - */ -export class SendTaskSuccessAWSSfnTool extends Tool { - static lc_name() { - return "SendTaskSuccessAWSSfnTool"; - } - - name = "send-task-success-aws-sfn"; - - description = - "This tool should ALWAYS be used for sending task success to an AWS Step Function execution (aka. statemachine exeuction). 
Input to this tool is a stringify JSON object containing the taskToken and output."; - - sfnConfig: Omit; - - constructor(config: Omit & ToolParams) { - super(config); - this.sfnConfig = config; - } - - /** @ignore */ - async _call(input: string) { - const clientConstructorArgs: SfnClientConstructorArgs = - getClientConstructorArgs(this.sfnConfig); - const sfnClient = new Client(clientConstructorArgs); - - let payload; - try { - payload = JSON.parse(input); - } catch (e) { - console.error("Error starting state machine execution:", e); - return "failed to complete request"; - } - - const command = new TaskSuccessSender({ - taskToken: payload.taskToken, - output: JSON.stringify(payload.output), - }); - - return await sfnClient - .send(command) - .then(() => "request completed.") - .catch((error: Error) => { - console.error( - "Error sending task success to state machine execution:", - error - ); - return "failed to complete request"; - }); - } -} - -/** - * Helper function to construct the AWS SFN client. - */ -function getClientConstructorArgs(config: Partial) { - const clientConstructorArgs: SfnClientConstructorArgs = {}; - - if (config.region) { - clientConstructorArgs.region = config.region; - } - - if (config.accessKeyId && config.secretAccessKey) { - clientConstructorArgs.credentials = { - accessKeyId: config.accessKeyId, - secretAccessKey: config.secretAccessKey, - }; - } - - return clientConstructorArgs; -} +export * from "@langchain/community/tools/aws_sfn"; diff --git a/langchain/src/tools/bingserpapi.ts b/langchain/src/tools/bingserpapi.ts index 0a3010f30034..378e52d4e2c3 100644 --- a/langchain/src/tools/bingserpapi.ts +++ b/langchain/src/tools/bingserpapi.ts @@ -1,78 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * A tool for web search functionality using Bing's search engine. It - * extends the base `Tool` class and implements the `_call` method to - * perform the search operation. 
Requires an API key for Bing's search - * engine, which can be set in the environment variables. Also accepts - * additional parameters for the search query. - */ -class BingSerpAPI extends Tool { - static lc_name() { - return "BingSerpAPI"; - } - - /** - * Not implemented. Will throw an error if called. - */ - toJSON() { - return this.toJSONNotImplemented(); - } - - name = "bing-search"; - - description = - "a search engine. useful for when you need to answer questions about current events. input should be a search query."; - - key: string; - - params: Record; - - constructor( - apiKey: string | undefined = getEnvironmentVariable("BingApiKey"), - params: Record = {} - ) { - super(...arguments); - - if (!apiKey) { - throw new Error( - "BingSerpAPI API key not set. You can set it as BingApiKey in your .env file." - ); - } - - this.key = apiKey; - this.params = params; - } - - /** @ignore */ - async _call(input: string): Promise { - const headers = { "Ocp-Apim-Subscription-Key": this.key }; - const params = { q: input, textDecorations: "true", textFormat: "HTML" }; - const searchUrl = new URL("https://api.bing.microsoft.com/v7.0/search"); - - Object.entries(params).forEach(([key, value]) => { - searchUrl.searchParams.append(key, value); - }); - - const response = await fetch(searchUrl, { headers }); - - if (!response.ok) { - throw new Error(`HTTP error ${response.status}`); - } - - const res = await response.json(); - const results: [] = res.webPages.value; - - if (results.length === 0) { - return "No good results found."; - } - const snippets = results - .map((result: { snippet: string }) => result.snippet) - .join(" "); - - return snippets; - } -} - -export { BingSerpAPI }; +export * from "@langchain/community/tools/bingserpapi"; diff --git a/langchain/src/tools/brave_search.ts b/langchain/src/tools/brave_search.ts index 7d5498f2b95b..07bfc4995109 100644 --- a/langchain/src/tools/brave_search.ts +++ b/langchain/src/tools/brave_search.ts @@ -1,77 +1 @@ -import { 
getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * Interface for the parameters required to instantiate a BraveSearch - * instance. - */ -export interface BraveSearchParams { - apiKey?: string; -} - -/** - * Class for interacting with the Brave Search engine. It extends the Tool - * class and requires an API key to function. The API key can be passed in - * during instantiation or set as an environment variable named - * 'BRAVE_SEARCH_API_KEY'. - */ -export class BraveSearch extends Tool { - static lc_name() { - return "BraveSearch"; - } - - name = "brave-search"; - - description = - "a search engine. useful for when you need to answer questions about current events. input should be a search query."; - - apiKey: string; - - constructor( - fields: BraveSearchParams = { - apiKey: getEnvironmentVariable("BRAVE_SEARCH_API_KEY"), - } - ) { - super(); - - if (!fields.apiKey) { - throw new Error( - `Brave API key not set. Please pass it in or set it as an environment variable named "BRAVE_SEARCH_API_KEY".` - ); - } - - this.apiKey = fields.apiKey; - } - - /** @ignore */ - async _call(input: string): Promise { - const headers = { - "X-Subscription-Token": this.apiKey, - Accept: "application/json", - }; - const searchUrl = new URL( - `https://api.search.brave.com/res/v1/web/search?q=${encodeURIComponent( - input - )}` - ); - - const response = await fetch(searchUrl, { headers }); - - if (!response.ok) { - throw new Error(`HTTP error ${response.status}`); - } - - const parsedResponse = await response.json(); - const webSearchResults = parsedResponse.web?.results; - const finalResults = Array.isArray(webSearchResults) - ? 
webSearchResults.map( - (item: { title?: string; url?: string; description?: string }) => ({ - title: item.title, - link: item.url, - snippet: item.description, - }) - ) - : []; - return JSON.stringify(finalResults); - } -} +export * from "@langchain/community/tools/brave_search"; diff --git a/langchain/src/tools/connery.ts b/langchain/src/tools/connery.ts index 5ee594ba85a9..73bad45cebf7 100644 --- a/langchain/src/tools/connery.ts +++ b/langchain/src/tools/connery.ts @@ -1,353 +1 @@ -import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * An object containing configuration parameters for the ConneryService class. - * @extends AsyncCallerParams - */ -export interface ConneryServiceParams extends AsyncCallerParams { - runnerUrl: string; - apiKey: string; -} - -type ApiResponse = { - status: "success"; - data: T; -}; - -type ApiErrorResponse = { - status: "error"; - error: { - message: string; - }; -}; - -type Parameter = { - key: string; - title: string; - description: string; - type: string; - validation?: { - required?: boolean; - }; -}; - -type Action = { - id: string; - key: string; - title: string; - description: string; - type: string; - inputParameters: Parameter[]; - outputParameters: Parameter[]; - pluginId: string; -}; - -type Input = { - [key: string]: string; -}; - -type Output = { - [key: string]: string; -}; - -type RunActionResult = { - output: Output; - used: { - actionId: string; - input: Input; - }; -}; - -/** - * A LangChain Tool object wrapping a Connery action. - * @extends Tool - */ -export class ConneryAction extends Tool { - name: string; - - description: string; - - /** - * Creates a ConneryAction instance based on the provided Connery action. - * @param _action The Connery action. - * @param _service The ConneryService instance. - * @returns A ConneryAction instance. 
- */ - constructor(protected _action: Action, protected _service: ConneryService) { - super(); - - this.name = this._action.title; - this.description = this.getDescription(); - } - - /** - * Runs the Connery action. - * @param prompt This is a plain English prompt with all the information needed to run the action. - * @returns A promise that resolves to a JSON string containing the output of the action. - */ - protected _call(prompt: string): Promise { - return this._service.runAction(this._action.id, prompt); - } - - /** - * Returns the description of the Connery action. - * @returns A string containing the description of the Connery action together with the instructions on how to use it. - */ - protected getDescription(): string { - const { title, description } = this._action; - const inputParameters = this.prepareJsonForTemplate( - this._action.inputParameters - ); - const example1InputParametersSchema = this.prepareJsonForTemplate([ - { - key: "recipient", - title: "Email Recipient", - description: "Email address of the email recipient.", - type: "string", - validation: { - required: true, - }, - }, - { - key: "subject", - title: "Email Subject", - description: "Subject of the email.", - type: "string", - validation: { - required: true, - }, - }, - { - key: "body", - title: "Email Body", - description: "Body of the email.", - type: "string", - validation: { - required: true, - }, - }, - ]); - - const descriptionTemplate = - "# Instructions about tool input:\n" + - "The input to this tool is a plain English prompt with all the input parameters needed to call it. " + - "The input parameters schema of this tool is provided below. " + - "Use the input parameters schema to construct the prompt for the tool. " + - "If the input parameter is required in the schema, it must be provided in the prompt. " + - "Do not come up with the values for the input parameters yourself. 
" + - "If you do not have enough information to fill in the input parameter, ask the user to provide it. " + - "See examples below on how to construct the prompt based on the provided tool information. " + - "\n\n" + - "# Instructions about tool output:\n" + - "The output of this tool is a JSON string. " + - "Retrieve the output parameters from the JSON string and use them in the next tool. " + - "Do not return the JSON string as the output of the tool. " + - "\n\n" + - "# Example:\n" + - "Tool information:\n" + - "- Title: Send email\n" + - "- Description: Send an email to a recipient.\n" + - `- Input parameters schema in JSON fromat: ${example1InputParametersSchema}\n` + - "The tool input prompt:\n" + - "recipient: test@example.com, subject: 'Test email', body: 'This is a test email sent from Langchain Connery tool.'\n" + - "\n\n" + - "# The tool information\n" + - `- Title: ${title}\n` + - `- Description: ${description}\n` + - `- Input parameters schema in JSON fromat: ${inputParameters}\n`; - - return descriptionTemplate; - } - - /** - * Converts the provided object to a JSON string and escapes '{' and '}' characters. - * @param obj The object to convert to a JSON string. - * @returns A string containing the JSON representation of the provided object with '{' and '}' characters escaped. - */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - protected prepareJsonForTemplate(obj: any): string { - // Convert the object to a JSON string - const jsonString = JSON.stringify(obj); - - // Replace '{' with '{{' and '}' with '}}' - const escapedJSON = jsonString.replace(/{/g, "{{").replace(/}/g, "}}"); - - return escapedJSON; - } -} - -/** - * A service for working with Connery actions. - * - * Connery is an open-source plugin infrastructure for AI. 
- * Source code: https://github.com/connery-io/connery-platform - */ -export class ConneryService { - protected runnerUrl: string; - - protected apiKey: string; - - protected asyncCaller: AsyncCaller; - - /** - * Creates a ConneryService instance. - * @param params A ConneryServiceParams object. - * If not provided, the values are retrieved from the CONNERY_RUNNER_URL - * and CONNERY_RUNNER_API_KEY environment variables. - * @returns A ConneryService instance. - */ - constructor(params?: ConneryServiceParams) { - const runnerUrl = - params?.runnerUrl ?? getEnvironmentVariable("CONNERY_RUNNER_URL"); - const apiKey = - params?.apiKey ?? getEnvironmentVariable("CONNERY_RUNNER_API_KEY"); - - if (!runnerUrl || !apiKey) { - throw new Error( - "CONNERY_RUNNER_URL and CONNERY_RUNNER_API_KEY environment variables must be set." - ); - } - - this.runnerUrl = runnerUrl; - this.apiKey = apiKey; - - this.asyncCaller = new AsyncCaller(params ?? {}); - } - - /** - * Returns the list of Connery actions wrapped as a LangChain Tool objects. - * @returns A promise that resolves to an array of ConneryAction objects. - */ - async listActions(): Promise { - const actions = await this._listActions(); - return actions.map((action) => new ConneryAction(action, this)); - } - - /** - * Returns the specified Connery action wrapped as a LangChain Tool object. - * @param actionId The ID of the action to return. - * @returns A promise that resolves to a ConneryAction object. - */ - async getAction(actionId: string): Promise { - const action = await this._getAction(actionId); - return new ConneryAction(action, this); - } - - /** - * Runs the specified Connery action with the provided input. - * @param actionId The ID of the action to run. - * @param prompt This is a plain English prompt with all the information needed to run the action. - * @param input The input expected by the action. - * If provided together with the prompt, the input takes precedence over the input specified in the prompt. 
- * @returns A promise that resolves to a JSON string containing the output of the action. - */ - async runAction( - actionId: string, - prompt?: string, - input?: Input - ): Promise { - const result = await this._runAction(actionId, prompt, input); - return JSON.stringify(result); - } - - /** - * Returns the list of actions available in the Connery runner. - * @returns A promise that resolves to an array of Action objects. - */ - protected async _listActions(): Promise { - const response = await this.asyncCaller.call( - fetch, - `${this.runnerUrl}/v1/actions`, - { - method: "GET", - headers: this._getHeaders(), - } - ); - await this._handleError(response, "Failed to list actions"); - - const apiResponse: ApiResponse = await response.json(); - return apiResponse.data; - } - - /** - * Returns the specified action available in the Connery runner. - * @param actionId The ID of the action to return. - * @returns A promise that resolves to an Action object. - * @throws An error if the action with the specified ID is not found. - */ - protected async _getAction(actionId: string): Promise { - const actions = await this._listActions(); - const action = actions.find((a) => a.id === actionId); - if (!action) { - throw new Error( - `The action with ID "${actionId}" was not found in the list of available actions in the Connery runner.` - ); - } - return action; - } - - /** - * Runs the specified Connery action with the provided input. - * @param actionId The ID of the action to run. - * @param prompt This is a plain English prompt with all the information needed to run the action. - * @param input The input object expected by the action. - * If provided together with the prompt, the input takes precedence over the input specified in the prompt. - * @returns A promise that resolves to a RunActionResult object. 
- */ - protected async _runAction( - actionId: string, - prompt?: string, - input?: Input - ): Promise { - const response = await this.asyncCaller.call( - fetch, - `${this.runnerUrl}/v1/actions/${actionId}/run`, - { - method: "POST", - headers: this._getHeaders(), - body: JSON.stringify({ - prompt, - input, - }), - } - ); - await this._handleError(response, "Failed to run action"); - - const apiResponse: ApiResponse = await response.json(); - return apiResponse.data.output; - } - - /** - * Returns a standard set of HTTP headers to be used in API calls to the Connery runner. - * @returns An object containing the standard set of HTTP headers. - */ - protected _getHeaders(): Record { - return { - "Content-Type": "application/json", - "x-api-key": this.apiKey, - }; - } - - /** - * Shared error handler for API calls to the Connery runner. - * If the response is not ok, an error is thrown containing the error message returned by the Connery runner. - * Otherwise, the promise resolves to void. - * @param response The response object returned by the Connery runner. - * @param errorMessage The error message to be used in the error thrown if the response is not ok. - * @returns A promise that resolves to void. - * @throws An error containing the error message returned by the Connery runner. - */ - protected async _handleError( - response: Response, - errorMessage: string - ): Promise { - if (response.ok) return; - - const apiErrorResponse: ApiErrorResponse = await response.json(); - throw new Error( - `${errorMessage}. Status code: ${response.status}. 
Error message: ${apiErrorResponse.error.message}` - ); - } -} +export * from "@langchain/community/tools/connery"; diff --git a/langchain/src/tools/convert_to_openai.ts b/langchain/src/tools/convert_to_openai.ts index 0c599a14e3b7..b5a875324ec8 100644 --- a/langchain/src/tools/convert_to_openai.ts +++ b/langchain/src/tools/convert_to_openai.ts @@ -1,5 +1,5 @@ import { zodToJsonSchema } from "zod-to-json-schema"; -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import { StructuredTool } from "./base.js"; diff --git a/langchain/src/tools/dadjokeapi.ts b/langchain/src/tools/dadjokeapi.ts index 856b7a519155..8b5874b0e2b5 100644 --- a/langchain/src/tools/dadjokeapi.ts +++ b/langchain/src/tools/dadjokeapi.ts @@ -1,44 +1 @@ -import { Tool } from "./base.js"; - -/** - * The DadJokeAPI class is a tool for generating dad jokes based on a - * specific topic. It fetches jokes from an external API and returns a - * random joke from the results. If no jokes are found for the given - * search term, it returns a message indicating that no jokes were found. - */ -class DadJokeAPI extends Tool { - static lc_name() { - return "DadJokeAPI"; - } - - name = "dadjoke"; - - description = - "a dad joke generator. get a dad joke about a specific topic. 
input should be a search term."; - - /** @ignore */ - async _call(input: string): Promise { - const headers = { Accept: "application/json" }; - const searchUrl = `https://icanhazdadjoke.com/search?term=${input}`; - - const response = await fetch(searchUrl, { headers }); - - if (!response.ok) { - throw new Error(`HTTP error ${response.status}`); - } - - const data = await response.json(); - const jokes = data.results; - - if (jokes.length === 0) { - return `No dad jokes found about ${input}`; - } - - const randomIndex = Math.floor(Math.random() * jokes.length); - const randomJoke = jokes[randomIndex].joke; - - return randomJoke; - } -} - -export { DadJokeAPI }; +export * from "@langchain/community/tools/dadjokeapi"; diff --git a/langchain/src/tools/dataforseo_api_search.ts b/langchain/src/tools/dataforseo_api_search.ts index 6f6b246351b1..00e1aedc3c18 100644 --- a/langchain/src/tools/dataforseo_api_search.ts +++ b/langchain/src/tools/dataforseo_api_search.ts @@ -1,378 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * @interface DataForSeoApiConfig - * @description Represents the configuration object used to set up a DataForSeoAPISearch instance. - */ -export interface DataForSeoApiConfig { - /** - * @property apiLogin - * @type {string} - * @description The API login credential for DataForSEO. If not provided, it will be fetched from environment variables. - */ - apiLogin?: string; - - /** - * @property apiPassword - * @type {string} - * @description The API password credential for DataForSEO. If not provided, it will be fetched from environment variables. - */ - apiPassword?: string; - - /** - * @property params - * @type {Record} - * @description Additional parameters to customize the API request. - */ - params?: Record; - - /** - * @property useJsonOutput - * @type {boolean} - * @description Determines if the output should be in JSON format. 
- */ - useJsonOutput?: boolean; - - /** - * @property jsonResultTypes - * @type {Array} - * @description Specifies the types of results to include in the output. - */ - jsonResultTypes?: Array; - - /** - * @property jsonResultFields - * @type {Array} - * @description Specifies the fields to include in each result object. - */ - jsonResultFields?: Array; - - /** - * @property topCount - * @type {number} - * @description Specifies the maximum number of results to return. - */ - topCount?: number; -} - -/** - * Represents a task in the API response. - */ -type Task = { - id: string; - status_code: number; - status_message: string; - time: string; - result: Result[]; -}; - -/** - * Represents a result in the API response. - */ -type Result = { - keyword: string; - check_url: string; - datetime: string; - spell?: string; - item_types: string[]; - se_results_count: number; - items_count: number; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - items: any[]; -}; - -/** - * Represents the API response. - */ -type ApiResponse = { - status_code: number; - status_message: string; - tasks: Task[]; -}; - -/** - * @class DataForSeoAPISearch - * @extends {Tool} - * @description Represents a wrapper class to work with DataForSEO SERP API. - */ -export class DataForSeoAPISearch extends Tool { - static lc_name() { - return "DataForSeoAPISearch"; - } - - name = "dataforseo-api-wrapper"; - - description = - "A robust Google Search API provided by DataForSeo. This tool is handy when you need information about trending topics or current events."; - - protected apiLogin: string; - - protected apiPassword: string; - - /** - * @property defaultParams - * @type {Record} - * @description These are the default parameters to be used when making an API request. 
- */ - protected defaultParams: Record = { - location_name: "United States", - language_code: "en", - depth: 10, - se_name: "google", - se_type: "organic", - }; - - protected params: Record = {}; - - protected jsonResultTypes: Array | undefined; - - protected jsonResultFields: Array | undefined; - - protected topCount: number | undefined; - - protected useJsonOutput = false; - - /** - * @constructor - * @param {DataForSeoApiConfig} config - * @description Sets up the class, throws an error if the API login/password isn't provided. - */ - constructor(config: DataForSeoApiConfig = {}) { - super(); - const apiLogin = - config.apiLogin ?? getEnvironmentVariable("DATAFORSEO_LOGIN"); - const apiPassword = - config.apiPassword ?? getEnvironmentVariable("DATAFORSEO_PASSWORD"); - const params = config.params ?? {}; - if (!apiLogin || !apiPassword) { - throw new Error( - "DataForSEO login or password not set. You can set it as DATAFORSEO_LOGIN and DATAFORSEO_PASSWORD in your .env file, or pass it to DataForSeoAPISearch." - ); - } - this.params = { ...this.defaultParams, ...params }; - this.apiLogin = apiLogin; - this.apiPassword = apiPassword; - this.jsonResultTypes = config.jsonResultTypes; - this.jsonResultFields = config.jsonResultFields; - this.useJsonOutput = config.useJsonOutput ?? false; - this.topCount = config.topCount; - } - - /** - * @method _call - * @param {string} keyword - * @returns {Promise} - * @description Initiates a call to the API and processes the response. - */ - async _call(keyword: string): Promise { - return this.useJsonOutput - ? JSON.stringify(await this.results(keyword)) - : this.processResponse(await this.getResponseJson(keyword)); - } - - /** - * @method results - * @param {string} keyword - * @returns {Promise>} - * @description Fetches the results from the API for the given keyword. 
- */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - async results(keyword: string): Promise> { - const res = await this.getResponseJson(keyword); - return this.filterResults(res, this.jsonResultTypes); - } - - /** - * @method prepareRequest - * @param {string} keyword - * @returns {{url: string; headers: HeadersInit; data: BodyInit}} - * @description Prepares the request details for the API call. - */ - protected prepareRequest(keyword: string): { - url: string; - headers: HeadersInit; - data: BodyInit; - } { - if (this.apiLogin === undefined || this.apiPassword === undefined) { - throw new Error("api_login or api_password is not provided"); - } - - const credentials = Buffer.from( - `${this.apiLogin}:${this.apiPassword}`, - "utf-8" - ).toString("base64"); - const headers = { - Authorization: `Basic ${credentials}`, - "Content-Type": "application/json", - }; - - const params = { ...this.params }; - params.keyword ??= keyword; - const data = [params]; - - return { - url: `https://api.dataforseo.com/v3/serp/${params.se_name}/${params.se_type}/live/advanced`, - headers, - data: JSON.stringify(data), - }; - } - - /** - * @method getResponseJson - * @param {string} keyword - * @returns {Promise} - * @description Executes a POST request to the provided URL and returns a parsed JSON response. - */ - protected async getResponseJson(keyword: string): Promise { - const requestDetails = this.prepareRequest(keyword); - const response = await fetch(requestDetails.url, { - method: "POST", - headers: requestDetails.headers, - body: requestDetails.data, - }); - - if (!response.ok) { - throw new Error( - `Got ${response.status} error from DataForSEO: ${response.statusText}` - ); - } - - const result: ApiResponse = await response.json(); - return this.checkResponse(result); - } - - /** - * @method checkResponse - * @param {ApiResponse} response - * @returns {ApiResponse} - * @description Checks the response status code. 
- */ - private checkResponse(response: ApiResponse): ApiResponse { - if (response.status_code !== 20000) { - throw new Error( - `Got error from DataForSEO SERP API: ${response.status_message}` - ); - } - for (const task of response.tasks) { - if (task.status_code !== 20000) { - throw new Error( - `Got error from DataForSEO SERP API: ${task.status_message}` - ); - } - } - return response; - } - - /* eslint-disable @typescript-eslint/no-explicit-any */ - /** - * @method filterResults - * @param {ApiResponse} res - * @param {Array | undefined} types - * @returns {Array} - * @description Filters the results based on the specified result types. - */ - private filterResults( - res: ApiResponse, - types: Array | undefined - ): Array { - const output: Array = []; - for (const task of res.tasks || []) { - for (const result of task.result || []) { - for (const item of result.items || []) { - if ( - types === undefined || - types.length === 0 || - types.includes(item.type) - ) { - const newItem = this.cleanupUnnecessaryItems(item); - if (Object.keys(newItem).length !== 0) { - output.push(newItem); - } - } - if (this.topCount !== undefined && output.length >= this.topCount) { - break; - } - } - } - } - return output; - } - - /* eslint-disable @typescript-eslint/no-explicit-any */ - /* eslint-disable no-param-reassign */ - /** - * @method cleanupUnnecessaryItems - * @param {any} d - * @description Removes unnecessary items from the response. 
- */ - private cleanupUnnecessaryItems(d: any): any { - if (Array.isArray(d)) { - return d.map((item) => this.cleanupUnnecessaryItems(item)); - } - - const toRemove = ["xpath", "position", "rectangle"]; - if (typeof d === "object" && d !== null) { - return Object.keys(d).reduce((newObj: any, key: string) => { - if ( - (this.jsonResultFields === undefined || - this.jsonResultFields.includes(key)) && - !toRemove.includes(key) - ) { - if (typeof d[key] === "object" && d[key] !== null) { - newObj[key] = this.cleanupUnnecessaryItems(d[key]); - } else { - newObj[key] = d[key]; - } - } - return newObj; - }, {}); - } - - return d; - } - - /** - * @method processResponse - * @param {ApiResponse} res - * @returns {string} - * @description Processes the response to extract meaningful data. - */ - protected processResponse(res: ApiResponse): string { - let returnValue = "No good search result found"; - for (const task of res.tasks || []) { - for (const result of task.result || []) { - const { item_types } = result; - const items = result.items || []; - if (item_types.includes("answer_box")) { - returnValue = items.find( - (item: { type: string; text: string }) => item.type === "answer_box" - ).text; - } else if (item_types.includes("knowledge_graph")) { - returnValue = items.find( - (item: { type: string; description: string }) => - item.type === "knowledge_graph" - ).description; - } else if (item_types.includes("featured_snippet")) { - returnValue = items.find( - (item: { type: string; description: string }) => - item.type === "featured_snippet" - ).description; - } else if (item_types.includes("shopping")) { - returnValue = items.find( - (item: { type: string; price: string }) => item.type === "shopping" - ).price; - } else if (item_types.includes("organic")) { - returnValue = items.find( - (item: { type: string; description: string }) => - item.type === "organic" - ).description; - } - if (returnValue) { - break; - } - } - } - return returnValue; - } -} +export * from 
"@langchain/community/tools/dataforseo_api_search"; diff --git a/langchain/src/tools/gmail/index.ts b/langchain/src/tools/gmail/index.ts index d2f854da54a4..1c913b663768 100644 --- a/langchain/src/tools/gmail/index.ts +++ b/langchain/src/tools/gmail/index.ts @@ -1,12 +1 @@ -export { GmailCreateDraft } from "./create_draft.js"; -export { GmailGetMessage } from "./get_message.js"; -export { GmailGetThread } from "./get_thread.js"; -export { GmailSearch } from "./search.js"; -export { GmailSendMessage } from "./send_message.js"; - -export type { GmailBaseToolParams } from "./base.js"; -export type { CreateDraftSchema } from "./create_draft.js"; -export type { GetMessageSchema } from "./get_message.js"; -export type { GetThreadSchema } from "./get_thread.js"; -export type { SearchSchema } from "./search.js"; -export type { SendMessageSchema } from "./send_message.js"; +export * from "@langchain/community/tools/gmail"; diff --git a/langchain/src/tools/google_custom_search.ts b/langchain/src/tools/google_custom_search.ts index 003353d2419a..5748c242bdaa 100644 --- a/langchain/src/tools/google_custom_search.ts +++ b/langchain/src/tools/google_custom_search.ts @@ -1,83 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * Interface for parameters required by GoogleCustomSearch class. - */ -export interface GoogleCustomSearchParams { - apiKey?: string; - googleCSEId?: string; -} - -/** - * Class that uses the Google Search API to perform custom searches. - * Requires environment variables `GOOGLE_API_KEY` and `GOOGLE_CSE_ID` to - * be set. - */ -export class GoogleCustomSearch extends Tool { - static lc_name() { - return "GoogleCustomSearch"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "GOOGLE_API_KEY", - }; - } - - name = "google-custom-search"; - - protected apiKey: string; - - protected googleCSEId: string; - - description = - "a custom search engine. 
useful for when you need to answer questions about current events. input should be a search query. outputs a JSON array of results."; - - constructor( - fields: GoogleCustomSearchParams = { - apiKey: getEnvironmentVariable("GOOGLE_API_KEY"), - googleCSEId: getEnvironmentVariable("GOOGLE_CSE_ID"), - } - ) { - super(...arguments); - if (!fields.apiKey) { - throw new Error( - `Google API key not set. You can set it as "GOOGLE_API_KEY" in your environment variables.` - ); - } - if (!fields.googleCSEId) { - throw new Error( - `Google custom search engine id not set. You can set it as "GOOGLE_CSE_ID" in your environment variables.` - ); - } - this.apiKey = fields.apiKey; - this.googleCSEId = fields.googleCSEId; - } - - async _call(input: string) { - const res = await fetch( - `https://www.googleapis.com/customsearch/v1?key=${this.apiKey}&cx=${ - this.googleCSEId - }&q=${encodeURIComponent(input)}` - ); - - if (!res.ok) { - throw new Error( - `Got ${res.status} error from Google custom search: ${res.statusText}` - ); - } - - const json = await res.json(); - - const results = - json?.items?.map( - (item: { title?: string; link?: string; snippet?: string }) => ({ - title: item.title, - link: item.link, - snippet: item.snippet, - }) - ) ?? []; - return JSON.stringify(results); - } -} +export * from "@langchain/community/tools/google_custom_search"; diff --git a/langchain/src/tools/google_places.ts b/langchain/src/tools/google_places.ts index 826a33e22c74..a5c7404f9810 100644 --- a/langchain/src/tools/google_places.ts +++ b/langchain/src/tools/google_places.ts @@ -1,96 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * Interface for parameters required by GooglePlacesAPI class. 
- */ -export interface GooglePlacesAPIParams { - apiKey?: string; -} - -/** - * Tool that queries the Google Places API - */ -export class GooglePlacesAPI extends Tool { - static lc_name() { - return "GooglePlacesAPI"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiKey: "GOOGLE_PLACES_API_KEY", - }; - } - - name = "google_places"; - - protected apiKey: string; - - description = `A wrapper around Google Places API. Useful for when you need to validate or - discover addresses from ambiguous text. Input should be a search query.`; - - constructor(fields?: GooglePlacesAPIParams) { - super(...arguments); - const apiKey = - fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PLACES_API_KEY"); - if (apiKey === undefined) { - throw new Error( - `Google Places API key not set. You can set it as "GOOGLE_PLACES_API_KEY" in your environment variables.` - ); - } - this.apiKey = apiKey; - } - - async _call(input: string) { - const res = await fetch( - `https://places.googleapis.com/v1/places:searchText`, - { - method: "POST", - body: JSON.stringify({ - textQuery: input, - languageCode: "en", - }), - headers: { - "X-Goog-Api-Key": this.apiKey, - "X-Goog-FieldMask": - "places.displayName,places.formattedAddress,places.id,places.internationalPhoneNumber,places.websiteUri", - "Content-Type": "application/json", - }, - } - ); - - if (!res.ok) { - let message; - try { - const json = await res.json(); - message = json.error.message; - } catch (e) { - message = - "Unable to parse error message: Google did not return a JSON response."; - } - throw new Error( - `Got ${res.status}: ${res.statusText} error from Google Places API: ${message}` - ); - } - - const json = await res.json(); - - const results = - json?.places?.map( - (place: { - id?: string; - internationalPhoneNumber?: string; - formattedAddress?: string; - websiteUri?: string; - displayName?: { text?: string }; - }) => ({ - name: place.displayName?.text, - id: place.id, - address: 
place.formattedAddress, - phoneNumber: place.internationalPhoneNumber, - website: place.websiteUri, - }) - ) ?? []; - return JSON.stringify(results); - } -} +export * from "@langchain/community/tools/google_places"; diff --git a/langchain/src/tools/searchapi.ts b/langchain/src/tools/searchapi.ts index fb48c34581db..b10e7a37472a 100644 --- a/langchain/src/tools/searchapi.ts +++ b/langchain/src/tools/searchapi.ts @@ -1,204 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -type JSONPrimitive = string | number | boolean | null; -type JSONValue = JSONPrimitive | JSONObject | JSONArray; -interface JSONObject { - [key: string]: JSONValue; -} -interface JSONArray extends Array {} - -function isJSONObject(value: JSONValue): value is JSONObject { - return value !== null && typeof value === "object" && !Array.isArray(value); -} - -/** - * SearchApiParameters Type Definition. - * - * For more parameters and supported search engines, refer specific engine documentation: - * Google - https://www.searchapi.io/docs/google - * Google News - https://www.searchapi.io/docs/google-news - * Google Scholar - https://www.searchapi.io/docs/google-scholar - * YouTube Transcripts - https://www.searchapi.io/docs/youtube-transcripts - * and others. - * - */ -export type SearchApiParameters = { - [key: string]: JSONValue; -}; - -/** - * SearchApi Class Definition. - * - * Provides a wrapper around the SearchApi. - * - * Ensure you've set the SEARCHAPI_API_KEY environment variable for authentication. - * You can obtain a free API key from https://www.searchapi.io/. 
- * @example - * ```typescript - * const searchApi = new SearchApi("your-api-key", { - * engine: "google_news", - * }); - * const agent = RunnableSequence.from([ - * ChatPromptTemplate.fromMessages([ - * ["ai", "Answer the following questions using a bulleted list markdown format.""], - * ["human", "{input}"], - * ]), - * new ChatOpenAI({ temperature: 0 }), - * (input: BaseMessageChunk) => ({ - * log: "test", - * returnValues: { - * output: input, - * }, - * }), - * ]); - * const executor = AgentExecutor.fromAgentAndTools({ - * agent, - * tools: [searchApi], - * }); - * const res = await executor.invoke({ - * input: "What's happening in Ukraine today?"", - * }); - * console.log(res); - * ``` - */ -export class SearchApi extends Tool { - static lc_name() { - return "SearchApi"; - } - - /** - * Converts the SearchApi instance to JSON. This method is not implemented - * and will throw an error if called. - * @returns Throws an error. - */ - toJSON() { - return this.toJSONNotImplemented(); - } - - protected apiKey: string; - - protected params: Partial; - - constructor( - apiKey: string | undefined = getEnvironmentVariable("SEARCHAPI_API_KEY"), - params: Partial = {} - ) { - super(...arguments); - - if (!apiKey) { - throw new Error( - "SearchApi requires an API key. Please set it as SEARCHAPI_API_KEY in your .env file, or pass it as a parameter to the SearchApi constructor." - ); - } - - this.apiKey = apiKey; - this.params = params; - } - - name = "search"; - - /** - * Builds a URL for the SearchApi request. - * @param parameters The parameters for the request. - * @returns A string representing the built URL. 
- */ - protected buildUrl(searchQuery: string): string { - const preparedParams: [string, string][] = Object.entries({ - engine: "google", - api_key: this.apiKey, - ...this.params, - q: searchQuery, - }) - .filter( - ([key, value]) => - value !== undefined && value !== null && key !== "apiKey" - ) - .map(([key, value]) => [key, `${value}`]); - - const searchParams = new URLSearchParams(preparedParams); - return `https://www.searchapi.io/api/v1/search?${searchParams}`; - } - - /** @ignore */ - /** - * Calls the SearchAPI. - * - * Accepts an input query and fetches the result from SearchApi. - * - * @param {string} input - Search query. - * @returns {string} - Formatted search results or an error message. - * - * NOTE: This method is the core search handler and processes various types - * of search results including Google organic results, videos, jobs, and images. - */ - async _call(input: string) { - const resp = await fetch(this.buildUrl(input)); - - const json = await resp.json(); - - if (json.error) { - throw new Error( - `Failed to load search results from SearchApi due to: ${json.error}` - ); - } - - // Google Search results - if (json.answer_box?.answer) { - return json.answer_box.answer; - } - - if (json.answer_box?.snippet) { - return json.answer_box.snippet; - } - - if (json.knowledge_graph?.description) { - return json.knowledge_graph.description; - } - - // Organic results (Google, Google News) - if (json.organic_results) { - const snippets = json.organic_results - .filter((r: JSONObject) => r.snippet) - .map((r: JSONObject) => r.snippet); - return snippets.join("\n"); - } - - // Google Jobs results - if (json.jobs) { - const jobDescriptions = json.jobs - .slice(0, 1) - .filter((r: JSONObject) => r.description) - .map((r: JSONObject) => r.description); - return jobDescriptions.join("\n"); - } - - // Google Videos results - if (json.videos) { - const videoInfo = json.videos - .filter((r: JSONObject) => r.title && r.link) - .map((r: JSONObject) => `Title: 
"${r.title}" Link: ${r.link}`); - return videoInfo.join("\n"); - } - - // Google Images results - if (json.images) { - const image_results = json.images.slice(0, 15); - const imageInfo = image_results - .filter( - (r: JSONObject) => - r.title && r.original && isJSONObject(r.original) && r.original.link - ) - .map( - (r: JSONObject) => - `Title: "${r.title}" Link: ${(r.original as JSONObject).link}` - ); - return imageInfo.join("\n"); - } - - return "No good search result found"; - } - - description = - "a search engine. useful for when you need to answer questions about current events. input should be a search query."; -} +export * from "@langchain/community/tools/searchapi"; diff --git a/langchain/src/tools/searxng_search.ts b/langchain/src/tools/searxng_search.ts index 8b1c353e949d..d792040f540a 100644 --- a/langchain/src/tools/searxng_search.ts +++ b/langchain/src/tools/searxng_search.ts @@ -1,258 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * Interface for the results returned by the Searxng search. - */ -interface SearxngResults { - query: string; - number_of_results: number; - results: Array<{ - url: string; - title: string; - content: string; - img_src: string; - engine: string; - parsed_url: Array; - template: string; - engines: Array; - positions: Array; - score: number; - category: string; - pretty_url: string; - open_group?: boolean; - close_group?: boolean; - }>; - answers: Array; - corrections: Array; - infoboxes: Array<{ - infobox: string; - content: string; - engine: string; - engines: Array; - }>; - suggestions: Array; - unresponsive_engines: Array; -} - -/** - * Interface for custom headers used in the Searxng search. 
- */ -interface SearxngCustomHeaders { - [key: string]: string; -} - -interface SearxngSearchParams { - /** - * @default 10 - * Number of results included in results - */ - numResults?: number; - /** Comma separated list, specifies the active search categories - * https://docs.searxng.org/user/configured_engines.html#configured-engines - */ - categories?: string; - - /** Comma separated list, specifies the active search engines - * https://docs.searxng.org/user/configured_engines.html#configured-engines - */ - engines?: string; - - /** Code of the language. */ - language?: string; - /** Search page number. */ - pageNumber?: number; - /** - * day / month / year - * - * Time range of search for engines which support it. See if an engine supports time range search in the preferences page of an instance. - */ - timeRange?: number; - - /** - * Throws Error if format is set anything other than "json" - * Output format of results. Format needs to be activated in search: - */ - format?: "json"; - /** Open search results on new tab. */ - resultsOnNewTab?: 0 | 1; - /** Proxy image results through SearXNG. */ - imageProxy?: boolean; - autocomplete?: string; - /** - * Filter search results of engines which support safe search. See if an engine supports safe search in the preferences page of an instance. - */ - safesearch?: 0 | 1 | 2; -} - -/** - * SearxngSearch class represents a meta search engine tool. - * Use this class when you need to answer questions about current events. - * The input should be a search query, and the output is a JSON array of the query results. 
- * - * note: works best with *agentType*: `structured-chat-zero-shot-react-description` - * https://github.com/searxng/searxng - * @example - * ```typescript - * const executor = AgentExecutor.fromAgentAndTools({ - * agent, - * tools: [ - * new SearxngSearch({ - * params: { - * format: "json", - * engines: "google", - * }, - * headers: {}, - * }), - * ], - * }); - * const result = await executor.invoke({ - * input: `What is Langchain? Describe in 50 words`, - * }); - * ``` - */ -export class SearxngSearch extends Tool { - static lc_name() { - return "SearxngSearch"; - } - - name = "searxng-search"; - - description = - "A meta search engine. Useful for when you need to answer questions about current events. Input should be a search query. Output is a JSON array of the query results"; - - protected apiBase?: string; - - protected params?: SearxngSearchParams = { - numResults: 10, - pageNumber: 1, - format: "json", - imageProxy: true, - safesearch: 0, - }; - - protected headers?: SearxngCustomHeaders; - - get lc_secrets(): { [key: string]: string } | undefined { - return { - apiBase: "SEARXNG_API_BASE", - }; - } - - /** - * Constructor for the SearxngSearch class - * @param apiBase Base URL of the Searxng instance - * @param params SearxNG parameters - * @param headers Custom headers - */ - constructor({ - apiBase, - params, - headers, - }: { - /** Base URL of Searxng instance */ - apiBase?: string; - - /** SearxNG Paramerters - * - * https://docs.searxng.org/dev/search_api.html check here for more details - */ - params?: SearxngSearchParams; - - /** - * Custom headers - * Set custom headers if you're using a api from RapidAPI (https://rapidapi.com/iamrony777/api/searxng) - * No headers needed for a locally self-hosted instance - */ - headers?: SearxngCustomHeaders; - }) { - super(...arguments); - - this.apiBase = getEnvironmentVariable("SEARXNG_API_BASE") || apiBase; - this.headers = { "content-type": "application/json", ...headers }; - - if (!this.apiBase) { - 
throw new Error( - `SEARXNG_API_BASE not set. You can set it as "SEARXNG_API_BASE" in your environment variables.` - ); - } - - if (params) { - this.params = { ...this.params, ...params }; - } - } - - /** - * Builds the URL for the Searxng search. - * @param path The path for the URL. - * @param parameters The parameters for the URL. - * @param baseUrl The base URL. - * @returns The complete URL as a string. - */ - protected buildUrl

( - path: string, - parameters: P, - baseUrl: string - ): string { - const nonUndefinedParams: [string, string][] = Object.entries(parameters) - .filter(([_, value]) => value !== undefined) - .map(([key, value]) => [key, value.toString()]); // Avoid string conversion - const searchParams = new URLSearchParams(nonUndefinedParams); - return `${baseUrl}/${path}?${searchParams}`; - } - - async _call(input: string): Promise { - const queryParams = { - q: input, - ...this.params, - }; - const url = this.buildUrl("search", queryParams, this.apiBase as string); - - const resp = await fetch(url, { - method: "POST", - headers: this.headers, - signal: AbortSignal.timeout(5 * 1000), // 5 seconds - }); - - if (!resp.ok) { - throw new Error(resp.statusText); - } - - const res: SearxngResults = await resp.json(); - - if ( - !res.results.length && - !res.answers.length && - !res.infoboxes.length && - !res.suggestions.length - ) { - return "No good results found."; - } else if (res.results.length) { - const response: string[] = []; - - res.results.forEach((r) => { - response.push( - JSON.stringify({ - title: r.title || "", - link: r.url || "", - snippet: r.content || "", - }) - ); - }); - - return response.slice(0, this.params?.numResults).toString(); - } else if (res.answers.length) { - return res.answers[0]; - } else if (res.infoboxes.length) { - return res.infoboxes[0]?.content.replaceAll(/<[^>]+>/gi, ""); - } else if (res.suggestions.length) { - let suggestions = "Suggestions: "; - res.suggestions.forEach((s) => { - suggestions += `${s}, `; - }); - return suggestions; - } else { - return "No good results found."; - } - } -} +export * from "@langchain/community/tools/searxng_search"; diff --git a/langchain/src/tools/serpapi.ts b/langchain/src/tools/serpapi.ts index fc3044e9d658..2c3ff83f9b21 100644 --- a/langchain/src/tools/serpapi.ts +++ b/langchain/src/tools/serpapi.ts @@ -1,505 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; 
- -/** - * This does not use the `serpapi` package because it appears to cause issues - * when used in `jest` tests. Part of the issue seems to be that the `serpapi` - * package imports a wasm module to use instead of native `fetch`, which we - * don't want anyway. - * - * NOTE: you must provide location, gl and hl or your region and language will - * may not match your location, and will not be deterministic. - */ - -// Copied over from `serpapi` package -interface BaseParameters { - /** - * Parameter defines the device to use to get the results. It can be set to - * `desktop` (default) to use a regular browser, `tablet` to use a tablet browser - * (currently using iPads), or `mobile` to use a mobile browser (currently - * using iPhones). - */ - device?: "desktop" | "tablet" | "mobile"; - /** - * Parameter will force SerpApi to fetch the Google results even if a cached - * version is already present. A cache is served only if the query and all - * parameters are exactly the same. Cache expires after 1h. Cached searches - * are free, and are not counted towards your searches per month. It can be set - * to `false` (default) to allow results from the cache, or `true` to disallow - * results from the cache. `no_cache` and `async` parameters should not be used together. - */ - no_cache?: boolean; - /** - * Specify the client-side timeout of the request. In milliseconds. - */ - timeout?: number; -} - -export interface SerpAPIParameters extends BaseParameters { - /** - * Search Query - * Parameter defines the query you want to search. You can use anything that you - * would use in a regular Google search. e.g. `inurl:`, `site:`, `intitle:`. We - * also support advanced search query parameters such as as_dt and as_eq. See the - * [full list](https://serpapi.com/advanced-google-query-parameters) of supported - * advanced search query parameters. - */ - q: string; - /** - * Location - * Parameter defines from where you want the search to originate. 
If several - * locations match the location requested, we'll pick the most popular one. Head to - * [/locations.json API](https://serpapi.com/locations-api) if you need more - * precise control. location and uule parameters can't be used together. Avoid - * utilizing location when setting the location outside the U.S. when using Google - * Shopping and/or Google Product API. - */ - location?: string; - /** - * Encoded Location - * Parameter is the Google encoded location you want to use for the search. uule - * and location parameters can't be used together. - */ - uule?: string; - /** - * Google Place ID - * Parameter defines the id (`CID`) of the Google My Business listing you want to - * scrape. Also known as Google Place ID. - */ - ludocid?: string; - /** - * Additional Google Place ID - * Parameter that you might have to use to force the knowledge graph map view to - * show up. You can find the lsig ID by using our [Local Pack - * API](https://serpapi.com/local-pack) or [Local Places Results - * API](https://serpapi.com/local-results). - * lsig ID is also available via a redirect Google uses within [Google My - * Business](https://www.google.com/business/). - */ - lsig?: string; - /** - * Google Knowledge Graph ID - * Parameter defines the id (`KGMID`) of the Google Knowledge Graph listing you - * want to scrape. Also known as Google Knowledge Graph ID. Searches with kgmid - * parameter will return results for the originally encrypted search parameters. - * For some searches, kgmid may override all other parameters except start, and num - * parameters. - */ - kgmid?: string; - /** - * Google Cached Search Parameters ID - * Parameter defines the cached search parameters of the Google Search you want to - * scrape. Searches with si parameter will return results for the originally - * encrypted search parameters. For some searches, si may override all other - * parameters except start, and num parameters. si can be used to scrape Google - * Knowledge Graph Tabs. 
- */ - si?: string; - /** - * Domain - * Parameter defines the Google domain to use. It defaults to `google.com`. Head to - * the [Google domains page](https://serpapi.com/google-domains) for a full list of - * supported Google domains. - */ - google_domain?: string; - /** - * Country - * Parameter defines the country to use for the Google search. It's a two-letter - * country code. (e.g., `us` for the United States, `uk` for United Kingdom, or - * `fr` for France). Head to the [Google countries - * page](https://serpapi.com/google-countries) for a full list of supported Google - * countries. - */ - gl?: string; - /** - * Language - * Parameter defines the language to use for the Google search. It's a two-letter - * language code. (e.g., `en` for English, `es` for Spanish, or `fr` for French). - * Head to the [Google languages page](https://serpapi.com/google-languages) for a - * full list of supported Google languages. - */ - hl?: string; - /** - * Set Multiple Languages - * Parameter defines one or multiple languages to limit the search to. It uses - * `lang_{two-letter language code}` to specify languages and `|` as a delimiter. - * (e.g., `lang_fr|lang_de` will only search French and German pages). Head to the - * [Google lr languages page](https://serpapi.com/google-lr-languages) for a full - * list of supported languages. - */ - lr?: string; - /** - * as_dt - * Parameter controls whether to include or exclude results from the site named in - * the as_sitesearch parameter. - */ - as_dt?: string; - /** - * as_epq - * Parameter identifies a phrase that all documents in the search results must - * contain. You can also use the [phrase - * search](https://developers.google.com/custom-search/docs/xml_results#PhraseSearchqt) - * query term to search for a phrase. - */ - as_epq?: string; - /** - * as_eq - * Parameter identifies a word or phrase that should not appear in any documents in - * the search results. 
You can also use the [exclude - * query](https://developers.google.com/custom-search/docs/xml_results#Excludeqt) - * term to ensure that a particular word or phrase will not appear in the documents - * in a set of search results. - */ - as_eq?: string; - /** - * as_lq - * Parameter specifies that all search results should contain a link to a - * particular URL. You can also use the - * [link:](https://developers.google.com/custom-search/docs/xml_results#BackLinksqt) - * query term for this type of query. - */ - as_lq?: string; - /** - * as_nlo - * Parameter specifies the starting value for a search range. Use as_nlo and as_nhi - * to append an inclusive search range. - */ - as_nlo?: string; - /** - * as_nhi - * Parameter specifies the ending value for a search range. Use as_nlo and as_nhi - * to append an inclusive search range. - */ - as_nhi?: string; - /** - * as_oq - * Parameter provides additional search terms to check for in a document, where - * each document in the search results must contain at least one of the additional - * search terms. You can also use the [Boolean - * OR](https://developers.google.com/custom-search/docs/xml_results#BooleanOrqt) - * query term for this type of query. - */ - as_oq?: string; - /** - * as_q - * Parameter provides search terms to check for in a document. This parameter is - * also commonly used to allow users to specify additional terms to search for - * within a set of search results. - */ - as_q?: string; - /** - * as_qdr - * Parameter requests search results from a specified time period (quick date - * range). The following values are supported: - * `d[number]`: requests results from the specified number of past days. Example - * for the past 10 days: `as_qdr=d10` - * `w[number]`: requests results from the specified number of past weeks. - * `m[number]`: requests results from the specified number of past months. - * `y[number]`: requests results from the specified number of past years. 
Example - * for the past year: `as_qdr=y` - */ - as_qdr?: string; - /** - * as_rq - * Parameter specifies that all search results should be pages that are related to - * the specified URL. The parameter value should be a URL. You can also use the - * [related:](https://developers.google.com/custom-search/docs/xml_results#RelatedLinksqt) - * query term for this type of query. - */ - as_rq?: string; - /** - * as_sitesearch - * Parameter allows you to specify that all search results should be pages from a - * given site. By setting the as_dt parameter, you can also use it to exclude pages - * from a given site from your search resutls. - */ - as_sitesearch?: string; - /** - * Advanced Search Parameters - * (to be searched) parameter defines advanced search parameters that aren't - * possible in the regular query field. (e.g., advanced search for patents, dates, - * news, videos, images, apps, or text contents). - */ - tbs?: string; - /** - * Adult Content Filtering - * Parameter defines the level of filtering for adult content. It can be set to - * `active`, or `off` (default). - */ - safe?: string; - /** - * Exclude Auto-corrected Results - * Parameter defines the exclusion of results from an auto-corrected query that is - * spelled wrong. It can be set to `1` to exclude these results, or `0` to include - * them (default). - */ - nfpr?: string; - /** - * Results Filtering - * Parameter defines if the filters for 'Similar Results' and 'Omitted Results' are - * on or off. It can be set to `1` (default) to enable these filters, or `0` to - * disable these filters. - */ - filter?: string; - /** - * Search Type - * (to be matched) parameter defines the type of search you want to do. 
- * It can be set to: - * `(no tbm parameter)`: regular Google Search, - * `isch`: [Google Images API](https://serpapi.com/images-results), - * `lcl` - [Google Local API](https://serpapi.com/local-results) - * `vid`: [Google Videos API](https://serpapi.com/videos-results), - * `nws`: [Google News API](https://serpapi.com/news-results), - * `shop`: [Google Shopping API](https://serpapi.com/shopping-results), - * or any other Google service. - */ - tbm?: string; - /** - * Result Offset - * Parameter defines the result offset. It skips the given number of results. It's - * used for pagination. (e.g., `0` (default) is the first page of results, `10` is - * the 2nd page of results, `20` is the 3rd page of results, etc.). - * Google Local Results only accepts multiples of `20`(e.g. `20` for the second - * page results, `40` for the third page results, etc.) as the start value. - */ - start?: number; - /** - * Number of Results - * Parameter defines the maximum number of results to return. (e.g., `10` (default) - * returns 10 results, `40` returns 40 results, and `100` returns 100 results). - */ - num?: string; - /** - * Page Number (images) - * Parameter defines the page number for [Google - * Images](https://serpapi.com/images-results). There are 100 images per page. This - * parameter is equivalent to start (offset) = ijn * 100. This parameter works only - * for [Google Images](https://serpapi.com/images-results) (set tbm to `isch`). - */ - ijn?: string; -} - -type UrlParameters = Record< - string, - string | number | boolean | undefined | null ->; - -/** - * Wrapper around SerpAPI. - * - * To use, you should have the `serpapi` package installed and the SERPAPI_API_KEY environment variable set. 
- */ -export class SerpAPI extends Tool { - static lc_name() { - return "SerpAPI"; - } - - toJSON() { - return this.toJSONNotImplemented(); - } - - protected key: string; - - protected params: Partial; - - protected baseUrl: string; - - constructor( - apiKey: string | undefined = getEnvironmentVariable("SERPAPI_API_KEY"), - params: Partial = {}, - baseUrl = "https://serpapi.com" - ) { - super(...arguments); - - if (!apiKey) { - throw new Error( - "SerpAPI API key not set. You can set it as SERPAPI_API_KEY in your .env file, or pass it to SerpAPI." - ); - } - - this.key = apiKey; - this.params = params; - this.baseUrl = baseUrl; - } - - name = "search"; - - /** - * Builds a URL for the SerpAPI request. - * @param path The path for the request. - * @param parameters The parameters for the request. - * @param baseUrl The base URL for the request. - * @returns A string representing the built URL. - */ - protected buildUrl

( - path: string, - parameters: P, - baseUrl: string - ): string { - const nonUndefinedParams: [string, string][] = Object.entries(parameters) - .filter(([_, value]) => value !== undefined) - .map(([key, value]) => [key, `${value}`]); - const searchParams = new URLSearchParams(nonUndefinedParams); - return `${baseUrl}/${path}?${searchParams}`; - } - - /** @ignore */ - async _call(input: string) { - const { timeout, ...params } = this.params; - const resp = await fetch( - this.buildUrl( - "search", - { - ...params, - api_key: this.key, - q: input, - }, - this.baseUrl - ), - { - signal: timeout ? AbortSignal.timeout(timeout) : undefined, - } - ); - - const res = await resp.json(); - - if (res.error) { - throw new Error(`Got error from serpAPI: ${res.error}`); - } - - const answer_box = res.answer_box_list - ? res.answer_box_list[0] - : res.answer_box; - if (answer_box) { - if (answer_box.result) { - return answer_box.result; - } else if (answer_box.answer) { - return answer_box.answer; - } else if (answer_box.snippet) { - return answer_box.snippet; - } else if (answer_box.snippet_highlighted_words) { - return answer_box.snippet_highlighted_words.toString(); - } else { - const answer: { [key: string]: string } = {}; - Object.keys(answer_box) - .filter( - (k) => - !Array.isArray(answer_box[k]) && - typeof answer_box[k] !== "object" && - !( - typeof answer_box[k] === "string" && - answer_box[k].startsWith("http") - ) - ) - .forEach((k) => { - answer[k] = answer_box[k]; - }); - return JSON.stringify(answer); - } - } - - if (res.events_results) { - return JSON.stringify(res.events_results); - } - - if (res.sports_results) { - return JSON.stringify(res.sports_results); - } - - if (res.top_stories) { - return JSON.stringify(res.top_stories); - } - - if (res.news_results) { - return JSON.stringify(res.news_results); - } - - if (res.jobs_results?.jobs) { - return JSON.stringify(res.jobs_results.jobs); - } - - if (res.questions_and_answers) { - return 
JSON.stringify(res.questions_and_answers); - } - - if (res.popular_destinations?.destinations) { - return JSON.stringify(res.popular_destinations.destinations); - } - - if (res.top_sights?.sights) { - const sights: Array<{ [key: string]: string }> = res.top_sights.sights - .map((s: { [key: string]: string }) => ({ - title: s.title, - description: s.description, - price: s.price, - })) - .slice(0, 8); - return JSON.stringify(sights); - } - - if (res.shopping_results && res.shopping_results[0]?.title) { - return JSON.stringify(res.shopping_results.slice(0, 3)); - } - - if (res.images_results && res.images_results[0]?.thumbnail) { - return res.images_results - .map((ir: { thumbnail: string }) => ir.thumbnail) - .slice(0, 10) - .toString(); - } - - const snippets = []; - if (res.knowledge_graph) { - if (res.knowledge_graph.description) { - snippets.push(res.knowledge_graph.description); - } - - const title = res.knowledge_graph.title || ""; - Object.keys(res.knowledge_graph) - .filter( - (k) => - typeof res.knowledge_graph[k] === "string" && - k !== "title" && - k !== "description" && - !k.endsWith("_stick") && - !k.endsWith("_link") && - !k.startsWith("http") - ) - .forEach((k) => - snippets.push(`${title} ${k}: ${res.knowledge_graph[k]}`) - ); - } - - const first_organic_result = res.organic_results?.[0]; - if (first_organic_result) { - if (first_organic_result.snippet) { - snippets.push(first_organic_result.snippet); - } else if (first_organic_result.snippet_highlighted_words) { - snippets.push(first_organic_result.snippet_highlighted_words); - } else if (first_organic_result.rich_snippet) { - snippets.push(first_organic_result.rich_snippet); - } else if (first_organic_result.rich_snippet_table) { - snippets.push(first_organic_result.rich_snippet_table); - } else if (first_organic_result.link) { - snippets.push(first_organic_result.link); - } - } - - if (res.buying_guide) { - snippets.push(res.buying_guide); - } - - if (res.local_results?.places) { - 
snippets.push(res.local_results.places); - } - - if (snippets.length > 0) { - return JSON.stringify(snippets); - } else { - return "No good search result found"; - } - } - - description = - "a search engine. useful for when you need to answer questions about current events. input should be a search query."; -} +export * from "@langchain/community/tools/serpapi"; diff --git a/langchain/src/tools/serper.ts b/langchain/src/tools/serper.ts index 622dd6398aa9..275f96a4bc15 100644 --- a/langchain/src/tools/serper.ts +++ b/langchain/src/tools/serper.ts @@ -1,107 +1 @@ -import { getEnvironmentVariable } from "../util/env.js"; -import { Tool } from "./base.js"; - -/** - * Defines the parameters that can be passed to the Serper class during - * instantiation. It includes `gl` and `hl` which are optional. - */ -export type SerperParameters = { - gl?: string; - hl?: string; -}; - -/** - * Wrapper around serper. - * - * You can create a free API key at https://serper.dev. - * - * To use, you should have the SERPER_API_KEY environment variable set. - */ -export class Serper extends Tool { - static lc_name() { - return "Serper"; - } - - /** - * Converts the Serper instance to JSON. This method is not implemented - * and will throw an error if called. - * @returns Throws an error. - */ - toJSON() { - return this.toJSONNotImplemented(); - } - - protected key: string; - - protected params: Partial; - - constructor( - apiKey: string | undefined = getEnvironmentVariable("SERPER_API_KEY"), - params: Partial = {} - ) { - super(); - - if (!apiKey) { - throw new Error( - "Serper API key not set. You can set it as SERPER_API_KEY in your .env file, or pass it to Serper." 
- ); - } - - this.key = apiKey; - this.params = params; - } - - name = "search"; - - /** @ignore */ - async _call(input: string) { - const options = { - method: "POST", - headers: { - "X-API-KEY": this.key, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - q: input, - ...this.params, - }), - }; - - const res = await fetch("https://google.serper.dev/search", options); - - if (!res.ok) { - throw new Error(`Got ${res.status} error from serper: ${res.statusText}`); - } - - const json = await res.json(); - - if (json.answerBox?.answer) { - return json.answerBox.answer; - } - - if (json.answerBox?.snippet) { - return json.answerBox.snippet; - } - - if (json.answerBox?.snippet_highlighted_words) { - return json.answerBox.snippet_highlighted_words[0]; - } - - if (json.sportsResults?.game_spotlight) { - return json.sportsResults.game_spotlight; - } - - if (json.knowledgeGraph?.description) { - return json.knowledgeGraph.description; - } - - if (json.organic?.[0]?.snippet) { - return json.organic[0].snippet; - } - - return "No good search result found"; - } - - description = - "a search engine. useful for when you need to answer questions about current events. 
input should be a search query."; -} +export * from "@langchain/community/tools/serper"; diff --git a/langchain/src/tools/tests/gmail.test.ts b/langchain/src/tools/tests/gmail.test.ts index e44b6f7fef36..8ab6743d0071 100644 --- a/langchain/src/tools/tests/gmail.test.ts +++ b/langchain/src/tools/tests/gmail.test.ts @@ -1,5 +1,5 @@ import { jest, expect, describe } from "@jest/globals"; -import { GmailGetMessage } from "../gmail/get_message.js"; +import { GmailGetMessage } from "../gmail/index.js"; jest.mock("googleapis", () => ({ google: { diff --git a/langchain/src/tools/wikipedia_query_run.ts b/langchain/src/tools/wikipedia_query_run.ts index 127010b46cf3..f80c7eefd903 100644 --- a/langchain/src/tools/wikipedia_query_run.ts +++ b/langchain/src/tools/wikipedia_query_run.ts @@ -1,181 +1 @@ -import { Tool } from "./base.js"; - -/** - * Interface for the parameters that can be passed to the - * WikipediaQueryRun constructor. - */ -export interface WikipediaQueryRunParams { - topKResults?: number; - maxDocContentLength?: number; - baseUrl?: string; -} - -/** - * Type alias for URL parameters. Represents a record where keys are - * strings and values can be string, number, boolean, undefined, or null. - */ -type UrlParameters = Record< - string, - string | number | boolean | undefined | null ->; - -/** - * Interface for the structure of search results returned by the Wikipedia - * API. - */ -interface SearchResults { - query: { - search: Array<{ - title: string; - }>; - }; -} - -/** - * Interface for the structure of a page returned by the Wikipedia API. - */ -interface Page { - pageid: number; - ns: number; - title: string; - extract: string; -} - -/** - * Interface for the structure of a page result returned by the Wikipedia - * API. - */ -interface PageResult { - batchcomplete: string; - query: { - pages: Record; - }; -} - -/** - * Class for interacting with and fetching data from the Wikipedia API. It - * extends the Tool class. 
- * @example - * ```typescript - * const wikipediaQuery = new WikipediaQueryRun({ - * topKResults: 3, - * maxDocContentLength: 4000, - * }); - * const result = await wikipediaQuery.call("Langchain"); - * ``` - */ -export class WikipediaQueryRun extends Tool { - static lc_name() { - return "WikipediaQueryRun"; - } - - name = "wikipedia-api"; - - description = - "A tool for interacting with and fetching data from the Wikipedia API."; - - protected topKResults = 3; - - protected maxDocContentLength = 4000; - - protected baseUrl = "https://en.wikipedia.org/w/api.php"; - - constructor(params: WikipediaQueryRunParams = {}) { - super(); - - this.topKResults = params.topKResults ?? this.topKResults; - this.maxDocContentLength = - params.maxDocContentLength ?? this.maxDocContentLength; - this.baseUrl = params.baseUrl ?? this.baseUrl; - } - - async _call(query: string): Promise { - const searchResults = await this._fetchSearchResults(query); - const summaries: string[] = []; - - for ( - let i = 0; - i < Math.min(this.topKResults, searchResults.query.search.length); - i += 1 - ) { - const page = searchResults.query.search[i].title; - const pageDetails = await this._fetchPage(page, true); - - if (pageDetails) { - const summary = `Page: ${page}\nSummary: ${pageDetails.extract}`; - summaries.push(summary); - } - } - - if (summaries.length === 0) { - return "No good Wikipedia Search Result was found"; - } else { - return summaries.join("\n\n").slice(0, this.maxDocContentLength); - } - } - - /** - * Fetches the content of a specific Wikipedia page. It returns the - * extracted content as a string. - * @param page The specific Wikipedia page to fetch its content. - * @param redirect A boolean value to indicate whether to redirect or not. - * @returns The extracted content of the specific Wikipedia page as a string. 
- */ - public async content(page: string, redirect = true): Promise { - try { - const result = await this._fetchPage(page, redirect); - return result.extract; - } catch (error) { - throw new Error(`Failed to fetch content for page "${page}": ${error}`); - } - } - - /** - * Builds a URL for the Wikipedia API using the provided parameters. - * @param parameters The parameters to be used in building the URL. - * @returns A string representing the built URL. - */ - protected buildUrl

(parameters: P): string { - const nonUndefinedParams: [string, string][] = Object.entries(parameters) - .filter(([_, value]) => value !== undefined) - .map(([key, value]) => [key, `${value}`]); - const searchParams = new URLSearchParams(nonUndefinedParams); - return `${this.baseUrl}?${searchParams}`; - } - - private async _fetchSearchResults(query: string): Promise { - const searchParams = new URLSearchParams({ - action: "query", - list: "search", - srsearch: query, - format: "json", - }); - - const response = await fetch(`${this.baseUrl}?${searchParams.toString()}`); - if (!response.ok) throw new Error("Network response was not ok"); - - const data: SearchResults = await response.json(); - - return data; - } - - private async _fetchPage(page: string, redirect: boolean): Promise { - const params = new URLSearchParams({ - action: "query", - prop: "extracts", - explaintext: "true", - redirects: redirect ? "1" : "0", - format: "json", - titles: page, - }); - - const response = await fetch(`${this.baseUrl}?${params.toString()}`); - if (!response.ok) throw new Error("Network response was not ok"); - - const data: PageResult = await response.json(); - const { pages } = data.query; - const pageId = Object.keys(pages)[0]; - - return pages[pageId]; - } -} +export * from "@langchain/community/tools/wikipedia_query_run"; diff --git a/langchain/src/tools/wolframalpha.ts b/langchain/src/tools/wolframalpha.ts index 76658428a734..e0e76e072248 100644 --- a/langchain/src/tools/wolframalpha.ts +++ b/langchain/src/tools/wolframalpha.ts @@ -1,41 +1 @@ -import { Tool, ToolParams } from "./base.js"; - -/** - * @example - * ```typescript - * const tool = new WolframAlphaTool({ - * appid: "YOUR_APP_ID", - * }); - * const res = await tool.invoke("What is 2 * 2?"); - * ``` - */ -export class WolframAlphaTool extends Tool { - appid: string; - - name = "wolfram_alpha"; - - description = `A wrapper around Wolfram Alpha. 
Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Input should be a search query.`; - - constructor(fields: ToolParams & { appid: string }) { - super(fields); - - this.appid = fields.appid; - } - - get lc_namespace() { - return [...super.lc_namespace, "wolframalpha"]; - } - - static lc_name() { - return "WolframAlphaTool"; - } - - async _call(query: string): Promise { - const url = `https://www.wolframalpha.com/api/v1/llm-api?appid=${ - this.appid - }&input=${encodeURIComponent(query)}`; - const res = await fetch(url); - - return res.text(); - } -} +export * from "@langchain/community/tools/wolframalpha"; diff --git a/langchain/src/types/openai-types.ts b/langchain/src/types/openai-types.ts index f3df0278a6a9..a3acaf1904d0 100644 --- a/langchain/src/types/openai-types.ts +++ b/langchain/src/types/openai-types.ts @@ -1,4 +1,4 @@ -import type { OpenAI as OpenAIClient } from "openai"; +import type { OpenAIClient } from "@langchain/openai"; import { TiktokenModel } from "js-tiktoken/lite"; import { BaseLanguageModelCallOptions } from "../base_language/index.js"; diff --git a/langchain/src/util/event-source-parse.ts b/langchain/src/util/event-source-parse.ts index 5f43f25a0271..0b74a11974ae 100644 --- a/langchain/src/util/event-source-parse.ts +++ b/langchain/src/util/event-source-parse.ts @@ -1,287 +1 @@ -/* eslint-disable prefer-template */ -/* eslint-disable default-case */ -/* eslint-disable no-plusplus */ -// Adapted from https://github.com/gfortaine/fetch-event-source/blob/main/src/parse.ts -// due to a packaging issue in the original. 
-// MIT License -import { type Readable } from "stream"; -import { IterableReadableStream } from "./stream.js"; - -export const EventStreamContentType = "text/event-stream"; - -/** - * Represents a message sent in an event stream - * https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format - */ -export interface EventSourceMessage { - /** The event ID to set the EventSource object's last event ID value. */ - id: string; - /** A string identifying the type of event described. */ - event: string; - /** The event data */ - data: string; - /** The reconnection interval (in milliseconds) to wait before retrying the connection */ - retry?: number; -} - -function isNodeJSReadable(x: unknown): x is Readable { - return x != null && typeof x === "object" && "on" in x; -} - -/** - * Converts a ReadableStream into a callback pattern. - * @param stream The input ReadableStream. - * @param onChunk A function that will be called on each new byte chunk in the stream. - * @returns {Promise} A promise that will be resolved when the stream closes. - */ -export async function getBytes( - stream: ReadableStream, - onChunk: (arr: Uint8Array, flush?: boolean) => void -) { - // stream is a Node.js Readable / PassThrough stream - // this can happen if node-fetch is polyfilled - if (isNodeJSReadable(stream)) { - return new Promise((resolve) => { - stream.on("readable", () => { - let chunk; - // eslint-disable-next-line no-constant-condition - while (true) { - chunk = stream.read(); - if (chunk == null) { - onChunk(new Uint8Array(), true); - break; - } - onChunk(chunk); - } - - resolve(); - }); - }); - } - - const reader = stream.getReader(); - // CHANGED: Introduced a "flush" mechanism to process potential pending messages when the stream ends. - // This change is essential to ensure that we capture every last piece of information from streams, - // such as those from Azure OpenAI, which may not terminate with a blank line. 
Without this - // mechanism, we risk ignoring a possibly significant last message. - // See https://github.com/langchain-ai/langchainjs/issues/1299 for details. - // eslint-disable-next-line no-constant-condition - while (true) { - const result = await reader.read(); - if (result.done) { - onChunk(new Uint8Array(), true); - break; - } - onChunk(result.value); - } -} - -const enum ControlChars { - NewLine = 10, - CarriageReturn = 13, - Space = 32, - Colon = 58, -} - -/** - * Parses arbitary byte chunks into EventSource line buffers. - * Each line should be of the format "field: value" and ends with \r, \n, or \r\n. - * @param onLine A function that will be called on each new EventSource line. - * @returns A function that should be called for each incoming byte chunk. - */ -export function getLines( - onLine: (line: Uint8Array, fieldLength: number, flush?: boolean) => void -) { - let buffer: Uint8Array | undefined; - let position: number; // current read position - let fieldLength: number; // length of the `field` portion of the line - let discardTrailingNewline = false; - - // return a function that can process each incoming byte chunk: - return function onChunk(arr: Uint8Array, flush?: boolean) { - if (flush) { - onLine(arr, 0, true); - return; - } - - if (buffer === undefined) { - buffer = arr; - position = 0; - fieldLength = -1; - } else { - // we're still parsing the old line. 
Append the new bytes into buffer: - buffer = concat(buffer, arr); - } - - const bufLength = buffer.length; - let lineStart = 0; // index where the current line starts - while (position < bufLength) { - if (discardTrailingNewline) { - if (buffer[position] === ControlChars.NewLine) { - lineStart = ++position; // skip to next char - } - - discardTrailingNewline = false; - } - - // start looking forward till the end of line: - let lineEnd = -1; // index of the \r or \n char - for (; position < bufLength && lineEnd === -1; ++position) { - switch (buffer[position]) { - case ControlChars.Colon: - if (fieldLength === -1) { - // first colon in line - fieldLength = position - lineStart; - } - break; - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore:7029 \r case below should fallthrough to \n: - case ControlChars.CarriageReturn: - discardTrailingNewline = true; - // eslint-disable-next-line no-fallthrough - case ControlChars.NewLine: - lineEnd = position; - break; - } - } - - if (lineEnd === -1) { - // We reached the end of the buffer but the line hasn't ended. - // Wait for the next arr and then continue parsing: - break; - } - - // we've reached the line end, send it out: - onLine(buffer.subarray(lineStart, lineEnd), fieldLength); - lineStart = position; // we're now on the next line - fieldLength = -1; - } - - if (lineStart === bufLength) { - buffer = undefined; // we've finished reading it - } else if (lineStart !== 0) { - // Create a new view into buffer beginning at lineStart so we don't - // need to copy over the previous lines when we get the new arr: - buffer = buffer.subarray(lineStart); - position -= lineStart; - } - }; -} - -/** - * Parses line buffers into EventSourceMessages. - * @param onId A function that will be called on each `id` field. - * @param onRetry A function that will be called on each `retry` field. - * @param onMessage A function that will be called on each message. 
- * @returns A function that should be called for each incoming line buffer. - */ -export function getMessages( - onMessage?: (msg: EventSourceMessage) => void, - onId?: (id: string) => void, - onRetry?: (retry: number) => void -) { - let message = newMessage(); - const decoder = new TextDecoder(); - - // return a function that can process each incoming line buffer: - return function onLine( - line: Uint8Array, - fieldLength: number, - flush?: boolean - ) { - if (flush) { - if (!isEmpty(message)) { - onMessage?.(message); - message = newMessage(); - } - return; - } - - if (line.length === 0) { - // empty line denotes end of message. Trigger the callback and start a new message: - onMessage?.(message); - message = newMessage(); - } else if (fieldLength > 0) { - // exclude comments and lines with no values - // line is of format ":" or ": " - // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - const field = decoder.decode(line.subarray(0, fieldLength)); - const valueOffset = - fieldLength + (line[fieldLength + 1] === ControlChars.Space ? 2 : 1); - const value = decoder.decode(line.subarray(valueOffset)); - - switch (field) { - case "data": - // if this message already has data, append the new value to the old. - // otherwise, just set to the new value: - message.data = message.data ? 
message.data + "\n" + value : value; // otherwise, - break; - case "event": - message.event = value; - break; - case "id": - onId?.((message.id = value)); - break; - case "retry": { - const retry = parseInt(value, 10); - if (!Number.isNaN(retry)) { - // per spec, ignore non-integers - onRetry?.((message.retry = retry)); - } - break; - } - } - } - }; -} - -function concat(a: Uint8Array, b: Uint8Array) { - const res = new Uint8Array(a.length + b.length); - res.set(a); - res.set(b, a.length); - return res; -} - -function newMessage(): EventSourceMessage { - // data, event, and id must be initialized to empty strings: - // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - // retry should be initialized to undefined so we return a consistent shape - // to the js engine all the time: https://mathiasbynens.be/notes/shapes-ics#takeaways - return { - data: "", - event: "", - id: "", - retry: undefined, - }; -} - -export function convertEventStreamToIterableReadableDataStream( - stream: ReadableStream -) { - const dataStream = new ReadableStream({ - async start(controller) { - const enqueueLine = getMessages((msg) => { - if (msg.data) controller.enqueue(msg.data); - }); - const onLine = ( - line: Uint8Array, - fieldLength: number, - flush?: boolean - ) => { - enqueueLine(line, fieldLength, flush); - if (flush) controller.close(); - }; - await getBytes(stream, getLines(onLine)); - }, - }); - return IterableReadableStream.fromReadableStream(dataStream); -} - -function isEmpty(message: EventSourceMessage): boolean { - return ( - message.data === "" && - message.event === "" && - message.id === "" && - message.retry === undefined - ); -} +export * from "@langchain/community/utils/event_source_parse"; diff --git a/langchain/src/util/math.ts b/langchain/src/util/math.ts index fe703c2d5f79..5c1f37d1b0aa 100644 --- a/langchain/src/util/math.ts +++ b/langchain/src/util/math.ts @@ -1,180 +1 @@ -import { - similarity as ml_distance_similarity, 
- distance as ml_distance, -} from "ml-distance"; - -type VectorFunction = (xVector: number[], yVector: number[]) => number; - -/** - * Apply a row-wise function between two matrices with the same number of columns. - * - * @param {number[][]} X - The first matrix. - * @param {number[][]} Y - The second matrix. - * @param {VectorFunction} func - The function to apply. - * - * @throws {Error} If the number of columns in X and Y are not the same. - * - * @returns {number[][] | [[]]} A matrix where each row represents the result of applying the function between the corresponding rows of X and Y. - */ - -export function matrixFunc( - X: number[][], - Y: number[][], - func: VectorFunction -): number[][] { - if ( - X.length === 0 || - X[0].length === 0 || - Y.length === 0 || - Y[0].length === 0 - ) { - return [[]]; - } - - if (X[0].length !== Y[0].length) { - throw new Error( - `Number of columns in X and Y must be the same. X has shape ${[ - X.length, - X[0].length, - ]} and Y has shape ${[Y.length, Y[0].length]}.` - ); - } - - return X.map((xVector) => - Y.map((yVector) => func(xVector, yVector)).map((similarity) => - Number.isNaN(similarity) ? 0 : similarity - ) - ); -} - -export function normalize(M: number[][], similarity = false): number[][] { - const max = matrixMaxVal(M); - return M.map((row) => - row.map((val) => (similarity ? 1 - val / max : val / max)) - ); -} - -/** - * This function calculates the row-wise cosine similarity between two matrices with the same number of columns. - * - * @param {number[][]} X - The first matrix. - * @param {number[][]} Y - The second matrix. - * - * @throws {Error} If the number of columns in X and Y are not the same. - * - * @returns {number[][] | [[]]} A matrix where each row represents the cosine similarity values between the corresponding rows of X and Y. 
- */ -export function cosineSimilarity(X: number[][], Y: number[][]): number[][] { - return matrixFunc(X, Y, ml_distance_similarity.cosine); -} - -export function innerProduct(X: number[][], Y: number[][]): number[][] { - return matrixFunc(X, Y, ml_distance.innerProduct); -} - -export function euclideanDistance(X: number[][], Y: number[][]): number[][] { - return matrixFunc(X, Y, ml_distance.euclidean); -} - -/** - * This function implements the Maximal Marginal Relevance algorithm - * to select a set of embeddings that maximizes the diversity and relevance to a query embedding. - * - * @param {number[]|number[][]} queryEmbedding - The query embedding. - * @param {number[][]} embeddingList - The list of embeddings to select from. - * @param {number} [lambda=0.5] - The trade-off parameter between relevance and diversity. - * @param {number} [k=4] - The maximum number of embeddings to select. - * - * @returns {number[]} The indexes of the selected embeddings in the embeddingList. - */ -export function maximalMarginalRelevance( - queryEmbedding: number[] | number[][], - embeddingList: number[][], - lambda = 0.5, - k = 4 -): number[] { - if (Math.min(k, embeddingList.length) <= 0) { - return []; - } - - const queryEmbeddingExpanded = ( - Array.isArray(queryEmbedding[0]) ? 
queryEmbedding : [queryEmbedding] - ) as number[][]; - - const similarityToQuery = cosineSimilarity( - queryEmbeddingExpanded, - embeddingList - )[0]; - const mostSimilarEmbeddingIndex = argMax(similarityToQuery).maxIndex; - - const selectedEmbeddings = [embeddingList[mostSimilarEmbeddingIndex]]; - const selectedEmbeddingsIndexes = [mostSimilarEmbeddingIndex]; - - while (selectedEmbeddingsIndexes.length < Math.min(k, embeddingList.length)) { - let bestScore = -Infinity; - let bestIndex = -1; - - const similarityToSelected = cosineSimilarity( - embeddingList, - selectedEmbeddings - ); - - similarityToQuery.forEach((queryScore, queryScoreIndex) => { - if (selectedEmbeddingsIndexes.includes(queryScoreIndex)) { - return; - } - const maxSimilarityToSelected = Math.max( - ...similarityToSelected[queryScoreIndex] - ); - const score = - lambda * queryScore - (1 - lambda) * maxSimilarityToSelected; - - if (score > bestScore) { - bestScore = score; - bestIndex = queryScoreIndex; - } - }); - selectedEmbeddings.push(embeddingList[bestIndex]); - selectedEmbeddingsIndexes.push(bestIndex); - } - - return selectedEmbeddingsIndexes; -} - -type MaxInfo = { - maxIndex: number; - maxValue: number; -}; - -/** - * Finds the index of the maximum value in the given array. - * @param {number[]} array - The input array. - * - * @returns {number} The index of the maximum value in the array. If the array is empty, returns -1. 
- */ -function argMax(array: number[]): MaxInfo { - if (array.length === 0) { - return { - maxIndex: -1, - maxValue: NaN, - }; - } - - let maxValue = array[0]; - let maxIndex = 0; - - for (let i = 1; i < array.length; i += 1) { - if (array[i] > maxValue) { - maxIndex = i; - maxValue = array[i]; - } - } - return { maxIndex, maxValue }; -} - -function matrixMaxVal(arrays: number[][]): number { - return arrays.reduce( - (acc, array) => Math.max(acc, argMax(array).maxValue), - 0 - ); -} +export * from "@langchain/core/utils/math"; diff --git a/langchain/src/util/openai-format-fndef.ts b/langchain/src/util/openai-format-fndef.ts deleted file mode 100644 index 3fe476b90554..000000000000 --- a/langchain/src/util/openai-format-fndef.ts +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Formatting function definitions for calculating openai function defination token usage. - * - * https://github.com/hmarr/openai-chat-tokens/blob/main/src/functions.ts - * (c) 2023 Harry Marr - * MIT license - */ -import OpenAI from "openai"; - -type OpenAIFunction = OpenAI.Chat.ChatCompletionCreateParams.Function; - -// Types representing the OpenAI function definitions. While the OpenAI client library -// does have types for function definitions, the properties are just Record, -// which isn't very useful for type checking this formatting code. 
-export interface FunctionDef extends Omit { - name: string; - description?: string; - parameters: ObjectProp; -} - -interface ObjectProp { - type: "object"; - properties?: { - [key: string]: Prop; - }; - required?: string[]; -} - -interface AnyOfProp { - anyOf: Prop[]; -} - -type Prop = { - description?: string; -} & ( - | AnyOfProp - | ObjectProp - | { - type: "string"; - enum?: string[]; - } - | { - type: "number" | "integer"; - minimum?: number; - maximum?: number; - enum?: number[]; - } - | { type: "boolean" } - | { type: "null" } - | { - type: "array"; - items?: Prop; - } -); - -function isAnyOfProp(prop: Prop): prop is AnyOfProp { - return ( - (prop as AnyOfProp).anyOf !== undefined && - Array.isArray((prop as AnyOfProp).anyOf) - ); -} - -// When OpenAI use functions in the prompt, they format them as TypeScript definitions rather than OpenAPI JSON schemas. -// This function converts the JSON schemas into TypeScript definitions. -export function formatFunctionDefinitions(functions: FunctionDef[]) { - const lines = ["namespace functions {", ""]; - for (const f of functions) { - if (f.description) { - lines.push(`// ${f.description}`); - } - if (Object.keys(f.parameters.properties ?? {}).length > 0) { - lines.push(`type ${f.name} = (_: {`); - lines.push(formatObjectProperties(f.parameters, 0)); - lines.push("}) => any;"); - } else { - lines.push(`type ${f.name} = () => any;`); - } - lines.push(""); - } - lines.push("} // namespace functions"); - return lines.join("\n"); -} - -// Format just the properties of an object (not including the surrounding braces) -function formatObjectProperties(obj: ObjectProp, indent: number): string { - const lines: string[] = []; - for (const [name, param] of Object.entries(obj.properties ?? 
{})) { - if (param.description && indent < 2) { - lines.push(`// ${param.description}`); - } - if (obj.required?.includes(name)) { - lines.push(`${name}: ${formatType(param, indent)},`); - } else { - lines.push(`${name}?: ${formatType(param, indent)},`); - } - } - return lines.map((line) => " ".repeat(indent) + line).join("\n"); -} - -// Format a single property type -function formatType(param: Prop, indent: number): string { - if (isAnyOfProp(param)) { - return param.anyOf.map((v) => formatType(v, indent)).join(" | "); - } - switch (param.type) { - case "string": - if (param.enum) { - return param.enum.map((v) => `"${v}"`).join(" | "); - } - return "string"; - case "number": - if (param.enum) { - return param.enum.map((v) => `${v}`).join(" | "); - } - return "number"; - case "integer": - if (param.enum) { - return param.enum.map((v) => `${v}`).join(" | "); - } - return "number"; - case "boolean": - return "boolean"; - case "null": - return "null"; - case "object": - return ["{", formatObjectProperties(param, indent + 2), "}"].join("\n"); - case "array": - if (param.items) { - return `${formatType(param.items, indent)}[]`; - } - return "any[]"; - default: - return ""; - } -} diff --git a/langchain/src/util/openai.ts b/langchain/src/util/openai.ts deleted file mode 100644 index db7fe66f032e..000000000000 --- a/langchain/src/util/openai.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { APIConnectionTimeoutError, APIUserAbortError } from "openai"; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export function wrapOpenAIClientError(e: any) { - let error; - if (e.constructor.name === APIConnectionTimeoutError.name) { - error = new Error(e.message); - error.name = "TimeoutError"; - } else if (e.constructor.name === APIUserAbortError.name) { - error = new Error(e.message); - error.name = "AbortError"; - } else { - error = e; - } - return error; -} diff --git a/langchain/src/vectorstores/analyticdb.ts b/langchain/src/vectorstores/analyticdb.ts index 
15f0eb6d831f..99b074dc0ace 100644 --- a/langchain/src/vectorstores/analyticdb.ts +++ b/langchain/src/vectorstores/analyticdb.ts @@ -1,390 +1 @@ -import * as uuid from "uuid"; -import pg, { Pool, PoolConfig } from "pg"; -import { from as copyFrom } from "pg-copy-streams"; -import { pipeline } from "node:stream/promises"; -import { Readable } from "node:stream"; - -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -const _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"; - -/** - * Interface defining the arguments required to create an instance of - * `AnalyticDBVectorStore`. - */ -export interface AnalyticDBArgs { - connectionOptions: PoolConfig; - embeddingDimension?: number; - collectionName?: string; - preDeleteCollection?: boolean; -} - -/** - * Interface defining the structure of data to be stored in the - * AnalyticDB. - */ -interface DataType { - id: string; - embedding: number[]; - document: string; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - metadata: Record; -} - -/** - * Class that provides methods for creating and managing a collection of - * documents in an AnalyticDB, adding documents or vectors to the - * collection, performing similarity search on vectors, and creating an - * instance of `AnalyticDBVectorStore` from texts or documents. 
- */ -export class AnalyticDBVectorStore extends VectorStore { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - declare FilterType: Record; - - private pool: Pool; - - private embeddingDimension?: number; - - private collectionName: string; - - private preDeleteCollection: boolean; - - private isCreateCollection = false; - - _vectorstoreType(): string { - return "analyticdb"; - } - - constructor(embeddings: Embeddings, args: AnalyticDBArgs) { - super(embeddings, args); - - this.pool = new pg.Pool({ - host: args.connectionOptions.host, - port: args.connectionOptions.port, - database: args.connectionOptions.database, - user: args.connectionOptions.user, - password: args.connectionOptions.password, - }); - this.embeddingDimension = args.embeddingDimension; - this.collectionName = - args.collectionName || _LANGCHAIN_DEFAULT_COLLECTION_NAME; - this.preDeleteCollection = args.preDeleteCollection || false; - } - - /** - * Closes all the clients in the pool and terminates the pool. - * @returns Promise that resolves when all clients are closed and the pool is terminated. - */ - async end(): Promise { - return this.pool.end(); - } - - /** - * Creates a new table in the database if it does not already exist. The - * table is created with columns for id, embedding, document, and - * metadata. An index is also created on the embedding column if it does - * not already exist. - * @returns Promise that resolves when the table and index are created. 
- */ - async createTableIfNotExists(): Promise { - if (!this.embeddingDimension) { - this.embeddingDimension = ( - await this.embeddings.embedQuery("test") - ).length; - } - const client = await this.pool.connect(); - try { - await client.query("BEGIN"); - // Create the table if it doesn't exist - await client.query(` - CREATE TABLE IF NOT EXISTS ${this.collectionName} ( - id TEXT PRIMARY KEY DEFAULT NULL, - embedding REAL[], - document TEXT, - metadata JSON - ); - `); - - // Check if the index exists - const indexName = `${this.collectionName}_embedding_idx`; - const indexQuery = ` - SELECT 1 - FROM pg_indexes - WHERE indexname = '${indexName}'; - `; - const result = await client.query(indexQuery); - - // Create the index if it doesn't exist - if (result.rowCount === 0) { - const indexStatement = ` - CREATE INDEX ${indexName} - ON ${this.collectionName} USING ann(embedding) - WITH ( - "dim" = ${this.embeddingDimension}, - "hnsw_m" = 100 - ); - `; - await client.query(indexStatement); - } - await client.query("COMMIT"); - } catch (err) { - await client.query("ROLLBACK"); - throw err; - } finally { - client.release(); - } - } - - /** - * Deletes the collection from the database if it exists. - * @returns Promise that resolves when the collection is deleted. - */ - async deleteCollection(): Promise { - const dropStatement = `DROP TABLE IF EXISTS ${this.collectionName};`; - await this.pool.query(dropStatement); - } - - /** - * Creates a new collection in the database. If `preDeleteCollection` is - * true, any existing collection with the same name is deleted before the - * new collection is created. - * @returns Promise that resolves when the collection is created. - */ - async createCollection(): Promise { - if (this.preDeleteCollection) { - await this.deleteCollection(); - } - await this.createTableIfNotExists(); - this.isCreateCollection = true; - } - - /** - * Adds an array of documents to the collection. 
The documents are first - * converted to vectors using the `embedDocuments` method of the - * `embeddings` instance. - * @param documents Array of Document instances to be added to the collection. - * @returns Promise that resolves when the documents are added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Adds an array of vectors and corresponding documents to the collection. - * The vectors and documents are batch inserted into the database. - * @param vectors Array of vectors to be added to the collection. - * @param documents Array of Document instances corresponding to the vectors. - * @returns Promise that resolves when the vectors and documents are added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - if (vectors.length !== documents.length) { - throw new Error(`Vectors and documents must have the same length`); - } - if (!this.embeddingDimension) { - this.embeddingDimension = ( - await this.embeddings.embedQuery("test") - ).length; - } - if (vectors[0].length !== this.embeddingDimension) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${this.embeddingDimension})` - ); - } - - if (!this.isCreateCollection) { - await this.createCollection(); - } - - const client = await this.pool.connect(); - try { - const chunkSize = 500; - const chunksTableData: DataType[] = []; - - for (let i = 0; i < documents.length; i += 1) { - chunksTableData.push({ - id: uuid.v4(), - embedding: vectors[i], - document: documents[i].pageContent, - metadata: documents[i].metadata, - }); - - // Execute the batch insert when the batch size is reached - if (chunksTableData.length === chunkSize) { - const rs = new Readable(); - let currentIndex = 0; - rs._read = function () { - if (currentIndex 
=== chunkSize) { - rs.push(null); - } else { - const data = chunksTableData[currentIndex]; - rs.push( - `${data.id}\t{${data.embedding.join(",")}}\t${ - data.document - }\t${JSON.stringify(data.metadata)}\n` - ); - currentIndex += 1; - } - }; - const ws = client.query( - copyFrom( - `COPY ${this.collectionName}(id, embedding, document, metadata) FROM STDIN` - ) - ); - - await pipeline(rs, ws); - // Clear the chunksTableData list for the next batch - chunksTableData.length = 0; - } - } - - // Insert any remaining records that didn't make up a full batch - if (chunksTableData.length > 0) { - const rs = new Readable(); - let currentIndex = 0; - rs._read = function () { - if (currentIndex === chunksTableData.length) { - rs.push(null); - } else { - const data = chunksTableData[currentIndex]; - rs.push( - `${data.id}\t{${data.embedding.join(",")}}\t${ - data.document - }\t${JSON.stringify(data.metadata)}\n` - ); - currentIndex += 1; - } - }; - const ws = client.query( - copyFrom( - `COPY ${this.collectionName}(id, embedding, document, metadata) FROM STDIN` - ) - ); - await pipeline(rs, ws); - } - } finally { - client.release(); - } - } - - /** - * Performs a similarity search on the vectors in the collection. The - * search is performed using the given query vector and returns the top k - * most similar vectors along with their corresponding documents and - * similarity scores. - * @param query Query vector for the similarity search. - * @param k Number of top similar vectors to return. - * @param filter Optional. Filter to apply on the metadata of the documents. - * @returns Promise that resolves to an array of tuples, each containing a Document instance and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - if (!this.isCreateCollection) { - await this.createCollection(); - } - - let filterCondition = ""; - const filterEntries = filter ? 
Object.entries(filter) : []; - if (filterEntries.length > 0) { - const conditions = filterEntries.map( - (_, index) => `metadata->>$${2 * index + 3} = $${2 * index + 4}` - ); - filterCondition = `WHERE ${conditions.join(" AND ")}`; - } - - const sqlQuery = ` - SELECT *, l2_distance(embedding, $1::real[]) AS distance - FROM ${this.collectionName} - ${filterCondition} - ORDER BY embedding <-> $1 - LIMIT $2; - `; - - // Execute the query and fetch the results - const { rows } = await this.pool.query(sqlQuery, [ - query, - k, - ...filterEntries.flatMap(([key, value]) => [key, value]), - ]); - - const result: [Document, number][] = rows.map((row) => [ - new Document({ pageContent: row.document, metadata: row.metadata }), - row.distance, - ]); - - return result; - } - - /** - * Creates an instance of `AnalyticDBVectorStore` from an array of texts - * and corresponding metadata. The texts are first converted to Document - * instances before being added to the collection. - * @param texts Array of texts to be added to the collection. - * @param metadatas Array or object of metadata corresponding to the texts. - * @param embeddings Embeddings instance used to convert the texts to vectors. - * @param dbConfig Configuration for the AnalyticDB. - * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: AnalyticDBArgs - ): Promise { - const docs = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return AnalyticDBVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates an instance of `AnalyticDBVectorStore` from an array of - * Document instances. The documents are added to the collection. 
- * @param docs Array of Document instances to be added to the collection. - * @param embeddings Embeddings instance used to convert the documents to vectors. - * @param dbConfig Configuration for the AnalyticDB. - * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: AnalyticDBArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Creates an instance of `AnalyticDBVectorStore` from an existing index - * in the database. A new collection is created in the database. - * @param embeddings Embeddings instance used to convert the documents to vectors. - * @param dbConfig Configuration for the AnalyticDB. - * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. - */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: AnalyticDBArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.createCollection(); - return instance; - } -} +export * from "@langchain/community/vectorstores/analyticdb"; diff --git a/langchain/src/vectorstores/base.ts b/langchain/src/vectorstores/base.ts index 29701d6b091b..58e6d1c589ed 100644 --- a/langchain/src/vectorstores/base.ts +++ b/langchain/src/vectorstores/base.ts @@ -1,299 +1 @@ -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js"; -import { Serializable } from "../load/serializable.js"; -import { - CallbackManagerForRetrieverRun, - Callbacks, -} from "../callbacks/manager.js"; - -/** - * Type for options when adding a document to the VectorStore. - */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -type AddDocumentOptions = Record; - -/** - * Type for options when performing a maximal marginal relevance search. 
- */ -export type MaxMarginalRelevanceSearchOptions = { - k: number; - fetchK?: number; - lambda?: number; - filter?: FilterType; -}; - -/** - * Type for options when performing a maximal marginal relevance search - * with the VectorStoreRetriever. - */ -export type VectorStoreRetrieverMMRSearchKwargs = { - fetchK?: number; - lambda?: number; -}; - -/** - * Type for input when creating a VectorStoreRetriever instance. - */ -export type VectorStoreRetrieverInput = - BaseRetrieverInput & - ( - | { - vectorStore: V; - k?: number; - filter?: V["FilterType"]; - searchType?: "similarity"; - } - | { - vectorStore: V; - k?: number; - filter?: V["FilterType"]; - searchType: "mmr"; - searchKwargs?: VectorStoreRetrieverMMRSearchKwargs; - } - ); - -/** - * Class for performing document retrieval from a VectorStore. Can perform - * similarity search or maximal marginal relevance search. - */ -export class VectorStoreRetriever< - V extends VectorStore = VectorStore -> extends BaseRetriever { - static lc_name() { - return "VectorStoreRetriever"; - } - - get lc_namespace() { - return ["langchain", "retrievers", "base"]; - } - - vectorStore: V; - - k = 4; - - searchType = "similarity"; - - searchKwargs?: VectorStoreRetrieverMMRSearchKwargs; - - filter?: V["FilterType"]; - - _vectorstoreType(): string { - return this.vectorStore._vectorstoreType(); - } - - constructor(fields: VectorStoreRetrieverInput) { - super(fields); - this.vectorStore = fields.vectorStore; - this.k = fields.k ?? this.k; - this.searchType = fields.searchType ?? 
this.searchType; - this.filter = fields.filter; - if (fields.searchType === "mmr") { - this.searchKwargs = fields.searchKwargs; - } - } - - async _getRelevantDocuments( - query: string, - runManager?: CallbackManagerForRetrieverRun - ): Promise { - if (this.searchType === "mmr") { - if (typeof this.vectorStore.maxMarginalRelevanceSearch !== "function") { - throw new Error( - `The vector store backing this retriever, ${this._vectorstoreType()} does not support max marginal relevance search.` - ); - } - return this.vectorStore.maxMarginalRelevanceSearch( - query, - { - k: this.k, - filter: this.filter, - ...this.searchKwargs, - }, - runManager?.getChild("vectorstore") - ); - } - return this.vectorStore.similaritySearch( - query, - this.k, - this.filter, - runManager?.getChild("vectorstore") - ); - } - - async addDocuments( - documents: Document[], - options?: AddDocumentOptions - ): Promise { - return this.vectorStore.addDocuments(documents, options); - } -} - -/** - * Abstract class representing a store of vectors. Provides methods for - * adding vectors and documents, deleting from the store, and searching - * the store. 
- */ -export abstract class VectorStore extends Serializable { - declare FilterType: object | string; - - lc_namespace = ["langchain", "vectorstores", this._vectorstoreType()]; - - embeddings: Embeddings; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - constructor(embeddings: Embeddings, dbConfig: Record) { - super(dbConfig); - this.embeddings = embeddings; - } - - abstract _vectorstoreType(): string; - - abstract addVectors( - vectors: number[][], - documents: Document[], - options?: AddDocumentOptions - ): Promise; - - abstract addDocuments( - documents: Document[], - options?: AddDocumentOptions - ): Promise; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - async delete(_params?: Record): Promise { - throw new Error("Not implemented."); - } - - abstract similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]>; - - async similaritySearch( - query: string, - k = 4, - filter: this["FilterType"] | undefined = undefined, - _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later - ): Promise { - const results = await this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k, - filter - ); - - return results.map((result) => result[0]); - } - - async similaritySearchWithScore( - query: string, - k = 4, - filter: this["FilterType"] | undefined = undefined, - _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later - ): Promise<[Document, number][]> { - return this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k, - filter - ); - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. 
- * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {this["FilterType"]} options.filter - Optional filter - * @param _callbacks - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. - */ - async maxMarginalRelevanceSearch?( - query: string, - options: MaxMarginalRelevanceSearchOptions, - _callbacks: Callbacks | undefined // implement passing to embedQuery later - ): Promise; - - static fromTexts( - _texts: string[], - _metadatas: object[] | object, - _embeddings: Embeddings, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - _dbConfig: Record - ): Promise { - throw new Error( - "the Langchain vectorstore implementation you are using forgot to override this, please report a bug" - ); - } - - static fromDocuments( - _docs: Document[], - _embeddings: Embeddings, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - _dbConfig: Record - ): Promise { - throw new Error( - "the Langchain vectorstore implementation you are using forgot to override this, please report a bug" - ); - } - - asRetriever( - kOrFields?: number | Partial>, - filter?: this["FilterType"], - callbacks?: Callbacks, - tags?: string[], - metadata?: Record, - verbose?: boolean - ): VectorStoreRetriever { - if (typeof kOrFields === "number") { - return new VectorStoreRetriever({ - vectorStore: this, - k: kOrFields, - filter, - tags: [...(tags ?? []), this._vectorstoreType()], - metadata, - verbose, - callbacks, - }); - } else { - const params = { - vectorStore: this, - k: kOrFields?.k, - filter: kOrFields?.filter, - tags: [...(kOrFields?.tags ?? 
[]), this._vectorstoreType()], - metadata: kOrFields?.metadata, - verbose: kOrFields?.verbose, - callbacks: kOrFields?.callbacks, - searchType: kOrFields?.searchType, - }; - if (kOrFields?.searchType === "mmr") { - return new VectorStoreRetriever({ - ...params, - searchKwargs: kOrFields.searchKwargs, - }); - } - return new VectorStoreRetriever({ ...params }); - } - } -} - -/** - * Abstract class extending VectorStore with functionality for saving and - * loading the vector store. - */ -export abstract class SaveableVectorStore extends VectorStore { - abstract save(directory: string): Promise; - - static load( - _directory: string, - _embeddings: Embeddings - ): Promise { - throw new Error("Not implemented"); - } -} +export * from "@langchain/core/vectorstores"; diff --git a/langchain/src/vectorstores/cassandra.ts b/langchain/src/vectorstores/cassandra.ts index d3938d4a6b5f..6c6e084674a4 100644 --- a/langchain/src/vectorstores/cassandra.ts +++ b/langchain/src/vectorstores/cassandra.ts @@ -1,581 +1 @@ -/* eslint-disable prefer-template */ -import { Client as CassandraClient, DseClientOptions } from "cassandra-driver"; - -import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -export interface Column { - type: string; - name: string; - partition?: boolean; -} - -export interface Index { - name: string; - value: string; -} - -export interface Filter { - name: string; - value: unknown; - operator?: string; -} - -export type WhereClause = Filter[] | Filter | Record; - -export type SupportedVectorTypes = "cosine" | "dot_product" | "euclidean"; - -export interface CassandraLibArgs extends DseClientOptions, AsyncCallerParams { - table: string; - keyspace: string; - vectorType?: SupportedVectorTypes; - dimensions: number; - primaryKey: Column | Column[]; - metadataColumns: Column[]; - withClause?: string; - indices?: 
Index[]; - batchSize?: number; -} - -/** - * Class for interacting with the Cassandra database. It extends the - * VectorStore class and provides methods for adding vectors and - * documents, searching for similar vectors, and creating instances from - * texts or documents. - */ -export class CassandraStore extends VectorStore { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - declare FilterType: WhereClause; - - private client: CassandraClient; - - private readonly vectorType: SupportedVectorTypes; - - private readonly dimensions: number; - - private readonly keyspace: string; - - private primaryKey: Column[]; - - private metadataColumns: Column[]; - - private withClause: string; - - private selectColumns: string; - - private readonly table: string; - - private indices: Index[]; - - private isInitialized = false; - - asyncCaller: AsyncCaller; - - private readonly batchSize: number; - - _vectorstoreType(): string { - return "cassandra"; - } - - constructor(embeddings: Embeddings, args: CassandraLibArgs) { - super(embeddings, args); - - const { - indices = [], - maxConcurrency = 25, - withClause = "", - batchSize = 1, - vectorType = "cosine", - dimensions, - keyspace, - table, - primaryKey, - metadataColumns, - } = args; - - const argsWithDefaults = { - ...args, - indices, - maxConcurrency, - withClause, - batchSize, - vectorType, - }; - this.asyncCaller = new AsyncCaller(argsWithDefaults); - this.client = new CassandraClient(argsWithDefaults); - - // Assign properties - this.vectorType = vectorType; - this.dimensions = dimensions; - this.keyspace = keyspace; - this.table = table; - this.primaryKey = Array.isArray(primaryKey) ? primaryKey : [primaryKey]; - this.metadataColumns = metadataColumns; - this.withClause = withClause.trim().replace(/^with\s*/i, ""); - this.indices = indices; - this.batchSize = batchSize >= 1 ? batchSize : 1; - } - - /** - * Method to save vectors to the Cassandra database. - * @param vectors Vectors to save. 
- * @param documents The documents associated with the vectors. - * @returns Promise that resolves when the vectors have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - - if (!this.isInitialized) { - await this.initialize(); - } - - await this.insertAll(vectors, documents); - } - - /** - * Method to add documents to the Cassandra database. - * @param documents The documents to add. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - return this.addVectors( - await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), - documents - ); - } - - /** - * Method to search for vectors that are similar to a given query vector. - * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter - * @returns Promise that resolves with an array of tuples, each containing a Document and a score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: WhereClause - ): Promise<[Document, number][]> { - if (!this.isInitialized) { - await this.initialize(); - } - - // Ensure we have an array of Filter from the public interface - const filters = this.asFilters(filter); - - const queryStr = this.buildSearchQuery(filters); - - // Search query will be of format: - // SELECT ..., text, similarity_x(?) AS similarity_score - // FROM ... - // - // ORDER BY vector ANN OF ? - // LIMIT ? - // If any filter values are specified, they will be in the WHERE clause as - // filter.name filter.operator ? 
- // queryParams is a list of bind variables sent with the prepared statement - const queryParams = []; - const vectorAsFloat32Array = new Float32Array(query); - queryParams.push(vectorAsFloat32Array); - if (filters) { - const values = (filters as Filter[]).map(({ value }) => value); - queryParams.push(...values); - } - queryParams.push(vectorAsFloat32Array); - queryParams.push(k); - - const queryResultSet = await this.client.execute(queryStr, queryParams, { - prepare: true, - }); - - return queryResultSet?.rows.map((row) => { - const textContent = row.text; - const sanitizedRow = { ...row }; - delete sanitizedRow.text; - delete sanitizedRow.similarity_score; - - // A null value in Cassandra evaluates to a deleted column - // as this is treated as a tombstone record for the cell. - Object.keys(sanitizedRow).forEach((key) => { - if (sanitizedRow[key] === null) { - delete sanitizedRow[key]; - } - }); - - return [ - new Document({ pageContent: textContent, metadata: sanitizedRow }), - row.similarity_score, - ]; - }); - } - - /** - * Static method to create an instance of CassandraStore from texts. - * @param texts The texts to use. - * @param metadatas The metadata associated with the texts. - * @param embeddings The embeddings to use. - * @param args The arguments for the CassandraStore. - * @returns Promise that resolves with a new instance of CassandraStore. - */ - static async fromTexts( - texts: string[], - metadatas: object | object[], - embeddings: Embeddings, - args: CassandraLibArgs - ): Promise { - const docs: Document[] = []; - - for (let index = 0; index < texts.length; index += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[index] : metadatas; - const doc = new Document({ - pageContent: texts[index], - metadata, - }); - docs.push(doc); - } - - return CassandraStore.fromDocuments(docs, embeddings, args); - } - - /** - * Static method to create an instance of CassandraStore from documents. - * @param docs The documents to use. 
- * @param embeddings The embeddings to use. - * @param args The arguments for the CassandraStore. - * @returns Promise that resolves with a new instance of CassandraStore. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args: CassandraLibArgs - ): Promise { - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create an instance of CassandraStore from an existing - * index. - * @param embeddings The embeddings to use. - * @param args The arguments for the CassandraStore. - * @returns Promise that resolves with a new instance of CassandraStore. - */ - static async fromExistingIndex( - embeddings: Embeddings, - args: CassandraLibArgs - ): Promise { - const instance = new this(embeddings, args); - - await instance.initialize(); - return instance; - } - - /** - * Method to initialize the Cassandra database. - * @returns Promise that resolves when the database has been initialized. - */ - private async initialize(): Promise { - let cql = ""; - cql = `CREATE TABLE IF NOT EXISTS ${this.keyspace}.${this.table} ( - ${this.primaryKey.map((col) => `${col.name} ${col.type}`).join(", ")} - , text TEXT - ${ - this.metadataColumns.length > 0 - ? ", " + - this.metadataColumns - .map((col) => `${col.name} ${col.type}`) - .join(", ") - : "" - } - , vector VECTOR - , ${this.buildPrimaryKey(this.primaryKey)} - ) ${this.withClause ? `WITH ${this.withClause}` : ""};`; - - await this.client.execute(cql); - - this.selectColumns = `${this.primaryKey - .map((col) => `${col.name}`) - .join(", ")} - ${ - this.metadataColumns.length > 0 - ? 
", " + - this.metadataColumns - .map((col) => `${col.name}`) - .join(", ") - : "" - }`; - - cql = `CREATE CUSTOM INDEX IF NOT EXISTS idx_vector_${this.table} - ON ${this.keyspace}.${ - this.table - }(vector) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function': '${this.vectorType.toUpperCase()}'};`; - await this.client.execute(cql); - - for await (const { name, value } of this.indices) { - cql = `CREATE CUSTOM INDEX IF NOT EXISTS idx_${this.table}_${name} - ON ${this.keyspace}.${this.table} ${value} USING 'StorageAttachedIndex';`; - await this.client.execute(cql); - } - this.isInitialized = true; - } - - /** - * Method to build the PRIMARY KEY clause for CREATE TABLE. - * @param columns: list of Column to include in the key - * @returns The clause, including PRIMARY KEY - */ - private buildPrimaryKey(columns: Column[]): string { - // Partition columns may be specified with optional attribute col.partition - const partitionColumns = columns - .filter((col) => col.partition) - .map((col) => col.name) - .join(", "); - - // All columns not part of the partition key are clustering columns - const clusteringColumns = columns - .filter((col) => !col.partition) - .map((col) => col.name) - .join(", "); - - let primaryKey = ""; - - // If partition columns are specified, they are included in a () wrapper - // If not, the clustering columns are used, and the first clustering column - // is the partition key per normal Cassandra behaviour. - if (partitionColumns) { - primaryKey = `PRIMARY KEY ((${partitionColumns}), ${clusteringColumns})`; - } else { - primaryKey = `PRIMARY KEY (${clusteringColumns})`; - } - - return primaryKey; - } - - /** - * Type guard to check if an object is a Filter. 
- * @param obj: the object to check - * @returns boolean indicating if the object is a Filter - */ - private isFilter(obj: unknown): obj is Filter { - return ( - typeof obj === "object" && obj !== null && "name" in obj && "value" in obj - ); - } - - /** - * Helper to convert Record to a Filter[] - * @param record: a key-value Record collection - * @returns Record as a Filter[] - */ - private convertToFilters(record: Record): Filter[] { - return Object.entries(record).map(([name, value]) => ({ - name, - value, - operator: "=", - })); - } - - /** - * Input santisation method for filters, as FilterType is not required to be - * Filter[], but we want to use Filter[] internally. - * @param record: the proposed filter - * @returns A Filter[], which may be empty - */ - private asFilters(record: WhereClause | undefined): Filter[] { - if (!record) { - return []; - } - - // If record is already an array - if (Array.isArray(record)) { - return record.flatMap((item) => { - // Check if item is a Filter before passing it to convertToFilters - if (this.isFilter(item)) { - return [item]; - } else { - // Here item is treated as Record - return this.convertToFilters(item); - } - }); - } - - // If record is a single Filter object, return it in an array - if (this.isFilter(record)) { - return [record]; - } - - // If record is a Record, convert it to an array of Filter - return this.convertToFilters(record); - } - - /** - * Method to build the WHERE clause of a CQL query, using bind variable ? - * @param filters list of filters to include in the WHERE clause - * @returns The WHERE clause - */ - private buildWhereClause(filters?: Filter[]): string { - if (!filters || filters.length === 0) { - return ""; - } - - const whereConditions = filters.map( - ({ name, operator = "=" }) => `${name} ${operator} ?` - ); - - return `WHERE ${whereConditions.join(" AND ")}`; - } - - /** - * Method to build an CQL query for searching for similar vectors in the - * Cassandra database. 
- * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filters - * @returns The CQL query string. - */ - private buildSearchQuery(filters: Filter[]): string { - const whereClause = filters ? this.buildWhereClause(filters) : ""; - - const cqlQuery = `SELECT ${this.selectColumns}, text, similarity_${this.vectorType}(vector, ?) AS similarity_score - FROM ${this.keyspace}.${this.table} ${whereClause} ORDER BY vector ANN OF ? LIMIT ?`; - - return cqlQuery; - } - - /** - * Method for inserting vectors and documents into the Cassandra database in a batch. - * @param batchVectors The list of vectors to insert. - * @param batchDocuments The list of documents to insert. - * @returns Promise that resolves when the batch has been inserted. - */ - private async executeInsert( - batchVectors: number[][], - batchDocuments: Document[] - ): Promise { - // Input validation: Check if the lengths of batchVectors and batchDocuments are the same - if (batchVectors.length !== batchDocuments.length) { - throw new Error( - `The lengths of vectors (${batchVectors.length}) and documents (${batchDocuments.length}) must be the same.` - ); - } - - // Initialize an array to hold query objects - const queries = []; - - // Loop through each vector and document in the batch - for (let i = 0; i < batchVectors.length; i += 1) { - // Convert the list of numbers to a Float32Array, the driver's expected format of a vector - const preparedVector = new Float32Array(batchVectors[i]); - // Retrieve the corresponding document - const document = batchDocuments[i]; - - // Extract metadata column names and values from the document - const metadataColNames = Object.keys(document.metadata); - const metadataVals = Object.values(document.metadata); - - // Prepare the metadata columns string for the query, if metadata exists - const metadataInsert = - metadataColNames.length > 0 ? 
", " + metadataColNames.join(", ") : ""; - - // Construct the query string and parameters - const query = { - query: `INSERT INTO ${this.keyspace}.${ - this.table - } (vector, text${metadataInsert}) - VALUES (?, ?${", ?".repeat(metadataColNames.length)})`, - params: [preparedVector, document.pageContent, ...metadataVals], - }; - - // Add the query to the list - queries.push(query); - } - - // Execute the queries: use a batch if multiple, otherwise execute a single query - if (queries.length === 1) { - await this.client.execute(queries[0].query, queries[0].params, { - prepare: true, - }); - } else { - await this.client.batch(queries, { prepare: true, logged: false }); - } - } - - /** - * Method for inserting vectors and documents into the Cassandra database in - * parallel, keeping within maxConcurrency number of active insert statements. - * @param vectors The vectors to insert. - * @param documents The documents to insert. - * @returns Promise that resolves when the documents have been added. 
- */ - private async insertAll( - vectors: number[][], - documents: Document[] - ): Promise { - // Input validation: Check if the lengths of vectors and documents are the same - if (vectors.length !== documents.length) { - throw new Error( - `The lengths of vectors (${vectors.length}) and documents (${documents.length}) must be the same.` - ); - } - - // Early exit: If there are no vectors or documents to insert, return immediately - if (vectors.length === 0) { - return; - } - - // Ensure the store is initialized before proceeding - if (!this.isInitialized) { - await this.initialize(); - } - - // Initialize an array to hold promises for each batch insert - const insertPromises: Promise[] = []; - - // Buffers to hold the current batch of vectors and documents - let currentBatchVectors: number[][] = []; - let currentBatchDocuments: Document[] = []; - - // Loop through each vector/document pair to insert; we use - // <= vectors.length to ensure the last batch is inserted - for (let i = 0; i <= vectors.length; i += 1) { - // Check if we're still within the array boundaries - if (i < vectors.length) { - // Add the current vector and document to the batch - currentBatchVectors.push(vectors[i]); - currentBatchDocuments.push(documents[i]); - } - - // Check if we've reached the batch size or end of the array - if ( - currentBatchVectors.length >= this.batchSize || - i === vectors.length - ) { - // Only proceed if there are items in the current batch - if (currentBatchVectors.length > 0) { - // Create copies of the current batch arrays to use in the async insert operation - const batchVectors = [...currentBatchVectors]; - const batchDocuments = [...currentBatchDocuments]; - - // Execute the insert using the AsyncCaller - it will handle concurrency and queueing. 
- insertPromises.push( - this.asyncCaller.call(() => - this.executeInsert(batchVectors, batchDocuments) - ) - ); - - // Clear the current buffers for the next iteration - currentBatchVectors = []; - currentBatchDocuments = []; - } - } - } - - // Wait for all insert operations to complete. - await Promise.all(insertPromises); - } -} +export * from "@langchain/community/vectorstores/cassandra"; diff --git a/langchain/src/vectorstores/chroma.ts b/langchain/src/vectorstores/chroma.ts index 6b52b8c6f1b7..583129d43503 100644 --- a/langchain/src/vectorstores/chroma.ts +++ b/langchain/src/vectorstores/chroma.ts @@ -1,364 +1 @@ -import * as uuid from "uuid"; -import type { ChromaClient as ChromaClientT, Collection } from "chromadb"; -import type { CollectionMetadata, Where } from "chromadb/dist/main/types.js"; - -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -/** - * Defines the arguments that can be passed to the `Chroma` class - * constructor. It can either contain a `url` for the Chroma database, the - * number of dimensions for the vectors (`numDimensions`), a - * `collectionName` for the collection to be used in the database, and a - * `filter` object; or it can contain an `index` which is an instance of - * `ChromaClientT`, along with the `numDimensions`, `collectionName`, and - * `filter`. - */ -export type ChromaLibArgs = - | { - url?: string; - numDimensions?: number; - collectionName?: string; - filter?: object; - collectionMetadata?: CollectionMetadata; - } - | { - index?: ChromaClientT; - numDimensions?: number; - collectionName?: string; - filter?: object; - collectionMetadata?: CollectionMetadata; - }; - -/** - * Defines the parameters for the `delete` method in the `Chroma` class. - * It can either contain an array of `ids` of the documents to be deleted - * or a `filter` object to specify the documents to be deleted. 
- */ -export interface ChromaDeleteParams { - ids?: string[]; - filter?: T; -} - -/** - * The main class that extends the `VectorStore` class. It provides - * methods for interacting with the Chroma database, such as adding - * documents, deleting documents, and searching for similar vectors. - */ -export class Chroma extends VectorStore { - declare FilterType: Where; - - index?: ChromaClientT; - - collection?: Collection; - - collectionName: string; - - collectionMetadata?: CollectionMetadata; - - numDimensions?: number; - - url: string; - - filter?: object; - - _vectorstoreType(): string { - return "chroma"; - } - - constructor(embeddings: Embeddings, args: ChromaLibArgs) { - super(embeddings, args); - this.numDimensions = args.numDimensions; - this.embeddings = embeddings; - this.collectionName = ensureCollectionName(args.collectionName); - this.collectionMetadata = args.collectionMetadata; - if ("index" in args) { - this.index = args.index; - } else if ("url" in args) { - this.url = args.url || "http://localhost:8000"; - } - - this.filter = args.filter; - } - - /** - * Adds documents to the Chroma database. The documents are first - * converted to vectors using the `embeddings` instance, and then added to - * the database. - * @param documents An array of `Document` instances to be added to the database. - * @param options Optional. An object containing an array of `ids` for the documents. - * @returns A promise that resolves when the documents have been added to the database. - */ - async addDocuments(documents: Document[], options?: { ids?: string[] }) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Ensures that a collection exists in the Chroma database. If the - * collection does not exist, it is created. - * @returns A promise that resolves with the `Collection` instance. 
- */ - async ensureCollection(): Promise { - if (!this.collection) { - if (!this.index) { - const { ChromaClient } = await Chroma.imports(); - this.index = new ChromaClient({ path: this.url }); - } - try { - this.collection = await this.index.getOrCreateCollection({ - name: this.collectionName, - ...(this.collectionMetadata && { metadata: this.collectionMetadata }), - }); - } catch (err) { - throw new Error(`Chroma getOrCreateCollection error: ${err}`); - } - } - - return this.collection; - } - - /** - * Adds vectors to the Chroma database. The vectors are associated with - * the provided documents. - * @param vectors An array of vectors to be added to the database. - * @param documents An array of `Document` instances associated with the vectors. - * @param options Optional. An object containing an array of `ids` for the vectors. - * @returns A promise that resolves with an array of document IDs when the vectors have been added to the database. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ) { - if (vectors.length === 0) { - return []; - } - if (this.numDimensions === undefined) { - this.numDimensions = vectors[0].length; - } - if (vectors.length !== documents.length) { - throw new Error(`Vectors and metadatas must have the same length`); - } - if (vectors[0].length !== this.numDimensions) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${this.numDimensions})` - ); - } - - const documentIds = - options?.ids ?? 
Array.from({ length: vectors.length }, () => uuid.v1()); - const collection = await this.ensureCollection(); - - const mappedMetadatas = documents.map(({ metadata }) => { - let locFrom; - let locTo; - - if (metadata?.loc) { - if (metadata.loc.lines?.from !== undefined) - locFrom = metadata.loc.lines.from; - if (metadata.loc.lines?.to !== undefined) locTo = metadata.loc.lines.to; - } - - const newMetadata: Document["metadata"] = { - ...metadata, - ...(locFrom !== undefined && { locFrom }), - ...(locTo !== undefined && { locTo }), - }; - - if (newMetadata.loc) delete newMetadata.loc; - - return newMetadata; - }); - - await collection.upsert({ - ids: documentIds, - embeddings: vectors, - metadatas: mappedMetadatas, - documents: documents.map(({ pageContent }) => pageContent), - }); - return documentIds; - } - - /** - * Deletes documents from the Chroma database. The documents to be deleted - * can be specified by providing an array of `ids` or a `filter` object. - * @param params An object containing either an array of `ids` of the documents to be deleted or a `filter` object to specify the documents to be deleted. - * @returns A promise that resolves when the specified documents have been deleted from the database. - */ - async delete(params: ChromaDeleteParams): Promise { - const collection = await this.ensureCollection(); - if (Array.isArray(params.ids)) { - await collection.delete({ ids: params.ids }); - } else if (params.filter) { - await collection.delete({ - where: { ...params.filter }, - }); - } else { - throw new Error(`You must provide one of "ids or "filter".`); - } - } - - /** - * Searches for vectors in the Chroma database that are similar to the - * provided query vector. The search can be filtered using the provided - * `filter` object or the `filter` property of the `Chroma` instance. - * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter Optional. A `filter` object to filter the search results. 
- * @returns A promise that resolves with an array of tuples, each containing a `Document` instance and a similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ) { - if (filter && this.filter) { - throw new Error("cannot provide both `filter` and `this.filter`"); - } - const _filter = filter ?? this.filter; - - const collection = await this.ensureCollection(); - - // similaritySearchVectorWithScore supports one query vector at a time - // chroma supports multiple query vectors at a time - const result = await collection.query({ - queryEmbeddings: query, - nResults: k, - where: { ..._filter }, - }); - - const { ids, distances, documents, metadatas } = result; - if (!ids || !distances || !documents || !metadatas) { - return []; - } - // get the result data from the first and only query vector - const [firstIds] = ids; - const [firstDistances] = distances; - const [firstDocuments] = documents; - const [firstMetadatas] = metadatas; - - const results: [Document, number][] = []; - for (let i = 0; i < firstIds.length; i += 1) { - let metadata: Document["metadata"] = firstMetadatas?.[i] ?? {}; - - if (metadata.locFrom && metadata.locTo) { - metadata = { - ...metadata, - loc: { - lines: { - from: metadata.locFrom, - to: metadata.locTo, - }, - }, - }; - - delete metadata.locFrom; - delete metadata.locTo; - } - - results.push([ - new Document({ - pageContent: firstDocuments?.[i] ?? "", - metadata, - }), - firstDistances[i], - ]); - } - return results; - } - - /** - * Creates a new `Chroma` instance from an array of text strings. The text - * strings are converted to `Document` instances and added to the Chroma - * database. - * @param texts An array of text strings. - * @param metadatas An array of metadata objects or a single metadata object. If an array is provided, it must have the same length as the `texts` array. 
- * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. - * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. - * @returns A promise that resolves with a new `Chroma` instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: ChromaLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return this.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new `Chroma` instance from an array of `Document` instances. - * The documents are added to the Chroma database. - * @param docs An array of `Document` instances. - * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. - * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. - * @returns A promise that resolves with a new `Chroma` instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: ChromaLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Creates a new `Chroma` instance from an existing collection in the - * Chroma database. - * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. - * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. - * @returns A promise that resolves with a new `Chroma` instance. 
- */ - static async fromExistingCollection( - embeddings: Embeddings, - dbConfig: ChromaLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.ensureCollection(); - return instance; - } - - /** - * Imports the `ChromaClient` from the `chromadb` module. - * @returns A promise that resolves with an object containing the `ChromaClient` constructor. - */ - static async imports(): Promise<{ - ChromaClient: typeof ChromaClientT; - }> { - try { - const { ChromaClient } = await import("chromadb"); - return { ChromaClient }; - } catch (e) { - throw new Error( - "Please install chromadb as a dependency with, e.g. `npm install -S chromadb`" - ); - } - } -} - -/** - * Generates a unique collection name if none is provided. - */ -function ensureCollectionName(collectionName?: string) { - if (!collectionName) { - return `langchain-${uuid.v4()}`; - } - return collectionName; -} +export * from "@langchain/community/vectorstores/chroma"; diff --git a/langchain/src/vectorstores/clickhouse.ts b/langchain/src/vectorstores/clickhouse.ts index 003256e8d7a8..ead163d980af 100644 --- a/langchain/src/vectorstores/clickhouse.ts +++ b/langchain/src/vectorstores/clickhouse.ts @@ -1,338 +1 @@ -import * as uuid from "uuid"; -import { ClickHouseClient, createClient } from "@clickhouse/client"; -import { format } from "mysql2"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -/** - * Arguments for the ClickHouseStore class, which include the host, port, - * protocol, username, password, index type, index parameters, - * index query params, column map, database, table. 
- */ -export interface ClickHouseLibArgs { - host: string; - port: string | number; - protocol?: string; - username: string; - password: string; - indexType?: string; - indexParam?: Record; - indexQueryParams?: Record; - columnMap?: ColumnMap; - database?: string; - table?: string; -} - -/** - * Mapping of columns in the ClickHouse database. - */ -export interface ColumnMap { - id: string; - uuid: string; - document: string; - embedding: string; - metadata: string; -} - -/** - * Type for filtering search results in the ClickHouse database. - */ -export interface ClickHouseFilter { - whereStr: string; -} - -/** - * Class for interacting with the ClickHouse database. It extends the - * VectorStore class and provides methods for adding vectors and - * documents, searching for similar vectors, and creating instances from - * texts or documents. - */ -export class ClickHouseStore extends VectorStore { - declare FilterType: ClickHouseFilter; - - private client: ClickHouseClient; - - private indexType: string; - - private indexParam: Record; - - private indexQueryParams: Record; - - private columnMap: ColumnMap; - - private database: string; - - private table: string; - - private isInitialized = false; - - _vectorstoreType(): string { - return "clickhouse"; - } - - constructor(embeddings: Embeddings, args: ClickHouseLibArgs) { - super(embeddings, args); - - this.indexType = args.indexType || "annoy"; - this.indexParam = args.indexParam || { L2Distance: 100 }; - this.indexQueryParams = args.indexQueryParams || {}; - this.columnMap = args.columnMap || { - id: "id", - document: "document", - embedding: "embedding", - metadata: "metadata", - uuid: "uuid", - }; - this.database = args.database || "default"; - this.table = args.table || "vector_table"; - - this.client = createClient({ - host: `${args.protocol ?? 
"https://"}${args.host}:${args.port}`, - username: args.username, - password: args.password, - session_id: uuid.v4(), - }); - } - - /** - * Method to add vectors to the ClickHouse database. - * @param vectors The vectors to add. - * @param documents The documents associated with the vectors. - * @returns Promise that resolves when the vectors have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - - if (!this.isInitialized) { - await this.initialize(vectors[0].length); - } - - const queryStr = this.buildInsertQuery(vectors, documents); - await this.client.exec({ query: queryStr }); - } - - /** - * Method to add documents to the ClickHouse database. - * @param documents The documents to add. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - return this.addVectors( - await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), - documents - ); - } - - /** - * Method to search for vectors that are similar to a given query vector. - * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter Optional filter for the search results. - * @returns Promise that resolves with an array of tuples, each containing a Document and a score. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - if (!this.isInitialized) { - await this.initialize(query.length); - } - const queryStr = this.buildSearchQuery(query, k, filter); - - const queryResultSet = await this.client.query({ query: queryStr }); - - const queryResult: { - data: { document: string; metadata: object; dist: number }[]; - } = await queryResultSet.json(); - - const result: [Document, number][] = queryResult.data.map((item) => [ - new Document({ pageContent: item.document, metadata: item.metadata }), - item.dist, - ]); - - return result; - } - - /** - * Static method to create an instance of ClickHouseStore from texts. - * @param texts The texts to use. - * @param metadatas The metadata associated with the texts. - * @param embeddings The embeddings to use. - * @param args The arguments for the ClickHouseStore. - * @returns Promise that resolves with a new instance of ClickHouseStore. - */ - static async fromTexts( - texts: string[], - metadatas: object | object[], - embeddings: Embeddings, - args: ClickHouseLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return ClickHouseStore.fromDocuments(docs, embeddings, args); - } - - /** - * Static method to create an instance of ClickHouseStore from documents. - * @param docs The documents to use. - * @param embeddings The embeddings to use. - * @param args The arguments for the ClickHouseStore. - * @returns Promise that resolves with a new instance of ClickHouseStore. 
- */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args: ClickHouseLibArgs - ): Promise { - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create an instance of ClickHouseStore from an existing - * index. - * @param embeddings The embeddings to use. - * @param args The arguments for the ClickHouseStore. - * @returns Promise that resolves with a new instance of ClickHouseStore. - */ - static async fromExistingIndex( - embeddings: Embeddings, - args: ClickHouseLibArgs - ): Promise { - const instance = new this(embeddings, args); - - await instance.initialize(); - return instance; - } - - /** - * Method to initialize the ClickHouse database. - * @param dimension Optional dimension of the vectors. - * @returns Promise that resolves when the database has been initialized. - */ - private async initialize(dimension?: number): Promise { - const dim = dimension ?? (await this.embeddings.embedQuery("test")).length; - - const indexParamStr = this.indexParam - ? 
Object.entries(this.indexParam) - .map(([key, value]) => `'${key}', ${value}`) - .join(", ") - : ""; - - const query = ` - CREATE TABLE IF NOT EXISTS ${this.database}.${this.table}( - ${this.columnMap.id} Nullable(String), - ${this.columnMap.document} Nullable(String), - ${this.columnMap.embedding} Array(Float32), - ${this.columnMap.metadata} JSON, - ${this.columnMap.uuid} UUID DEFAULT generateUUIDv4(), - CONSTRAINT cons_vec_len CHECK length(${this.columnMap.embedding}) = ${dim}, - INDEX vec_idx ${this.columnMap.embedding} TYPE ${this.indexType}(${indexParamStr}) GRANULARITY 1000 - ) ENGINE = MergeTree ORDER BY ${this.columnMap.uuid} SETTINGS index_granularity = 8192;`; - - await this.client.exec({ - query, - clickhouse_settings: { - allow_experimental_object_type: 1, - allow_experimental_annoy_index: 1, - }, - }); - this.isInitialized = true; - } - - /** - * Method to build an SQL query for inserting vectors and documents into - * the ClickHouse database. - * @param vectors The vectors to insert. - * @param documents The documents to insert. - * @returns The SQL query string. 
- */ - private buildInsertQuery(vectors: number[][], documents: Document[]): string { - const columnsStr = Object.values( - Object.fromEntries( - Object.entries(this.columnMap).filter( - ([key]) => key !== this.columnMap.uuid - ) - ) - ).join(", "); - - const placeholders = vectors.map(() => "(?, ?, ?, ?)").join(", "); - const values = []; - - for (let i = 0; i < vectors.length; i += 1) { - const vector = vectors[i]; - const document = documents[i]; - values.push( - uuid.v4(), - this.escapeString(document.pageContent), - JSON.stringify(vector), - JSON.stringify(document.metadata) - ); - } - - const insertQueryStr = ` - INSERT INTO TABLE ${this.database}.${this.table}(${columnsStr}) - VALUES ${placeholders} - `; - - const insertQuery = format(insertQueryStr, values); - return insertQuery; - } - - private escapeString(str: string): string { - return str.replace(/\\/g, "\\\\").replace(/'/g, "\\'"); - } - - /** - * Method to build an SQL query for searching for similar vectors in the - * ClickHouse database. - * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter Optional filter for the search results. - * @returns The SQL query string. - */ - private buildSearchQuery( - query: number[], - k: number, - filter?: ClickHouseFilter - ): string { - const order = "ASC"; - const whereStr = filter ? 
`PREWHERE ${filter.whereStr}` : ""; - const placeholders = query.map(() => "?").join(", "); - - const settingStrings: string[] = []; - if (this.indexQueryParams) { - for (const [key, value] of Object.entries(this.indexQueryParams)) { - settingStrings.push(`SETTING ${key}=${value}`); - } - } - - const searchQueryStr = ` - SELECT ${this.columnMap.document} AS document, ${ - this.columnMap.metadata - } AS metadata, dist - FROM ${this.database}.${this.table} - ${whereStr} - ORDER BY L2Distance(${ - this.columnMap.embedding - }, [${placeholders}]) AS dist ${order} - LIMIT ${k} ${settingStrings.join(" ")} - `; - - // Format the query with actual values - const searchQuery = format(searchQueryStr, query); - return searchQuery; - } -} +export * from "@langchain/community/vectorstores/clickhouse"; diff --git a/langchain/src/vectorstores/closevector/node.ts b/langchain/src/vectorstores/closevector/node.ts index dd91f71a57b9..d3a2298f9577 100644 --- a/langchain/src/vectorstores/closevector/node.ts +++ b/langchain/src/vectorstores/closevector/node.ts @@ -1,182 +1 @@ -import { - CloseVectorHNSWNode, - HierarchicalNSWT, - CloseVectorHNSWLibArgs, - CloseVectorCredentials, -} from "closevector-node"; - -import { CloseVector } from "./common.js"; - -import { Embeddings } from "../../embeddings/base.js"; -import { Document } from "../../document.js"; - -/** - * package closevector-node is largely based on hnswlib.ts in the current folder with the following exceptions: - * 1. It uses a modified version of hnswlib-node to ensure the generated index can be loaded by closevector_web.ts. - * 2. It adds features to upload and download the index to/from the CDN provided by CloseVector. - * - * For more information, check out https://closevector-docs.getmegaportal.com/ - */ - -/** - * Arguments for creating a CloseVectorNode instance, extending CloseVectorHNSWLibArgs. 
- */ -export interface CloseVectorNodeArgs - extends CloseVectorHNSWLibArgs { - instance?: CloseVectorHNSWNode; -} - -/** - * Class that implements a vector store using Hierarchical Navigable Small - * World (HNSW) graphs. It extends the SaveableVectorStore class and - * provides methods for adding documents and vectors, performing - * similarity searches, and saving and loading the vector store. - */ -export class CloseVectorNode extends CloseVector { - declare FilterType: (doc: Document) => boolean; - - constructor( - embeddings: Embeddings, - args: CloseVectorNodeArgs, - credentials?: CloseVectorCredentials - ) { - super(embeddings, args, credentials); - if (args.instance) { - this.instance = args.instance; - } else { - this.instance = new CloseVectorHNSWNode(embeddings, args); - } - if (this.credentials?.key) { - this.instance.accessKey = this.credentials.key; - } - if (this.credentials?.secret) { - this.instance.secret = this.credentials.secret; - } - } - - /** - * Method to save the index to the CloseVector CDN. - * @param options - * @param options.description A description of the index. - * @param options.public Whether the index should be public or private. Defaults to false. - * @param options.uuid A UUID for the index. If not provided, a new index will be created. - * @param options.onProgress A callback function that will be called with the progress of the upload. - */ - async saveToCloud( - options: Parameters[0] - ) { - await this.instance.saveToCloud(options); - } - - /** - * Method to load the index from the CloseVector CDN. - * @param options - * @param options.uuid The UUID of the index to be downloaded. - * @param options.credentials The credentials to be used by the CloseVectorNode instance. - * @param options.embeddings The embeddings to be used by the CloseVectorNode instance. - * @param options.onProgress A callback function that will be called with the progress of the download. 
- */ - static async loadFromCloud( - options: Omit< - Parameters<(typeof CloseVectorHNSWNode)["loadFromCloud"]>[0] & { - embeddings: Embeddings; - credentials: CloseVectorCredentials; - }, - "accessKey" | "secret" - > - ) { - if (!options.credentials.key || !options.credentials.secret) { - throw new Error("key and secret must be provided"); - } - const instance = await CloseVectorHNSWNode.loadFromCloud({ - ...options, - accessKey: options.credentials.key, - secret: options.credentials.secret, - }); - const vectorstore = new this( - options.embeddings, - instance.args, - options.credentials - ); - return vectorstore; - } - - /** - * Static method to load a vector store from a directory. It reads the - * HNSW index, the arguments, and the document store from the directory, - * then creates a new HNSWLib instance with these values. - * @param directory The directory from which to load the vector store. - * @param embeddings The embeddings to be used by the CloseVectorNode instance. - * @returns A Promise that resolves to a new CloseVectorNode instance. - */ - static async load( - directory: string, - embeddings: Embeddings, - credentials?: CloseVectorCredentials - ) { - const instance = await CloseVectorHNSWNode.load(directory, embeddings); - const vectorstore = new this(embeddings, instance.args, credentials); - return vectorstore; - } - - /** - * Static method to create a new CloseVectorWeb instance from texts and metadata. - * It creates a new Document instance for each text and metadata, then - * calls the fromDocuments method to create the CloseVectorWeb instance. - * @param texts The texts to be used to create the documents. - * @param metadatas The metadata to be used to create the documents. - * @param embeddings The embeddings to be used by the CloseVectorWeb instance. - * @param args An optional configuration object for the CloseVectorWeb instance. - * @param credential An optional credential object for the CloseVector API. 
- * @returns A Promise that resolves to a new CloseVectorWeb instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - args?: Record, - credential?: CloseVectorCredentials - ): Promise { - const docs = CloseVector.textsToDocuments(texts, metadatas); - return await CloseVectorNode.fromDocuments( - docs, - embeddings, - args, - credential - ); - } - - /** - * Static method to create a new CloseVectorNode instance from documents. It - * creates a new CloseVectorNode instance, adds the documents to it, then returns - * the instance. - * @param docs The documents to be added to the HNSWLib instance. - * @param embeddings The embeddings to be used by the HNSWLib instance. - * @param args An optional configuration object for the HNSWLib instance. - * @param credentials An optional credential object for the CloseVector API. - * @returns A Promise that resolves to a new CloseVectorNode instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args?: Record, - credentials?: CloseVectorCredentials - ): Promise { - const _args: Record = args || { - space: "cosine", - }; - const instance = new this( - embeddings, - _args as unknown as CloseVectorNodeArgs, - credentials - ); - await instance.addDocuments(docs); - return instance; - } - - static async imports(): Promise<{ - HierarchicalNSW: typeof HierarchicalNSWT; - }> { - return CloseVectorHNSWNode.imports(); - } -} +export * from "@langchain/community/vectorstores/closevector/node"; diff --git a/langchain/src/vectorstores/closevector/web.ts b/langchain/src/vectorstores/closevector/web.ts index 06882c2befa9..bc67272fa78e 100644 --- a/langchain/src/vectorstores/closevector/web.ts +++ b/langchain/src/vectorstores/closevector/web.ts @@ -1,179 +1 @@ -import { - CloseVectorHNSWWeb, - HierarchicalNSWT, - CloseVectorHNSWLibArgs, - CloseVectorCredentials, - HnswlibModule, -} from "closevector-web"; - -import { CloseVector } from 
"./common.js"; - -import { Embeddings } from "../../embeddings/base.js"; -import { Document } from "../../document.js"; - -/** - * package closevector-node is largely based on hnswlib.ts in the current folder with the following exceptions: - * 1. It uses a modified version of hnswlib-node to ensure the generated index can be loaded by closevector_web.ts. - * 2. It adds features to upload and download the index to/from the CDN provided by CloseVector. - * - * For more information, check out https://closevector-docs.getmegaportal.com/ - */ - -/** - * Arguments for creating a CloseVectorWeb instance, extending CloseVectorHNSWLibArgs. - */ -export interface CloseVectorWebArgs - extends CloseVectorHNSWLibArgs { - instance?: CloseVectorHNSWWeb; -} - -/** - * Class that implements a vector store using CloseVector, It extends the SaveableVectorStore class and - * provides methods for adding documents and vectors, performing - * similarity searches, and saving and loading the vector store. - */ -export class CloseVectorWeb extends CloseVector { - declare FilterType: (doc: Document) => boolean; - - constructor( - embeddings: Embeddings, - args: CloseVectorWebArgs, - credentials?: CloseVectorCredentials - ) { - super(embeddings, args, credentials); - if (args.instance) { - this.instance = args.instance; - } else { - this.instance = new CloseVectorHNSWWeb(embeddings, args); - } - } - - /** - * Method to save the index to the CloseVector CDN. 
- * @param options - * @param options.url the upload url generated by the CloseVector API: https://closevector-docs.getmegaportal.com/docs/api/http-api/file-url - * @param options.onProgress a callback function to track the upload progress - */ - async saveToCloud( - options: Parameters[0] & { - uuid?: string; - } - ) { - if (!this.instance.uuid && !options.uuid) { - throw new Error("No uuid provided"); - } - if (!this.instance.uuid) { - this.instance._uuid = options.uuid; - } - await this.save(this.instance.uuid); - await this.instance.saveToCloud(options); - } - - /** - * Method to load the index from the CloseVector CDN. - * @param options - * @param options.url the upload url generated by the CloseVector API: https://closevector-docs.getmegaportal.com/docs/api/http-api/file-url - * @param options.onProgress a callback function to track the upload progress - * @param options.uuid the uuid of the index to be downloaded - * @param options.embeddings the embeddings to be used by the CloseVectorWeb instance - */ - static async loadFromCloud( - options: Parameters[0] & { - embeddings: Embeddings; - credentials?: CloseVectorCredentials; - } - ) { - const instance = await CloseVectorHNSWWeb.loadFromCloud(options); - const vectorstore = new this( - options.embeddings, - instance.args, - options.credentials - ); - return vectorstore; - } - - /** - * Static method to load a vector store from a directory. It reads the - * HNSW index, the arguments, and the document store from the directory, - * then creates a new CloseVectorWeb instance with these values. - * @param directory The directory from which to load the vector store. - * @param embeddings The embeddings to be used by the CloseVectorWeb instance. - * @returns A Promise that resolves to a new CloseVectorWeb instance. 
- */ - static async load( - directory: string, - embeddings: Embeddings, - credentials?: CloseVectorCredentials - ) { - const instance = await CloseVectorHNSWWeb.load(directory, embeddings); - const vectorstore = new this(embeddings, instance.args, credentials); - return vectorstore; - } - - /** - * Static method to create a new CloseVectorWeb instance from texts and metadata. - * It creates a new Document instance for each text and metadata, then - * calls the fromDocuments method to create the CloseVectorWeb instance. - * @param texts The texts to be used to create the documents. - * @param metadatas The metadata to be used to create the documents. - * @param embeddings The embeddings to be used by the CloseVectorWeb instance. - * @param args An optional configuration object for the CloseVectorWeb instance. - * @param credential An optional credential object for the CloseVector API. - * @returns A Promise that resolves to a new CloseVectorWeb instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - args?: Record, - credential?: CloseVectorCredentials - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return await CloseVectorWeb.fromDocuments( - docs, - embeddings, - args, - credential - ); - } - - /** - * Static method to create a new CloseVectorWeb instance from documents. It - * creates a new CloseVectorWeb instance, adds the documents to it, then returns - * the instance. - * @param docs The documents to be added to the CloseVectorWeb instance. - * @param embeddings The embeddings to be used by the CloseVectorWeb instance. - * @param args An optional configuration object for the CloseVectorWeb instance. - * @param credentials An optional credential object for the CloseVector API. 
- * @returns A Promise that resolves to a new CloseVectorWeb instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args?: Record, - credentials?: CloseVectorCredentials - ): Promise { - const _args: Record = args || { - space: "cosine", - }; - const instance = new this( - embeddings, - _args as unknown as CloseVectorWebArgs, - credentials - ); - await instance.addDocuments(docs); - return instance; - } - - static async imports(): Promise { - return CloseVectorHNSWWeb.imports(); - } -} +export * from "@langchain/community/vectorstores/closevector/web"; diff --git a/langchain/src/vectorstores/cloudflare_vectorize.ts b/langchain/src/vectorstores/cloudflare_vectorize.ts index 8a5babf49b1f..5ea66a2f66ae 100644 --- a/langchain/src/vectorstores/cloudflare_vectorize.ts +++ b/langchain/src/vectorstores/cloudflare_vectorize.ts @@ -1,227 +1 @@ -import * as uuid from "uuid"; - -import { - VectorizeIndex, - VectorizeVectorMetadata, -} from "@cloudflare/workers-types"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { chunkArray } from "../util/chunk.js"; -import { AsyncCaller, type AsyncCallerParams } from "../util/async_caller.js"; - -export interface VectorizeLibArgs extends AsyncCallerParams { - index: VectorizeIndex; - textKey?: string; -} - -/** - * Type that defines the parameters for the delete operation in the - * CloudflareVectorizeStore class. It includes ids, deleteAll flag, and namespace. - */ -export type VectorizeDeleteParams = { - ids: string[]; -}; - -/** - * Class that extends the VectorStore class and provides methods to - * interact with the Cloudflare Vectorize vector database. 
- */ -export class CloudflareVectorizeStore extends VectorStore { - textKey: string; - - namespace?: string; - - index: VectorizeIndex; - - caller: AsyncCaller; - - _vectorstoreType(): string { - return "cloudflare_vectorize"; - } - - constructor(embeddings: Embeddings, args: VectorizeLibArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - const { index, textKey, ...asyncCallerArgs } = args; - if (!index) { - throw new Error( - "Must supply a Vectorize index binding, eg { index: env.VECTORIZE }" - ); - } - this.index = index; - this.textKey = textKey ?? "text"; - this.caller = new AsyncCaller({ - maxConcurrency: 6, - maxRetries: 0, - ...asyncCallerArgs, - }); - } - - /** - * Method that adds documents to the Vectorize database. - * @param documents Array of documents to add. - * @param options Optional ids for the documents. - * @returns Promise that resolves with the ids of the added documents. - */ - async addDocuments( - documents: Document[], - options?: { ids?: string[] } | string[] - ) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method that adds vectors to the Vectorize database. - * @param vectors Array of vectors to add. - * @param documents Array of documents associated with the vectors. - * @param options Optional ids for the vectors. - * @returns Promise that resolves with the ids of the added vectors. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } | string[] - ) { - const ids = Array.isArray(options) ? options : options?.ids; - const documentIds = ids == null ? 
documents.map(() => uuid.v4()) : ids; - const vectorizeVectors = vectors.map((values, idx) => { - const metadata: Record = { - ...documents[idx].metadata, - [this.textKey]: documents[idx].pageContent, - }; - return { - id: documentIds[idx], - metadata, - values, - }; - }); - - // Stick to a limit of 500 vectors per upsert request - const chunkSize = 500; - const chunkedVectors = chunkArray(vectorizeVectors, chunkSize); - const batchRequests = chunkedVectors.map((chunk) => - this.caller.call(async () => this.index.upsert(chunk)) - ); - - await Promise.all(batchRequests); - - return documentIds; - } - - /** - * Method that deletes vectors from the Vectorize database. - * @param params Parameters for the delete operation. - * @returns Promise that resolves when the delete operation is complete. - */ - async delete(params: VectorizeDeleteParams): Promise { - const batchSize = 1000; - const batchedIds = chunkArray(params.ids, batchSize); - const batchRequests = batchedIds.map((batchIds) => - this.caller.call(async () => this.index.deleteByIds(batchIds)) - ); - await Promise.all(batchRequests); - } - - /** - * Method that performs a similarity search in the Vectorize database and - * returns the results along with their scores. - * @param query Query vector for the similarity search. - * @param k Number of top results to return. - * @returns Promise that resolves with an array of documents and their scores. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number - ): Promise<[Document, number][]> { - const results = await this.index.query(query, { - returnVectors: true, - topK: k, - }); - - const result: [Document, number][] = []; - - if (results.matches) { - for (const res of results.matches) { - const { [this.textKey]: pageContent, ...metadata } = - res.vector?.metadata ?? 
{}; - result.push([ - new Document({ metadata, pageContent: pageContent as string }), - res.score, - ]); - } - } - - return result; - } - - /** - * Static method that creates a new instance of the CloudflareVectorizeStore class - * from texts. - * @param texts Array of texts to add to the Vectorize database. - * @param metadatas Metadata associated with the texts. - * @param embeddings Embeddings to use for the texts. - * @param dbConfig Configuration for the Vectorize database. - * @param options Optional ids for the vectors. - * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. - */ - static async fromTexts( - texts: string[], - metadatas: - | Record[] - | Record, - embeddings: Embeddings, - dbConfig: VectorizeLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return CloudflareVectorizeStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method that creates a new instance of the CloudflareVectorizeStore class - * from documents. - * @param docs Array of documents to add to the Vectorize database. - * @param embeddings Embeddings to use for the documents. - * @param dbConfig Configuration for the Vectorize database. - * @param options Optional ids for the vectors. - * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: VectorizeLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method that creates a new instance of the CloudflareVectorizeStore class - * from an existing index. - * @param embeddings Embeddings to use for the documents. 
- * @param dbConfig Configuration for the Vectorize database. - * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. - */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: VectorizeLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - return instance; - } -} +export * from "@langchain/community/vectorstores/cloudflare_vectorize"; diff --git a/langchain/src/vectorstores/convex.ts b/langchain/src/vectorstores/convex.ts index e42264ee6060..65839c433e6b 100644 --- a/langchain/src/vectorstores/convex.ts +++ b/langchain/src/vectorstores/convex.ts @@ -1,376 +1 @@ -// eslint-disable-next-line import/no-extraneous-dependencies -import { - DocumentByInfo, - FieldPaths, - FilterExpression, - FunctionReference, - GenericActionCtx, - GenericDataModel, - GenericTableInfo, - NamedTableInfo, - NamedVectorIndex, - TableNamesInDataModel, - VectorFilterBuilder, - VectorIndexNames, - makeFunctionReference, -} from "convex/server"; -import { Document } from "../document.js"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; - -/** - * Type that defines the config required to initialize the - * ConvexVectorStore class. It includes the table name, - * index name, text field name, and embedding field name. 
- */ -export type ConvexVectorStoreConfig< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - TextFieldName extends FieldPaths>, - EmbeddingFieldName extends FieldPaths>, - MetadataFieldName extends FieldPaths>, - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - GetQuery extends FunctionReference< - "query", - "internal", - { id: string }, - object | null - > -> = { - readonly ctx: GenericActionCtx; - /** - * Defaults to "documents" - */ - readonly table?: TableName; - /** - * Defaults to "byEmbedding" - */ - readonly index?: IndexName; - /** - * Defaults to "text" - */ - readonly textField?: TextFieldName; - /** - * Defaults to "embedding" - */ - readonly embeddingField?: EmbeddingFieldName; - /** - * Defaults to "metadata" - */ - readonly metadataField?: MetadataFieldName; - /** - * Defaults to `internal.langchain.db.insert` - */ - readonly insert?: InsertMutation; - /** - * Defaults to `internal.langchain.db.get` - */ - readonly get?: GetQuery; -}; - -/** - * Class that is a wrapper around Convex storage and vector search. It is used - * to insert embeddings in Convex documents with a vector search index, - * and perform a vector search on them. - * - * ConvexVectorStore does NOT implement maxMarginalRelevanceSearch. 
- */ -export class ConvexVectorStore< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - TextFieldName extends FieldPaths>, - EmbeddingFieldName extends FieldPaths>, - MetadataFieldName extends FieldPaths>, - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - GetQuery extends FunctionReference< - "query", - "internal", - { id: string }, - object | null - > -> extends VectorStore { - /** - * Type that defines the filter used in the - * similaritySearchVectorWithScore and maxMarginalRelevanceSearch methods. - * It includes limit, filter and a flag to include embeddings. - */ - declare FilterType: { - filter?: ( - q: VectorFilterBuilder< - DocumentByInfo, - NamedVectorIndex, IndexName> - > - ) => FilterExpression; - includeEmbeddings?: boolean; - }; - - private readonly ctx: GenericActionCtx; - - private readonly table: TableName; - - private readonly index: IndexName; - - private readonly textField: TextFieldName; - - private readonly embeddingField: EmbeddingFieldName; - - private readonly metadataField: MetadataFieldName; - - private readonly insert: InsertMutation; - - private readonly get: GetQuery; - - _vectorstoreType(): string { - return "convex"; - } - - constructor( - embeddings: Embeddings, - config: ConvexVectorStoreConfig< - DataModel, - TableName, - IndexName, - TextFieldName, - EmbeddingFieldName, - MetadataFieldName, - InsertMutation, - GetQuery - > - ) { - super(embeddings, config); - this.ctx = config.ctx; - this.table = config.table ?? ("documents" as TableName); - this.index = config.index ?? ("byEmbedding" as IndexName); - this.textField = config.textField ?? ("text" as TextFieldName); - this.embeddingField = - config.embeddingField ?? ("embedding" as EmbeddingFieldName); - this.metadataField = - config.metadataField ?? 
("metadata" as MetadataFieldName); - this.insert = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - config.insert ?? (makeFunctionReference("langchain/db:insert") as any); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.get = config.get ?? (makeFunctionReference("langchain/db:get") as any); - } - - /** - * Add vectors and their corresponding documents to the Convex table. - * @param vectors Vectors to be added. - * @param documents Corresponding documents to be added. - * @returns Promise that resolves when the vectors and documents have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - const convexDocuments = vectors.map((embedding, idx) => ({ - [this.textField]: documents[idx].pageContent, - [this.embeddingField]: embedding, - [this.metadataField]: documents[idx].metadata, - })); - // TODO: Remove chunking when Convex handles the concurrent requests correctly - const PAGE_SIZE = 16; - for (let i = 0; i < convexDocuments.length; i += PAGE_SIZE) { - await Promise.all( - convexDocuments.slice(i, i + PAGE_SIZE).map((document) => - this.ctx.runMutation(this.insert, { - table: this.table, - document, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any) - ) - ); - } - } - - /** - * Add documents to the Convex table. It first converts - * the documents to vectors using the embeddings and then calls the - * addVectors method. - * @param documents Documents to be added. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Similarity search on the vectors stored in the - * Convex table. It returns a list of documents and their - * corresponding similarity scores. - * @param query Query vector for the similarity search. 
- * @param k Number of nearest neighbors to return. - * @param filter Optional filter to be applied. - * @returns Promise that resolves to a list of documents and their corresponding similarity scores. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - const idsAndScores = await this.ctx.vectorSearch(this.table, this.index, { - vector: query, - limit: k, - filter: filter?.filter, - }); - - const documents = await Promise.all( - idsAndScores.map(({ _id }) => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.ctx.runQuery(this.get, { id: _id } as any) - ) - ); - - return documents.map( - ( - { - [this.textField]: text, - [this.embeddingField]: embedding, - [this.metadataField]: metadata, - }, - idx - ) => [ - new Document({ - pageContent: text as string, - metadata: { - ...metadata, - ...(filter?.includeEmbeddings ? { embedding } : null), - }, - }), - idsAndScores[idx]._score, - ] - ); - } - - /** - * Static method to create an instance of ConvexVectorStore from a - * list of texts. It first converts the texts to vectors and then adds - * them to the Convex table. - * @param texts List of texts to be converted to vectors. - * @param metadatas Metadata for the texts. - * @param embeddings Embeddings to be used for conversion. - * @param dbConfig Database configuration for Convex. - * @returns Promise that resolves to a new instance of ConvexVectorStore. 
- */ - static async fromTexts< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - TextFieldName extends FieldPaths>, - EmbeddingFieldName extends FieldPaths>, - MetadataFieldName extends FieldPaths>, - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - GetQuery extends FunctionReference< - "query", - "internal", - { id: string }, - object | null - > - >( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: ConvexVectorStoreConfig< - DataModel, - TableName, - IndexName, - TextFieldName, - EmbeddingFieldName, - MetadataFieldName, - InsertMutation, - GetQuery - > - ): Promise< - ConvexVectorStore< - DataModel, - TableName, - IndexName, - TextFieldName, - EmbeddingFieldName, - MetadataFieldName, - InsertMutation, - GetQuery - > - > { - const docs = texts.map( - (text, i) => - new Document({ - pageContent: text, - metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas, - }) - ); - return ConvexVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create an instance of ConvexVectorStore from a - * list of documents. It first converts the documents to vectors and then - * adds them to the Convex table. - * @param docs List of documents to be converted to vectors. - * @param embeddings Embeddings to be used for conversion. - * @param dbConfig Database configuration for Convex. - * @returns Promise that resolves to a new instance of ConvexVectorStore. 
- */ - static async fromDocuments< - DataModel extends GenericDataModel, - TableName extends TableNamesInDataModel, - IndexName extends VectorIndexNames>, - TextFieldName extends FieldPaths>, - EmbeddingFieldName extends FieldPaths>, - MetadataFieldName extends FieldPaths>, - InsertMutation extends FunctionReference< - "mutation", - "internal", - { table: string; document: object } - >, - GetQuery extends FunctionReference< - "query", - "internal", - { id: string }, - object | null - > - >( - docs: Document[], - embeddings: Embeddings, - dbConfig: ConvexVectorStoreConfig< - DataModel, - TableName, - IndexName, - TextFieldName, - EmbeddingFieldName, - MetadataFieldName, - InsertMutation, - GetQuery - > - ): Promise< - ConvexVectorStore< - DataModel, - TableName, - IndexName, - TextFieldName, - EmbeddingFieldName, - MetadataFieldName, - InsertMutation, - GetQuery - > - > { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/convex"; diff --git a/langchain/src/vectorstores/elasticsearch.ts b/langchain/src/vectorstores/elasticsearch.ts index 3d0161bb8035..cc0d9f90ea4d 100644 --- a/langchain/src/vectorstores/elasticsearch.ts +++ b/langchain/src/vectorstores/elasticsearch.ts @@ -1,342 +1 @@ -import * as uuid from "uuid"; -import { Client, estypes } from "@elastic/elasticsearch"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { VectorStore } from "./base.js"; - -/** - * Type representing the k-nearest neighbors (k-NN) engine used in - * Elasticsearch. - */ -type ElasticKnnEngine = "hnsw"; -/** - * Type representing the similarity measure used in Elasticsearch. - */ -type ElasticSimilarity = "l2_norm" | "dot_product" | "cosine"; - -/** - * Interface defining the options for vector search in Elasticsearch. 
- */ -interface VectorSearchOptions { - readonly engine?: ElasticKnnEngine; - readonly similarity?: ElasticSimilarity; - readonly m?: number; - readonly efConstruction?: number; - readonly candidates?: number; -} - -/** - * Interface defining the arguments required to create an Elasticsearch - * client. - */ -export interface ElasticClientArgs { - readonly client: Client; - readonly indexName?: string; - readonly vectorSearchOptions?: VectorSearchOptions; -} - -/** - * Type representing a filter object in Elasticsearch. - */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -type ElasticFilter = object | { field: string; operator: string; value: any }[]; - -/** - * Class for interacting with an Elasticsearch database. It extends the - * VectorStore base class and provides methods for adding documents and - * vectors to the Elasticsearch database, performing similarity searches, - * deleting documents, and more. - */ -export class ElasticVectorSearch extends VectorStore { - declare FilterType: ElasticFilter; - - private readonly client: Client; - - private readonly indexName: string; - - private readonly engine: ElasticKnnEngine; - - private readonly similarity: ElasticSimilarity; - - private readonly efConstruction: number; - - private readonly m: number; - - private readonly candidates: number; - - _vectorstoreType(): string { - return "elasticsearch"; - } - - constructor(embeddings: Embeddings, args: ElasticClientArgs) { - super(embeddings, args); - - this.engine = args.vectorSearchOptions?.engine ?? "hnsw"; - this.similarity = args.vectorSearchOptions?.similarity ?? "l2_norm"; - this.m = args.vectorSearchOptions?.m ?? 16; - this.efConstruction = args.vectorSearchOptions?.efConstruction ?? 100; - this.candidates = args.vectorSearchOptions?.candidates ?? 200; - - this.client = args.client.child({ - headers: { "user-agent": "langchain-js-vs/0.0.1" }, - }); - this.indexName = args.indexName ?? 
"documents"; - } - - /** - * Method to add documents to the Elasticsearch database. It first - * converts the documents to vectors using the embeddings, then adds the - * vectors to the database. - * @param documents The documents to add to the database. - * @param options Optional parameter that can contain the IDs for the documents. - * @returns A promise that resolves with the IDs of the added documents. - */ - async addDocuments(documents: Document[], options?: { ids?: string[] }) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method to add vectors to the Elasticsearch database. It ensures the - * index exists, then adds the vectors and their corresponding documents - * to the database. - * @param vectors The vectors to add to the database. - * @param documents The documents corresponding to the vectors. - * @param options Optional parameter that can contain the IDs for the documents. - * @returns A promise that resolves with the IDs of the added documents. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ) { - await this.ensureIndexExists( - vectors[0].length, - this.engine, - this.similarity, - this.efConstruction, - this.m - ); - const documentIds = - options?.ids ?? Array.from({ length: vectors.length }, () => uuid.v4()); - const operations = vectors.flatMap((embedding, idx) => [ - { - index: { - _id: documentIds[idx], - _index: this.indexName, - }, - }, - { - embedding, - metadata: documents[idx].metadata, - text: documents[idx].pageContent, - }, - ]); - await this.client.bulk({ refresh: true, operations }); - return documentIds; - } - - /** - * Method to perform a similarity search in the Elasticsearch database - * using a vector. It returns the k most similar documents along with - * their similarity scores. - * @param query The query vector. 
- * @param k The number of most similar documents to return. - * @param filter Optional filter to apply to the search. - * @returns A promise that resolves with an array of tuples, where each tuple contains a Document and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: ElasticFilter - ): Promise<[Document, number][]> { - const result = await this.client.search({ - index: this.indexName, - size: k, - knn: { - field: "embedding", - query_vector: query, - filter: this.buildMetadataTerms(filter), - k, - num_candidates: this.candidates, - }, - }); - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return result.hits.hits.map((hit: any) => [ - new Document({ - pageContent: hit._source.text, - metadata: hit._source.metadata, - }), - hit._score, - ]); - } - - /** - * Method to delete documents from the Elasticsearch database. - * @param params Object containing the IDs of the documents to delete. - * @returns A promise that resolves when the deletion is complete. - */ - async delete(params: { ids: string[] }): Promise { - const operations = params.ids.map((id) => ({ - delete: { - _id: id, - _index: this.indexName, - }, - })); - await this.client.bulk({ refresh: true, operations }); - } - - /** - * Static method to create an ElasticVectorSearch instance from texts. It - * creates Document instances from the texts and their corresponding - * metadata, then calls the fromDocuments method to create the - * ElasticVectorSearch instance. - * @param texts The texts to create the ElasticVectorSearch instance from. - * @param metadatas The metadata corresponding to the texts. - * @param embeddings The embeddings to use for the documents. - * @param args The arguments to create the Elasticsearch client. - * @returns A promise that resolves with the created ElasticVectorSearch instance. 
- */ - static fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - args: ElasticClientArgs - ): Promise { - const documents = texts.map((text, idx) => { - const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas; - return new Document({ pageContent: text, metadata }); - }); - - return ElasticVectorSearch.fromDocuments(documents, embeddings, args); - } - - /** - * Static method to create an ElasticVectorSearch instance from Document - * instances. It adds the documents to the Elasticsearch database, then - * returns the ElasticVectorSearch instance. - * @param docs The Document instances to create the ElasticVectorSearch instance from. - * @param embeddings The embeddings to use for the documents. - * @param dbConfig The configuration for the Elasticsearch database. - * @returns A promise that resolves with the created ElasticVectorSearch instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: ElasticClientArgs - ): Promise { - const store = new ElasticVectorSearch(embeddings, dbConfig); - await store.addDocuments(docs).then(() => store); - return store; - } - - /** - * Static method to create an ElasticVectorSearch instance from an - * existing index in the Elasticsearch database. It checks if the index - * exists, then returns the ElasticVectorSearch instance if it does. - * @param embeddings The embeddings to use for the documents. - * @param dbConfig The configuration for the Elasticsearch database. - * @returns A promise that resolves with the created ElasticVectorSearch instance if the index exists, otherwise it throws an error. 
- */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: ElasticClientArgs - ): Promise { - const store = new ElasticVectorSearch(embeddings, dbConfig); - const exists = await store.doesIndexExist(); - if (exists) { - return store; - } - throw new Error(`The index ${store.indexName} does not exist.`); - } - - private async ensureIndexExists( - dimension: number, - engine = "hnsw", - similarity = "l2_norm", - efConstruction = 100, - m = 16 - ): Promise { - const request: estypes.IndicesCreateRequest = { - index: this.indexName, - mappings: { - dynamic_templates: [ - { - // map all metadata properties to be keyword - "metadata.*": { - match_mapping_type: "*", - mapping: { type: "keyword" }, - }, - }, - ], - properties: { - text: { type: "text" }, - metadata: { type: "object" }, - embedding: { - type: "dense_vector", - dims: dimension, - index: true, - similarity, - index_options: { - type: engine, - m, - ef_construction: efConstruction, - }, - }, - }, - }, - }; - - const indexExists = await this.doesIndexExist(); - if (indexExists) return; - - await this.client.indices.create(request); - } - - private buildMetadataTerms( - filter?: ElasticFilter - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ): { [operator: string]: { [field: string]: any } }[] { - if (filter == null) return []; - const result = []; - const filters = Array.isArray(filter) - ? filter - : Object.entries(filter).map(([key, value]) => ({ - operator: "term", - field: key, - value, - })); - for (const condition of filters) { - result.push({ - [condition.operator]: { - [`metadata.${condition.field}`]: condition.value, - }, - }); - } - return result; - } - - /** - * Method to check if an index exists in the Elasticsearch database. - * @returns A promise that resolves with a boolean indicating whether the index exists. 
- */ - async doesIndexExist(): Promise { - return await this.client.indices.exists({ index: this.indexName }); - } - - /** - * Method to delete an index from the Elasticsearch database if it exists. - * @returns A promise that resolves when the deletion is complete. - */ - async deleteIfExists(): Promise { - const indexExists = await this.doesIndexExist(); - if (!indexExists) return; - - await this.client.indices.delete({ index: this.indexName }); - } -} +export * from "@langchain/community/vectorstores/elasticsearch"; diff --git a/langchain/src/vectorstores/faiss.ts b/langchain/src/vectorstores/faiss.ts index 9ba2ffd3ac49..3696e021927b 100644 --- a/langchain/src/vectorstores/faiss.ts +++ b/langchain/src/vectorstores/faiss.ts @@ -1,461 +1 @@ -import type { IndexFlatL2 } from "faiss-node"; -import type { NameRegistry, Parser } from "pickleparser"; -import * as uuid from "uuid"; -import { Embeddings } from "../embeddings/base.js"; -import { SaveableVectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { SynchronousInMemoryDocstore } from "../stores/doc/in_memory.js"; - -/** - * Interface for the arguments required to initialize a FaissStore - * instance. - */ -export interface FaissLibArgs { - docstore?: SynchronousInMemoryDocstore; - index?: IndexFlatL2; - mapping?: Record; -} - -/** - * A class that wraps the FAISS (Facebook AI Similarity Search) vector - * database for efficient similarity search and clustering of dense - * vectors. 
- */ -export class FaissStore extends SaveableVectorStore { - _index?: IndexFlatL2; - - _mapping: Record; - - docstore: SynchronousInMemoryDocstore; - - args: FaissLibArgs; - - _vectorstoreType(): string { - return "faiss"; - } - - getMapping(): Record { - return this._mapping; - } - - getDocstore(): SynchronousInMemoryDocstore { - return this.docstore; - } - - constructor(embeddings: Embeddings, args: FaissLibArgs) { - super(embeddings, args); - this.args = args; - this._index = args.index; - this._mapping = args.mapping ?? {}; - this.embeddings = embeddings; - this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore(); - } - - /** - * Adds an array of Document objects to the store. - * @param documents An array of Document objects. - * @returns A Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[], options?: { ids?: string[] }) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - public get index(): IndexFlatL2 { - if (!this._index) { - throw new Error( - "Vector store not initialised yet. Try calling `fromTexts`, `fromDocuments` or `fromIndex` first." - ); - } - return this._index; - } - - private set index(index: IndexFlatL2) { - this._index = index; - } - - /** - * Adds an array of vectors and their corresponding Document objects to - * the store. - * @param vectors An array of vectors. - * @param documents An array of Document objects corresponding to the vectors. - * @returns A Promise that resolves with an array of document IDs when the vectors and documents have been added. 
- */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ) { - if (vectors.length === 0) { - return []; - } - if (vectors.length !== documents.length) { - throw new Error(`Vectors and documents must have the same length`); - } - const dv = vectors[0].length; - if (!this._index) { - const { IndexFlatL2 } = await FaissStore.importFaiss(); - this._index = new IndexFlatL2(dv); - } - const d = this.index.getDimension(); - if (dv !== d) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${d})` - ); - } - - const docstoreSize = this.index.ntotal(); - const documentIds = options?.ids ?? documents.map(() => uuid.v4()); - for (let i = 0; i < vectors.length; i += 1) { - const documentId = documentIds[i]; - const id = docstoreSize + i; - this.index.add(vectors[i]); - this._mapping[id] = documentId; - this.docstore.add({ [documentId]: documents[i] }); - } - return documentIds; - } - - /** - * Performs a similarity search in the vector store using a query vector - * and returns the top k results along with their scores. - * @param query A query vector. - * @param k The number of top results to return. - * @returns A Promise that resolves with an array of tuples, each containing a Document and its corresponding score. 
- */ - async similaritySearchVectorWithScore(query: number[], k: number) { - const d = this.index.getDimension(); - if (query.length !== d) { - throw new Error( - `Query vector must have the same length as the number of dimensions (${d})` - ); - } - if (k > this.index.ntotal()) { - const total = this.index.ntotal(); - console.warn( - `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}` - ); - // eslint-disable-next-line no-param-reassign - k = total; - } - const result = this.index.search(query, k); - return result.labels.map((id, index) => { - const uuid = this._mapping[id]; - return [this.docstore.search(uuid), result.distances[index]] as [ - Document, - number - ]; - }); - } - - /** - * Saves the current state of the FaissStore to a specified directory. - * @param directory The directory to save the state to. - * @returns A Promise that resolves when the state has been saved. - */ - async save(directory: string) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - await fs.mkdir(directory, { recursive: true }); - await Promise.all([ - this.index.write(path.join(directory, "faiss.index")), - await fs.writeFile( - path.join(directory, "docstore.json"), - JSON.stringify([ - Array.from(this.docstore._docs.entries()), - this._mapping, - ]) - ), - ]); - } - - /** - * Method to delete documents. - * @param params Object containing the IDs of the documents to delete. - * @returns A promise that resolves when the deletion is complete. 
- */ - async delete(params: { ids: string[] }) { - const documentIds = params.ids; - if (documentIds == null) { - throw new Error("No documentIds provided to delete."); - } - - const mappings = new Map( - Object.entries(this._mapping).map(([key, value]) => [ - parseInt(key, 10), - value, - ]) - ); - const reversedMappings = new Map( - Array.from(mappings, (entry) => [entry[1], entry[0]]) - ); - - const missingIds = new Set( - documentIds.filter((id) => !reversedMappings.has(id)) - ); - if (missingIds.size > 0) { - throw new Error( - `Some specified documentIds do not exist in the current store. DocumentIds not found: ${Array.from( - missingIds - ).join(", ")}` - ); - } - - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const indexIdToDelete = documentIds.map((id) => reversedMappings.get(id)!); - - // remove from index - this.index.removeIds(indexIdToDelete); - // remove from docstore - documentIds.forEach((id) => { - this.docstore._docs.delete(id); - }); - // remove from mappings - indexIdToDelete.forEach((id) => { - mappings.delete(id); - }); - - this._mapping = { ...Array.from(mappings.values()) }; - } - - /** - * Merges the current FaissStore with another FaissStore. - * @param targetIndex The FaissStore to merge with. - * @returns A Promise that resolves with an array of document IDs when the merge is complete. 
- */ - async mergeFrom(targetIndex: FaissStore) { - const targetIndexDimensions = targetIndex.index.getDimension(); - if (!this._index) { - const { IndexFlatL2 } = await FaissStore.importFaiss(); - this._index = new IndexFlatL2(targetIndexDimensions); - } - const d = this.index.getDimension(); - if (targetIndexDimensions !== d) { - throw new Error("Cannot merge indexes with different dimensions."); - } - const targetMapping = targetIndex.getMapping(); - const targetDocstore = targetIndex.getDocstore(); - const targetSize = targetIndex.index.ntotal(); - const documentIds = []; - const currentDocstoreSize = this.index.ntotal(); - for (let i = 0; i < targetSize; i += 1) { - const targetId = targetMapping[i]; - documentIds.push(targetId); - const targetDocument = targetDocstore.search(targetId); - const id = currentDocstoreSize + i; - this._mapping[id] = targetId; - this.docstore.add({ [targetId]: targetDocument }); - } - this.index.mergeFrom(targetIndex.index); - return documentIds; - } - - /** - * Loads a FaissStore from a specified directory. - * @param directory The directory to load the FaissStore from. - * @param embeddings An Embeddings object. - * @returns A Promise that resolves with a new FaissStore instance. 
- */ - static async load(directory: string, embeddings: Embeddings) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - const readStore = (directory: string) => - fs - .readFile(path.join(directory, "docstore.json"), "utf8") - .then(JSON.parse) as Promise< - [Map, Record] - >; - const readIndex = async (directory: string) => { - const { IndexFlatL2 } = await this.importFaiss(); - return IndexFlatL2.read(path.join(directory, "faiss.index")); - }; - const [[docstoreFiles, mapping], index] = await Promise.all([ - readStore(directory), - readIndex(directory), - ]); - const docstore = new SynchronousInMemoryDocstore(new Map(docstoreFiles)); - return new this(embeddings, { docstore, index, mapping }); - } - - static async loadFromPython(directory: string, embeddings: Embeddings) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - const { Parser, NameRegistry } = await this.importPickleparser(); - - class PyDocument extends Map { - toDocument(): Document { - return new Document({ - pageContent: this.get("page_content"), - metadata: this.get("metadata"), - }); - } - } - - class PyInMemoryDocstore { - _dict: Map; - - toInMemoryDocstore(): SynchronousInMemoryDocstore { - const s = new SynchronousInMemoryDocstore(); - for (const [key, value] of Object.entries(this._dict)) { - s._docs.set(key, value.toDocument()); - } - return s; - } - } - - const readStore = async (directory: string) => { - const pkl = await fs.readFile( - path.join(directory, "index.pkl"), - "binary" - ); - const buffer = Buffer.from(pkl, "binary"); - - const registry = new NameRegistry() - .register( - "langchain.docstore.in_memory", - "InMemoryDocstore", - PyInMemoryDocstore - ) - .register("langchain.schema", "Document", PyDocument) - .register("langchain.docstore.document", "Document", PyDocument) - .register("langchain.schema.document", "Document", PyDocument) - .register("pathlib", "WindowsPath", (...args) => 
args.join("\\")) - .register("pathlib", "PosixPath", (...args) => args.join("/")); - - const pickleparser = new Parser({ - nameResolver: registry, - }); - const [rawStore, mapping] = - pickleparser.parse<[PyInMemoryDocstore, Record]>( - buffer - ); - const store = rawStore.toInMemoryDocstore(); - return { store, mapping }; - }; - const readIndex = async (directory: string) => { - const { IndexFlatL2 } = await this.importFaiss(); - return IndexFlatL2.read(path.join(directory, "index.faiss")); - }; - const [store, index] = await Promise.all([ - readStore(directory), - readIndex(directory), - ]); - return new this(embeddings, { - docstore: store.store, - index, - mapping: store.mapping, - }); - } - - /** - * Creates a new FaissStore from an array of texts, their corresponding - * metadata, and an Embeddings object. - * @param texts An array of texts. - * @param metadatas An array of metadata corresponding to the texts, or a single metadata object to be used for all texts. - * @param embeddings An Embeddings object. - * @param dbConfig An optional configuration object for the document store. - * @returns A Promise that resolves with a new FaissStore instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return this.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new FaissStore from an array of Document objects and an - * Embeddings object. - * @param docs An array of Document objects. - * @param embeddings An Embeddings object. - * @param dbConfig An optional configuration object for the document store. 
- * @returns A Promise that resolves with a new FaissStore instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const args: FaissLibArgs = { - docstore: dbConfig?.docstore, - }; - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Creates a new FaissStore from an existing FaissStore and an Embeddings - * object. - * @param targetIndex An existing FaissStore. - * @param embeddings An Embeddings object. - * @param dbConfig An optional configuration object for the document store. - * @returns A Promise that resolves with a new FaissStore instance. - */ - static async fromIndex( - targetIndex: FaissStore, - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const args: FaissLibArgs = { - docstore: dbConfig?.docstore, - }; - const instance = new this(embeddings, args); - await instance.mergeFrom(targetIndex); - return instance; - } - - static async importFaiss(): Promise<{ IndexFlatL2: typeof IndexFlatL2 }> { - try { - const { - default: { IndexFlatL2 }, - } = await import("faiss-node"); - - return { IndexFlatL2 }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (err: any) { - throw new Error( - `Could not import faiss-node. Please install faiss-node as a dependency with, e.g. \`npm install -S faiss-node\`.\n\nError: ${err?.message}` - ); - } - } - - static async importPickleparser(): Promise<{ - Parser: typeof Parser; - NameRegistry: typeof NameRegistry; - }> { - try { - const { - default: { Parser, NameRegistry }, - } = await import("pickleparser"); - - return { Parser, NameRegistry }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (err: any) { - throw new Error( - `Could not import pickleparser. Please install pickleparser as a dependency with, e.g. 
\`npm install -S pickleparser\`.\n\nError: ${err?.message}` - ); - } - } -} +export * from "@langchain/community/vectorstores/faiss"; diff --git a/langchain/src/vectorstores/googlevertexai.ts b/langchain/src/vectorstores/googlevertexai.ts index 0deaf7cfa5a3..dd4573f3feb8 100644 --- a/langchain/src/vectorstores/googlevertexai.ts +++ b/langchain/src/vectorstores/googlevertexai.ts @@ -1,737 +1 @@ -import * as uuid from "uuid"; -import flatten from "flat"; -import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document, DocumentInput } from "../document.js"; -import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js"; -import { - AsyncCaller, - AsyncCallerCallOptions, - AsyncCallerParams, -} from "../util/async_caller.js"; -import { - GoogleVertexAIConnectionParams, - GoogleResponse, - GoogleAbstractedClientOpsMethod, -} from "../types/googlevertexai-types.js"; -import { Docstore } from "../schema/index.js"; - -/** - * Allows us to create IdDocument classes that contain the ID. - */ -export interface IdDocumentInput extends DocumentInput { - id?: string; -} - -/** - * A Document that optionally includes the ID of the document. 
- */ -export class IdDocument extends Document implements IdDocumentInput { - id?: string; - - constructor(fields: IdDocumentInput) { - super(fields); - this.id = fields.id; - } -} - -interface IndexEndpointConnectionParams - extends GoogleVertexAIConnectionParams { - indexEndpoint: string; -} - -interface DeployedIndex { - id: string; - index: string; - // There are other attributes, but we don't care about them right now -} - -interface IndexEndpointResponse extends GoogleResponse { - data: { - deployedIndexes: DeployedIndex[]; - publicEndpointDomainName: string; - // There are other attributes, but we don't care about them right now - }; -} - -class IndexEndpointConnection extends GoogleVertexAIConnection< - AsyncCallerCallOptions, - IndexEndpointResponse, - GoogleAuthOptions -> { - indexEndpoint: string; - - constructor(fields: IndexEndpointConnectionParams, caller: AsyncCaller) { - super(fields, caller, new GoogleAuth(fields.authOptions)); - - this.indexEndpoint = fields.indexEndpoint; - } - - async buildUrl(): Promise { - const projectId = await this.client.getProjectId(); - const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}`; - return url; - } - - buildMethod(): GoogleAbstractedClientOpsMethod { - return "GET"; - } - - async request( - options: AsyncCallerCallOptions - ): Promise { - return this._request(undefined, options); - } -} - -/** - * Used to represent parameters that are necessary to delete documents - * from the matching engine. 
These must be a list of string IDs - */ -export interface MatchingEngineDeleteParams { - ids: string[]; -} - -interface RemoveDatapointParams - extends GoogleVertexAIConnectionParams { - index: string; -} - -interface RemoveDatapointRequest { - datapointIds: string[]; -} - -interface RemoveDatapointResponse extends GoogleResponse { - // Should be empty -} - -class RemoveDatapointConnection extends GoogleVertexAIConnection< - AsyncCallerCallOptions, - RemoveDatapointResponse, - GoogleAuthOptions -> { - index: string; - - constructor(fields: RemoveDatapointParams, caller: AsyncCaller) { - super(fields, caller, new GoogleAuth(fields.authOptions)); - - this.index = fields.index; - } - - async buildUrl(): Promise { - const projectId = await this.client.getProjectId(); - const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:removeDatapoints`; - return url; - } - - buildMethod(): GoogleAbstractedClientOpsMethod { - return "POST"; - } - - async request( - datapointIds: string[], - options: AsyncCallerCallOptions - ): Promise { - const data: RemoveDatapointRequest = { - datapointIds, - }; - return this._request(data, options); - } -} - -interface UpsertDatapointParams - extends GoogleVertexAIConnectionParams { - index: string; -} - -export interface Restriction { - namespace: string; - allowList?: string[]; - denyList?: string[]; -} - -interface CrowdingTag { - crowdingAttribute: string; -} - -interface IndexDatapoint { - datapointId: string; - featureVector: number[]; - restricts?: Restriction[]; - crowdingTag?: CrowdingTag; -} - -interface UpsertDatapointRequest { - datapoints: IndexDatapoint[]; -} - -interface UpsertDatapointResponse extends GoogleResponse { - // Should be empty -} - -class UpsertDatapointConnection extends GoogleVertexAIConnection< - AsyncCallerCallOptions, - UpsertDatapointResponse, - GoogleAuthOptions -> { - index: string; - - constructor(fields: UpsertDatapointParams, caller: 
AsyncCaller) { - super(fields, caller, new GoogleAuth(fields.authOptions)); - - this.index = fields.index; - } - - async buildUrl(): Promise { - const projectId = await this.client.getProjectId(); - const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:upsertDatapoints`; - return url; - } - - buildMethod(): GoogleAbstractedClientOpsMethod { - return "POST"; - } - - async request( - datapoints: IndexDatapoint[], - options: AsyncCallerCallOptions - ): Promise { - const data: UpsertDatapointRequest = { - datapoints, - }; - return this._request(data, options); - } -} - -interface FindNeighborsConnectionParams - extends GoogleVertexAIConnectionParams { - indexEndpoint: string; - - deployedIndexId: string; -} - -interface FindNeighborsRequestQuery { - datapoint: { - datapointId: string; - featureVector: number[]; - restricts?: Restriction[]; - }; - neighborCount: number; -} - -interface FindNeighborsRequest { - deployedIndexId: string; - queries: FindNeighborsRequestQuery[]; -} - -interface FindNeighborsResponseNeighbor { - datapoint: { - datapointId: string; - crowdingTag: { - crowdingTagAttribute: string; - }; - }; - distance: number; -} - -interface FindNeighborsResponseNearestNeighbor { - id: string; - neighbors: FindNeighborsResponseNeighbor[]; -} - -interface FindNeighborsResponse extends GoogleResponse { - data: { - nearestNeighbors: FindNeighborsResponseNearestNeighbor[]; - }; -} - -class FindNeighborsConnection - extends GoogleVertexAIConnection< - AsyncCallerCallOptions, - FindNeighborsResponse, - GoogleAuthOptions - > - implements FindNeighborsConnectionParams -{ - indexEndpoint: string; - - deployedIndexId: string; - - constructor(params: FindNeighborsConnectionParams, caller: AsyncCaller) { - super(params, caller, new GoogleAuth(params.authOptions)); - - this.indexEndpoint = params.indexEndpoint; - this.deployedIndexId = params.deployedIndexId; - } - - async buildUrl(): Promise { - 
const projectId = await this.client.getProjectId(); - const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}:findNeighbors`; - return url; - } - - buildMethod(): GoogleAbstractedClientOpsMethod { - return "POST"; - } - - async request( - request: FindNeighborsRequest, - options: AsyncCallerCallOptions - ): Promise { - return this._request(request, options); - } -} - -/** - * Information about the Matching Engine public API endpoint. - * Primarily exported to allow for testing. - */ -export interface PublicAPIEndpointInfo { - apiEndpoint?: string; - - deployedIndexId?: string; -} - -/** - * Parameters necessary to configure the Matching Engine. - */ -export interface MatchingEngineArgs - extends GoogleVertexAIConnectionParams, - IndexEndpointConnectionParams, - UpsertDatapointParams { - docstore: Docstore; - - callerParams?: AsyncCallerParams; - - callerOptions?: AsyncCallerCallOptions; - - apiEndpoint?: string; - - deployedIndexId?: string; -} - -/** - * A class that represents a connection to a Google Vertex AI Matching Engine - * instance. - */ -export class MatchingEngine extends VectorStore implements MatchingEngineArgs { - declare FilterType: Restriction[]; - - /** - * Docstore that retains the document, stored by ID - */ - docstore: Docstore; - - /** - * The host to connect to for queries and upserts. 
- */ - apiEndpoint: string; - - apiVersion = "v1"; - - endpoint = "us-central1-aiplatform.googleapis.com"; - - location = "us-central1"; - - /** - * The id for the index endpoint - */ - indexEndpoint: string; - - /** - * The id for the index - */ - index: string; - - /** - * The id for the "deployed index", which is an identifier in the - * index endpoint that references the index (but is not the index id) - */ - deployedIndexId: string; - - callerParams: AsyncCallerParams; - - callerOptions: AsyncCallerCallOptions; - - caller: AsyncCaller; - - indexEndpointClient: IndexEndpointConnection; - - removeDatapointClient: RemoveDatapointConnection; - - upsertDatapointClient: UpsertDatapointConnection; - - constructor(embeddings: Embeddings, args: MatchingEngineArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - this.docstore = args.docstore; - - this.apiEndpoint = args.apiEndpoint ?? this.apiEndpoint; - this.deployedIndexId = args.deployedIndexId ?? this.deployedIndexId; - - this.apiVersion = args.apiVersion ?? this.apiVersion; - this.endpoint = args.endpoint ?? this.endpoint; - this.location = args.location ?? this.location; - this.indexEndpoint = args.indexEndpoint ?? this.indexEndpoint; - this.index = args.index ?? this.index; - - this.callerParams = args.callerParams ?? this.callerParams; - this.callerOptions = args.callerOptions ?? 
this.callerOptions; - this.caller = new AsyncCaller(this.callerParams || {}); - - const indexClientParams: IndexEndpointConnectionParams = { - endpoint: this.endpoint, - location: this.location, - apiVersion: this.apiVersion, - indexEndpoint: this.indexEndpoint, - }; - this.indexEndpointClient = new IndexEndpointConnection( - indexClientParams, - this.caller - ); - - const removeClientParams: RemoveDatapointParams = { - endpoint: this.endpoint, - location: this.location, - apiVersion: this.apiVersion, - index: this.index, - }; - this.removeDatapointClient = new RemoveDatapointConnection( - removeClientParams, - this.caller - ); - - const upsertClientParams: UpsertDatapointParams = { - endpoint: this.endpoint, - location: this.location, - apiVersion: this.apiVersion, - index: this.index, - }; - this.upsertDatapointClient = new UpsertDatapointConnection( - upsertClientParams, - this.caller - ); - } - - _vectorstoreType(): string { - return "googlevertexai"; - } - - async addDocuments(documents: Document[]): Promise { - const texts: string[] = documents.map((doc) => doc.pageContent); - const vectors: number[][] = await this.embeddings.embedDocuments(texts); - return this.addVectors(vectors, documents); - } - - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length !== documents.length) { - throw new Error(`Vectors and metadata must have the same length`); - } - const datapoints: IndexDatapoint[] = vectors.map((vector, idx) => - this.buildDatapoint(vector, documents[idx]) - ); - const options = {}; - const response = await this.upsertDatapointClient.request( - datapoints, - options - ); - if (Object.keys(response?.data ?? 
{}).length === 0) { - // Nothing in the response in the body means we saved it ok - const idDoc = documents as IdDocument[]; - const docsToStore: Record = {}; - idDoc.forEach((doc) => { - if (doc.id) { - docsToStore[doc.id] = doc; - } - }); - await this.docstore.add(docsToStore); - } - } - - // TODO: Refactor this into a utility type and use with pinecone as well? - // eslint-disable-next-line @typescript-eslint/no-explicit-any - cleanMetadata(documentMetadata: Record): { - [key: string]: string | number | boolean | string[] | null; - } { - type metadataType = { - [key: string]: string | number | boolean | string[] | null; - }; - - function getStringArrays( - prefix: string, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - m: Record - ): Record { - let ret: Record = {}; - - Object.keys(m).forEach((key) => { - const newPrefix = prefix.length > 0 ? `${prefix}.${key}` : key; - const val = m[key]; - if (!val) { - // Ignore it - } else if (Array.isArray(val)) { - // Make sure everything in the array is a string - ret[newPrefix] = val.map((v) => `${v}`); - } else if (typeof val === "object") { - const subArrays = getStringArrays(newPrefix, val); - ret = { ...ret, ...subArrays }; - } - }); - - return ret; - } - - const stringArrays: Record = getStringArrays( - "", - documentMetadata - ); - - const flatMetadata: metadataType = flatten(documentMetadata); - Object.keys(flatMetadata).forEach((key) => { - Object.keys(stringArrays).forEach((arrayKey) => { - const matchKey = `${arrayKey}.`; - if (key.startsWith(matchKey)) { - delete flatMetadata[key]; - } - }); - }); - - const metadata: metadataType = { - ...flatMetadata, - ...stringArrays, - }; - return metadata; - } - - /** - * Given the metadata from a document, convert it to an array of Restriction - * objects that may be passed to the Matching Engine and stored. - * The default implementation flattens any metadata and includes it as - * an "allowList". 
Subclasses can choose to convert some of these to - * "denyList" items or to add additional restrictions (for example, to format - * dates into a different structure or to add additional restrictions - * based on the date). - * @param documentMetadata - The metadata from a document - * @returns a Restriction[] (or an array of a subclass, from the FilterType) - */ - metadataToRestrictions( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - documentMetadata: Record - ): this["FilterType"] { - const metadata = this.cleanMetadata(documentMetadata); - - const restrictions: this["FilterType"] = []; - for (const key of Object.keys(metadata)) { - // Make sure the value is an array (or that we'll ignore it) - let valArray; - const val = metadata[key]; - if (val === null) { - valArray = null; - } else if (Array.isArray(val) && val.length > 0) { - valArray = val; - } else { - valArray = [`${val}`]; - } - - // Add to the restrictions if we do have a valid value - if (valArray) { - // Determine if this key is for the allowList or denyList - // TODO: get which ones should be on the deny list - const listType = "allowList"; - - // Create the restriction - const restriction: Restriction = { - namespace: key, - [listType]: valArray, - }; - - // Add it to the restriction list - restrictions.push(restriction); - } - } - return restrictions; - } - - /** - * Create an index datapoint for the vector and document id. - * If an id does not exist, create it and set the document to its value. 
- * @param vector - * @param document - */ - buildDatapoint(vector: number[], document: IdDocument): IndexDatapoint { - if (!document.id) { - // eslint-disable-next-line no-param-reassign - document.id = uuid.v4(); - } - const ret: IndexDatapoint = { - datapointId: document.id, - featureVector: vector, - }; - const restrictions = this.metadataToRestrictions(document.metadata); - if (restrictions?.length > 0) { - ret.restricts = restrictions; - } - return ret; - } - - async delete(params: MatchingEngineDeleteParams): Promise { - const options = {}; - await this.removeDatapointClient.request(params.ids, options); - } - - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - // Format the query into the request - const deployedIndexId = await this.getDeployedIndexId(); - const requestQuery: FindNeighborsRequestQuery = { - neighborCount: k, - datapoint: { - datapointId: `0`, - featureVector: query, - }, - }; - if (filter) { - requestQuery.datapoint.restricts = filter; - } - const request: FindNeighborsRequest = { - deployedIndexId, - queries: [requestQuery], - }; - - // Build the connection. - // Has to be done here, since we defer getting the endpoint until - // we need it. - const apiEndpoint = await this.getPublicAPIEndpoint(); - const findNeighborsParams: FindNeighborsConnectionParams = { - endpoint: apiEndpoint, - indexEndpoint: this.indexEndpoint, - apiVersion: this.apiVersion, - location: this.location, - deployedIndexId, - }; - const connection = new FindNeighborsConnection( - findNeighborsParams, - this.caller - ); - - // Make the call - const options = {}; - const response = await connection.request(request, options); - - // Get the document for each datapoint id and return them - const nearestNeighbors = response?.data?.nearestNeighbors ?? []; - const nearestNeighbor = nearestNeighbors[0]; - const neighbors = nearestNeighbor?.neighbors ?? 
[]; - const ret: [Document, number][] = await Promise.all( - neighbors.map(async (neighbor) => { - const id = neighbor?.datapoint?.datapointId; - const distance = neighbor?.distance; - let doc: IdDocument; - try { - doc = await this.docstore.search(id); - } catch (xx) { - // Documents that are in the index are returned, even if they - // are not in the document store, to allow for some way to get - // the id so they can be deleted. - console.error(xx); - console.warn( - [ - `Document with id "${id}" is missing from the backing docstore.`, - `This can occur if you clear the docstore without deleting from the corresponding Matching Engine index.`, - `To resolve this, you should call .delete() with this id as part of the "ids" parameter.`, - ].join("\n") - ); - doc = new Document({ pageContent: `Missing document ${id}` }); - } - doc.id ??= id; - return [doc, distance]; - }) - ); - - return ret; - } - - /** - * For this index endpoint, figure out what API Endpoint URL and deployed - * index ID should be used to do upserts and queries. - * Also sets the `apiEndpoint` and `deployedIndexId` property for future use. - * @return The URL - */ - async determinePublicAPIEndpoint(): Promise { - const response: IndexEndpointResponse = - await this.indexEndpointClient.request(this.callerOptions); - - // Get the endpoint - const publicEndpointDomainName = response?.data?.publicEndpointDomainName; - this.apiEndpoint = publicEndpointDomainName; - - // Determine which of the deployed indexes match the index id - // and get the deployed index id. The list of deployed index ids - // contain the "index name" or path, but not the index id by itself, - // so we need to extract it from the name - const indexPathPattern = /projects\/.+\/locations\/.+\/indexes\/(.+)$/; - const deployedIndexes = response?.data?.deployedIndexes ?? 
[]; - const deployedIndex = deployedIndexes.find((index) => { - const deployedIndexPath = index.index; - const match = deployedIndexPath.match(indexPathPattern); - if (match) { - const [, potentialIndexId] = match; - if (potentialIndexId === this.index) { - return true; - } - } - return false; - }); - if (deployedIndex) { - this.deployedIndexId = deployedIndex.id; - } - - return { - apiEndpoint: this.apiEndpoint, - deployedIndexId: this.deployedIndexId, - }; - } - - async getPublicAPIEndpoint(): Promise { - return ( - this.apiEndpoint ?? (await this.determinePublicAPIEndpoint()).apiEndpoint - ); - } - - async getDeployedIndexId(): Promise { - return ( - this.deployedIndexId ?? - (await this.determinePublicAPIEndpoint()).deployedIndexId - ); - } - - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: MatchingEngineArgs - ): Promise { - const docs: Document[] = texts.map( - (text, index): Document => ({ - pageContent: text, - metadata: Array.isArray(metadatas) ? 
metadatas[index] : metadatas, - }) - ); - return this.fromDocuments(docs, embeddings, dbConfig); - } - - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: MatchingEngineArgs - ): Promise { - const ret = new MatchingEngine(embeddings, dbConfig); - await ret.addDocuments(docs); - return ret; - } -} +export * from "@langchain/community/vectorstores/googlevertexai"; diff --git a/langchain/src/vectorstores/hnswlib.ts b/langchain/src/vectorstores/hnswlib.ts index 7f29eaaba44f..2d91e0d321c8 100644 --- a/langchain/src/vectorstores/hnswlib.ts +++ b/langchain/src/vectorstores/hnswlib.ts @@ -1,354 +1 @@ -import type { - HierarchicalNSW as HierarchicalNSWT, - SpaceName, -} from "hnswlib-node"; -import { Embeddings } from "../embeddings/base.js"; -import { SaveableVectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { SynchronousInMemoryDocstore } from "../stores/doc/in_memory.js"; - -/** - * Interface for the base configuration of HNSWLib. It includes the space - * name and the number of dimensions. - */ -export interface HNSWLibBase { - space: SpaceName; - numDimensions?: number; -} - -/** - * Interface for the arguments that can be passed to the HNSWLib - * constructor. It extends HNSWLibBase and includes properties for the - * document store and HNSW index. - */ -export interface HNSWLibArgs extends HNSWLibBase { - docstore?: SynchronousInMemoryDocstore; - index?: HierarchicalNSWT; -} - -/** - * Class that implements a vector store using Hierarchical Navigable Small - * World (HNSW) graphs. It extends the SaveableVectorStore class and - * provides methods for adding documents and vectors, performing - * similarity searches, and saving and loading the vector store. 
- */ -export class HNSWLib extends SaveableVectorStore { - declare FilterType: (doc: Document) => boolean; - - _index?: HierarchicalNSWT; - - docstore: SynchronousInMemoryDocstore; - - args: HNSWLibBase; - - _vectorstoreType(): string { - return "hnswlib"; - } - - constructor(embeddings: Embeddings, args: HNSWLibArgs) { - super(embeddings, args); - this._index = args.index; - this.args = args; - this.embeddings = embeddings; - this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore(); - } - - /** - * Method to add documents to the vector store. It first converts the - * documents to vectors using the embeddings, then adds the vectors to the - * vector store. - * @param documents The documents to be added to the vector store. - * @returns A Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - private static async getHierarchicalNSW(args: HNSWLibBase) { - const { HierarchicalNSW } = await HNSWLib.imports(); - if (!args.space) { - throw new Error("hnswlib-node requires a space argument"); - } - if (args.numDimensions === undefined) { - throw new Error("hnswlib-node requires a numDimensions argument"); - } - return new HierarchicalNSW(args.space, args.numDimensions); - } - - private async initIndex(vectors: number[][]) { - if (!this._index) { - if (this.args.numDimensions === undefined) { - this.args.numDimensions = vectors[0].length; - } - this.index = await HNSWLib.getHierarchicalNSW(this.args); - } - if (!this.index.getCurrentCount()) { - this.index.initIndex(vectors.length); - } - } - - public get index(): HierarchicalNSWT { - if (!this._index) { - throw new Error( - "Vector store not initialised yet. Try calling `addTexts` first." 
- ); - } - return this._index; - } - - private set index(index: HierarchicalNSWT) { - this._index = index; - } - - /** - * Method to add vectors to the vector store. It first initializes the - * index if it hasn't been initialized yet, then adds the vectors to the - * index and the documents to the document store. - * @param vectors The vectors to be added to the vector store. - * @param documents The documents corresponding to the vectors. - * @returns A Promise that resolves when the vectors and documents have been added. - */ - async addVectors(vectors: number[][], documents: Document[]) { - if (vectors.length === 0) { - return; - } - await this.initIndex(vectors); - - // TODO here we could optionally normalise the vectors to unit length - // so that dot product is equivalent to cosine similarity, like this - // https://github.com/nmslib/hnswlib/issues/384#issuecomment-1155737730 - // While we only support OpenAI embeddings this isn't necessary - if (vectors.length !== documents.length) { - throw new Error(`Vectors and metadatas must have the same length`); - } - if (vectors[0].length !== this.args.numDimensions) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${this.args.numDimensions})` - ); - } - const capacity = this.index.getMaxElements(); - const needed = this.index.getCurrentCount() + vectors.length; - if (needed > capacity) { - this.index.resizeIndex(needed); - } - const docstoreSize = this.index.getCurrentCount(); - const toSave: Record = {}; - for (let i = 0; i < vectors.length; i += 1) { - this.index.addPoint(vectors[i], docstoreSize + i); - toSave[docstoreSize + i] = documents[i]; - } - this.docstore.add(toSave); - } - - /** - * Method to perform a similarity search in the vector store using a query - * vector. It returns the k most similar documents along with their - * similarity scores. An optional filter function can be provided to - * filter the documents. - * @param query The query vector. 
- * @param k The number of most similar documents to return. - * @param filter An optional filter function to filter the documents. - * @returns A Promise that resolves to an array of tuples, where each tuple contains a document and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ) { - if (this.args.numDimensions && !this._index) { - await this.initIndex([[]]); - } - if (query.length !== this.args.numDimensions) { - throw new Error( - `Query vector must have the same length as the number of dimensions (${this.args.numDimensions})` - ); - } - if (k > this.index.getCurrentCount()) { - const total = this.index.getCurrentCount(); - console.warn( - `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}` - ); - // eslint-disable-next-line no-param-reassign - k = total; - } - const filterFunction = (label: number): boolean => { - if (!filter) { - return true; - } - const document = this.docstore.search(String(label)); - // eslint-disable-next-line no-instanceof/no-instanceof - if (typeof document !== "string") { - return filter(document); - } - return false; - }; - const result = this.index.searchKnn( - query, - k, - filter ? filterFunction : undefined - ); - return result.neighbors.map( - (docIndex, resultIndex) => - [ - this.docstore.search(String(docIndex)), - result.distances[resultIndex], - ] as [Document, number] - ); - } - - /** - * Method to delete the vector store from a directory. It deletes the - * hnswlib.index file, the docstore.json file, and the args.json file from - * the directory. - * @param params An object with a directory property that specifies the directory from which to delete the vector store. - * @returns A Promise that resolves when the vector store has been deleted. 
- */ - async delete(params: { directory: string }) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - try { - await fs.access(path.join(params.directory, "hnswlib.index")); - } catch (err) { - throw new Error( - `Directory ${params.directory} does not contain a hnswlib.index file.` - ); - } - - await Promise.all([ - await fs.rm(path.join(params.directory, "hnswlib.index"), { - force: true, - }), - await fs.rm(path.join(params.directory, "docstore.json"), { - force: true, - }), - await fs.rm(path.join(params.directory, "args.json"), { force: true }), - ]); - } - - /** - * Method to save the vector store to a directory. It saves the HNSW - * index, the arguments, and the document store to the directory. - * @param directory The directory to which to save the vector store. - * @returns A Promise that resolves when the vector store has been saved. - */ - async save(directory: string) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - await fs.mkdir(directory, { recursive: true }); - await Promise.all([ - this.index.writeIndex(path.join(directory, "hnswlib.index")), - await fs.writeFile( - path.join(directory, "args.json"), - JSON.stringify(this.args) - ), - await fs.writeFile( - path.join(directory, "docstore.json"), - JSON.stringify(Array.from(this.docstore._docs.entries())) - ), - ]); - } - - /** - * Static method to load a vector store from a directory. It reads the - * HNSW index, the arguments, and the document store from the directory, - * then creates a new HNSWLib instance with these values. - * @param directory The directory from which to load the vector store. - * @param embeddings The embeddings to be used by the HNSWLib instance. - * @returns A Promise that resolves to a new HNSWLib instance. 
- */ - static async load(directory: string, embeddings: Embeddings) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - const args = JSON.parse( - await fs.readFile(path.join(directory, "args.json"), "utf8") - ); - const index = await HNSWLib.getHierarchicalNSW(args); - const [docstoreFiles] = await Promise.all([ - fs - .readFile(path.join(directory, "docstore.json"), "utf8") - .then(JSON.parse), - index.readIndex(path.join(directory, "hnswlib.index")), - ]); - args.docstore = new SynchronousInMemoryDocstore(new Map(docstoreFiles)); - - args.index = index; - - return new HNSWLib(embeddings, args); - } - - /** - * Static method to create a new HNSWLib instance from texts and metadata. - * It creates a new Document instance for each text and metadata, then - * calls the fromDocuments method to create the HNSWLib instance. - * @param texts The texts to be used to create the documents. - * @param metadatas The metadata to be used to create the documents. - * @param embeddings The embeddings to be used by the HNSWLib instance. - * @param dbConfig An optional configuration object for the document store. - * @returns A Promise that resolves to a new HNSWLib instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return HNSWLib.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a new HNSWLib instance from documents. It - * creates a new HNSWLib instance, adds the documents to it, then returns - * the instance. - * @param docs The documents to be added to the HNSWLib instance. 
- * @param embeddings The embeddings to be used by the HNSWLib instance. - * @param dbConfig An optional configuration object for the document store. - * @returns A Promise that resolves to a new HNSWLib instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const args: HNSWLibArgs = { - docstore: dbConfig?.docstore, - space: "cosine", - }; - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - static async imports(): Promise<{ - HierarchicalNSW: typeof HierarchicalNSWT; - }> { - try { - const { - default: { HierarchicalNSW }, - } = await import("hnswlib-node"); - - return { HierarchicalNSW }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (err: any) { - throw new Error( - `Could not import hnswlib-node. Please install hnswlib-node as a dependency with, e.g. \`npm install -S hnswlib-node\`.\n\nError: ${err?.message}` - ); - } - } -} +export * from "@langchain/community/vectorstores/hnswlib"; diff --git a/langchain/src/vectorstores/lancedb.ts b/langchain/src/vectorstores/lancedb.ts index 398147ed0d3b..4b4b62d64a75 100644 --- a/langchain/src/vectorstores/lancedb.ts +++ b/langchain/src/vectorstores/lancedb.ts @@ -1,152 +1 @@ -import { Table } from "vectordb"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -/** - * Defines the arguments for the LanceDB class constructor. It includes a - * table and an optional textKey. - */ -export type LanceDBArgs = { - table: Table; - textKey?: string; -}; - -/** - * A wrapper for an open-source database for vector-search with persistent - * storage. It simplifies retrieval, filtering, and management of - * embeddings. 
- */ -export class LanceDB extends VectorStore { - private table: Table; - - private textKey: string; - - constructor(embeddings: Embeddings, args: LanceDBArgs) { - super(embeddings, args); - this.table = args.table; - this.embeddings = embeddings; - this.textKey = args.textKey || "text"; - } - - /** - * Adds documents to the database. - * @param documents The documents to be added. - * @returns A Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - _vectorstoreType(): string { - return "lancedb"; - } - - /** - * Adds vectors and their corresponding documents to the database. - * @param vectors The vectors to be added. - * @param documents The corresponding documents to be added. - * @returns A Promise that resolves when the vectors and documents have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - if (vectors.length !== documents.length) { - throw new Error(`Vectors and documents must have the same length`); - } - - const data: Array> = []; - for (let i = 0; i < documents.length; i += 1) { - const record = { - vector: vectors[i], - [this.textKey]: documents[i].pageContent, - }; - Object.keys(documents[i].metadata).forEach((metaKey) => { - record[metaKey] = documents[i].metadata[metaKey]; - }); - data.push(record); - } - await this.table.add(data); - } - - /** - * Performs a similarity search on the vectors in the database and returns - * the documents and their scores. - * @param query The query vector. - * @param k The number of results to return. - * @returns A Promise that resolves with an array of tuples, each containing a Document and its score. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number - ): Promise<[Document, number][]> { - const results = await this.table.search(query).limit(k).execute(); - - const docsAndScore: [Document, number][] = []; - results.forEach((item) => { - const metadata: Record = {}; - Object.keys(item).forEach((key) => { - if (key !== "vector" && key !== "score" && key !== this.textKey) { - metadata[key] = item[key]; - } - }); - - docsAndScore.push([ - new Document({ - pageContent: item[this.textKey] as string, - metadata, - }), - item.score as number, - ]); - }); - return docsAndScore; - } - - /** - * Creates a new instance of LanceDB from texts. - * @param texts The texts to be converted into documents. - * @param metadatas The metadata for the texts. - * @param embeddings The embeddings to be managed. - * @param dbConfig The configuration for the LanceDB instance. - * @returns A Promise that resolves with a new instance of LanceDB. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: LanceDBArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return LanceDB.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new instance of LanceDB from documents. - * @param docs The documents to be added to the database. - * @param embeddings The embeddings to be managed. - * @param dbConfig The configuration for the LanceDB instance. - * @returns A Promise that resolves with a new instance of LanceDB. 
- */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: LanceDBArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/lancedb"; diff --git a/langchain/src/vectorstores/memory.ts b/langchain/src/vectorstores/memory.ts index c2e396980293..5cf932eecac8 100644 --- a/langchain/src/vectorstores/memory.ts +++ b/langchain/src/vectorstores/memory.ts @@ -1,192 +1 @@ -import { similarity as ml_distance_similarity } from "ml-distance"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -/** - * Interface representing a vector in memory. It includes the content - * (text), the corresponding embedding (vector), and any associated - * metadata. - */ -interface MemoryVector { - content: string; - embedding: number[]; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - metadata: Record; -} - -/** - * Interface for the arguments that can be passed to the - * `MemoryVectorStore` constructor. It includes an optional `similarity` - * function. - */ -export interface MemoryVectorStoreArgs { - similarity?: typeof ml_distance_similarity.cosine; -} - -/** - * Class that extends `VectorStore` to store vectors in memory. Provides - * methods for adding documents, performing similarity searches, and - * creating instances from texts, documents, or an existing index. - */ -export class MemoryVectorStore extends VectorStore { - declare FilterType: (doc: Document) => boolean; - - memoryVectors: MemoryVector[] = []; - - similarity: typeof ml_distance_similarity.cosine; - - _vectorstoreType(): string { - return "memory"; - } - - constructor( - embeddings: Embeddings, - { similarity, ...rest }: MemoryVectorStoreArgs = {} - ) { - super(embeddings, rest); - - this.similarity = similarity ?? 
ml_distance_similarity.cosine; - } - - /** - * Method to add documents to the memory vector store. It extracts the - * text from each document, generates embeddings for them, and adds the - * resulting vectors to the store. - * @param documents Array of `Document` instances to be added to the store. - * @returns Promise that resolves when all documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Method to add vectors to the memory vector store. It creates - * `MemoryVector` instances for each vector and document pair and adds - * them to the store. - * @param vectors Array of vectors to be added to the store. - * @param documents Array of `Document` instances corresponding to the vectors. - * @returns Promise that resolves when all vectors have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - const memoryVectors = vectors.map((embedding, idx) => ({ - content: documents[idx].pageContent, - embedding, - metadata: documents[idx].metadata, - })); - - this.memoryVectors = this.memoryVectors.concat(memoryVectors); - } - - /** - * Method to perform a similarity search in the memory vector store. It - * calculates the similarity between the query vector and each vector in - * the store, sorts the results by similarity, and returns the top `k` - * results along with their scores. - * @param query Query vector to compare against the vectors in the store. - * @param k Number of top results to return. - * @param filter Optional filter function to apply to the vectors before performing the search. - * @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - const filterFunction = (memoryVector: MemoryVector) => { - if (!filter) { - return true; - } - - const doc = new Document({ - metadata: memoryVector.metadata, - pageContent: memoryVector.content, - }); - return filter(doc); - }; - const filteredMemoryVectors = this.memoryVectors.filter(filterFunction); - const searches = filteredMemoryVectors - .map((vector, index) => ({ - similarity: this.similarity(query, vector.embedding), - index, - })) - .sort((a, b) => (a.similarity > b.similarity ? -1 : 0)) - .slice(0, k); - - const result: [Document, number][] = searches.map((search) => [ - new Document({ - metadata: filteredMemoryVectors[search.index].metadata, - pageContent: filteredMemoryVectors[search.index].content, - }), - search.similarity, - ]); - - return result; - } - - /** - * Static method to create a `MemoryVectorStore` instance from an array of - * texts. It creates a `Document` for each text and metadata pair, and - * adds them to the store. - * @param texts Array of texts to be added to the store. - * @param metadatas Array or single object of metadata corresponding to the texts. - * @param embeddings `Embeddings` instance used to generate embeddings for the texts. - * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the `MemoryVectorStore` instance. - * @returns Promise that resolves with a new `MemoryVectorStore` instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: MemoryVectorStoreArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? 
metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return MemoryVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a `MemoryVectorStore` instance from an array of - * `Document` instances. It adds the documents to the store. - * @param docs Array of `Document` instances to be added to the store. - * @param embeddings `Embeddings` instance used to generate embeddings for the documents. - * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the `MemoryVectorStore` instance. - * @returns Promise that resolves with a new `MemoryVectorStore` instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: MemoryVectorStoreArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create a `MemoryVectorStore` instance from an existing - * index. It creates a new `MemoryVectorStore` instance without adding any - * documents or vectors. - * @param embeddings `Embeddings` instance used to generate embeddings for the documents. - * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the `MemoryVectorStore` instance. - * @returns Promise that resolves with a new `MemoryVectorStore` instance. 
- */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig?: MemoryVectorStoreArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - return instance; - } -} +export * from "@langchain/community/vectorstores/memory"; diff --git a/langchain/src/vectorstores/milvus.ts b/langchain/src/vectorstores/milvus.ts index 463da9e2e044..c99dfbc8ad45 100644 --- a/langchain/src/vectorstores/milvus.ts +++ b/langchain/src/vectorstores/milvus.ts @@ -1,674 +1 @@ -import * as uuid from "uuid"; -import { - MilvusClient, - DataType, - DataTypeMap, - ErrorCode, - FieldType, - ClientConfig, -} from "@zilliz/milvus2-sdk-node"; - -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface for the arguments required by the Milvus class constructor. - */ -export interface MilvusLibArgs { - collectionName?: string; - primaryField?: string; - vectorField?: string; - textField?: string; - url?: string; // db address - ssl?: boolean; - username?: string; - password?: string; - textFieldMaxLength?: number; - clientConfig?: ClientConfig; - autoId?: boolean; -} - -/** - * Type representing the type of index used in the Milvus database. - */ -type IndexType = - | "IVF_FLAT" - | "IVF_SQ8" - | "IVF_PQ" - | "HNSW" - | "RHNSW_FLAT" - | "RHNSW_SQ" - | "RHNSW_PQ" - | "IVF_HNSW" - | "ANNOY"; - -/** - * Interface for the parameters required to create an index in the Milvus - * database. - */ -interface IndexParam { - params: { nprobe?: number; ef?: number; search_k?: number }; -} - -interface InsertRow { - [x: string]: string | number[]; -} - -const MILVUS_PRIMARY_FIELD_NAME = "langchain_primaryid"; -const MILVUS_VECTOR_FIELD_NAME = "langchain_vector"; -const MILVUS_TEXT_FIELD_NAME = "langchain_text"; -const MILVUS_COLLECTION_NAME_PREFIX = "langchain_col"; - -/** - * Class for interacting with a Milvus database. 
Extends the VectorStore - * class. - */ -export class Milvus extends VectorStore { - get lc_secrets(): { [key: string]: string } { - return { - ssl: "MILVUS_SSL", - username: "MILVUS_USERNAME", - password: "MILVUS_PASSWORD", - }; - } - - declare FilterType: string; - - collectionName: string; - - numDimensions?: number; - - autoId?: boolean; - - primaryField: string; - - vectorField: string; - - textField: string; - - textFieldMaxLength: number; - - fields: string[]; - - client: MilvusClient; - - indexParams: Record = { - IVF_FLAT: { params: { nprobe: 10 } }, - IVF_SQ8: { params: { nprobe: 10 } }, - IVF_PQ: { params: { nprobe: 10 } }, - HNSW: { params: { ef: 10 } }, - RHNSW_FLAT: { params: { ef: 10 } }, - RHNSW_SQ: { params: { ef: 10 } }, - RHNSW_PQ: { params: { ef: 10 } }, - IVF_HNSW: { params: { nprobe: 10, ef: 10 } }, - ANNOY: { params: { search_k: 10 } }, - }; - - indexCreateParams = { - index_type: "HNSW", - metric_type: "L2", - params: JSON.stringify({ M: 8, efConstruction: 64 }), - }; - - indexSearchParams = JSON.stringify({ ef: 64 }); - - _vectorstoreType(): string { - return "milvus"; - } - - constructor(embeddings: Embeddings, args: MilvusLibArgs) { - super(embeddings, args); - this.embeddings = embeddings; - this.collectionName = args.collectionName ?? genCollectionName(); - this.textField = args.textField ?? MILVUS_TEXT_FIELD_NAME; - - this.autoId = args.autoId ?? true; - this.primaryField = args.primaryField ?? MILVUS_PRIMARY_FIELD_NAME; - this.vectorField = args.vectorField ?? MILVUS_VECTOR_FIELD_NAME; - - this.textFieldMaxLength = args.textFieldMaxLength ?? 0; - - this.fields = []; - - const url = args.url ?? 
getEnvironmentVariable("MILVUS_URL"); - const { - address = "", - username = "", - password = "", - ssl, - } = args.clientConfig || {}; - - // combine args clientConfig and env variables - const clientConfig: ClientConfig = { - ...(args.clientConfig || {}), - address: url || address, - username: args.username || username, - password: args.password || password, - ssl: args.ssl || ssl, - }; - - if (!clientConfig.address) { - throw new Error("Milvus URL address is not provided."); - } - this.client = new MilvusClient(clientConfig); - } - - /** - * Adds documents to the Milvus database. - * @param documents Array of Document instances to be added to the database. - * @returns Promise resolving to void. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - await this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Adds vectors to the Milvus database. - * @param vectors Array of vectors to be added to the database. - * @param documents Array of Document instances associated with the vectors. - * @returns Promise resolving to void. 
- */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - await this.ensureCollection(vectors, documents); - - const insertDatas: InsertRow[] = []; - // eslint-disable-next-line no-plusplus - for (let index = 0; index < vectors.length; index++) { - const vec = vectors[index]; - const doc = documents[index]; - const data: InsertRow = { - [this.textField]: doc.pageContent, - [this.vectorField]: vec, - }; - this.fields.forEach((field) => { - switch (field) { - case this.primaryField: - if (!this.autoId) { - if (doc.metadata[this.primaryField] === undefined) { - throw new Error( - `The Collection's primaryField is configured with autoId=false, thus its value must be provided through metadata.` - ); - } - data[field] = doc.metadata[this.primaryField]; - } - break; - case this.textField: - data[field] = doc.pageContent; - break; - case this.vectorField: - data[field] = vec; - break; - default: // metadata fields - if (doc.metadata[field] === undefined) { - throw new Error( - `The field "${field}" is not provided in documents[${index}].metadata.` - ); - } else if (typeof doc.metadata[field] === "object") { - data[field] = JSON.stringify(doc.metadata[field]); - } else { - data[field] = doc.metadata[field]; - } - break; - } - }); - - insertDatas.push(data); - } - - const insertResp = await this.client.insert({ - collection_name: this.collectionName, - fields_data: insertDatas, - }); - if (insertResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error inserting data: ${JSON.stringify(insertResp)}`); - } - await this.client.flushSync({ collection_names: [this.collectionName] }); - } - - /** - * Searches for vectors in the Milvus database that are similar to a given - * vector. - * @param query Vector to compare with the vectors in the database. - * @param k Number of similar vectors to return. - * @param filter Optional filter to apply to the search. 
- * @returns Promise resolving to an array of tuples, each containing a Document instance and a similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: string - ): Promise<[Document, number][]> { - const hasColResp = await this.client.hasCollection({ - collection_name: this.collectionName, - }); - if (hasColResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error checking collection: ${hasColResp}`); - } - if (hasColResp.value === false) { - throw new Error( - `Collection not found: ${this.collectionName}, please create collection before search.` - ); - } - - const filterStr = filter ?? ""; - - await this.grabCollectionFields(); - - const loadResp = await this.client.loadCollectionSync({ - collection_name: this.collectionName, - }); - if (loadResp.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error loading collection: ${loadResp}`); - } - - // clone this.field and remove vectorField - const outputFields = this.fields.filter( - (field) => field !== this.vectorField - ); - - const searchResp = await this.client.search({ - collection_name: this.collectionName, - search_params: { - anns_field: this.vectorField, - topk: k.toString(), - metric_type: this.indexCreateParams.metric_type, - params: this.indexSearchParams, - }, - output_fields: outputFields, - vector_type: DataType.FloatVector, - vectors: [query], - filter: filterStr, - }); - if (searchResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error searching data: ${JSON.stringify(searchResp)}`); - } - const results: [Document, number][] = []; - searchResp.results.forEach((result) => { - const fields = { - pageContent: "", - // eslint-disable-next-line @typescript-eslint/no-explicit-any - metadata: {} as Record, - }; - Object.keys(result).forEach((key) => { - if (key === this.textField) { - fields.pageContent = result[key]; - } else if (this.fields.includes(key) || key === this.primaryField) { - if (typeof result[key] === 
"string") { - const { isJson, obj } = checkJsonString(result[key]); - fields.metadata[key] = isJson ? obj : result[key]; - } else { - fields.metadata[key] = result[key]; - } - } - }); - results.push([new Document(fields), result.score]); - }); - // console.log("Search result: " + JSON.stringify(results, null, 2)); - return results; - } - - /** - * Ensures that a collection exists in the Milvus database. - * @param vectors Optional array of vectors to be used if a new collection needs to be created. - * @param documents Optional array of Document instances to be used if a new collection needs to be created. - * @returns Promise resolving to void. - */ - async ensureCollection(vectors?: number[][], documents?: Document[]) { - const hasColResp = await this.client.hasCollection({ - collection_name: this.collectionName, - }); - if (hasColResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error( - `Error checking collection: ${JSON.stringify(hasColResp, null, 2)}` - ); - } - - if (hasColResp.value === false) { - if (vectors === undefined || documents === undefined) { - throw new Error( - `Collection not found: ${this.collectionName}, please provide vectors and documents to create collection.` - ); - } - await this.createCollection(vectors, documents); - } else { - await this.grabCollectionFields(); - } - } - - /** - * Creates a collection in the Milvus database. - * @param vectors Array of vectors to be added to the new collection. - * @param documents Array of Document instances to be added to the new collection. - * @returns Promise resolving to void. 
- */ - async createCollection( - vectors: number[][], - documents: Document[] - ): Promise { - const fieldList: FieldType[] = []; - - fieldList.push(...createFieldTypeForMetadata(documents, this.primaryField)); - - fieldList.push( - { - name: this.primaryField, - description: "Primary key", - data_type: DataType.Int64, - is_primary_key: true, - autoID: this.autoId, - }, - { - name: this.textField, - description: "Text field", - data_type: DataType.VarChar, - type_params: { - max_length: - this.textFieldMaxLength > 0 - ? this.textFieldMaxLength.toString() - : getTextFieldMaxLength(documents).toString(), - }, - }, - { - name: this.vectorField, - description: "Vector field", - data_type: DataType.FloatVector, - type_params: { - dim: getVectorFieldDim(vectors).toString(), - }, - } - ); - - fieldList.forEach((field) => { - if (!field.autoID) { - this.fields.push(field.name); - } - }); - - const createRes = await this.client.createCollection({ - collection_name: this.collectionName, - fields: fieldList, - }); - - if (createRes.error_code !== ErrorCode.SUCCESS) { - console.log(createRes); - throw new Error(`Failed to create collection: ${createRes}`); - } - - await this.client.createIndex({ - collection_name: this.collectionName, - field_name: this.vectorField, - extra_params: this.indexCreateParams, - }); - } - - /** - * Retrieves the fields of a collection in the Milvus database. - * @returns Promise resolving to void. 
- */ - async grabCollectionFields(): Promise { - if (!this.collectionName) { - throw new Error("Need collection name to grab collection fields"); - } - if ( - this.primaryField && - this.vectorField && - this.textField && - this.fields.length > 0 - ) { - return; - } - const desc = await this.client.describeCollection({ - collection_name: this.collectionName, - }); - desc.schema.fields.forEach((field) => { - this.fields.push(field.name); - if (field.autoID) { - const index = this.fields.indexOf(field.name); - if (index !== -1) { - this.fields.splice(index, 1); - } - } - if (field.is_primary_key) { - this.primaryField = field.name; - } - const dtype = DataTypeMap[field.data_type]; - if (dtype === DataType.FloatVector || dtype === DataType.BinaryVector) { - this.vectorField = field.name; - } - - if (dtype === DataType.VarChar && field.name === MILVUS_TEXT_FIELD_NAME) { - this.textField = field.name; - } - }); - } - - /** - * Creates a Milvus instance from a set of texts and their associated - * metadata. - * @param texts Array of texts to be added to the database. - * @param metadatas Array of metadata objects associated with the texts. - * @param embeddings Embeddings instance used to generate vector embeddings for the texts. - * @param dbConfig Optional configuration for the Milvus database. - * @returns Promise resolving to a new Milvus instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: MilvusLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return Milvus.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a Milvus instance from a set of Document instances. - * @param docs Array of Document instances to be added to the database. 
- * @param embeddings Embeddings instance used to generate vector embeddings for the documents. - * @param dbConfig Optional configuration for the Milvus database. - * @returns Promise resolving to a new Milvus instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: MilvusLibArgs - ): Promise { - const args: MilvusLibArgs = { - collectionName: dbConfig?.collectionName || genCollectionName(), - url: dbConfig?.url, - ssl: dbConfig?.ssl, - username: dbConfig?.username, - password: dbConfig?.password, - textField: dbConfig?.textField, - primaryField: dbConfig?.primaryField, - vectorField: dbConfig?.vectorField, - clientConfig: dbConfig?.clientConfig, - autoId: dbConfig?.autoId, - }; - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Creates a Milvus instance from an existing collection in the Milvus - * database. - * @param embeddings Embeddings instance used to generate vector embeddings for the documents in the collection. - * @param dbConfig Configuration for the Milvus database. - * @returns Promise resolving to a new Milvus instance. - */ - static async fromExistingCollection( - embeddings: Embeddings, - dbConfig: MilvusLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.ensureCollection(); - return instance; - } - - /** - * Deletes data from the Milvus database. - * @param params Object containing a filter to apply to the deletion. - * @returns Promise resolving to void. 
- */ - async delete(params: { filter: string }): Promise { - const hasColResp = await this.client.hasCollection({ - collection_name: this.collectionName, - }); - if (hasColResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error checking collection: ${hasColResp}`); - } - if (hasColResp.value === false) { - throw new Error( - `Collection not found: ${this.collectionName}, please create collection before search.` - ); - } - - const { filter } = params; - - const deleteResp = await this.client.deleteEntities({ - collection_name: this.collectionName, - expr: filter, - }); - - if (deleteResp.status.error_code !== ErrorCode.SUCCESS) { - throw new Error(`Error deleting data: ${JSON.stringify(deleteResp)}`); - } - } -} - -function createFieldTypeForMetadata( - documents: Document[], - primaryFieldName: string -): FieldType[] { - const sampleMetadata = documents[0].metadata; - let textFieldMaxLength = 0; - let jsonFieldMaxLength = 0; - documents.forEach(({ metadata }) => { - // check all keys name and count in metadata is same as sampleMetadata - Object.keys(metadata).forEach((key) => { - if ( - !(key in metadata) || - typeof metadata[key] !== typeof sampleMetadata[key] - ) { - throw new Error( - "All documents must have same metadata keys and datatype" - ); - } - - // find max length of string field and json field, cache json string value - if (typeof metadata[key] === "string") { - if (metadata[key].length > textFieldMaxLength) { - textFieldMaxLength = metadata[key].length; - } - } else if (typeof metadata[key] === "object") { - const json = JSON.stringify(metadata[key]); - if (json.length > jsonFieldMaxLength) { - jsonFieldMaxLength = json.length; - } - } - }); - }); - - const fields: FieldType[] = []; - for (const [key, value] of Object.entries(sampleMetadata)) { - const type = typeof value; - - if (key === primaryFieldName) { - /** - * skip primary field - * because we will create primary field in createCollection - * */ - } else if (type === "string") 
{ - fields.push({ - name: key, - description: `Metadata String field`, - data_type: DataType.VarChar, - type_params: { - max_length: textFieldMaxLength.toString(), - }, - }); - } else if (type === "number") { - fields.push({ - name: key, - description: `Metadata Number field`, - data_type: DataType.Float, - }); - } else if (type === "boolean") { - fields.push({ - name: key, - description: `Metadata Boolean field`, - data_type: DataType.Bool, - }); - } else if (value === null) { - // skip - } else { - // use json for other types - try { - fields.push({ - name: key, - description: `Metadata JSON field`, - data_type: DataType.VarChar, - type_params: { - max_length: jsonFieldMaxLength.toString(), - }, - }); - } catch (e) { - throw new Error("Failed to parse metadata field as JSON"); - } - } - } - return fields; -} - -function genCollectionName(): string { - return `${MILVUS_COLLECTION_NAME_PREFIX}_${uuid.v4().replaceAll("-", "")}`; -} - -function getTextFieldMaxLength(documents: Document[]) { - let textMaxLength = 0; - const textEncoder = new TextEncoder(); - // eslint-disable-next-line no-plusplus - for (let i = 0; i < documents.length; i++) { - const text = documents[i].pageContent; - const textLengthInBytes = textEncoder.encode(text).length; - if (textLengthInBytes > textMaxLength) { - textMaxLength = textLengthInBytes; - } - } - return textMaxLength; -} - -function getVectorFieldDim(vectors: number[][]) { - if (vectors.length === 0) { - throw new Error("No vectors found"); - } - return vectors[0].length; -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function checkJsonString(value: string): { isJson: boolean; obj: any } { - try { - const result = JSON.parse(value); - return { isJson: true, obj: result }; - } catch (e) { - return { isJson: false, obj: null }; - } -} +export * from "@langchain/community/vectorstores/milvus"; diff --git a/langchain/src/vectorstores/momento_vector_index.ts b/langchain/src/vectorstores/momento_vector_index.ts 
index afc13a1b8248..ad282e0176ea 100644 --- a/langchain/src/vectorstores/momento_vector_index.ts +++ b/langchain/src/vectorstores/momento_vector_index.ts @@ -1,399 +1 @@ -/* eslint-disable no-instanceof/no-instanceof */ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { - ALL_VECTOR_METADATA, - IVectorIndexClient, - VectorIndexItem, - CreateVectorIndex, - VectorUpsertItemBatch, - VectorDeleteItemBatch, - VectorSearch, - VectorSearchAndFetchVectors, -} from "@gomomento/sdk-core"; -import * as uuid from "uuid"; -import { Document } from "../document.js"; -import { Embeddings } from "../embeddings/base.js"; -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { maximalMarginalRelevance } from "../util/math.js"; - -export interface DocumentProps { - ids: string[]; -} - -export interface MomentoVectorIndexLibArgs { - /** - * The Momento Vector Index client. - */ - client: IVectorIndexClient; - /** - * The name of the index to use to store the data. - * Defaults to "default". - */ - indexName?: string; - /** - * The name of the metadata field to use to store the text of the document. - * Defaults to "text". - */ - textField?: string; - /** - * Whether to create the index if it does not already exist. - * Defaults to true. - */ - ensureIndexExists?: boolean; -} - -export interface DeleteProps { - /** - * The ids of the documents to delete. - */ - ids: string[]; -} - -/** - * A vector store that uses the Momento Vector Index. - * - * @remarks - * To sign up for a free Momento account, visit https://console.gomomento.com. - */ -export class MomentoVectorIndex extends VectorStore { - private client: IVectorIndexClient; - - private indexName: string; - - private textField: string; - - private _ensureIndexExists: boolean; - - _vectorstoreType(): string { - return "momento"; - } - - /** - * Creates a new `MomentoVectorIndex` instance. - * @param embeddings The embeddings instance to use to generate embeddings from documents. 
- * @param args The arguments to use to configure the vector store. - */ - constructor(embeddings: Embeddings, args: MomentoVectorIndexLibArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - this.client = args.client; - this.indexName = args.indexName ?? "default"; - this.textField = args.textField ?? "text"; - this._ensureIndexExists = args.ensureIndexExists ?? true; - } - - /** - * Returns the Momento Vector Index client. - * @returns The Momento Vector Index client. - */ - public getClient(): IVectorIndexClient { - return this.client; - } - - /** - * Creates the index if it does not already exist. - * @param numDimensions The number of dimensions of the vectors to be stored in the index. - * @returns Promise that resolves to true if the index was created, false if it already existed. - */ - private async ensureIndexExists(numDimensions: number): Promise { - const response = await this.client.createIndex( - this.indexName, - numDimensions - ); - if (response instanceof CreateVectorIndex.Success) { - return true; - } else if (response instanceof CreateVectorIndex.AlreadyExists) { - return false; - } else if (response instanceof CreateVectorIndex.Error) { - throw new Error(response.toString()); - } else { - throw new Error(`Unknown response type: ${response.toString()}`); - } - } - - /** - * Converts the documents to a format that can be stored in the index. - * - * This is necessary because the Momento Vector Index requires that the metadata - * be a map of strings to strings. - * @param vectors The vectors to convert. - * @param documents The documents to convert. - * @param ids The ids to convert. - * @returns The converted documents. 
- */ - private prepareItemBatch( - vectors: number[][], - documents: Document>[], - ids: string[] - ): VectorIndexItem[] { - return vectors.map((vector, idx) => ({ - id: ids[idx], - vector, - metadata: { - ...documents[idx].metadata, - [this.textField]: documents[idx].pageContent, - }, - })); - } - - /** - * Adds vectors to the index. - * - * @remarks If the index does not already exist, it will be created if `ensureIndexExists` is true. - * @param vectors The vectors to add to the index. - * @param documents The documents to add to the index. - * @param documentProps The properties of the documents to add to the index, specifically the ids. - * @returns Promise that resolves when the vectors have been added to the index. Also returns the ids of the - * documents that were added. - */ - public async addVectors( - vectors: number[][], - documents: Document>[], - documentProps?: DocumentProps - ): Promise { - if (vectors.length === 0) { - return; - } - - if (documents.length !== vectors.length) { - throw new Error( - `Number of vectors (${vectors.length}) does not equal number of documents (${documents.length})` - ); - } - - if (vectors.some((v) => v.length !== vectors[0].length)) { - throw new Error("All vectors must have the same length"); - } - - if ( - documentProps?.ids !== undefined && - documentProps.ids.length !== vectors.length - ) { - throw new Error( - `Number of ids (${ - documentProps?.ids?.length || "null" - }) does not equal number of vectors (${vectors.length})` - ); - } - - if (this._ensureIndexExists) { - await this.ensureIndexExists(vectors[0].length); - } - const documentIds = documentProps?.ids ?? 
documents.map(() => uuid.v4()); - - const batchSize = 128; - const numBatches = Math.ceil(vectors.length / batchSize); - - // Add each batch of vectors to the index - for (let i = 0; i < numBatches; i += 1) { - const [startIndex, endIndex] = [ - i * batchSize, - Math.min((i + 1) * batchSize, vectors.length), - ]; - - const batchVectors = vectors.slice(startIndex, endIndex); - const batchDocuments = documents.slice(startIndex, endIndex); - const batchDocumentIds = documentIds.slice(startIndex, endIndex); - - // Insert the items to the index - const response = await this.client.upsertItemBatch( - this.indexName, - this.prepareItemBatch(batchVectors, batchDocuments, batchDocumentIds) - ); - if (response instanceof VectorUpsertItemBatch.Success) { - // eslint-disable-next-line no-continue - continue; - } else if (response instanceof VectorUpsertItemBatch.Error) { - throw new Error(response.toString()); - } else { - throw new Error(`Unknown response type: ${response.toString()}`); - } - } - } - - /** - * Adds vectors to the index. Generates embeddings from the documents - * using the `Embeddings` instance passed to the constructor. - * @param documents Array of `Document` instances to be added to the index. - * @returns Promise that resolves when the documents have been added to the index. - */ - async addDocuments( - documents: Document[], - documentProps?: DocumentProps - ): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - await this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - documentProps - ); - } - - /** - * Deletes vectors from the index by id. - * @param params The parameters to use to delete the vectors, specifically the ids. 
- */ - public async delete(params: DeleteProps): Promise { - const response = await this.client.deleteItemBatch( - this.indexName, - params.ids - ); - if (response instanceof VectorDeleteItemBatch.Success) { - // pass - } else if (response instanceof VectorDeleteItemBatch.Error) { - throw new Error(response.toString()); - } else { - throw new Error(`Unknown response type: ${response.toString()}`); - } - } - - /** - * Searches the index for the most similar vectors to the query vector. - * @param query The query vector. - * @param k The number of results to return. - * @returns Promise that resolves to the documents of the most similar vectors - * to the query vector. - */ - public async similaritySearchVectorWithScore( - query: number[], - k: number - ): Promise<[Document>, number][]> { - const response = await this.client.search(this.indexName, query, { - topK: k, - metadataFields: ALL_VECTOR_METADATA, - }); - if (response instanceof VectorSearch.Success) { - if (response.hits === undefined) { - return []; - } - - return response.hits().map((hit) => [ - new Document({ - pageContent: hit.metadata[this.textField]?.toString() ?? "", - metadata: Object.fromEntries( - Object.entries(hit.metadata).filter( - ([key]) => key !== this.textField - ) - ), - }), - hit.score, - ]); - } else if (response instanceof VectorSearch.Error) { - throw new Error(response.toString()); - } else { - throw new Error(`Unknown response type: ${response.toString()}`); - } - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm. 
- * @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {this["FilterType"]} options.filter - Optional filter - * @param _callbacks - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. - */ - async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions - ): Promise { - const queryEmbedding = await this.embeddings.embedQuery(query); - const response = await this.client.searchAndFetchVectors( - this.indexName, - queryEmbedding, - { topK: options.fetchK ?? 20, metadataFields: ALL_VECTOR_METADATA } - ); - - if (response instanceof VectorSearchAndFetchVectors.Success) { - const hits = response.hits(); - - // Gather the embeddings of the search results - const embeddingList = hits.map((hit) => hit.vector); - - // Gather the ids of the most relevant results when applying MMR - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - options.lambda, - options.k - ); - - const finalResult = mmrIndexes.map((index) => { - const hit = hits[index]; - const { [this.textField]: pageContent, ...metadata } = hit.metadata; - return new Document({ metadata, pageContent: pageContent as string }); - }); - return finalResult; - } else if (response instanceof VectorSearchAndFetchVectors.Error) { - throw new Error(response.toString()); - } else { - throw new Error(`Unknown response type: ${response.toString()}`); - } - } - - /** - * Stores the documents in the index. - * - * Converts the documents to vectors using the `Embeddings` instance passed. - * @param texts The texts to store in the index. - * @param metadatas The metadata to store in the index. - * @param embeddings The embeddings instance to use to generate embeddings from the documents. - * @param dbConfig The configuration to use to instantiate the vector store. 
- * @param documentProps The properties of the documents to add to the index, specifically the ids. - * @returns Promise that resolves to the vector store. - */ - public static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: MomentoVectorIndexLibArgs, - documentProps?: DocumentProps - ): Promise { - if (Array.isArray(metadatas) && texts.length !== metadatas.length) { - throw new Error( - `Number of texts (${texts.length}) does not equal number of metadatas (${metadatas.length})` - ); - } - - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - const metadata: object = Array.isArray(metadatas) - ? metadatas[i] - : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return await this.fromDocuments(docs, embeddings, dbConfig, documentProps); - } - - /** - * Stores the documents in the index. - * @param docs The documents to store in the index. - * @param embeddings The embeddings instance to use to generate embeddings from the documents. - * @param dbConfig The configuration to use to instantiate the vector store. - * @param documentProps The properties of the documents to add to the index, specifically the ids. - * @returns Promise that resolves to the vector store. 
- */ - public static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: MomentoVectorIndexLibArgs, - documentProps?: DocumentProps - ): Promise { - const vectorStore = new MomentoVectorIndex(embeddings, dbConfig); - await vectorStore.addDocuments(docs, documentProps); - return vectorStore; - } -} +export * from "@langchain/community/vectorstores/momento_vector_index"; diff --git a/langchain/src/vectorstores/mongodb_atlas.ts b/langchain/src/vectorstores/mongodb_atlas.ts index 4815330bddb4..3d211e633bd9 100755 --- a/langchain/src/vectorstores/mongodb_atlas.ts +++ b/langchain/src/vectorstores/mongodb_atlas.ts @@ -1,279 +1 @@ -import type { Collection, Document as MongoDBDocument } from "mongodb"; -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { maximalMarginalRelevance } from "../util/math.js"; - -/** - * Type that defines the arguments required to initialize the - * MongoDBAtlasVectorSearch class. It includes the MongoDB collection, - * index name, text key, and embedding key. - */ -export type MongoDBAtlasVectorSearchLibArgs = { - readonly collection: Collection; - readonly indexName?: string; - readonly textKey?: string; - readonly embeddingKey?: string; -}; - -/** - * Type that defines the filter used in the - * similaritySearchVectorWithScore and maxMarginalRelevanceSearch methods. - * It includes pre-filter, post-filter pipeline, and a flag to include - * embeddings. - */ -type MongoDBAtlasFilter = { - preFilter?: MongoDBDocument; - postFilterPipeline?: MongoDBDocument[]; - includeEmbeddings?: boolean; -} & MongoDBDocument; - -/** - * Class that is a wrapper around MongoDB Atlas Vector Search. It is used - * to store embeddings in MongoDB documents, create a vector search index, - * and perform K-Nearest Neighbors (KNN) search with an approximate - * nearest neighbor algorithm. 
- */ -export class MongoDBAtlasVectorSearch extends VectorStore { - declare FilterType: MongoDBAtlasFilter; - - private readonly collection: Collection; - - private readonly indexName: string; - - private readonly textKey: string; - - private readonly embeddingKey: string; - - _vectorstoreType(): string { - return "mongodb_atlas"; - } - - constructor(embeddings: Embeddings, args: MongoDBAtlasVectorSearchLibArgs) { - super(embeddings, args); - this.collection = args.collection; - this.indexName = args.indexName ?? "default"; - this.textKey = args.textKey ?? "text"; - this.embeddingKey = args.embeddingKey ?? "embedding"; - } - - /** - * Method to add vectors and their corresponding documents to the MongoDB - * collection. - * @param vectors Vectors to be added. - * @param documents Corresponding documents to be added. - * @returns Promise that resolves when the vectors and documents have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - const docs = vectors.map((embedding, idx) => ({ - [this.textKey]: documents[idx].pageContent, - [this.embeddingKey]: embedding, - ...documents[idx].metadata, - })); - await this.collection.insertMany(docs); - } - - /** - * Method to add documents to the MongoDB collection. It first converts - * the documents to vectors using the embeddings and then calls the - * addVectors method. - * @param documents Documents to be added. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Method that performs a similarity search on the vectors stored in the - * MongoDB collection. It returns a list of documents and their - * corresponding similarity scores. - * @param query Query vector for the similarity search. 
- * @param k Number of nearest neighbors to return. - * @param filter Optional filter to be applied. - * @returns Promise that resolves to a list of documents and their corresponding similarity scores. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: MongoDBAtlasFilter - ): Promise<[Document, number][]> { - const postFilterPipeline = filter?.postFilterPipeline ?? []; - const preFilter: MongoDBDocument | undefined = - filter?.preFilter || - filter?.postFilterPipeline || - filter?.includeEmbeddings - ? filter.preFilter - : filter; - const removeEmbeddingsPipeline = !filter?.includeEmbeddings - ? [ - { - $project: { - [this.embeddingKey]: 0, - }, - }, - ] - : []; - - const pipeline: MongoDBDocument[] = [ - { - $vectorSearch: { - queryVector: MongoDBAtlasVectorSearch.fixArrayPrecision(query), - index: this.indexName, - path: this.embeddingKey, - limit: k, - numCandidates: 10 * k, - ...(preFilter && { filter: preFilter }), - }, - }, - { - $set: { - score: { $meta: "vectorSearchScore" }, - }, - }, - ...removeEmbeddingsPipeline, - ...postFilterPipeline, - ]; - - const results = this.collection - .aggregate(pipeline) - .map<[Document, number]>((result) => { - const { score, [this.textKey]: text, ...metadata } = result; - return [new Document({ pageContent: text, metadata }), score]; - }); - - return results.toArray(); - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK=20- Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. 
- * @param {MongoDBAtlasFilter} options.filter - Optional Atlas Search operator to pre-filter on document fields - * or post-filter following the knnBeta search. - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. - */ - async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions - ): Promise { - const { k, fetchK = 20, lambda = 0.5, filter } = options; - - const queryEmbedding = await this.embeddings.embedQuery(query); - - // preserve the original value of includeEmbeddings - const includeEmbeddingsFlag = options.filter?.includeEmbeddings || false; - - // update filter to include embeddings, as they will be used in MMR - const includeEmbeddingsFilter = { - ...filter, - includeEmbeddings: true, - }; - - const resultDocs = await this.similaritySearchVectorWithScore( - MongoDBAtlasVectorSearch.fixArrayPrecision(queryEmbedding), - fetchK, - includeEmbeddingsFilter - ); - - const embeddingList = resultDocs.map( - (doc) => doc[0].metadata[this.embeddingKey] - ); - - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - lambda, - k - ); - - return mmrIndexes.map((idx) => { - const doc = resultDocs[idx][0]; - - // remove embeddings if they were not requested originally - if (!includeEmbeddingsFlag) { - delete doc.metadata[this.embeddingKey]; - } - return doc; - }); - } - - /** - * Static method to create an instance of MongoDBAtlasVectorSearch from a - * list of texts. It first converts the texts to vectors and then adds - * them to the MongoDB collection. - * @param texts List of texts to be converted to vectors. - * @param metadatas Metadata for the texts. - * @param embeddings Embeddings to be used for conversion. - * @param dbConfig Database configuration for MongoDB Atlas. - * @returns Promise that resolves to a new instance of MongoDBAtlasVectorSearch. 
- */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: MongoDBAtlasVectorSearchLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return MongoDBAtlasVectorSearch.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create an instance of MongoDBAtlasVectorSearch from a - * list of documents. It first converts the documents to vectors and then - * adds them to the MongoDB collection. - * @param docs List of documents to be converted to vectors. - * @param embeddings Embeddings to be used for conversion. - * @param dbConfig Database configuration for MongoDB Atlas. - * @returns Promise that resolves to a new instance of MongoDBAtlasVectorSearch. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: MongoDBAtlasVectorSearchLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to fix the precision of the array that ensures that - * every number in this array is always float when casted to other types. - * This is needed since MongoDB Atlas Vector Search does not cast integer - * inside vector search to float automatically. - * This method shall introduce a hint of error but should be safe to use - * since introduced error is very small, only applies to integer numbers - * returned by embeddings, and most embeddings shall not have precision - * as high as 15 decimal places. - * @param array Array of number to be fixed. 
- * @returns - */ - static fixArrayPrecision(array: number[]) { - return array.map((value) => { - if (Number.isInteger(value)) { - return value + 0.000000000000001; - } - return value; - }); - } -} +export * from "@langchain/community/vectorstores/mongodb_atlas"; diff --git a/langchain/src/vectorstores/myscale.ts b/langchain/src/vectorstores/myscale.ts index 3fdd997c44a0..d1c1a8cda53e 100644 --- a/langchain/src/vectorstores/myscale.ts +++ b/langchain/src/vectorstores/myscale.ts @@ -1,314 +1 @@ -import * as uuid from "uuid"; -import { ClickHouseClient, createClient } from "@clickhouse/client"; - -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -/** - * Arguments for the MyScaleStore class, which include the host, port, - * protocol, username, password, index type, index parameters, column map, - * database, table, and metric. - */ -export interface MyScaleLibArgs { - host: string; - port: string | number; - protocol?: string; - username: string; - password: string; - indexType?: string; - indexParam?: Record; - columnMap?: ColumnMap; - database?: string; - table?: string; - metric?: metric; -} - -/** - * Mapping of columns in the MyScale database. - */ -export interface ColumnMap { - id: string; - text: string; - vector: string; - metadata: string; -} - -/** - * Type of metric used in the MyScale database. - */ -export type metric = "L2" | "Cosine" | "IP"; - -/** - * Type for filtering search results in the MyScale database. - */ -export interface MyScaleFilter { - whereStr: string; -} - -/** - * Class for interacting with the MyScale database. It extends the - * VectorStore class and provides methods for adding vectors and - * documents, searching for similar vectors, and creating instances from - * texts or documents. 
- */ -export class MyScaleStore extends VectorStore { - declare FilterType: MyScaleFilter; - - private client: ClickHouseClient; - - private indexType: string; - - private indexParam: Record; - - private columnMap: ColumnMap; - - private database: string; - - private table: string; - - private metric: metric; - - private isInitialized = false; - - _vectorstoreType(): string { - return "myscale"; - } - - constructor(embeddings: Embeddings, args: MyScaleLibArgs) { - super(embeddings, args); - - this.indexType = args.indexType || "MSTG"; - this.indexParam = args.indexParam || {}; - this.columnMap = args.columnMap || { - id: "id", - text: "text", - vector: "vector", - metadata: "metadata", - }; - this.database = args.database || "default"; - this.table = args.table || "vector_table"; - this.metric = args.metric || "Cosine"; - - this.client = createClient({ - host: `${args.protocol ?? "https://"}${args.host}:${args.port}`, - username: args.username, - password: args.password, - session_id: uuid.v4(), - }); - } - - /** - * Method to add vectors to the MyScale database. - * @param vectors The vectors to add. - * @param documents The documents associated with the vectors. - * @returns Promise that resolves when the vectors have been added. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - - if (!this.isInitialized) { - await this.initialize(vectors[0].length); - } - - const queryStr = this.buildInsertQuery(vectors, documents); - await this.client.exec({ query: queryStr }); - } - - /** - * Method to add documents to the MyScale database. - * @param documents The documents to add. - * @returns Promise that resolves when the documents have been added. 
- */ - async addDocuments(documents: Document[]): Promise { - return this.addVectors( - await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), - documents - ); - } - - /** - * Method to search for vectors that are similar to a given query vector. - * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter Optional filter for the search results. - * @returns Promise that resolves with an array of tuples, each containing a Document and a score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - if (!this.isInitialized) { - await this.initialize(query.length); - } - const queryStr = this.buildSearchQuery(query, k, filter); - - const queryResultSet = await this.client.query({ query: queryStr }); - const queryResult: { - data: { text: string; metadata: object; dist: number }[]; - } = await queryResultSet.json(); - - const result: [Document, number][] = queryResult.data.map((item) => [ - new Document({ pageContent: item.text, metadata: item.metadata }), - item.dist, - ]); - - return result; - } - - /** - * Static method to create an instance of MyScaleStore from texts. - * @param texts The texts to use. - * @param metadatas The metadata associated with the texts. - * @param embeddings The embeddings to use. - * @param args The arguments for the MyScaleStore. - * @returns Promise that resolves with a new instance of MyScaleStore. - */ - static async fromTexts( - texts: string[], - metadatas: object | object[], - embeddings: Embeddings, - args: MyScaleLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? 
metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return MyScaleStore.fromDocuments(docs, embeddings, args); - } - - /** - * Static method to create an instance of MyScaleStore from documents. - * @param docs The documents to use. - * @param embeddings The embeddings to use. - * @param args The arguments for the MyScaleStore. - * @returns Promise that resolves with a new instance of MyScaleStore. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args: MyScaleLibArgs - ): Promise { - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create an instance of MyScaleStore from an existing - * index. - * @param embeddings The embeddings to use. - * @param args The arguments for the MyScaleStore. - * @returns Promise that resolves with a new instance of MyScaleStore. - */ - static async fromExistingIndex( - embeddings: Embeddings, - args: MyScaleLibArgs - ): Promise { - const instance = new this(embeddings, args); - - await instance.initialize(); - return instance; - } - - /** - * Method to initialize the MyScale database. - * @param dimension Optional dimension of the vectors. - * @returns Promise that resolves when the database has been initialized. - */ - private async initialize(dimension?: number): Promise { - const dim = dimension ?? 
(await this.embeddings.embedQuery("test")).length; - - let indexParamStr = ""; - for (const [key, value] of Object.entries(this.indexParam)) { - indexParamStr += `, '${key}=${value}'`; - } - - const query = ` - CREATE TABLE IF NOT EXISTS ${this.database}.${this.table}( - ${this.columnMap.id} String, - ${this.columnMap.text} String, - ${this.columnMap.vector} Array(Float32), - ${this.columnMap.metadata} JSON, - CONSTRAINT cons_vec_len CHECK length(${this.columnMap.vector}) = ${dim}, - VECTOR INDEX vidx ${this.columnMap.vector} TYPE ${this.indexType}('metric_type=${this.metric}'${indexParamStr}) - ) ENGINE = MergeTree ORDER BY ${this.columnMap.id} - `; - - await this.client.exec({ query: "SET allow_experimental_object_type=1" }); - await this.client.exec({ - query: "SET output_format_json_named_tuples_as_objects = 1", - }); - await this.client.exec({ query }); - this.isInitialized = true; - } - - /** - * Method to build an SQL query for inserting vectors and documents into - * the MyScale database. - * @param vectors The vectors to insert. - * @param documents The documents to insert. - * @returns The SQL query string. - */ - private buildInsertQuery(vectors: number[][], documents: Document[]): string { - const columnsStr = Object.values(this.columnMap).join(", "); - - const data: string[] = []; - for (let i = 0; i < vectors.length; i += 1) { - const vector = vectors[i]; - const document = documents[i]; - const item = [ - `'${uuid.v4()}'`, - `'${this.escapeString(document.pageContent)}'`, - `[${vector}]`, - `'${JSON.stringify(document.metadata)}'`, - ].join(", "); - data.push(`(${item})`); - } - const dataStr = data.join(", "); - - return ` - INSERT INTO TABLE - ${this.database}.${this.table}(${columnsStr}) - VALUES - ${dataStr} - `; - } - - private escapeString(str: string): string { - return str.replace(/\\/g, "\\\\").replace(/'/g, "\\'"); - } - - /** - * Method to build an SQL query for searching for similar vectors in the - * MyScale database. 
- * @param query The query vector. - * @param k The number of similar vectors to return. - * @param filter Optional filter for the search results. - * @returns The SQL query string. - */ - private buildSearchQuery( - query: number[], - k: number, - filter?: MyScaleFilter - ): string { - const order = this.metric === "IP" ? "DESC" : "ASC"; - - const whereStr = filter ? `PREWHERE ${filter.whereStr}` : ""; - return ` - SELECT ${this.columnMap.text} AS text, ${this.columnMap.metadata} AS metadata, dist - FROM ${this.database}.${this.table} - ${whereStr} - ORDER BY distance(${this.columnMap.vector}, [${query}]) AS dist ${order} - LIMIT ${k} - `; - } -} +export * from "@langchain/community/vectorstores/myscale"; diff --git a/langchain/src/vectorstores/neo4j_vector.ts b/langchain/src/vectorstores/neo4j_vector.ts index e7c496a6e9ed..0338e358a89b 100644 --- a/langchain/src/vectorstores/neo4j_vector.ts +++ b/langchain/src/vectorstores/neo4j_vector.ts @@ -1,731 +1 @@ -import neo4j from "neo4j-driver"; -import * as uuid from "uuid"; -import { Document } from "../document.js"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; - -export type SearchType = "vector" | "hybrid"; - -export type DistanceStrategy = "euclidean" | "cosine"; - -interface Neo4jVectorStoreArgs { - url: string; - username: string; - password: string; - database?: string; - preDeleteCollection?: boolean; - textNodeProperty?: string; - textNodeProperties?: string[]; - embeddingNodeProperty?: string; - keywordIndexName?: string; - indexName?: string; - searchType?: SearchType; - retrievalQuery?: string; - nodeLabel?: string; - createIdIndex?: boolean; -} - -const DEFAULT_SEARCH_TYPE = "vector"; -const DEFAULT_DISTANCE_STRATEGY = "cosine"; - -/** - * @security *Security note*: Make sure that the database connection uses credentials - * that are narrowly-scoped to only include necessary permissions. 
- * Failure to do so may result in data corruption or loss, since the calling - * code may attempt commands that would result in deletion, mutation - * of data if appropriately prompted or reading sensitive data if such - * data is present in the database. - * The best way to guard against such negative outcomes is to (as appropriate) - * limit the permissions granted to the credentials used with this tool. - * For example, creating read only users for the database is a good way to - * ensure that the calling code cannot mutate or delete data. - * - * @link See https://js.langchain.com/docs/security for more information. - */ -export class Neo4jVectorStore extends VectorStore { - private driver: neo4j.Driver; - - private database: string; - - private preDeleteCollection: boolean; - - private nodeLabel: string; - - private embeddingNodeProperty: string; - - private embeddingDimension: number; - - private textNodeProperty: string; - - private keywordIndexName: string; - - private indexName: string; - - private retrievalQuery: string; - - private searchType: SearchType; - - private distanceStrategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY; - - _vectorstoreType(): string { - return "neo4jvector"; - } - - constructor(embeddings: Embeddings, config: Neo4jVectorStoreArgs) { - super(embeddings, config); - } - - static async initialize( - embeddings: Embeddings, - config: Neo4jVectorStoreArgs - ) { - const store = new Neo4jVectorStore(embeddings, config); - await store._initializeDriver(config); - await store._verifyConnectivity(); - - const { - preDeleteCollection = false, - nodeLabel = "Chunk", - textNodeProperty = "text", - embeddingNodeProperty = "embedding", - keywordIndexName = "keyword", - indexName = "vector", - retrievalQuery = "", - searchType = DEFAULT_SEARCH_TYPE, - } = config; - - store.embeddingDimension = (await embeddings.embedQuery("foo")).length; - store.preDeleteCollection = preDeleteCollection; - store.nodeLabel = nodeLabel; - store.textNodeProperty 
= textNodeProperty; - store.embeddingNodeProperty = embeddingNodeProperty; - store.keywordIndexName = keywordIndexName; - store.indexName = indexName; - store.retrievalQuery = retrievalQuery; - store.searchType = searchType; - - if (store.preDeleteCollection) { - await store._dropIndex(); - } - - return store; - } - - async _initializeDriver({ - url, - username, - password, - database = "neo4j", - }: Neo4jVectorStoreArgs) { - try { - this.driver = neo4j.driver(url, neo4j.auth.basic(username, password)); - this.database = database; - } catch (error) { - throw new Error( - "Could not create a Neo4j driver instance. Please check the connection details." - ); - } - } - - async _verifyConnectivity() { - await this.driver.verifyAuthentication(); - } - - async close() { - await this.driver.close(); - } - - async _dropIndex() { - try { - await this.query(` - MATCH (n:\`${this.nodeLabel}\`) - CALL { - WITH n - DETACH DELETE n - } - IN TRANSACTIONS OF 10000 ROWS; - `); - await this.query(`DROP INDEX ${this.indexName}`); - } catch (error) { - console.error("An error occurred while dropping the index:", error); - } - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - async query(query: string, params: any = {}): Promise { - const session = this.driver.session({ database: this.database }); - const result = await session.run(query, params); - return toObjects(result.records); - } - - static async fromTexts( - texts: string[], - // eslint-disable-next-line @typescript-eslint/no-explicit-any - metadatas: any, - embeddings: Embeddings, - config: Neo4jVectorStoreArgs - ): Promise { - const docs = []; - - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? 
metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return Neo4jVectorStore.fromDocuments(docs, embeddings, config); - } - - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - config: Neo4jVectorStoreArgs - ): Promise { - const { - searchType = DEFAULT_SEARCH_TYPE, - createIdIndex = true, - textNodeProperties = [], - } = config; - - const store = await this.initialize(embeddings, config); - - const embeddingDimension = await store.retrieveExistingIndex(); - - if (!embeddingDimension) { - await store.createNewIndex(); - } else if (store.embeddingDimension !== embeddingDimension) { - throw new Error( - `Index with name "${store.indexName}" already exists. The provided embedding function and vector index dimensions do not match. - Embedding function dimension: ${store.embeddingDimension} - Vector index dimension: ${embeddingDimension}` - ); - } - - if (searchType === "hybrid") { - const ftsNodeLabel = await store.retrieveExistingFtsIndex(); - - if (!ftsNodeLabel) { - await store.createNewKeywordIndex(textNodeProperties); - } else { - if (ftsNodeLabel !== store.nodeLabel) { - throw Error( - "Vector and keyword index don't index the same node label" - ); - } - } - } - - if (createIdIndex) { - await store.query( - `CREATE CONSTRAINT IF NOT EXISTS FOR (n:${store.nodeLabel}) REQUIRE n.id IS UNIQUE;` - ); - } - - await store.addDocuments(docs); - - return store; - } - - static async fromExistingIndex( - embeddings: Embeddings, - config: Neo4jVectorStoreArgs - ) { - const { searchType = DEFAULT_SEARCH_TYPE, keywordIndexName = "keyword" } = - config; - - if (searchType === "hybrid" && !keywordIndexName) { - throw Error( - "keyword_index name has to be specified when using hybrid search option" - ); - } - - const store = await this.initialize(embeddings, config); - const embeddingDimension = await store.retrieveExistingIndex(); - - if (!embeddingDimension) { - throw Error( - 
"The specified vector index name does not exist. Make sure to check if you spelled it correctly" - ); - } - - if (store.embeddingDimension !== embeddingDimension) { - throw new Error( - `The provided embedding function and vector index dimensions do not match. - Embedding function dimension: ${store.embeddingDimension} - Vector index dimension: ${embeddingDimension}` - ); - } - - if (searchType === "hybrid") { - const ftsNodeLabel = await store.retrieveExistingFtsIndex(); - - if (!ftsNodeLabel) { - throw Error( - "The specified keyword index name does not exist. Make sure to check if you spelled it correctly" - ); - } else { - if (ftsNodeLabel !== store.nodeLabel) { - throw Error( - "Vector and keyword index don't index the same node label" - ); - } - } - } - - return store; - } - - static async fromExistingGraph( - embeddings: Embeddings, - config: Neo4jVectorStoreArgs - ) { - const { - textNodeProperties = [], - embeddingNodeProperty, - searchType = DEFAULT_SEARCH_TYPE, - retrievalQuery = "", - nodeLabel, - } = config; - - let _retrievalQuery = retrievalQuery; - - if (textNodeProperties.length === 0) { - throw Error( - "Parameter `text_node_properties` must not be an empty array" - ); - } - - if (!retrievalQuery) { - _retrievalQuery = ` - RETURN reduce(str='', k IN ${JSON.stringify(textNodeProperties)} | - str + '\\n' + k + ': ' + coalesce(node[k], '')) AS text, - node {.*, \`${embeddingNodeProperty}\`: Null, id: Null, ${textNodeProperties - .map((prop) => `\`${prop}\`: Null`) - .join(", ")} } AS metadata, score - `; - } - - const store = await this.initialize(embeddings, { - ...config, - retrievalQuery: _retrievalQuery, - }); - - const embeddingDimension = await store.retrieveExistingIndex(); - - if (!embeddingDimension) { - await store.createNewIndex(); - } else if (store.embeddingDimension !== embeddingDimension) { - throw new Error( - `Index with name ${store.indexName} already exists. 
The provided embedding function and vector index dimensions do not match.\nEmbedding function dimension: ${store.embeddingDimension}\nVector index dimension: ${embeddingDimension}` - ); - } - - if (searchType === "hybrid") { - const ftsNodeLabel = await store.retrieveExistingFtsIndex( - textNodeProperties - ); - - if (!ftsNodeLabel) { - await store.createNewKeywordIndex(textNodeProperties); - } else { - if (ftsNodeLabel !== store.nodeLabel) { - throw Error( - "Vector and keyword index don't index the same node label" - ); - } - } - } - - // eslint-disable-next-line no-constant-condition - while (true) { - const fetchQuery = ` - MATCH (n:\`${nodeLabel}\`) - WHERE n.${embeddingNodeProperty} IS null - AND any(k in $props WHERE n[k] IS NOT null) - RETURN elementId(n) AS id, reduce(str='', k IN $props | - str + '\\n' + k + ':' + coalesce(n[k], '')) AS text - LIMIT 1000 - `; - - const data = await store.query(fetchQuery, { props: textNodeProperties }); - - if (!data) { - continue; - } - - const textEmbeddings = await embeddings.embedDocuments( - data.map((el) => el.text) - ); - - const params = { - data: data.map((el, index) => ({ - id: el.id, - embedding: textEmbeddings[index], - })), - }; - - await store.query( - ` - UNWIND $data AS row - MATCH (n:\`${nodeLabel}\`) - WHERE elementId(n) = row.id - CALL db.create.setVectorProperty(n, '${embeddingNodeProperty}', row.embedding) - YIELD node RETURN count(*) - `, - params - ); - - if (data.length < 1000) { - break; - } - } - - return store; - } - - async createNewIndex(): Promise { - const indexQuery = ` - CALL db.index.vector.createNodeIndex( - $index_name, - $node_label, - $embedding_node_property, - toInteger($embedding_dimension), - $similarity_metric - ) - `; - - const parameters = { - index_name: this.indexName, - node_label: this.nodeLabel, - embedding_node_property: this.embeddingNodeProperty, - embedding_dimension: this.embeddingDimension, - similarity_metric: this.distanceStrategy, - }; - - await 
this.query(indexQuery, parameters); - } - - async retrieveExistingIndex() { - let indexInformation = await this.query( - ` - SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options - WHERE type = 'VECTOR' AND (name = $index_name - OR (labelsOrTypes[0] = $node_label AND - properties[0] = $embedding_node_property)) - RETURN name, labelsOrTypes, properties, options - `, - { - index_name: this.indexName, - node_label: this.nodeLabel, - embedding_node_property: this.embeddingNodeProperty, - } - ); - - if (indexInformation) { - indexInformation = this.sortByIndexName(indexInformation, this.indexName); - - try { - const [index] = indexInformation; - const [labelOrType] = index.labelsOrTypes; - const [property] = index.properties; - - this.indexName = index.name; - this.nodeLabel = labelOrType; - this.embeddingNodeProperty = property; - - const embeddingDimension = - index.options.indexConfig["vector.dimensions"]; - return Number(embeddingDimension); - } catch (error) { - return null; - } - } - - return null; - } - - async retrieveExistingFtsIndex( - textNodeProperties: string[] = [] - ): Promise { - const indexInformation = await this.query( - ` - SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options - WHERE type = 'FULLTEXT' AND (name = $keyword_index_name - OR (labelsOrTypes = [$node_label] AND - properties = $text_node_property)) - RETURN name, labelsOrTypes, properties, options - `, - { - keyword_index_name: this.keywordIndexName, - node_label: this.nodeLabel, - text_node_property: - textNodeProperties.length > 0 - ? 
textNodeProperties - : [this.textNodeProperty], - } - ); - - if (indexInformation) { - // Sort the index information by index name - const sortedIndexInformation = this.sortByIndexName( - indexInformation, - this.indexName - ); - - try { - const [index] = sortedIndexInformation; - const [labelOrType] = index.labelsOrTypes; - const [property] = index.properties; - - this.keywordIndexName = index.name; - this.textNodeProperty = property; - this.nodeLabel = labelOrType; - - return labelOrType; - } catch (error) { - return null; - } - } - - return null; - } - - async createNewKeywordIndex( - textNodeProperties: string[] = [] - ): Promise { - const nodeProps = - textNodeProperties.length > 0 - ? textNodeProperties - : [this.textNodeProperty]; - - // Construct the Cypher query to create a new full text index - const ftsIndexQuery = ` - CREATE FULLTEXT INDEX ${this.keywordIndexName} - FOR (n:\`${this.nodeLabel}\`) ON EACH - [${nodeProps.map((prop) => `n.\`${prop}\``).join(", ")}] - `; - - await this.query(ftsIndexQuery); - } - - sortByIndexName( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - values: Array<{ [key: string]: any }>, - indexName: string - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ): Array<{ [key: string]: any }> { - return values.sort( - (a, b) => - (a.index_name === indexName ? -1 : 0) - - (b.index_name === indexName ? 
-1 : 0) - ); - } - - async addVectors( - vectors: number[][], - documents: Document[], - // eslint-disable-next-line @typescript-eslint/no-explicit-any - metadatas?: Record[], - ids?: string[] - ): Promise { - let _ids = ids; - let _metadatas = metadatas; - - if (!_ids) { - _ids = documents.map(() => uuid.v1()); - } - - if (!metadatas) { - _metadatas = documents.map(() => ({})); - } - - const importQuery = ` - UNWIND $data AS row - CALL { - WITH row - MERGE (c:\`${this.nodeLabel}\` {id: row.id}) - WITH c, row - CALL db.create.setVectorProperty(c, '${this.embeddingNodeProperty}', row.embedding) - YIELD node - SET c.\`${this.textNodeProperty}\` = row.text - SET c += row.metadata - } IN TRANSACTIONS OF 1000 ROWS - `; - - const parameters = { - data: documents.map(({ pageContent, metadata }, index) => ({ - text: pageContent, - metadata: _metadatas ? _metadatas[index] : metadata, - embedding: vectors[index], - id: _ids ? _ids[index] : null, - })), - }; - - await this.query(importQuery, parameters); - - return _ids; - } - - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - async similaritySearch(query: string, k = 4): Promise { - const embedding = await this.embeddings.embedQuery(query); - - const results = await this.similaritySearchVectorWithScore( - embedding, - k, - query - ); - - return results.map((result) => result[0]); - } - - async similaritySearchVectorWithScore( - vector: number[], - k: number, - query: string - ): Promise<[Document, number][]> { - const defaultRetrieval = ` - RETURN node.${this.textNodeProperty} AS text, score, - node {.*, ${this.textNodeProperty}: Null, - ${this.embeddingNodeProperty}: Null, id: Null } AS metadata - `; - - const retrievalQuery = this.retrievalQuery - ? 
this.retrievalQuery - : defaultRetrieval; - - const readQuery = `${getSearchIndexQuery( - this.searchType - )} ${retrievalQuery}`; - - const parameters = { - index: this.indexName, - k: Number(k), - embedding: vector, - keyword_index: this.keywordIndexName, - query, - }; - const results = await this.query(readQuery, parameters); - - if (results) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const docs: [Document, number][] = results.map((result: any) => [ - new Document({ - pageContent: result.text, - metadata: Object.fromEntries( - Object.entries(result.metadata).filter(([_, v]) => v !== null) - ), - }), - result.score, - ]); - - return docs; - } - - return []; - } -} - -function toObjects(records: neo4j.Record[]) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const recordValues: Record[] = records.map((record) => { - const rObj = record.toObject(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const out: { [key: string]: any } = {}; - Object.keys(rObj).forEach((key) => { - out[key] = itemIntToString(rObj[key]); - }); - return out; - }); - return recordValues; -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function itemIntToString(item: any): any { - if (neo4j.isInt(item)) return item.toString(); - if (Array.isArray(item)) return item.map((ii) => itemIntToString(ii)); - if (["number", "string", "boolean"].indexOf(typeof item) !== -1) return item; - if (item === null) return item; - if (typeof item === "object") return objIntToString(item); -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function objIntToString(obj: any) { - const entry = extractFromNeoObjects(obj); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let newObj: any = null; - if (Array.isArray(entry)) { - newObj = entry.map((item) => itemIntToString(item)); - } else if (entry !== null && typeof entry === "object") { - newObj = {}; - Object.keys(entry).forEach((key) => { - 
newObj[key] = itemIntToString(entry[key]); - }); - } - return newObj; -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function extractFromNeoObjects(obj: any) { - if ( - // eslint-disable-next-line - obj instanceof (neo4j.types.Node as any) || - // eslint-disable-next-line - obj instanceof (neo4j.types.Relationship as any) - ) { - return obj.properties; - // eslint-disable-next-line - } else if (obj instanceof (neo4j.types.Path as any)) { - // eslint-disable-next-line - return [].concat.apply([], extractPathForRows(obj)); - } - return obj; -} - -function extractPathForRows(path: neo4j.Path) { - let { segments } = path; - // Zero length path. No relationship, end === start - if (!Array.isArray(path.segments) || path.segments.length < 1) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - segments = [{ ...path, end: null } as any]; - } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return segments.map((segment: any) => - [ - objIntToString(segment.start), - objIntToString(segment.relationship), - objIntToString(segment.end), - ].filter((part) => part !== null) - ); -} - -function getSearchIndexQuery(searchType: SearchType): string { - const typeToQueryMap: { [key in SearchType]: string } = { - vector: - "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score", - hybrid: ` - CALL { - CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score - RETURN node, score UNION - CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) YIELD node, score - WITH collect({node: node, score: score}) AS nodes, max(score) AS max - UNWIND nodes AS n - RETURN n.node AS node, (n.score / max) AS score - } - WITH node, max(score) AS score ORDER BY score DESC LIMIT toInteger($k) - `, - }; - - return typeToQueryMap[searchType]; -} +export * from "@langchain/community/vectorstores/neo4j_vector"; diff --git a/langchain/src/vectorstores/opensearch.ts 
b/langchain/src/vectorstores/opensearch.ts index cfd9fcd87bbe..1e8e2d6058fd 100644 --- a/langchain/src/vectorstores/opensearch.ts +++ b/langchain/src/vectorstores/opensearch.ts @@ -1,326 +1 @@ -import { Client, RequestParams, errors } from "@opensearch-project/opensearch"; -import * as uuid from "uuid"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { VectorStore } from "./base.js"; - -type OpenSearchEngine = "nmslib" | "hnsw"; -type OpenSearchSpaceType = "l2" | "cosinesimil" | "ip"; - -/** - * Interface defining the options for vector search in OpenSearch. It - * includes the engine type, space type, and parameters for the HNSW - * algorithm. - */ -interface VectorSearchOptions { - readonly engine?: OpenSearchEngine; - readonly spaceType?: OpenSearchSpaceType; - readonly m?: number; - readonly efConstruction?: number; - readonly efSearch?: number; -} - -/** - * Interface defining the arguments required to create an instance of the - * OpenSearchVectorStore class. It includes the OpenSearch client, index - * name, and vector search options. - */ -export interface OpenSearchClientArgs { - readonly client: Client; - readonly indexName?: string; - - readonly vectorSearchOptions?: VectorSearchOptions; -} - -/** - * Type alias for an object. It's used to define filters for OpenSearch - * queries. - */ -type OpenSearchFilter = object; - -/** - * Class that provides a wrapper around the OpenSearch service for vector - * search. It provides methods for adding documents and vectors to the - * OpenSearch index, searching for similar vectors, and managing the - * OpenSearch index. 
- */ -export class OpenSearchVectorStore extends VectorStore { - declare FilterType: OpenSearchFilter; - - private readonly client: Client; - - private readonly indexName: string; - - private readonly engine: OpenSearchEngine; - - private readonly spaceType: OpenSearchSpaceType; - - private readonly efConstruction: number; - - private readonly efSearch: number; - - private readonly m: number; - - _vectorstoreType(): string { - return "opensearch"; - } - - constructor(embeddings: Embeddings, args: OpenSearchClientArgs) { - super(embeddings, args); - - this.spaceType = args.vectorSearchOptions?.spaceType ?? "l2"; - this.engine = args.vectorSearchOptions?.engine ?? "nmslib"; - this.m = args.vectorSearchOptions?.m ?? 16; - this.efConstruction = args.vectorSearchOptions?.efConstruction ?? 512; - this.efSearch = args.vectorSearchOptions?.efSearch ?? 512; - - this.client = args.client; - this.indexName = args.indexName ?? "documents"; - } - - /** - * Method to add documents to the OpenSearch index. It first converts the - * documents to vectors using the embeddings, then adds the vectors to the - * index. - * @param documents The documents to be added to the OpenSearch index. - * @returns Promise resolving to void. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Method to add vectors to the OpenSearch index. It ensures the index - * exists, then adds the vectors and associated documents to the index. - * @param vectors The vectors to be added to the OpenSearch index. - * @param documents The documents associated with the vectors. - * @param options Optional parameter that can contain the IDs for the documents. - * @returns Promise resolving to void. 
- */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ): Promise { - await this.ensureIndexExists( - vectors[0].length, - this.engine, - this.spaceType, - this.efSearch, - this.efConstruction, - this.m - ); - const documentIds = - options?.ids ?? Array.from({ length: vectors.length }, () => uuid.v4()); - const operations = vectors.flatMap((embedding, idx) => [ - { - index: { - _index: this.indexName, - _id: documentIds[idx], - }, - }, - { - embedding, - metadata: documents[idx].metadata, - text: documents[idx].pageContent, - }, - ]); - await this.client.bulk({ body: operations }); - await this.client.indices.refresh({ index: this.indexName }); - } - - /** - * Method to perform a similarity search on the OpenSearch index using a - * query vector. It returns the k most similar documents and their scores. - * @param query The query vector. - * @param k The number of similar documents to return. - * @param filter Optional filter for the OpenSearch query. - * @returns Promise resolving to an array of tuples, each containing a Document and its score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: OpenSearchFilter | undefined - ): Promise<[Document, number][]> { - const search: RequestParams.Search = { - index: this.indexName, - body: { - query: { - bool: { - filter: { bool: { must: this.buildMetadataTerms(filter) } }, - must: [ - { - knn: { - embedding: { vector: query, k }, - }, - }, - ], - }, - }, - size: k, - }, - }; - - const { body } = await this.client.search(search); - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return body.hits.hits.map((hit: any) => [ - new Document({ - pageContent: hit._source.text, - metadata: hit._source.metadata, - }), - hit._score, - ]); - } - - /** - * Static method to create a new OpenSearchVectorStore from an array of - * texts, their metadata, embeddings, and OpenSearch client arguments. 
- * @param texts The texts to be converted into documents and added to the OpenSearch index. - * @param metadatas The metadata associated with the texts. Can be an array of objects or a single object. - * @param embeddings The embeddings used to convert the texts into vectors. - * @param args The OpenSearch client arguments. - * @returns Promise resolving to a new instance of OpenSearchVectorStore. - */ - static fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - args: OpenSearchClientArgs - ): Promise { - const documents = texts.map((text, idx) => { - const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas; - return new Document({ pageContent: text, metadata }); - }); - - return OpenSearchVectorStore.fromDocuments(documents, embeddings, args); - } - - /** - * Static method to create a new OpenSearchVectorStore from an array of - * Documents, embeddings, and OpenSearch client arguments. - * @param docs The documents to be added to the OpenSearch index. - * @param embeddings The embeddings used to convert the documents into vectors. - * @param dbConfig The OpenSearch client arguments. - * @returns Promise resolving to a new instance of OpenSearchVectorStore. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: OpenSearchClientArgs - ): Promise { - const store = new OpenSearchVectorStore(embeddings, dbConfig); - await store.addDocuments(docs).then(() => store); - return store; - } - - /** - * Static method to create a new OpenSearchVectorStore from an existing - * OpenSearch index, embeddings, and OpenSearch client arguments. - * @param embeddings The embeddings used to convert the documents into vectors. - * @param dbConfig The OpenSearch client arguments. - * @returns Promise resolving to a new instance of OpenSearchVectorStore. 
- */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: OpenSearchClientArgs - ): Promise { - const store = new OpenSearchVectorStore(embeddings, dbConfig); - await store.client.cat.indices({ index: store.indexName }); - return store; - } - - private async ensureIndexExists( - dimension: number, - engine = "nmslib", - spaceType = "l2", - efSearch = 512, - efConstruction = 512, - m = 16 - ): Promise { - const body = { - settings: { - index: { - number_of_shards: 5, - number_of_replicas: 1, - knn: true, - "knn.algo_param.ef_search": efSearch, - }, - }, - mappings: { - dynamic_templates: [ - { - // map all metadata properties to be keyword - "metadata.*": { - match_mapping_type: "*", - mapping: { type: "keyword" }, - }, - }, - ], - properties: { - text: { type: "text" }, - metadata: { type: "object" }, - embedding: { - type: "knn_vector", - dimension, - method: { - name: "hnsw", - engine, - space_type: spaceType, - parameters: { ef_construction: efConstruction, m }, - }, - }, - }, - }, - }; - - const indexExists = await this.doesIndexExist(); - if (indexExists) return; - - await this.client.indices.create({ index: this.indexName, body }); - } - - private buildMetadataTerms( - filter?: OpenSearchFilter - ): { [key: string]: Record }[] { - if (filter == null) return []; - const result = []; - for (const [key, value] of Object.entries(filter)) { - const aggregatorKey = Array.isArray(value) ? "terms" : "term"; - result.push({ [aggregatorKey]: { [`metadata.${key}`]: value } }); - } - return result; - } - - /** - * Method to check if the OpenSearch index exists. - * @returns Promise resolving to a boolean indicating whether the index exists. 
- */ - async doesIndexExist(): Promise { - try { - await this.client.cat.indices({ index: this.indexName }); - return true; - } catch (err: unknown) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (err instanceof errors.ResponseError && err.statusCode === 404) { - return false; - } - throw err; - } - } - - /** - * Method to delete the OpenSearch index if it exists. - * @returns Promise resolving to void. - */ - async deleteIfExists(): Promise { - const indexExists = await this.doesIndexExist(); - if (!indexExists) return; - - await this.client.indices.delete({ index: this.indexName }); - } -} +export * from "@langchain/community/vectorstores/opensearch"; diff --git a/langchain/src/vectorstores/pgvector.ts b/langchain/src/vectorstores/pgvector.ts index 023d64966153..026732814617 100644 --- a/langchain/src/vectorstores/pgvector.ts +++ b/langchain/src/vectorstores/pgvector.ts @@ -1,440 +1 @@ -import pg, { type Pool, type PoolClient, type PoolConfig } from "pg"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -type Metadata = Record; - -/** - * Interface that defines the arguments required to create a - * `PGVectorStore` instance. It includes Postgres connection options, - * table name, filter, and verbosity level. - */ -export interface PGVectorStoreArgs { - postgresConnectionOptions: PoolConfig; - tableName: string; - collectionTableName?: string; - collectionName?: string; - collectionMetadata?: Metadata | null; - columns?: { - idColumnName?: string; - vectorColumnName?: string; - contentColumnName?: string; - metadataColumnName?: string; - }; - filter?: Metadata; - verbose?: boolean; - /** - * The amount of documents to chunk by when - * adding vectors. - * @default 500 - */ - chunkSize?: number; -} - -/** - * Class that provides an interface to a Postgres vector database. 
It - * extends the `VectorStore` base class and implements methods for adding - * documents and vectors, performing similarity searches, and ensuring the - * existence of a table in the database. - */ -export class PGVectorStore extends VectorStore { - declare FilterType: Metadata; - - tableName: string; - - collectionTableName?: string; - - collectionName = "langchain"; - - collectionMetadata: Metadata | null; - - idColumnName: string; - - vectorColumnName: string; - - contentColumnName: string; - - metadataColumnName: string; - - filter?: Metadata; - - _verbose?: boolean; - - pool: Pool; - - client?: PoolClient; - - chunkSize = 500; - - _vectorstoreType(): string { - return "pgvector"; - } - - private constructor(embeddings: Embeddings, config: PGVectorStoreArgs) { - super(embeddings, config); - this.tableName = config.tableName; - this.collectionTableName = config.collectionTableName; - this.collectionName = config.collectionName ?? "langchain"; - this.collectionMetadata = config.collectionMetadata ?? null; - this.filter = config.filter; - - this.vectorColumnName = config.columns?.vectorColumnName ?? "embedding"; - this.contentColumnName = config.columns?.contentColumnName ?? "text"; - this.idColumnName = config.columns?.idColumnName ?? "id"; - this.metadataColumnName = config.columns?.metadataColumnName ?? "metadata"; - - const pool = new pg.Pool(config.postgresConnectionOptions); - this.pool = pool; - this.chunkSize = config.chunkSize ?? 500; - - this._verbose = - getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ?? - !!config.verbose; - } - - /** - * Static method to create a new `PGVectorStore` instance from a - * connection. It creates a table if one does not exist, and calls - * `connect` to return a new instance of `PGVectorStore`. - * - * @param embeddings - Embeddings instance. - * @param fields - `PGVectorStoreArgs` instance. - * @returns A new instance of `PGVectorStore`. 
- */ - static async initialize( - embeddings: Embeddings, - config: PGVectorStoreArgs - ): Promise { - const postgresqlVectorStore = new PGVectorStore(embeddings, config); - - await postgresqlVectorStore._initializeClient(); - await postgresqlVectorStore.ensureTableInDatabase(); - if (postgresqlVectorStore.collectionTableName) { - await postgresqlVectorStore.ensureCollectionTableInDatabase(); - } - - return postgresqlVectorStore; - } - - protected async _initializeClient() { - this.client = await this.pool.connect(); - } - - /** - * Method to add documents to the vector store. It converts the documents into - * vectors, and adds them to the store. - * - * @param documents - Array of `Document` instances. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Inserts a row for the collectionName provided at initialization if it does not - * exist and returns the collectionId. - * - * @returns The collectionId for the given collectionName. - */ - async getOrCreateCollection(): Promise { - const queryString = ` - SELECT uuid from ${this.collectionTableName} - WHERE name = $1; - `; - const queryResult = await this.pool.query(queryString, [ - this.collectionName, - ]); - let collectionId = queryResult.rows[0]?.uuid; - - if (!collectionId) { - const insertString = ` - INSERT INTO ${this.collectionTableName}( - uuid, - name, - cmetadata - ) - VALUES ( - uuid_generate_v4(), - $1, - $2 - ) - RETURNING uuid; - `; - const insertResult = await this.pool.query(insertString, [ - this.collectionName, - this.collectionMetadata, - ]); - collectionId = insertResult.rows[0]?.uuid; - } - - return collectionId; - } - - /** - * Generates the SQL placeholders for a specific row at the provided index. 
- * - * @param index - The index of the row for which placeholders need to be generated. - * @param numOfColumns - The number of columns we are inserting data into. - * @returns The SQL placeholders for the row values. - */ - private generatePlaceholderForRowAt( - index: number, - numOfColumns: number - ): string { - const placeholders = []; - for (let i = 0; i < numOfColumns; i += 1) { - placeholders.push(`$${index * numOfColumns + i + 1}`); - } - return `(${placeholders.join(", ")})`; - } - - /** - * Constructs the SQL query for inserting rows into the specified table. - * - * @param rows - The rows of data to be inserted, consisting of values and records. - * @param chunkIndex - The starting index for generating query placeholders based on chunk positioning. - * @returns The complete SQL INSERT INTO query string. - */ - private async buildInsertQuery(rows: (string | Record)[][]) { - let collectionId; - if (this.collectionTableName) { - collectionId = await this.getOrCreateCollection(); - } - - const columns = [ - this.contentColumnName, - this.vectorColumnName, - this.metadataColumnName, - ]; - - if (collectionId) { - columns.push("collection_id"); - } - - const valuesPlaceholders = rows - .map((_, j) => this.generatePlaceholderForRowAt(j, columns.length)) - .join(", "); - - const text = ` - INSERT INTO ${this.tableName}( - ${columns} - ) - VALUES ${valuesPlaceholders} - `; - return text; - } - - /** - * Method to add vectors to the vector store. It converts the vectors into - * rows and inserts them into the database. - * - * @param vectors - Array of vectors. - * @param documents - Array of `Document` instances. - * @returns Promise that resolves when the vectors have been added. 
- */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - const rows = []; - let collectionId; - if (this.collectionTableName) { - collectionId = await this.getOrCreateCollection(); - } - - for (let i = 0; i < vectors.length; i += 1) { - const values = []; - const embedding = vectors[i]; - const embeddingString = `[${embedding.join(",")}]`; - values.push( - documents[i].pageContent, - embeddingString, - documents[i].metadata - ); - if (collectionId) { - values.push(collectionId); - } - rows.push(values); - } - - for (let i = 0; i < rows.length; i += this.chunkSize) { - const chunk = rows.slice(i, i + this.chunkSize); - const insertQuery = await this.buildInsertQuery(chunk); - const flatValues = chunk.flat(); - try { - await this.pool.query(insertQuery, flatValues); - } catch (e) { - console.error(e); - throw new Error(`Error inserting: ${(e as Error).message}`); - } - } - } - - /** - * Method to perform a similarity search in the vector store. It returns - * the `k` most similar documents to the query vector, along with their - * similarity scores. - * - * @param query - Query vector. - * @param k - Number of most similar documents to return. - * @param filter - Optional filter to apply to the search. - * @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - const embeddingString = `[${query.join(",")}]`; - const _filter = filter ?? "{}"; - let collectionId; - if (this.collectionTableName) { - collectionId = await this.getOrCreateCollection(); - } - - const parameters = [embeddingString, _filter, k]; - if (collectionId) { - parameters.push(collectionId); - } - - const queryString = ` - SELECT *, ${this.vectorColumnName} <=> $1 as "_distance" - FROM ${this.tableName} - WHERE ${this.metadataColumnName}::jsonb @> $2 - ${collectionId ? 
"AND collection_id = $4" : ""} - ORDER BY "_distance" ASC - LIMIT $3; - `; - - const documents = (await this.pool.query(queryString, parameters)).rows; - - const results = [] as [Document, number][]; - for (const doc of documents) { - if (doc._distance != null && doc[this.contentColumnName] != null) { - const document = new Document({ - pageContent: doc[this.contentColumnName], - metadata: doc[this.metadataColumnName], - }); - results.push([document, doc._distance]); - } - } - return results; - } - - /** - * Method to ensure the existence of the table in the database. It creates - * the table if it does not already exist. - * - * @returns Promise that resolves when the table has been ensured. - */ - async ensureTableInDatabase(): Promise { - await this.pool.query("CREATE EXTENSION IF NOT EXISTS vector;"); - await this.pool.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'); - - await this.pool.query(` - CREATE TABLE IF NOT EXISTS ${this.tableName} ( - "${this.idColumnName}" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, - "${this.contentColumnName}" text, - "${this.metadataColumnName}" jsonb, - "${this.vectorColumnName}" vector - ); - `); - } - - /** - * Method to ensure the existence of the collection table in the database. - * It creates the table if it does not already exist. - * - * @returns Promise that resolves when the collection table has been ensured. 
- */ - async ensureCollectionTableInDatabase(): Promise { - try { - await this.pool.query(` - CREATE TABLE IF NOT EXISTS ${this.collectionTableName} ( - uuid uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, - name character varying, - cmetadata jsonb - ); - - ALTER TABLE ${this.tableName} - ADD COLUMN collection_id uuid; - - ALTER TABLE ${this.tableName} - ADD CONSTRAINT ${this.tableName}_collection_id_fkey - FOREIGN KEY (collection_id) - REFERENCES ${this.collectionTableName}(uuid) - ON DELETE CASCADE; - `); - } catch (e) { - if (!(e as Error).message.includes("already exists")) { - console.error(e); - throw new Error(`Error adding column: ${(e as Error).message}`); - } - } - } - - /** - * Static method to create a new `PGVectorStore` instance from an - * array of texts and their metadata. It converts the texts into - * `Document` instances and adds them to the store. - * - * @param texts - Array of texts. - * @param metadatas - Array of metadata objects or a single metadata object. - * @param embeddings - Embeddings instance. - * @param dbConfig - `PGVectorStoreArgs` instance. - * @returns Promise that resolves with a new instance of `PGVectorStore`. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: PGVectorStoreArgs - ): Promise { - const docs = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return PGVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a new `PGVectorStore` instance from an - * array of `Document` instances. It adds the documents to the store. - * - * @param docs - Array of `Document` instances. - * @param embeddings - Embeddings instance. - * @param dbConfig - `PGVectorStoreArgs` instance. 
- * @returns Promise that resolves with a new instance of `PGVectorStore`. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: PGVectorStoreArgs - ): Promise { - const instance = await PGVectorStore.initialize(embeddings, dbConfig); - await instance.addDocuments(docs); - - return instance; - } - - /** - * Closes all the clients in the pool and terminates the pool. - * - * @returns Promise that resolves when all clients are closed and the pool is terminated. - */ - async end(): Promise { - this.client?.release(); - return this.pool.end(); - } -} +export * from "@langchain/community/vectorstores/pgvector"; diff --git a/langchain/src/vectorstores/pinecone.ts b/langchain/src/vectorstores/pinecone.ts index d4978f7a9289..667b9e8a8860 100644 --- a/langchain/src/vectorstores/pinecone.ts +++ b/langchain/src/vectorstores/pinecone.ts @@ -1,360 +1 @@ -/* eslint-disable no-process-env */ -import * as uuid from "uuid"; -import flatten from "flat"; - -import { - RecordMetadata, - PineconeRecord, - Index as PineconeIndex, -} from "@pinecone-database/pinecone"; - -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; -import { maximalMarginalRelevance } from "../util/math.js"; -import { chunkArray } from "../util/chunk.js"; - -// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any -type PineconeMetadata = Record; - -export interface PineconeLibArgs extends AsyncCallerParams { - pineconeIndex: PineconeIndex; - textKey?: string; - namespace?: string; - filter?: PineconeMetadata; -} - -/** - * Type that defines the parameters for the delete operation in the - * PineconeStore class. It includes ids, filter, deleteAll flag, and namespace. 
- */ -export type PineconeDeleteParams = { - ids?: string[]; - deleteAll?: boolean; - filter?: object; - namespace?: string; -}; - -/** - * Class that extends the VectorStore class and provides methods to - * interact with the Pinecone vector database. - */ -export class PineconeStore extends VectorStore { - declare FilterType: PineconeMetadata; - - textKey: string; - - namespace?: string; - - pineconeIndex: PineconeIndex; - - filter?: PineconeMetadata; - - caller: AsyncCaller; - - _vectorstoreType(): string { - return "pinecone"; - } - - constructor(embeddings: Embeddings, args: PineconeLibArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - const { namespace, pineconeIndex, textKey, filter, ...asyncCallerArgs } = - args; - this.namespace = namespace; - this.pineconeIndex = pineconeIndex; - this.textKey = textKey ?? "text"; - this.filter = filter; - this.caller = new AsyncCaller(asyncCallerArgs); - } - - /** - * Method that adds documents to the Pinecone database. - * @param documents Array of documents to add to the Pinecone database. - * @param options Optional ids for the documents. - * @returns Promise that resolves with the ids of the added documents. - */ - async addDocuments( - documents: Document[], - options?: { ids?: string[] } | string[] - ) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method that adds vectors to the Pinecone database. - * @param vectors Array of vectors to add to the Pinecone database. - * @param documents Array of documents associated with the vectors. - * @param options Optional ids for the vectors. - * @returns Promise that resolves with the ids of the added vectors. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } | string[] - ) { - const ids = Array.isArray(options) ? 
options : options?.ids; - const documentIds = ids == null ? documents.map(() => uuid.v4()) : ids; - const pineconeVectors = vectors.map((values, idx) => { - // Pinecone doesn't support nested objects, so we flatten them - const documentMetadata = { ...documents[idx].metadata }; - // preserve string arrays which are allowed - const stringArrays: Record = {}; - for (const key of Object.keys(documentMetadata)) { - if ( - Array.isArray(documentMetadata[key]) && - // eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any - documentMetadata[key].every((el: any) => typeof el === "string") - ) { - stringArrays[key] = documentMetadata[key]; - delete documentMetadata[key]; - } - } - const metadata: { - [key: string]: string | number | boolean | string[] | null; - } = { - ...flatten(documentMetadata), - ...stringArrays, - [this.textKey]: documents[idx].pageContent, - }; - // Pinecone doesn't support null values, so we remove them - for (const key of Object.keys(metadata)) { - if (metadata[key] == null) { - delete metadata[key]; - } else if ( - typeof metadata[key] === "object" && - Object.keys(metadata[key] as unknown as object).length === 0 - ) { - delete metadata[key]; - } - } - - return { - id: documentIds[idx], - metadata, - values, - } as PineconeRecord; - }); - - const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); - // Pinecone recommends a limit of 100 vectors per upsert request - const chunkSize = 100; - const chunkedVectors = chunkArray(pineconeVectors, chunkSize); - const batchRequests = chunkedVectors.map((chunk) => - this.caller.call(async () => namespace.upsert(chunk)) - ); - - await Promise.all(batchRequests); - - return documentIds; - } - - /** - * Method that deletes vectors from the Pinecone database. - * @param params Parameters for the delete operation. - * @returns Promise that resolves when the delete operation is complete. 
- */ - async delete(params: PineconeDeleteParams): Promise { - const { deleteAll, ids, filter } = params; - const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); - - if (deleteAll) { - await namespace.deleteAll(); - } else if (ids) { - const batchSize = 1000; - for (let i = 0; i < ids.length; i += batchSize) { - const batchIds = ids.slice(i, i + batchSize); - await namespace.deleteMany(batchIds); - } - } else if (filter) { - await namespace.deleteMany(filter); - } else { - throw new Error("Either ids or delete_all must be provided."); - } - } - - protected async _runPineconeQuery( - query: number[], - k: number, - filter?: PineconeMetadata, - options?: { includeValues: boolean } - ) { - if (filter && this.filter) { - throw new Error("cannot provide both `filter` and `this.filter`"); - } - const _filter = filter ?? this.filter; - const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); - - const results = await namespace.query({ - includeMetadata: true, - topK: k, - vector: query, - filter: _filter, - ...options, - }); - - return results; - } - - /** - * Method that performs a similarity search in the Pinecone database and - * returns the results along with their scores. - * @param query Query vector for the similarity search. - * @param k Number of top results to return. - * @param filter Optional filter to apply to the search. - * @returns Promise that resolves with an array of documents and their scores. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: PineconeMetadata - ): Promise<[Document, number][]> { - const results = await this._runPineconeQuery(query, k, filter); - const result: [Document, number][] = []; - - if (results.matches) { - for (const res of results.matches) { - const { [this.textKey]: pageContent, ...metadata } = (res.metadata ?? 
- {}) as PineconeMetadata; - if (res.score) { - result.push([new Document({ metadata, pageContent }), res.score]); - } - } - } - - return result; - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK=20 - Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {PineconeMetadata} options.filter - Optional filter to apply to the search. - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. - */ - async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions - ): Promise { - const queryEmbedding = await this.embeddings.embedQuery(query); - - const results = await this._runPineconeQuery( - queryEmbedding, - options.fetchK ?? 20, - options.filter, - { includeValues: true } - ); - - const matches = results?.matches ?? []; - const embeddingList = matches.map((match) => match.values); - - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - options.lambda, - options.k - ); - - const topMmrMatches = mmrIndexes.map((idx) => matches[idx]); - - const finalResult: Document[] = []; - for (const res of topMmrMatches) { - const { [this.textKey]: pageContent, ...metadata } = (res.metadata ?? - {}) as PineconeMetadata; - if (res.score) { - finalResult.push(new Document({ metadata, pageContent })); - } - } - - return finalResult; - } - - /** - * Static method that creates a new instance of the PineconeStore class - * from texts. 
- * @param texts Array of texts to add to the Pinecone database. - * @param metadatas Metadata associated with the texts. - * @param embeddings Embeddings to use for the texts. - * @param dbConfig Configuration for the Pinecone database. - * @returns Promise that resolves with a new instance of the PineconeStore class. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: - | { - pineconeIndex: PineconeIndex; - textKey?: string; - namespace?: string | undefined; - } - | PineconeLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - const args: PineconeLibArgs = { - pineconeIndex: dbConfig.pineconeIndex, - textKey: dbConfig.textKey, - namespace: dbConfig.namespace, - }; - return PineconeStore.fromDocuments(docs, embeddings, args); - } - - /** - * Static method that creates a new instance of the PineconeStore class - * from documents. - * @param docs Array of documents to add to the Pinecone database. - * @param embeddings Embeddings to use for the documents. - * @param dbConfig Configuration for the Pinecone database. - * @returns Promise that resolves with a new instance of the PineconeStore class. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: PineconeLibArgs - ): Promise { - const args = dbConfig; - args.textKey = dbConfig.textKey ?? "text"; - - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method that creates a new instance of the PineconeStore class - * from an existing index. - * @param embeddings Embeddings to use for the documents. - * @param dbConfig Configuration for the Pinecone database. 
- * @returns Promise that resolves with a new instance of the PineconeStore class. - */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: PineconeLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - return instance; - } -} +export * from "@langchain/community/vectorstores/pinecone"; diff --git a/langchain/src/vectorstores/prisma.ts b/langchain/src/vectorstores/prisma.ts index 219f570df37d..9e31386edfe3 100644 --- a/langchain/src/vectorstores/prisma.ts +++ b/langchain/src/vectorstores/prisma.ts @@ -1,511 +1 @@ -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { type Embeddings } from "../embeddings/base.js"; -import { Callbacks } from "../callbacks/manager.js"; - -const IdColumnSymbol = Symbol("id"); -const ContentColumnSymbol = Symbol("content"); - -type ColumnSymbol = typeof IdColumnSymbol | typeof ContentColumnSymbol; - -declare type Value = unknown; -declare type RawValue = Value | Sql; - -declare class Sql { - strings: string[]; - - constructor( - rawStrings: ReadonlyArray, - rawValues: ReadonlyArray - ); -} - -type PrismaNamespace = { - ModelName: Record; - Sql: typeof Sql; - raw: (sql: string) => Sql; - join: ( - values: RawValue[], - separator?: string, - prefix?: string, - suffix?: string - ) => Sql; - sql: (strings: ReadonlyArray, ...values: RawValue[]) => Sql; -}; - -type PrismaClient = { - $queryRaw( - query: TemplateStringsArray | Sql, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ...values: any[] - ): Promise; - $executeRaw( - query: TemplateStringsArray | Sql, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ...values: any[] - ): // eslint-disable-next-line @typescript-eslint/no-explicit-any - Promise; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - $transaction

[]>(arg: [...P]): Promise; -}; - -type ObjectIntersect = { - [P in keyof A & keyof B]: A[P] | B[P]; -}; - -type ModelColumns> = { - [K in keyof TModel]?: true | ColumnSymbol; -}; - -export type PrismaSqlFilter> = { - [K in keyof TModel]?: { - equals?: TModel[K]; - in?: TModel[K][]; - isNull?: TModel[K]; - isNotNull?: TModel[K]; - like?: TModel[K]; - lt?: TModel[K]; - lte?: TModel[K]; - gt?: TModel[K]; - gte?: TModel[K]; - not?: TModel[K]; - }; -}; - -const OpMap = { - equals: "=", - in: "IN", - isNull: "IS NULL", - isNotNull: "IS NOT NULL", - like: "LIKE", - lt: "<", - lte: "<=", - gt: ">", - gte: ">=", - not: "<>", -}; - -type SimilarityModel< - TModel extends Record = Record, - TColumns extends ModelColumns = ModelColumns -> = Pick> & { - _distance: number | null; -}; - -type DefaultPrismaVectorStore = PrismaVectorStore< - Record, - string, - ModelColumns>, - PrismaSqlFilter> ->; - -/** - * A specific implementation of the VectorStore class that is designed to - * work with Prisma. It provides methods for adding models, documents, and - * vectors, as well as for performing similarity searches. 
- */ -export class PrismaVectorStore< - TModel extends Record, - TModelName extends string, - TSelectModel extends ModelColumns, - TFilterModel extends PrismaSqlFilter -> extends VectorStore { - protected tableName: string; - - protected vectorColumnName: string; - - protected selectColumns: string[]; - - filter?: TFilterModel; - - idColumn: keyof TModel & string; - - contentColumn: keyof TModel & string; - - static IdColumn: typeof IdColumnSymbol = IdColumnSymbol; - - static ContentColumn: typeof ContentColumnSymbol = ContentColumnSymbol; - - protected db: PrismaClient; - - protected Prisma: PrismaNamespace; - - _vectorstoreType(): string { - return "prisma"; - } - - constructor( - embeddings: Embeddings, - config: { - db: PrismaClient; - prisma: PrismaNamespace; - tableName: TModelName; - vectorColumnName: string; - columns: TSelectModel; - filter?: TFilterModel; - } - ) { - super(embeddings, {}); - - this.Prisma = config.prisma; - this.db = config.db; - - const entries = Object.entries(config.columns); - const idColumn = entries.find((i) => i[1] === IdColumnSymbol)?.[0]; - const contentColumn = entries.find( - (i) => i[1] === ContentColumnSymbol - )?.[0]; - - if (idColumn == null) throw new Error("Missing ID column"); - if (contentColumn == null) throw new Error("Missing content column"); - - this.idColumn = idColumn; - this.contentColumn = contentColumn; - - this.tableName = config.tableName; - this.vectorColumnName = config.vectorColumnName; - - this.selectColumns = entries - .map(([key, alias]) => (alias && key) || null) - .filter((x): x is string => !!x); - - if (config.filter) { - this.filter = config.filter; - } - } - - /** - * Creates a new PrismaVectorStore with the specified model. - * @param db The PrismaClient instance. - * @returns An object with create, fromTexts, and fromDocuments methods. 
- */ - static withModel>(db: PrismaClient) { - function create< - TPrisma extends PrismaNamespace, - TColumns extends ModelColumns, - TFilters extends PrismaSqlFilter - >( - embeddings: Embeddings, - config: { - prisma: TPrisma; - tableName: keyof TPrisma["ModelName"] & string; - vectorColumnName: string; - columns: TColumns; - filter?: TFilters; - } - ) { - type ModelName = keyof TPrisma["ModelName"] & string; - return new PrismaVectorStore( - embeddings, - { ...config, db } - ); - } - - async function fromTexts< - TPrisma extends PrismaNamespace, - TColumns extends ModelColumns - >( - texts: string[], - metadatas: TModel[], - embeddings: Embeddings, - dbConfig: { - prisma: TPrisma; - tableName: keyof TPrisma["ModelName"] & string; - vectorColumnName: string; - columns: TColumns; - } - ) { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return PrismaVectorStore.fromDocuments(docs, embeddings, { - ...dbConfig, - db, - }); - } - - async function fromDocuments< - TPrisma extends PrismaNamespace, - TColumns extends ModelColumns, - TFilters extends PrismaSqlFilter - >( - docs: Document[], - embeddings: Embeddings, - dbConfig: { - prisma: TPrisma; - tableName: keyof TPrisma["ModelName"] & string; - vectorColumnName: string; - columns: TColumns; - } - ) { - type ModelName = keyof TPrisma["ModelName"] & string; - const instance = new PrismaVectorStore< - TModel, - ModelName, - TColumns, - TFilters - >(embeddings, { ...dbConfig, db }); - await instance.addDocuments(docs); - return instance; - } - - return { create, fromTexts, fromDocuments }; - } - - /** - * Adds the specified models to the store. - * @param models The models to add. - * @returns A promise that resolves when the models have been added. 
- */ - async addModels(models: TModel[]) { - return this.addDocuments( - models.map((metadata) => { - const pageContent = metadata[this.contentColumn]; - if (typeof pageContent !== "string") - throw new Error("Content column must be a string"); - return new Document({ pageContent, metadata }); - }) - ); - } - - /** - * Adds the specified documents to the store. - * @param documents The documents to add. - * @returns A promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Adds the specified vectors to the store. - * @param vectors The vectors to add. - * @param documents The documents associated with the vectors. - * @returns A promise that resolves when the vectors have been added. - */ - async addVectors(vectors: number[][], documents: Document[]) { - // table name, column name cannot be parametrised - // these fields are thus not escaped by Prisma and can be dangerous if user input is used - const idColumnRaw = this.Prisma.raw(`"${this.idColumn}"`); - const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`); - const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`); - - await this.db.$transaction( - vectors.map( - (vector, idx) => this.db.$executeRaw` - UPDATE ${tableNameRaw} - SET ${vectorColumnRaw} = ${`[${vector.join(",")}]`}::vector - WHERE ${idColumnRaw} = ${documents[idx].metadata[this.idColumn]} - ` - ) - ); - } - - /** - * Performs a similarity search with the specified query. - * @param query The query to use for the similarity search. - * @param k The number of results to return. - * @param _filter The filter to apply to the results. - * @param _callbacks The callbacks to use during the search. - * @returns A promise that resolves with the search results. 
- */ - async similaritySearch( - query: string, - k = 4, - _filter: this["FilterType"] | undefined = undefined, // not used. here to make the interface compatible with the other stores - _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later - ): Promise>[]> { - const results = await this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k - ); - - return results.map((result) => result[0]); - } - - /** - * Performs a similarity search with the specified query and returns the - * results along with their scores. - * @param query The query to use for the similarity search. - * @param k The number of results to return. - * @param filter The filter to apply to the results. - * @param _callbacks The callbacks to use during the search. - * @returns A promise that resolves with the search results and their scores. - */ - async similaritySearchWithScore( - query: string, - k?: number, - filter?: TFilterModel, - _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later - ) { - return super.similaritySearchWithScore(query, k, filter); - } - - /** - * Performs a similarity search with the specified vector and returns the - * results along with their scores. - * @param query The vector to use for the similarity search. - * @param k The number of results to return. - * @param filter The filter to apply to the results. - * @returns A promise that resolves with the search results and their scores. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: TFilterModel - ): Promise<[Document>, number][]> { - // table name, column names cannot be parametrised - // these fields are thus not escaped by Prisma and can be dangerous if user input is used - const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`); - const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`); - const selectRaw = this.Prisma.raw( - this.selectColumns.map((x) => `"${x}"`).join(", ") - ); - - const vector = `[${query.join(",")}]`; - const articles = await this.db.$queryRaw< - Array> - >( - this.Prisma.join( - [ - this.Prisma.sql` - SELECT ${selectRaw}, ${vectorColumnRaw} <=> ${vector}::vector as "_distance" - FROM ${tableNameRaw} - `, - this.buildSqlFilterStr(filter ?? this.filter), - this.Prisma.sql` - ORDER BY "_distance" ASC - LIMIT ${k}; - `, - ].filter((x) => x != null), - "" - ) - ); - - const results: [Document>, number][] = - []; - for (const article of articles) { - if (article._distance != null && article[this.contentColumn] != null) { - results.push([ - new Document({ - pageContent: article[this.contentColumn] as string, - metadata: article, - }), - article._distance, - ]); - } - } - - return results; - } - - buildSqlFilterStr(filter?: TFilterModel) { - if (filter == null) return null; - return this.Prisma.join( - Object.entries(filter).flatMap(([key, ops]) => - Object.entries(ops).map(([opName, value]) => { - // column name, operators cannot be parametrised - // these fields are thus not escaped by Prisma and can be dangerous if user input is used - const opNameKey = opName as keyof typeof OpMap; - const colRaw = this.Prisma.raw(`"${key}"`); - const opRaw = this.Prisma.raw(OpMap[opNameKey]); - - switch (OpMap[opNameKey]) { - case OpMap.in: { - if ( - !Array.isArray(value) || - !value.every((v) => typeof v === "string") - ) { - throw new Error( - `Invalid filter: IN operator requires an array of strings. 
Received: ${JSON.stringify( - value, - null, - 2 - )}` - ); - } - return this.Prisma.sql`${colRaw} ${opRaw} (${this.Prisma.join( - value - )})`; - } - case OpMap.isNull: - case OpMap.isNotNull: - return this.Prisma.sql`${colRaw} ${opRaw}`; - default: - return this.Prisma.sql`${colRaw} ${opRaw} ${value}`; - } - }) - ), - " AND ", - " WHERE " - ); - } - - /** - * Creates a new PrismaVectorStore from the specified texts. - * @param texts The texts to use to create the store. - * @param metadatas The metadata for the texts. - * @param embeddings The embeddings to use. - * @param dbConfig The database configuration. - * @returns A promise that resolves with the new PrismaVectorStore. - */ - static async fromTexts( - texts: string[], - metadatas: object[], - embeddings: Embeddings, - dbConfig: { - db: PrismaClient; - prisma: PrismaNamespace; - tableName: string; - vectorColumnName: string; - columns: ModelColumns>; - } - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return PrismaVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new PrismaVectorStore from the specified documents. - * @param docs The documents to use to create the store. - * @param embeddings The embeddings to use. - * @param dbConfig The database configuration. - * @returns A promise that resolves with the new PrismaVectorStore. 
- */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: { - db: PrismaClient; - prisma: PrismaNamespace; - tableName: string; - vectorColumnName: string; - columns: ModelColumns>; - } - ): Promise { - const instance = new PrismaVectorStore(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/prisma"; diff --git a/langchain/src/vectorstores/qdrant.ts b/langchain/src/vectorstores/qdrant.ts index 6ec54233a6ec..a4a8c713d6f4 100644 --- a/langchain/src/vectorstores/qdrant.ts +++ b/langchain/src/vectorstores/qdrant.ts @@ -1,260 +1 @@ -import { QdrantClient } from "@qdrant/js-client-rest"; -import type { Schemas as QdrantSchemas } from "@qdrant/js-client-rest"; -import { v4 as uuid } from "uuid"; - -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface for the arguments that can be passed to the - * `QdrantVectorStore` constructor. It includes options for specifying a - * `QdrantClient` instance, the URL and API key for a Qdrant database, and - * the name and configuration for a collection. - */ -export interface QdrantLibArgs { - client?: QdrantClient; - url?: string; - apiKey?: string; - collectionName?: string; - collectionConfig?: QdrantSchemas["CreateCollection"]; -} - -/** - * Type for the response returned by a search operation in the Qdrant - * database. It includes the score and payload (metadata and content) for - * each point (document) in the search results. - */ -type QdrantSearchResponse = QdrantSchemas["ScoredPoint"] & { - payload: { - metadata: object; - content: string; - }; -}; - -/** - * Class that extends the `VectorStore` base class to interact with a - * Qdrant database. 
It includes methods for adding documents and vectors - * to the Qdrant database, searching for similar vectors, and ensuring the - * existence of a collection in the database. - */ -export class QdrantVectorStore extends VectorStore { - get lc_secrets(): { [key: string]: string } { - return { - apiKey: "QDRANT_API_KEY", - url: "QDRANT_URL", - }; - } - - client: QdrantClient; - - collectionName: string; - - collectionConfig?: QdrantSchemas["CreateCollection"]; - - _vectorstoreType(): string { - return "qdrant"; - } - - constructor(embeddings: Embeddings, args: QdrantLibArgs) { - super(embeddings, args); - - const url = args.url ?? getEnvironmentVariable("QDRANT_URL"); - const apiKey = args.apiKey ?? getEnvironmentVariable("QDRANT_API_KEY"); - - if (!args.client && !url) { - throw new Error("Qdrant client or url address must be set."); - } - - this.client = - args.client || - new QdrantClient({ - url, - apiKey, - }); - - this.collectionName = args.collectionName ?? "documents"; - - this.collectionConfig = args.collectionConfig; - } - - /** - * Method to add documents to the Qdrant database. It generates vectors - * from the documents using the `Embeddings` instance and then adds the - * vectors to the database. - * @param documents Array of `Document` instances to be added to the Qdrant database. - * @returns Promise that resolves when the documents have been added to the database. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - await this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Method to add vectors to the Qdrant database. Each vector is associated - * with a document, which is stored as the payload for a point in the - * database. - * @param vectors Array of vectors to be added to the Qdrant database. - * @param documents Array of `Document` instances associated with the vectors. 
- * @returns Promise that resolves when the vectors have been added to the database. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - - await this.ensureCollection(); - - const points = vectors.map((embedding, idx) => ({ - id: uuid(), - vector: embedding, - payload: { - content: documents[idx].pageContent, - metadata: documents[idx].metadata, - }, - })); - - try { - await this.client.upsert(this.collectionName, { - wait: true, - points, - }); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - const error = new Error( - `${e?.status ?? "Undefined error code"} ${e?.message}: ${ - e?.data?.status?.error - }` - ); - throw error; - } - } - - /** - * Method to search for vectors in the Qdrant database that are similar to - * a given query vector. The search results include the score and payload - * (metadata and content) for each similar vector. - * @param query Query vector to search for similar vectors in the Qdrant database. - * @param k Optional number of similar vectors to return. If not specified, all similar vectors are returned. - * @param filter Optional filter to apply to the search results. - * @returns Promise that resolves with an array of tuples, where each tuple includes a `Document` instance and a score for a similar vector. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k?: number, - filter?: QdrantSchemas["Filter"] - ): Promise<[Document, number][]> { - if (!query) { - return []; - } - - await this.ensureCollection(); - - const results = await this.client.search(this.collectionName, { - vector: query, - limit: k, - filter, - }); - - const result: [Document, number][] = ( - results as QdrantSearchResponse[] - ).map((res) => [ - new Document({ - metadata: res.payload.metadata, - pageContent: res.payload.content, - }), - res.score, - ]); - - return result; - } - - /** - * Method to ensure the existence of a collection in the Qdrant database. - * If the collection does not exist, it is created. - * @returns Promise that resolves when the existence of the collection has been ensured. - */ - async ensureCollection() { - const response = await this.client.getCollections(); - - const collectionNames = response.collections.map( - (collection) => collection.name - ); - - if (!collectionNames.includes(this.collectionName)) { - const collectionConfig = this.collectionConfig ?? { - vectors: { - size: (await this.embeddings.embedQuery("test")).length, - distance: "Cosine", - }, - }; - await this.client.createCollection(this.collectionName, collectionConfig); - } - } - - /** - * Static method to create a `QdrantVectorStore` instance from texts. Each - * text is associated with metadata and converted to a `Document` - * instance, which is then added to the Qdrant database. - * @param texts Array of texts to be converted to `Document` instances and added to the Qdrant database. - * @param metadatas Array or single object of metadata to be associated with the texts. - * @param embeddings `Embeddings` instance used to generate vectors from the texts. - * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. - * @returns Promise that resolves with a new `QdrantVectorStore` instance. 
- */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: QdrantLibArgs - ): Promise { - const docs = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return QdrantVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a `QdrantVectorStore` instance from `Document` - * instances. The documents are added to the Qdrant database. - * @param docs Array of `Document` instances to be added to the Qdrant database. - * @param embeddings `Embeddings` instance used to generate vectors from the documents. - * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. - * @returns Promise that resolves with a new `QdrantVectorStore` instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: QdrantLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create a `QdrantVectorStore` instance from an existing - * collection in the Qdrant database. - * @param embeddings `Embeddings` instance used to generate vectors from the documents in the collection. - * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. - * @returns Promise that resolves with a new `QdrantVectorStore` instance. 
- */ - static async fromExistingCollection( - embeddings: Embeddings, - dbConfig: QdrantLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.ensureCollection(); - return instance; - } -} +export * from "@langchain/community/vectorstores/qdrant"; diff --git a/langchain/src/vectorstores/redis.ts b/langchain/src/vectorstores/redis.ts index bc33d68ae28d..fa3b6aba3c78 100644 --- a/langchain/src/vectorstores/redis.ts +++ b/langchain/src/vectorstores/redis.ts @@ -1,458 +1 @@ -import type { - createCluster, - createClient, - RediSearchSchema, - SearchOptions, -} from "redis"; -import { SchemaFieldTypes, VectorAlgorithms } from "redis"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -// Adapated from internal redis types which aren't exported -/** - * Type for creating a schema vector field. It includes the algorithm, - * distance metric, and initial capacity. - */ -export type CreateSchemaVectorField< - T extends VectorAlgorithms, - A extends Record -> = { - ALGORITHM: T; - DISTANCE_METRIC: "L2" | "IP" | "COSINE"; - INITIAL_CAP?: number; -} & A; -/** - * Type for creating a flat schema vector field. It extends - * CreateSchemaVectorField with a block size property. - */ -export type CreateSchemaFlatVectorField = CreateSchemaVectorField< - VectorAlgorithms.FLAT, - { - BLOCK_SIZE?: number; - } ->; -/** - * Type for creating a HNSW schema vector field. It extends - * CreateSchemaVectorField with M, EF_CONSTRUCTION, and EF_RUNTIME - * properties. 
- */ -export type CreateSchemaHNSWVectorField = CreateSchemaVectorField< - VectorAlgorithms.HNSW, - { - M?: number; - EF_CONSTRUCTION?: number; - EF_RUNTIME?: number; - } ->; - -type CreateIndexOptions = NonNullable< - Parameters["ft"]["create"]>[3] ->; - -export type RedisSearchLanguages = `${NonNullable< - CreateIndexOptions["LANGUAGE"] ->}`; - -export type RedisVectorStoreIndexOptions = Omit< - CreateIndexOptions, - "LANGUAGE" -> & { LANGUAGE?: RedisSearchLanguages }; - -/** - * Interface for the configuration of the RedisVectorStore. It includes - * the Redis client, index name, index options, key prefix, content key, - * metadata key, vector key, and filter. - */ -export interface RedisVectorStoreConfig { - redisClient: - | ReturnType - | ReturnType; - indexName: string; - indexOptions?: CreateSchemaFlatVectorField | CreateSchemaHNSWVectorField; - createIndexOptions?: Omit; // PREFIX must be set with keyPrefix - keyPrefix?: string; - contentKey?: string; - metadataKey?: string; - vectorKey?: string; - filter?: RedisVectorStoreFilterType; -} - -/** - * Interface for the options when adding documents to the - * RedisVectorStore. It includes keys and batch size. - */ -export interface RedisAddOptions { - keys?: string[]; - batchSize?: number; -} - -/** - * Type for the filter used in the RedisVectorStore. It is an array of - * strings. - */ -export type RedisVectorStoreFilterType = string[]; - -/** - * Class representing a RedisVectorStore. It extends the VectorStore class - * and includes methods for adding documents and vectors, performing - * similarity searches, managing the index, and more. 
- */ -export class RedisVectorStore extends VectorStore { - declare FilterType: RedisVectorStoreFilterType; - - private redisClient: - | ReturnType - | ReturnType; - - indexName: string; - - indexOptions: CreateSchemaFlatVectorField | CreateSchemaHNSWVectorField; - - createIndexOptions: CreateIndexOptions; - - keyPrefix: string; - - contentKey: string; - - metadataKey: string; - - vectorKey: string; - - filter?: RedisVectorStoreFilterType; - - _vectorstoreType(): string { - return "redis"; - } - - constructor(embeddings: Embeddings, _dbConfig: RedisVectorStoreConfig) { - super(embeddings, _dbConfig); - - this.redisClient = _dbConfig.redisClient; - this.indexName = _dbConfig.indexName; - this.indexOptions = _dbConfig.indexOptions ?? { - ALGORITHM: VectorAlgorithms.HNSW, - DISTANCE_METRIC: "COSINE", - }; - this.keyPrefix = _dbConfig.keyPrefix ?? `doc:${this.indexName}:`; - this.contentKey = _dbConfig.contentKey ?? "content"; - this.metadataKey = _dbConfig.metadataKey ?? "metadata"; - this.vectorKey = _dbConfig.vectorKey ?? "content_vector"; - this.filter = _dbConfig.filter; - this.createIndexOptions = { - ON: "HASH", - PREFIX: this.keyPrefix, - ...(_dbConfig.createIndexOptions as CreateIndexOptions), - }; - } - - /** - * Method for adding documents to the RedisVectorStore. It first converts - * the documents to texts and then adds them as vectors. - * @param documents The documents to add. - * @param options Optional parameters for adding the documents. - * @returns A promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[], options?: RedisAddOptions) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method for adding vectors to the RedisVectorStore. It checks if the - * index exists and creates it if it doesn't, then adds the vectors in - * batches. 
- * @param vectors The vectors to add. - * @param documents The documents associated with the vectors. - * @param keys Optional keys for the vectors. - * @param batchSize The size of the batches in which to add the vectors. Defaults to 1000. - * @returns A promise that resolves when the vectors have been added. - */ - async addVectors( - vectors: number[][], - documents: Document[], - { keys, batchSize = 1000 }: RedisAddOptions = {} - ) { - if (!vectors.length || !vectors[0].length) { - throw new Error("No vectors provided"); - } - // check if the index exists and create it if it doesn't - await this.createIndex(vectors[0].length); - - const info = await this.redisClient.ft.info(this.indexName); - const lastKeyCount = parseInt(info.numDocs, 10) || 0; - const multi = this.redisClient.multi(); - - vectors.map(async (vector, idx) => { - const key = - keys && keys.length - ? keys[idx] - : `${this.keyPrefix}${idx + lastKeyCount}`; - const metadata = - documents[idx] && documents[idx].metadata - ? documents[idx].metadata - : {}; - - multi.hSet(key, { - [this.vectorKey]: this.getFloat32Buffer(vector), - [this.contentKey]: documents[idx].pageContent, - [this.metadataKey]: this.escapeSpecialChars(JSON.stringify(metadata)), - }); - - // write batch - if (idx % batchSize === 0) { - await multi.exec(); - } - }); - - // insert final batch - await multi.exec(); - } - - /** - * Method for performing a similarity search in the RedisVectorStore. It - * returns the documents and their scores. - * @param query The query vector. - * @param k The number of nearest neighbors to return. - * @param filter Optional filter to apply to the search. - * @returns A promise that resolves to an array of documents and their scores. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: RedisVectorStoreFilterType - ): Promise<[Document, number][]> { - if (filter && this.filter) { - throw new Error("cannot provide both `filter` and `this.filter`"); - } - - const _filter = filter ?? this.filter; - const results = await this.redisClient.ft.search( - this.indexName, - ...this.buildQuery(query, k, _filter) - ); - const result: [Document, number][] = []; - - if (results.total) { - for (const res of results.documents) { - if (res.value) { - const document = res.value; - if (document.vector_score) { - result.push([ - new Document({ - pageContent: document[this.contentKey] as string, - metadata: JSON.parse( - this.unEscapeSpecialChars(document.metadata as string) - ), - }), - Number(document.vector_score), - ]); - } - } - } - } - - return result; - } - - /** - * Static method for creating a new instance of RedisVectorStore from - * texts. It creates documents from the texts and metadata, then adds them - * to the RedisVectorStore. - * @param texts The texts to add. - * @param metadatas The metadata associated with the texts. - * @param embeddings The embeddings to use. - * @param dbConfig The configuration for the RedisVectorStore. - * @returns A promise that resolves to a new instance of RedisVectorStore. - */ - static fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: RedisVectorStoreConfig - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return RedisVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method for creating a new instance of RedisVectorStore from - * documents. It adds the documents to the RedisVectorStore. - * @param docs The documents to add. 
- * @param embeddings The embeddings to use. - * @param dbConfig The configuration for the RedisVectorStore. - * @returns A promise that resolves to a new instance of RedisVectorStore. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: RedisVectorStoreConfig - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Method for checking if an index exists in the RedisVectorStore. - * @returns A promise that resolves to a boolean indicating whether the index exists. - */ - async checkIndexExists() { - try { - await this.redisClient.ft.info(this.indexName); - } catch (err) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - if ((err as any)?.message.includes("unknown command")) { - throw new Error( - "Failed to run FT.INFO command. Please ensure that you are running a RediSearch-capable Redis instance: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/redis#setup" - ); - } - // index doesn't exist - return false; - } - - return true; - } - - /** - * Method for creating an index in the RedisVectorStore. If the index - * already exists, it does nothing. - * @param dimensions The dimensions of the index - * @returns A promise that resolves when the index has been created. - */ - async createIndex(dimensions = 1536): Promise { - if (await this.checkIndexExists()) { - return; - } - - const schema: RediSearchSchema = { - [this.vectorKey]: { - type: SchemaFieldTypes.VECTOR, - TYPE: "FLOAT32", - DIM: dimensions, - ...this.indexOptions, - }, - [this.contentKey]: SchemaFieldTypes.TEXT, - [this.metadataKey]: SchemaFieldTypes.TEXT, - }; - - await this.redisClient.ft.create( - this.indexName, - schema, - this.createIndexOptions - ); - } - - /** - * Method for dropping an index from the RedisVectorStore. - * @param deleteDocuments Optional boolean indicating whether to drop the associated documents. 
- * @returns A promise that resolves to a boolean indicating whether the index was dropped. - */ - async dropIndex(deleteDocuments?: boolean): Promise { - try { - const options = deleteDocuments ? { DD: deleteDocuments } : undefined; - await this.redisClient.ft.dropIndex(this.indexName, options); - - return true; - } catch (err) { - return false; - } - } - - /** - * Deletes vectors from the vector store. - * @param params The parameters for deleting vectors. - * @returns A promise that resolves when the vectors have been deleted. - */ - async delete(params: { deleteAll: boolean }): Promise { - if (params.deleteAll) { - await this.dropIndex(true); - } else { - throw new Error(`Invalid parameters passed to "delete".`); - } - } - - private buildQuery( - query: number[], - k: number, - filter?: RedisVectorStoreFilterType - ): [string, SearchOptions] { - const vectorScoreField = "vector_score"; - - let hybridFields = "*"; - // if a filter is set, modify the hybrid query - if (filter && filter.length) { - // `filter` is a list of strings, then it's applied using the OR operator in the metadata key - // for example: filter = ['foo', 'bar'] => this will filter all metadata containing either 'foo' OR 'bar' - hybridFields = `@${this.metadataKey}:(${this.prepareFilter(filter)})`; - } - - const baseQuery = `${hybridFields} => [KNN ${k} @${this.vectorKey} $vector AS ${vectorScoreField}]`; - const returnFields = [this.metadataKey, this.contentKey, vectorScoreField]; - - const options: SearchOptions = { - PARAMS: { - vector: this.getFloat32Buffer(query), - }, - RETURN: returnFields, - SORTBY: vectorScoreField, - DIALECT: 2, - LIMIT: { - from: 0, - size: k, - }, - }; - - return [baseQuery, options]; - } - - private prepareFilter(filter: RedisVectorStoreFilterType) { - return filter.map(this.escapeSpecialChars).join("|"); - } - - /** - * Escapes all '-' characters. 
- * RediSearch considers '-' as a negative operator, hence we need - * to escape it - * @see https://redis.io/docs/stack/search/reference/query_syntax - * - * @param str - * @returns - */ - private escapeSpecialChars(str: string) { - return str.replaceAll("-", "\\-"); - } - - /** - * Unescapes all '-' characters, returning the original string - * - * @param str - * @returns - */ - private unEscapeSpecialChars(str: string) { - return str.replaceAll("\\-", "-"); - } - - /** - * Converts the vector to the buffer Redis needs to - * correctly store an embedding - * - * @param vector - * @returns Buffer - */ - private getFloat32Buffer(vector: number[]) { - return Buffer.from(new Float32Array(vector).buffer); - } -} +export * from "@langchain/community/vectorstores/redis"; diff --git a/langchain/src/vectorstores/rockset.ts b/langchain/src/vectorstores/rockset.ts index 38a4f21dc5e3..5ae76eefb55f 100644 --- a/langchain/src/vectorstores/rockset.ts +++ b/langchain/src/vectorstores/rockset.ts @@ -1,453 +1 @@ -import { MainApi } from "@rockset/client"; -import type { CreateCollectionRequest } from "@rockset/client/dist/codegen/api.d.ts"; -import { Collection } from "@rockset/client/dist/codegen/api.js"; - -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -/** - * Generic Rockset vector storage error - */ -export class RocksetStoreError extends Error { - /** - * Constructs a RocksetStoreError - * @param message The error message - */ - constructor(message: string) { - super(message); - this.name = this.constructor.name; - } -} - -/** - * Error that is thrown when a RocksetStore function is called - * after `destroy()` is called (meaning the collection would be - * deleted). 
- */ -export class RocksetStoreDestroyedError extends RocksetStoreError { - constructor() { - super("The Rockset store has been destroyed"); - this.name = this.constructor.name; - } -} - -/** - * Functions to measure vector distance/similarity by. - * See https://rockset.com/docs/vector-functions/#vector-distance-functions - * @enum SimilarityMetric - */ -export const SimilarityMetric = { - CosineSimilarity: "COSINE_SIM", - EuclideanDistance: "EUCLIDEAN_DIST", - DotProduct: "DOT_PRODUCT", -} as const; - -export type SimilarityMetric = - (typeof SimilarityMetric)[keyof typeof SimilarityMetric]; - -interface CollectionNotFoundError { - message_key: string; -} - -/** - * Vector store arguments - * @interface RocksetStore - */ -export interface RocksetLibArgs { - /** - * The rockset client object constructed with `rocksetConfigure` - * @type {MainAPI} - */ - client: MainApi; - /** - * The name of the Rockset collection to store vectors - * @type {string} - */ - collectionName: string; - /** - * The name of othe Rockset workspace that holds @member collectionName - * @type {string} - */ - workspaceName?: string; - /** - * The name of the collection column to contain page contnent of documents - * @type {string} - */ - textKey?: string; - /** - * The name of the collection column to contain vectors - * @type {string} - */ - embeddingKey?: string; - /** - * The SQL `WHERE` clause to filter by - * @type {string} - */ - filter?: string; - /** - * The metric used to measure vector relationship - * @type {SimilarityMetric} - */ - similarityMetric?: SimilarityMetric; -} - -/** - * Exposes Rockset's vector store/search functionality - */ -export class RocksetStore extends VectorStore { - declare FilterType: string; - - client: MainApi; - - collectionName: string; - - workspaceName: string; - - textKey: string; - - embeddingKey: string; - - filter?: string; - - private _similarityMetric: SimilarityMetric; - - private similarityOrder: "ASC" | "DESC"; - - private destroyed: 
boolean; - - /** - * Gets a string representation of the type of this VectorStore - * @returns {"rockset"} - */ - _vectorstoreType(): "rockset" { - return "rockset"; - } - - /** - * Constructs a new RocksetStore - * @param {Embeddings} embeddings Object used to embed queries and - * page content - * @param {RocksetLibArgs} args - */ - constructor(embeddings: Embeddings, args: RocksetLibArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - this.client = args.client; - this.collectionName = args.collectionName; - this.workspaceName = args.workspaceName ?? "commons"; - this.textKey = args.textKey ?? "text"; - this.embeddingKey = args.embeddingKey ?? "embedding"; - this.filter = args.filter; - this.similarityMetric = - args.similarityMetric ?? SimilarityMetric.CosineSimilarity; - this.setSimilarityOrder(); - } - - /** - * Sets the object's similarity order based on what - * SimilarityMetric is being used - */ - private setSimilarityOrder() { - this.checkIfDestroyed(); - this.similarityOrder = - this.similarityMetric === SimilarityMetric.EuclideanDistance - ? "ASC" - : "DESC"; - } - - /** - * Embeds and adds Documents to the store. 
- * @param {Documents[]} documents The documents to store - * @returns {Promise} The _id's of the documents added - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - return await this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Adds vectors to the store given their corresponding Documents - * @param {number[][]} vectors The vectors to store - * @param {Document[]} documents The Documents they represent - * @return {Promise} The _id's of the added documents - */ - async addVectors(vectors: number[][], documents: Document[]) { - this.checkIfDestroyed(); - const rocksetDocs = []; - for (let i = 0; i < documents.length; i += 1) { - const currDoc = documents[i]; - const currVector = vectors[i]; - rocksetDocs.push({ - [this.textKey]: currDoc.pageContent, - [this.embeddingKey]: currVector, - ...currDoc.metadata, - }); - } - - return ( - await this.client.documents.addDocuments( - this.workspaceName, - this.collectionName, - { - data: rocksetDocs, - } - ) - ).data?.map((docStatus) => docStatus._id || ""); - } - - /** - * Deletes Rockset documements given their _id's - * @param {string[]} ids The IDS to remove documents with - */ - async delete(ids: string[]): Promise { - this.checkIfDestroyed(); - await this.client.documents.deleteDocuments( - this.workspaceName, - this.collectionName, - { - data: ids.map((id) => ({ _id: id })), - } - ); - } - - /** - * Gets the most relevant documents to a query along - * with their similarity score. 
The returned documents - * are ordered by similarity (most similar at the first - * index) - * @param {number[]} query The embedded query to search - * the store by - * @param {number} k The number of documents to retreive - * @param {string?} filter The SQL `WHERE` clause to filter by - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: string - ): Promise<[Document, number][]> { - this.checkIfDestroyed(); - if (filter && this.filter) { - throw new RocksetStoreError( - "cannot provide both `filter` and `this.filter`" - ); - } - const similarityKey = "similarity"; - const _filter = filter ?? this.filter; - return ( - ( - await this.client.queries.query({ - sql: { - query: ` - SELECT - * EXCEPT("${this.embeddingKey}"), - "${this.textKey}", - ${this.similarityMetric}(:query, "${ - this.embeddingKey - }") AS "${similarityKey}" - FROM - "${this.workspaceName}"."${this.collectionName}" - ${_filter ? `WHERE ${_filter}` : ""} - ORDER BY - "${similarityKey}" ${this.similarityOrder} - LIMIT - ${k} - `, - parameters: [ - { - name: "query", - type: "", - value: `[${query.toString()}]`, - }, - ], - }, - }) - ).results?.map((rocksetDoc) => [ - new Document>({ - pageContent: rocksetDoc[this.textKey], - metadata: (({ - [this.textKey]: t, - [similarityKey]: s, - ...rocksetDoc - }) => rocksetDoc)(rocksetDoc), - }), - rocksetDoc[similarityKey] as number, - ]) ?? [] - ); - } - - /** - * Constructs and returns a RocksetStore object given texts to store. 
- * @param {string[]} texts The texts to store - * @param {object[] | object} metadatas The metadatas that correspond - * to @param texts - * @param {Embeddings} embeddings The object used to embed queries - * and page content - * @param {RocksetLibArgs} dbConfig The options to be passed into the - * RocksetStore constructor - * @returns {RocksetStore} - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: RocksetLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return RocksetStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Constructs, adds docs to, and returns a RocksetStore object - * @param {Document[]} docs The Documents to store - * @param {Embeddings} embeddings The object used to embed queries - * and page content - * @param {RocksetLibArgs} dbConfig The options to be passed into the - * RocksetStore constructor - * @returns {RocksetStore} - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: RocksetLibArgs - ): Promise { - const args = { ...dbConfig, textKey: dbConfig.textKey ?? "text" }; - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Checks if a Rockset collection exists. - * @param {RocksetLibArgs} dbConfig The object containing the collection - * and workspace names - * @return {boolean} whether the collection exists - */ - private static async collectionExists(dbConfig: RocksetLibArgs) { - try { - await dbConfig.client.collections.getCollection( - dbConfig.workspaceName ?? 
"commons", - dbConfig.collectionName - ); - } catch (err) { - if ( - (err as CollectionNotFoundError).message_key === - "COLLECTION_DOES_NOT_EXIST" - ) { - return false; - } - throw err; - } - return true; - } - - /** - * Checks whether a Rockset collection is ready to be queried. - * @param {RocksetLibArgs} dbConfig The object containing the collection - * name and workspace - * @return {boolean} whether the collection is ready - */ - private static async collectionReady(dbConfig: RocksetLibArgs) { - return ( - ( - await dbConfig.client.collections.getCollection( - dbConfig.workspaceName ?? "commons", - dbConfig.collectionName - ) - ).data?.status === Collection.StatusEnum.READY - ); - } - - /** - * Deletes the collection this RocksetStore uses - * @param {boolean?} waitUntilDeletion Whether to sleep until the - * collection is ready to be - * queried - */ - async destroy(waitUntilDeletion?: boolean) { - await this.client.collections.deleteCollection( - this.workspaceName, - this.collectionName - ); - this.destroyed = true; - if (waitUntilDeletion) { - while ( - await RocksetStore.collectionExists({ - collectionName: this.collectionName, - client: this.client, - }) - ); - } - } - - /** - * Checks if this RocksetStore has been destroyed. - * @throws {RocksetStoreDestroyederror} if it has. - */ - private checkIfDestroyed() { - if (this.destroyed) { - throw new RocksetStoreDestroyedError(); - } - } - - /** - * Creates a new Rockset collection and returns a RocksetStore that - * uses it - * @param {Embeddings} embeddings Object used to embed queries and - * page content - * @param {RocksetLibArgs} dbConfig The options to be passed into the - * RocksetStore constructor - * @param {CreateCollectionRequest?} collectionOptions The arguments to sent with the - * HTTP request when creating the - * collection. Setting a field mapping - * that `VECTOR_ENFORCE`s is recommended - * when using this function. 
See - * https://rockset.com/docs/vector-functions/#vector_enforce - * @returns {RocsketStore} - */ - static async withNewCollection( - embeddings: Embeddings, - dbConfig: RocksetLibArgs, - collectionOptions?: CreateCollectionRequest - ): Promise { - if ( - collectionOptions?.name && - dbConfig.collectionName !== collectionOptions?.name - ) { - throw new RocksetStoreError( - "`dbConfig.name` and `collectionOptions.name` do not match" - ); - } - await dbConfig.client.collections.createCollection( - dbConfig.workspaceName ?? "commons", - collectionOptions || { name: dbConfig.collectionName } - ); - while ( - !(await this.collectionExists(dbConfig)) || - !(await this.collectionReady(dbConfig)) - ); - return new this(embeddings, dbConfig); - } - - public get similarityMetric() { - return this._similarityMetric; - } - - public set similarityMetric(metric: SimilarityMetric) { - this._similarityMetric = metric; - this.setSimilarityOrder(); - } -} +export * from "@langchain/community/vectorstores/rockset"; diff --git a/langchain/src/vectorstores/singlestore.ts b/langchain/src/vectorstores/singlestore.ts index e67d5e625f61..8d5df3a1dc1a 100644 --- a/langchain/src/vectorstores/singlestore.ts +++ b/langchain/src/vectorstores/singlestore.ts @@ -1,294 +1 @@ -import type { - Pool, - RowDataPacket, - OkPacket, - ResultSetHeader, - FieldPacket, - PoolOptions, -} from "mysql2/promise"; -import { format } from "mysql2"; -import { createPool } from "mysql2/promise"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type Metadata = Record; - -export type DistanceMetrics = "DOT_PRODUCT" | "EUCLIDEAN_DISTANCE"; - -const OrderingDirective: Record = { - DOT_PRODUCT: "DESC", - EUCLIDEAN_DISTANCE: "", -}; - -export interface ConnectionOptions extends PoolOptions {} - -type ConnectionWithUri = { - connectionOptions?: never; - 
connectionURI: string; -}; - -type ConnectionWithOptions = { - connectionURI?: never; - connectionOptions: ConnectionOptions; -}; - -type ConnectionConfig = ConnectionWithUri | ConnectionWithOptions; - -export type SingleStoreVectorStoreConfig = ConnectionConfig & { - tableName?: string; - contentColumnName?: string; - vectorColumnName?: string; - metadataColumnName?: string; - distanceMetric?: DistanceMetrics; -}; - -function withConnectAttributes( - config: SingleStoreVectorStoreConfig -): ConnectionOptions { - let newOptions: ConnectionOptions = {}; - if (config.connectionURI) { - newOptions = { - uri: config.connectionURI, - }; - } else if (config.connectionOptions) { - newOptions = { - ...config.connectionOptions, - }; - } - const result: ConnectionOptions = { - ...newOptions, - connectAttributes: { - ...newOptions.connectAttributes, - }, - }; - - if (!result.connectAttributes) { - result.connectAttributes = {}; - } - - result.connectAttributes = { - ...result.connectAttributes, - _connector_name: "langchain js sdk", - _connector_version: "1.0.0", - _driver_name: "Node-MySQL-2", - }; - - return result; -} - -/** - * Class for interacting with SingleStoreDB, a high-performance - * distributed SQL database. It provides vector storage and vector - * functions. - */ -export class SingleStoreVectorStore extends VectorStore { - connectionPool: Pool; - - tableName: string; - - contentColumnName: string; - - vectorColumnName: string; - - metadataColumnName: string; - - distanceMetric: DistanceMetrics; - - _vectorstoreType(): string { - return "singlestore"; - } - - constructor(embeddings: Embeddings, config: SingleStoreVectorStoreConfig) { - super(embeddings, config); - this.connectionPool = createPool(withConnectAttributes(config)); - this.tableName = config.tableName ?? "embeddings"; - this.contentColumnName = config.contentColumnName ?? "content"; - this.vectorColumnName = config.vectorColumnName ?? "vector"; - this.metadataColumnName = config.metadataColumnName ?? 
"metadata"; - this.distanceMetric = config.distanceMetric ?? "DOT_PRODUCT"; - } - - /** - * Creates a new table in the SingleStoreDB database if it does not - * already exist. - */ - async createTableIfNotExists(): Promise { - await this.connectionPool - .execute(`CREATE TABLE IF NOT EXISTS ${this.tableName} ( - ${this.contentColumnName} TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci, - ${this.vectorColumnName} BLOB, - ${this.metadataColumnName} JSON);`); - } - - /** - * Ends the connection to the SingleStoreDB database. - */ - async end(): Promise { - return this.connectionPool.end(); - } - - /** - * Adds new documents to the SingleStoreDB database. - * @param documents An array of Document objects. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - const vectors = await this.embeddings.embedDocuments(texts); - return this.addVectors(vectors, documents); - } - - /** - * Adds new vectors to the SingleStoreDB database. - * @param vectors An array of vectors. - * @param documents An array of Document objects. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - await this.createTableIfNotExists(); - const { tableName } = this; - - await Promise.all( - vectors.map(async (vector, idx) => { - try { - await this.connectionPool.execute( - format( - `INSERT INTO ${tableName} VALUES (?, JSON_ARRAY_PACK('[?]'), ?);`, - [ - documents[idx].pageContent, - vector, - JSON.stringify(documents[idx].metadata), - ] - ) - ); - } catch (error) { - console.error(`Error adding vector at index ${idx}:`, error); - } - }) - ); - } - - /** - * Performs a similarity search on the vectors stored in the SingleStoreDB - * database. - * @param query An array of numbers representing the query vector. - * @param k The number of nearest neighbors to return. - * @param filter Optional metadata to filter the vectors by. 
- * @returns Top matching vectors with score - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: Metadata - ): Promise<[Document, number][]> { - // build the where clause from filter - const whereArgs: string[] = []; - const buildWhereClause = (record: Metadata, argList: string[]): string => { - const whereTokens: string[] = []; - for (const key in record) - if (record[key] !== undefined) { - if ( - typeof record[key] === "object" && - record[key] != null && - !Array.isArray(record[key]) - ) { - whereTokens.push( - buildWhereClause(record[key], argList.concat([key])) - ); - } else { - whereTokens.push( - `JSON_EXTRACT_JSON(${this.metadataColumnName}, `.concat( - Array.from({ length: argList.length + 1 }, () => "?").join( - ", " - ), - ") = ?" - ) - ); - whereArgs.push(...argList, key, JSON.stringify(record[key])); - } - } - return whereTokens.join(" AND "); - }; - const whereClause = filter - ? "WHERE ".concat(buildWhereClause(filter, [])) - : ""; - - const [rows]: [ - ( - | RowDataPacket[] - | RowDataPacket[][] - | OkPacket - | OkPacket[] - | ResultSetHeader - ), - FieldPacket[] - ] = await this.connectionPool.query( - format( - `SELECT ${this.contentColumnName}, - ${this.metadataColumnName}, - ${this.distanceMetric}(${ - this.vectorColumnName - }, JSON_ARRAY_PACK('[?]')) as __score FROM ${ - this.tableName - } ${whereClause} - ORDER BY __score ${OrderingDirective[this.distanceMetric]} LIMIT ?;`, - [query, ...whereArgs, k] - ) - ); - const result: [Document, number][] = []; - for (const row of rows as RowDataPacket[]) { - const rowData = row as unknown as Record; - result.push([ - new Document({ - pageContent: rowData[this.contentColumnName] as string, - metadata: rowData[this.metadataColumnName] as Record, - }), - Number(rowData.score), - ]); - } - return result; - } - - /** - * Creates a new instance of the SingleStoreVectorStore class from a list - * of texts. - * @param texts An array of strings. 
- * @param metadatas An array of metadata objects. - * @param embeddings An Embeddings object. - * @param dbConfig A SingleStoreVectorStoreConfig object. - * @returns A new SingleStoreVectorStore instance - */ - static async fromTexts( - texts: string[], - metadatas: object[], - embeddings: Embeddings, - dbConfig: SingleStoreVectorStoreConfig - ): Promise { - const docs = texts.map((text, idx) => { - const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas; - return new Document({ - pageContent: text, - metadata, - }); - }); - return SingleStoreVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new instance of the SingleStoreVectorStore class from a list - * of Document objects. - * @param docs An array of Document objects. - * @param embeddings An Embeddings object. - * @param dbConfig A SingleStoreVectorStoreConfig object. - * @returns A new SingleStoreVectorStore instance - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: SingleStoreVectorStoreConfig - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/singlestore"; diff --git a/langchain/src/vectorstores/supabase.ts b/langchain/src/vectorstores/supabase.ts index 0659af76cf97..103b466c3ea5 100644 --- a/langchain/src/vectorstores/supabase.ts +++ b/langchain/src/vectorstores/supabase.ts @@ -1,310 +1 @@ -import type { SupabaseClient } from "@supabase/supabase-js"; -import type { PostgrestFilterBuilder } from "@supabase/postgrest-js"; -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { maximalMarginalRelevance } from "../util/math.js"; - -/** - * Interface for the parameters required for searching embeddings. 
- */ -interface SearchEmbeddingsParams { - query_embedding: number[]; - match_count: number; // int - filter?: SupabaseMetadata | SupabaseFilterRPCCall; -} - -// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any -export type SupabaseMetadata = Record; -// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any -export type SupabaseFilter = PostgrestFilterBuilder; -export type SupabaseFilterRPCCall = (rpcCall: SupabaseFilter) => SupabaseFilter; - -/** - * Interface for the response returned when searching embeddings. - */ -interface SearchEmbeddingsResponse { - id: number; - content: string; - metadata: object; - embedding: number[]; - similarity: number; -} - -/** - * Interface for the arguments required to initialize a Supabase library. - */ -export interface SupabaseLibArgs { - client: SupabaseClient; - tableName?: string; - queryName?: string; - filter?: SupabaseMetadata | SupabaseFilterRPCCall; - upsertBatchSize?: number; -} - -/** - * Class for interacting with a Supabase database to store and manage - * vectors. - */ -export class SupabaseVectorStore extends VectorStore { - declare FilterType: SupabaseMetadata | SupabaseFilterRPCCall; - - client: SupabaseClient; - - tableName: string; - - queryName: string; - - filter?: SupabaseMetadata | SupabaseFilterRPCCall; - - upsertBatchSize = 500; - - _vectorstoreType(): string { - return "supabase"; - } - - constructor(embeddings: Embeddings, args: SupabaseLibArgs) { - super(embeddings, args); - - this.client = args.client; - this.tableName = args.tableName || "documents"; - this.queryName = args.queryName || "match_documents"; - this.filter = args.filter; - this.upsertBatchSize = args.upsertBatchSize ?? this.upsertBatchSize; - } - - /** - * Adds documents to the vector store. - * @param documents The documents to add. - * @param options Optional parameters for adding the documents. 
- * @returns A promise that resolves when the documents have been added. - */ - async addDocuments( - documents: Document[], - options?: { ids?: string[] | number[] } - ) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Adds vectors to the vector store. - * @param vectors The vectors to add. - * @param documents The documents associated with the vectors. - * @param options Optional parameters for adding the vectors. - * @returns A promise that resolves with the IDs of the added vectors when the vectors have been added. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] | number[] } - ) { - const rows = vectors.map((embedding, idx) => ({ - content: documents[idx].pageContent, - embedding, - metadata: documents[idx].metadata, - })); - - // upsert returns 500/502/504 (yes really any of them) if given too many rows/characters - // ~2000 trips it, but my data is probably smaller than average pageContent and metadata - let returnedIds: string[] = []; - for (let i = 0; i < rows.length; i += this.upsertBatchSize) { - const chunk = rows.slice(i, i + this.upsertBatchSize).map((row, j) => { - if (options?.ids) { - return { id: options.ids[i + j], ...row }; - } - return row; - }); - - const res = await this.client.from(this.tableName).upsert(chunk).select(); - if (res.error) { - throw new Error( - `Error inserting: ${res.error.message} ${res.status} ${res.statusText}` - ); - } - if (res.data) { - returnedIds = returnedIds.concat(res.data.map((row) => row.id)); - } - } - return returnedIds; - } - - /** - * Deletes vectors from the vector store. - * @param params The parameters for deleting vectors. - * @returns A promise that resolves when the vectors have been deleted. 
- */ - async delete(params: { ids: string[] | number[] }): Promise { - const { ids } = params; - for (const id of ids) { - await this.client.from(this.tableName).delete().eq("id", id); - } - } - - protected async _searchSupabase( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise { - if (filter && this.filter) { - throw new Error("cannot provide both `filter` and `this.filter`"); - } - const _filter = filter ?? this.filter ?? {}; - const matchDocumentsParams: Partial = { - query_embedding: query, - }; - - let filterFunction: SupabaseFilterRPCCall; - - if (typeof _filter === "function") { - filterFunction = (rpcCall) => _filter(rpcCall).limit(k); - } else if (typeof _filter === "object") { - matchDocumentsParams.filter = _filter; - matchDocumentsParams.match_count = k; - filterFunction = (rpcCall) => rpcCall; - } else { - throw new Error("invalid filter type"); - } - - const rpcCall = this.client.rpc(this.queryName, matchDocumentsParams); - - const { data: searches, error } = await filterFunction(rpcCall); - - if (error) { - throw new Error( - `Error searching for documents: ${error.code} ${error.message} ${error.details}` - ); - } - - return searches; - } - - /** - * Performs a similarity search on the vector store. - * @param query The query vector. - * @param k The number of results to return. - * @param filter Optional filter to apply to the search. - * @returns A promise that resolves with the search results when the search is complete. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - const searches = await this._searchSupabase(query, k, filter); - const result: [Document, number][] = searches.map((resp) => [ - new Document({ - metadata: resp.metadata, - pageContent: resp.content, - }), - resp.similarity, - ]); - - return result; - } - - /** - * Return documents selected using the maximal marginal relevance. 
- * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK=20- Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {SupabaseLibArgs} options.filter - Optional filter to apply to the search. - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. - */ - async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions - ): Promise { - const queryEmbedding = await this.embeddings.embedQuery(query); - - const searches = await this._searchSupabase( - queryEmbedding, - options.fetchK ?? 20, - options.filter - ); - - const embeddingList = searches.map((searchResp) => searchResp.embedding); - - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - options.lambda, - options.k - ); - - return mmrIndexes.map( - (idx) => - new Document({ - metadata: searches[idx].metadata, - pageContent: searches[idx].content, - }) - ); - } - - /** - * Creates a new SupabaseVectorStore instance from an array of texts. - * @param texts The texts to create documents from. - * @param metadatas The metadata for the documents. - * @param embeddings The embeddings to use. - * @param dbConfig The configuration for the Supabase database. - * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created. 
- */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: SupabaseLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return SupabaseVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Creates a new SupabaseVectorStore instance from an array of documents. - * @param docs The documents to create the instance from. - * @param embeddings The embeddings to use. - * @param dbConfig The configuration for the Supabase database. - * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: SupabaseLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Creates a new SupabaseVectorStore instance from an existing index. - * @param embeddings The embeddings to use. - * @param dbConfig The configuration for the Supabase database. - * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created. 
- */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: SupabaseLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - return instance; - } -} +export * from "@langchain/community/vectorstores/supabase"; diff --git a/langchain/src/vectorstores/tests/convex/convex/langchain/db.ts b/langchain/src/vectorstores/tests/convex/convex/langchain/db.ts deleted file mode 100644 index e09d4ecfe02d..000000000000 --- a/langchain/src/vectorstores/tests/convex/convex/langchain/db.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "../../../../../util/convex.js"; diff --git a/langchain/src/vectorstores/tigris.ts b/langchain/src/vectorstores/tigris.ts index d4cfd73828fe..d3702e081d3b 100644 --- a/langchain/src/vectorstores/tigris.ts +++ b/langchain/src/vectorstores/tigris.ts @@ -1,177 +1 @@ -import * as uuid from "uuid"; - -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -/** - * Type definition for the arguments required to initialize a - * TigrisVectorStore instance. - */ -export type TigrisLibArgs = { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - index: any; -}; - -/** - * Class for managing and operating vector search applications with - * Tigris, an open-source Serverless NoSQL Database and Search Platform. - */ -export class TigrisVectorStore extends VectorStore { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - index?: any; - - _vectorstoreType(): string { - return "tigris"; - } - - constructor(embeddings: Embeddings, args: TigrisLibArgs) { - super(embeddings, args); - - this.embeddings = embeddings; - this.index = args.index; - } - - /** - * Method to add an array of documents to the Tigris database. - * @param documents An array of Document instances to be added to the Tigris database. 
- * @param options Optional parameter that can either be an array of string IDs or an object with a property 'ids' that is an array of string IDs. - * @returns A Promise that resolves when the documents have been added to the Tigris database. - */ - async addDocuments( - documents: Document[], - options?: { ids?: string[] } | string[] - ): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - await this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method to add vectors to the Tigris database. - * @param vectors An array of vectors to be added to the Tigris database. - * @param documents An array of Document instances corresponding to the vectors. - * @param options Optional parameter that can either be an array of string IDs or an object with a property 'ids' that is an array of string IDs. - * @returns A Promise that resolves when the vectors have been added to the Tigris database. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } | string[] - ) { - if (vectors.length === 0) { - return; - } - - if (vectors.length !== documents.length) { - throw new Error(`Vectors and metadatas must have the same length`); - } - - const ids = Array.isArray(options) ? options : options?.ids; - const documentIds = ids == null ? documents.map(() => uuid.v4()) : ids; - await this.index?.addDocumentsWithVectors({ - ids: documentIds, - embeddings: vectors, - documents: documents.map(({ metadata, pageContent }) => ({ - content: pageContent, - metadata, - })), - }); - } - - /** - * Method to perform a similarity search in the Tigris database and return - * the k most similar vectors along with their similarity scores. - * @param query The query vector. - * @param k The number of most similar vectors to return. - * @param filter Optional filter object to apply during the search. 
- * @returns A Promise that resolves to an array of tuples, each containing a Document and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: object - ) { - const result = await this.index?.similaritySearchVectorWithScore({ - query, - k, - filter, - }); - - if (!result) { - return []; - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return result.map(([document, score]: [any, any]) => [ - new Document({ - pageContent: document.content, - metadata: document.metadata, - }), - score, - ]) as [Document, number][]; - } - - /** - * Static method to create a new instance of TigrisVectorStore from an - * array of texts. - * @param texts An array of texts to be converted into Document instances and added to the Tigris database. - * @param metadatas Either an array of metadata objects or a single metadata object to be associated with the texts. - * @param embeddings An instance of Embeddings to be used for embedding the texts. - * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database. - * @returns A Promise that resolves to a new instance of TigrisVectorStore. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: TigrisLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return TigrisVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a new instance of TigrisVectorStore from an - * array of Document instances. - * @param docs An array of Document instances to be added to the Tigris database. - * @param embeddings An instance of Embeddings to be used for embedding the documents. 
- * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database. - * @returns A Promise that resolves to a new instance of TigrisVectorStore. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: TigrisLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create a new instance of TigrisVectorStore from an - * existing index. - * @param embeddings An instance of Embeddings to be used for embedding the documents. - * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database. - * @returns A Promise that resolves to a new instance of TigrisVectorStore. - */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: TigrisLibArgs - ): Promise { - const instance = new this(embeddings, dbConfig); - return instance; - } -} +export * from "@langchain/community/vectorstores/tigris"; diff --git a/langchain/src/vectorstores/typeorm.ts b/langchain/src/vectorstores/typeorm.ts index d0130303ed68..234f53c8d663 100644 --- a/langchain/src/vectorstores/typeorm.ts +++ b/langchain/src/vectorstores/typeorm.ts @@ -1,298 +1 @@ -import { Metadata } from "@opensearch-project/opensearch/api/types.js"; -import { DataSource, DataSourceOptions, EntitySchema } from "typeorm"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -/** - * Interface that defines the arguments required to create a - * `TypeORMVectorStore` instance. It includes Postgres connection options, - * table name, filter, and verbosity level. 
- */ -export interface TypeORMVectorStoreArgs { - postgresConnectionOptions: DataSourceOptions; - tableName?: string; - filter?: Metadata; - verbose?: boolean; -} - -/** - * Class that extends the `Document` base class and adds an `embedding` - * property. It represents a document in the vector store. - */ -export class TypeORMVectorStoreDocument extends Document { - embedding: string; - - id?: string; -} - -const defaultDocumentTableName = "documents"; - -/** - * Class that provides an interface to a Postgres vector database. It - * extends the `VectorStore` base class and implements methods for adding - * documents and vectors, performing similarity searches, and ensuring the - * existence of a table in the database. - */ -export class TypeORMVectorStore extends VectorStore { - declare FilterType: Metadata; - - tableName: string; - - documentEntity: EntitySchema; - - filter?: Metadata; - - appDataSource: DataSource; - - _verbose?: boolean; - - _vectorstoreType(): string { - return "typeorm"; - } - - private constructor(embeddings: Embeddings, fields: TypeORMVectorStoreArgs) { - super(embeddings, fields); - this.tableName = fields.tableName || defaultDocumentTableName; - this.filter = fields.filter; - - const TypeORMDocumentEntity = new EntitySchema({ - name: fields.tableName ?? defaultDocumentTableName, - columns: { - id: { - generated: "uuid", - type: "uuid", - primary: true, - }, - pageContent: { - type: String, - }, - metadata: { - type: "jsonb", - }, - embedding: { - type: String, - }, - }, - }); - const appDataSource = new DataSource({ - entities: [TypeORMDocumentEntity], - ...fields.postgresConnectionOptions, - }); - this.appDataSource = appDataSource; - this.documentEntity = TypeORMDocumentEntity; - - this._verbose = - getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ?? - fields.verbose ?? - false; - } - - /** - * Static method to create a new `TypeORMVectorStore` instance from a - * `DataSource`. 
It initializes the `DataSource` if it is not already - * initialized. - * @param embeddings Embeddings instance. - * @param fields `TypeORMVectorStoreArgs` instance. - * @returns A new instance of `TypeORMVectorStore`. - */ - static async fromDataSource( - embeddings: Embeddings, - fields: TypeORMVectorStoreArgs - ): Promise { - const postgresqlVectorStore = new TypeORMVectorStore(embeddings, fields); - - if (!postgresqlVectorStore.appDataSource.isInitialized) { - await postgresqlVectorStore.appDataSource.initialize(); - } - - return postgresqlVectorStore; - } - - /** - * Method to add documents to the vector store. It ensures the existence - * of the table in the database, converts the documents into vectors, and - * adds them to the store. - * @param documents Array of `Document` instances. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - // This will create the table if it does not exist. We can call it every time as it doesn't - // do anything if the table already exists, and it is not expensive in terms of performance - await this.ensureTableInDatabase(); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - /** - * Method to add vectors to the vector store. It converts the vectors into - * rows and inserts them into the database. - * @param vectors Array of vectors. - * @param documents Array of `Document` instances. - * @returns Promise that resolves when the vectors have been added. 
- */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - const rows = vectors.map((embedding, idx) => { - const embeddingString = `[${embedding.join(",")}]`; - const documentRow = { - pageContent: documents[idx].pageContent, - embedding: embeddingString, - metadata: documents[idx].metadata, - }; - - return documentRow; - }); - - const documentRepository = this.appDataSource.getRepository( - this.documentEntity - ); - - const chunkSize = 500; - for (let i = 0; i < rows.length; i += chunkSize) { - const chunk = rows.slice(i, i + chunkSize); - - try { - await documentRepository.save(chunk); - } catch (e) { - console.error(e); - throw new Error(`Error inserting: ${chunk[0].pageContent}`); - } - } - } - - /** - * Method to perform a similarity search in the vector store. It returns - * the `k` most similar documents to the query vector, along with their - * similarity scores. - * @param query Query vector. - * @param k Number of most similar documents to return. - * @param filter Optional filter to apply to the search. - * @returns Promise that resolves with an array of tuples, each containing a `TypeORMVectorStoreDocument` and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[TypeORMVectorStoreDocument, number][]> { - const embeddingString = `[${query.join(",")}]`; - const _filter = filter ?? 
"{}"; - - const queryString = ` - SELECT *, embedding <=> $1 as "_distance" - FROM ${this.tableName} - WHERE metadata @> $2 - ORDER BY "_distance" ASC - LIMIT $3;`; - - const documents = await this.appDataSource.query(queryString, [ - embeddingString, - _filter, - k, - ]); - - const results = [] as [TypeORMVectorStoreDocument, number][]; - for (const doc of documents) { - if (doc._distance != null && doc.pageContent != null) { - const document = new Document(doc) as TypeORMVectorStoreDocument; - document.id = doc.id; - results.push([document, doc._distance]); - } - } - - return results; - } - - /** - * Method to ensure the existence of the table in the database. It creates - * the table if it does not already exist. - * @returns Promise that resolves when the table has been ensured. - */ - async ensureTableInDatabase(): Promise { - await this.appDataSource.query("CREATE EXTENSION IF NOT EXISTS vector;"); - await this.appDataSource.query( - 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";' - ); - - await this.appDataSource.query(` - CREATE TABLE IF NOT EXISTS ${this.tableName} ( - "id" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, - "pageContent" text, - metadata jsonb, - embedding vector - ); - `); - } - - /** - * Static method to create a new `TypeORMVectorStore` instance from an - * array of texts and their metadata. It converts the texts into - * `Document` instances and adds them to the store. - * @param texts Array of texts. - * @param metadatas Array of metadata objects or a single metadata object. - * @param embeddings Embeddings instance. - * @param dbConfig `TypeORMVectorStoreArgs` instance. - * @returns Promise that resolves with a new instance of `TypeORMVectorStore`. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig: TypeORMVectorStoreArgs - ): Promise { - const docs = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? 
metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return TypeORMVectorStore.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a new `TypeORMVectorStore` instance from an - * array of `Document` instances. It adds the documents to the store. - * @param docs Array of `Document` instances. - * @param embeddings Embeddings instance. - * @param dbConfig `TypeORMVectorStoreArgs` instance. - * @returns Promise that resolves with a new instance of `TypeORMVectorStore`. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig: TypeORMVectorStoreArgs - ): Promise { - const instance = await TypeORMVectorStore.fromDataSource( - embeddings, - dbConfig - ); - await instance.addDocuments(docs); - - return instance; - } - - /** - * Static method to create a new `TypeORMVectorStore` instance from an - * existing index. - * @param embeddings Embeddings instance. - * @param dbConfig `TypeORMVectorStoreArgs` instance. - * @returns Promise that resolves with a new instance of `TypeORMVectorStore`. 
- */ - static async fromExistingIndex( - embeddings: Embeddings, - dbConfig: TypeORMVectorStoreArgs - ): Promise { - const instance = await TypeORMVectorStore.fromDataSource( - embeddings, - dbConfig - ); - return instance; - } -} +export * from "@langchain/community/vectorstores/typeorm"; diff --git a/langchain/src/vectorstores/typesense.ts b/langchain/src/vectorstores/typesense.ts index e6f608312cc0..bf389cde3fcb 100644 --- a/langchain/src/vectorstores/typesense.ts +++ b/langchain/src/vectorstores/typesense.ts @@ -1,320 +1 @@ -import type { Client } from "typesense"; -import type { MultiSearchRequestSchema } from "typesense/lib/Typesense/MultiSearch.js"; -import type { - SearchResponseHit, - DocumentSchema, -} from "typesense/lib/Typesense/Documents.js"; -import type { Document } from "../document.js"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; - -/** - * Interface for the response hit from a vector search in Typesense. - */ -interface VectorSearchResponseHit - extends SearchResponseHit { - vector_distance?: number; -} - -/** - * Typesense vector store configuration. - */ -export interface TypesenseConfig extends AsyncCallerParams { - /** - * Typesense client. - */ - typesenseClient: Client; - /** - * Typesense schema name in which documents will be stored and searched. - */ - schemaName: string; - /** - * Typesense search parameters. - * @default { q: '*', per_page: 5, query_by: '' } - */ - searchParams?: MultiSearchRequestSchema; - /** - * Column names. - */ - columnNames?: { - /** - * Vector column name. - * @default 'vec' - */ - vector?: string; - /** - * Page content column name. - * @default 'text' - */ - pageContent?: string; - /** - * Metadata column names. - * @default [] - */ - metadataColumnNames?: string[]; - }; - /** - * Replace default import function. 
- * Default import function will update documents if there is a document with the same id. - * @param data - * @param collectionName - */ - import? = Record>( - data: T[], - collectionName: string - ): Promise; -} - -/** - * Typesense vector store. - */ -export class Typesense extends VectorStore { - declare FilterType: Partial; - - private client: Client; - - private schemaName: string; - - private searchParams: MultiSearchRequestSchema; - - private vectorColumnName: string; - - private pageContentColumnName: string; - - private metadataColumnNames: string[]; - - private caller: AsyncCaller; - - private import: ( - data: Record[], - collectionName: string - ) => Promise; - - _vectorstoreType(): string { - return "typesense"; - } - - constructor(embeddings: Embeddings, config: TypesenseConfig) { - super(embeddings, config); - - // Assign config values to class properties. - this.client = config.typesenseClient; - this.schemaName = config.schemaName; - this.searchParams = config.searchParams || { - q: "*", - per_page: 5, - query_by: "", - }; - this.vectorColumnName = config.columnNames?.vector || "vec"; - this.pageContentColumnName = config.columnNames?.pageContent || "text"; - this.metadataColumnNames = config.columnNames?.metadataColumnNames || []; - - // Assign import function. - this.import = config.import || this.importToTypesense.bind(this); - - this.caller = new AsyncCaller(config); - } - - /** - * Default function to import data to typesense - * @param data - * @param collectionName - */ - private async importToTypesense< - T extends Record = Record - >(data: T[], collectionName: string) { - const chunkSize = 2000; - for (let i = 0; i < data.length; i += chunkSize) { - const chunk = data.slice(i, i + chunkSize); - - await this.caller.call(async () => { - await this.client - .collections(collectionName) - .documents() - .import(chunk, { action: "emplace", dirty_values: "drop" }); - }); - } - } - - /** - * Transform documents to Typesense records. 
- * @param documents - * @returns Typesense records. - */ - _documentsToTypesenseRecords( - documents: Document[], - vectors: number[][] - ): Record[] { - const metadatas = documents.map((doc) => doc.metadata); - - const typesenseDocuments = documents.map((doc, index) => { - const metadata = metadatas[index]; - const objectWithMetadatas: Record = {}; - - this.metadataColumnNames.forEach((metadataColumnName) => { - objectWithMetadatas[metadataColumnName] = metadata[metadataColumnName]; - }); - - return { - [this.pageContentColumnName]: doc.pageContent, - [this.vectorColumnName]: vectors[index], - ...objectWithMetadatas, - }; - }); - - return typesenseDocuments; - } - - /** - * Transform the Typesense records to documents. - * @param typesenseRecords - * @returns documents - */ - _typesenseRecordsToDocuments( - typesenseRecords: - | { document?: Record; vector_distance: number }[] - | undefined - ): [Document, number][] { - const documents: [Document, number][] = - typesenseRecords?.map((hit) => { - const objectWithMetadatas: Record = {}; - const hitDoc = hit.document || {}; - this.metadataColumnNames.forEach((metadataColumnName) => { - objectWithMetadatas[metadataColumnName] = hitDoc[metadataColumnName]; - }); - - const document: Document = { - pageContent: (hitDoc[this.pageContentColumnName] as string) || "", - metadata: objectWithMetadatas, - }; - return [document, hit.vector_distance]; - }) || []; - - return documents; - } - - /** - * Add documents to the vector store. - * Will be updated if in the metadata there is a document with the same id if is using the default import function. - * Metadata will be added in the columns of the schema based on metadataColumnNames. - * @param documents Documents to add. 
- */ - async addDocuments(documents: Document[]) { - const typesenseDocuments = this._documentsToTypesenseRecords( - documents, - await this.embeddings.embedDocuments( - documents.map((doc) => doc.pageContent) - ) - ); - await this.import(typesenseDocuments, this.schemaName); - } - - /** - * Adds vectors to the vector store. - * @param vectors Vectors to add. - * @param documents Documents associated with the vectors. - */ - async addVectors(vectors: number[][], documents: Document[]) { - const typesenseDocuments = this._documentsToTypesenseRecords( - documents, - vectors - ); - await this.import(typesenseDocuments, this.schemaName); - } - - /** - * Search for similar documents with their similarity score. - * @param vectorPrompt vector to search for - * @param k amount of results to return - * @returns similar documents with their similarity score - */ - async similaritySearchVectorWithScore( - vectorPrompt: number[], - k?: number, - filter: this["FilterType"] = {} - ) { - const amount = k || this.searchParams.per_page || 5; - const vector_query = `${this.vectorColumnName}:([${vectorPrompt}], k:${amount})`; - const typesenseResponse = await this.client.multiSearch.perform( - { - searches: [ - { - ...this.searchParams, - ...filter, - per_page: amount, - vector_query, - collection: this.schemaName, - }, - ], - }, - {} - ); - const results = typesenseResponse.results[0].hits; - - const hits = results?.map((hit: VectorSearchResponseHit) => ({ - document: hit?.document || {}, - vector_distance: hit?.vector_distance || 2, - })) as - | { document: Record; vector_distance: number }[] - | undefined; - - return this._typesenseRecordsToDocuments(hits); - } - - /** - * Delete documents from the vector store. 
- * @param documentIds ids of the documents to delete - */ - async deleteDocuments(documentIds: string[]) { - await this.client - .collections(this.schemaName) - .documents() - .delete({ - filter_by: `id:=${documentIds.join(",")}`, - }); - } - - /** - * Create a vector store from documents. - * @param docs documents - * @param embeddings embeddings - * @param config Typesense configuration - * @returns Typesense vector store - * @warning You can omit this method, and only use the constructor and addDocuments. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - config: TypesenseConfig - ): Promise { - const instance = new Typesense(embeddings, config); - await instance.addDocuments(docs); - - return instance; - } - - /** - * Create a vector store from texts. - * @param texts - * @param metadatas - * @param embeddings - * @param config - * @returns Typesense vector store - */ - static async fromTexts( - texts: string[], - metadatas: object[], - embeddings: Embeddings, - config: TypesenseConfig - ) { - const instance = new Typesense(embeddings, config); - const documents: Document[] = texts.map((text, i) => ({ - pageContent: text, - metadata: metadatas[i] || {}, - })); - await instance.addDocuments(documents); - - return instance; - } -} +export * from "@langchain/community/vectorstores/typesense"; diff --git a/langchain/src/vectorstores/usearch.ts b/langchain/src/vectorstores/usearch.ts index 00faf0bfa25b..181f05100ade 100644 --- a/langchain/src/vectorstores/usearch.ts +++ b/langchain/src/vectorstores/usearch.ts @@ -1,223 +1 @@ -import usearch from "usearch"; -import * as uuid from "uuid"; -import { Embeddings } from "../embeddings/base.js"; -import { SaveableVectorStore } from "./base.js"; -import { Document } from "../document.js"; -import { SynchronousInMemoryDocstore } from "../stores/doc/in_memory.js"; - -/** - * Interface that defines the arguments that can be passed to the - * `USearch` constructor. 
It includes optional properties for a - * `docstore`, `index`, and `mapping`. - */ -export interface USearchArgs { - docstore?: SynchronousInMemoryDocstore; - index?: usearch.Index; - mapping?: Record; -} - -/** - * Class that extends `SaveableVectorStore` and provides methods for - * adding documents and vectors to a `usearch` index, performing - * similarity searches, and saving the index. - */ -export class USearch extends SaveableVectorStore { - _index?: usearch.Index; - - _mapping: Record; - - docstore: SynchronousInMemoryDocstore; - - args: USearchArgs; - - _vectorstoreType(): string { - return "usearch"; - } - - constructor(embeddings: Embeddings, args: USearchArgs) { - super(embeddings, args); - this.args = args; - this._index = args.index; - this._mapping = args.mapping ?? {}; - this.embeddings = embeddings; - this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore(); - } - - /** - * Method that adds documents to the `usearch` index. It generates - * embeddings for the documents and adds them to the index. - * @param documents An array of `Document` instances to be added to the index. - * @returns A promise that resolves with an array of document IDs. - */ - async addDocuments(documents: Document[]) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); - } - - public get index(): usearch.Index { - if (!this._index) { - throw new Error( - "Vector store not initialised yet. Try calling `fromTexts` or `fromDocuments` first." - ); - } - return this._index; - } - - private set index(index: usearch.Index) { - this._index = index; - } - - /** - * Method that adds vectors to the `usearch` index. It also updates the - * mapping between vector IDs and document IDs. - * @param vectors An array of vectors to be added to the index. - * @param documents An array of `Document` instances corresponding to the vectors. 
- * @returns A promise that resolves with an array of document IDs. - */ - async addVectors(vectors: number[][], documents: Document[]) { - if (vectors.length === 0) { - return []; - } - if (vectors.length !== documents.length) { - throw new Error(`Vectors and documents must have the same length`); - } - const dv = vectors[0].length; - if (!this._index) { - this._index = new usearch.Index({ - metric: "l2sq", - connectivity: BigInt(16), - dimensions: BigInt(dv), - }); - } - const d = this.index.dimensions(); - if (BigInt(dv) !== d) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${d})` - ); - } - - const docstoreSize = this.index.size(); - const documentIds = []; - for (let i = 0; i < vectors.length; i += 1) { - const documentId = uuid.v4(); - documentIds.push(documentId); - const id = Number(docstoreSize) + i; - this.index.add(BigInt(id), new Float32Array(vectors[i])); - this._mapping[id] = documentId; - this.docstore.add({ [documentId]: documents[i] }); - } - return documentIds; - } - - /** - * Method that performs a similarity search in the `usearch` index. It - * returns the `k` most similar documents to a given query vector, along - * with their similarity scores. - * @param query The query vector. - * @param k The number of most similar documents to return. - * @returns A promise that resolves with an array of tuples, each containing a `Document` and its similarity score. 
- */ - async similaritySearchVectorWithScore(query: number[], k: number) { - const d = this.index.dimensions(); - if (BigInt(query.length) !== d) { - throw new Error( - `Query vector must have the same length as the number of dimensions (${d})` - ); - } - if (k > this.index.size()) { - const total = this.index.size(); - console.warn( - `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}` - ); - // eslint-disable-next-line no-param-reassign - k = Number(total); - } - const result = this.index.search(new Float32Array(query), BigInt(k)); - - const return_list: [Document, number][] = []; - for (let i = 0; i < result.count; i += 1) { - const uuid = this._mapping[Number(result.keys[i])]; - return_list.push([this.docstore.search(uuid), result.distances[i]]); - } - - return return_list; - } - - /** - * Method that saves the `usearch` index and the document store to disk. - * @param directory The directory where the index and document store should be saved. - * @returns A promise that resolves when the save operation is complete. - */ - async save(directory: string) { - const fs = await import("node:fs/promises"); - const path = await import("node:path"); - await fs.mkdir(directory, { recursive: true }); - await Promise.all([ - this.index.save(path.join(directory, "usearch.index")), - await fs.writeFile( - path.join(directory, "docstore.json"), - JSON.stringify([ - Array.from(this.docstore._docs.entries()), - this._mapping, - ]) - ), - ]); - } - - /** - * Static method that creates a new `USearch` instance from a list of - * texts. It generates embeddings for the texts and adds them to the - * `usearch` index. - * @param texts An array of texts to be added to the index. - * @param metadatas Metadata associated with the texts. - * @param embeddings An instance of `Embeddings` used to generate embeddings for the texts. - * @param dbConfig Optional configuration for the document store. 
- * @returns A promise that resolves with a new `USearch` instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return this.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method that creates a new `USearch` instance from a list of - * documents. It generates embeddings for the documents and adds them to - * the `usearch` index. - * @param docs An array of `Document` instances to be added to the index. - * @param embeddings An instance of `Embeddings` used to generate embeddings for the documents. - * @param dbConfig Optional configuration for the document store. - * @returns A promise that resolves with a new `USearch` instance. 
- */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: { - docstore?: SynchronousInMemoryDocstore; - } - ): Promise { - const args: USearchArgs = { - docstore: dbConfig?.docstore, - }; - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/usearch"; diff --git a/langchain/src/vectorstores/vectara.ts b/langchain/src/vectorstores/vectara.ts index 1a35a0f7c4ba..7b8d7579e011 100644 --- a/langchain/src/vectorstores/vectara.ts +++ b/langchain/src/vectorstores/vectara.ts @@ -1,532 +1 @@ -import * as uuid from "uuid"; - -import { Document } from "../document.js"; -import { Embeddings } from "../embeddings/base.js"; -import { FakeEmbeddings } from "../embeddings/fake.js"; -import { getEnvironmentVariable } from "../util/env.js"; -import { VectorStore } from "./base.js"; - -/** - * Interface for the arguments required to initialize a VectaraStore - * instance. - */ -export interface VectaraLibArgs { - customerId: number; - corpusId: number | number[]; - apiKey: string; - verbose?: boolean; - source?: string; -} - -/** - * Interface for the headers required for Vectara API calls. - */ -interface VectaraCallHeader { - headers: { - "x-api-key": string; - "Content-Type": string; - "customer-id": string; - "X-Source": string; - }; -} - -/** - * Interface for the file objects to be uploaded to Vectara. - */ -export interface VectaraFile { - // The contents of the file to be uploaded. - blob: Blob; - // The name of the file to be uploaded. - fileName: string; -} - -/** - * Interface for the filter options used in Vectara API calls. - */ -export interface VectaraFilter { - // Example of a vectara filter string can be: "doc.rating > 3.0 and part.lang = 'deu'" - // See https://docs.vectara.com/docs/search-apis/sql/filter-overview for more details. 
- filter?: string; - // Improve retrieval accuracy by adjusting the balance (from 0 to 1), known as lambda, - // between neural search and keyword-based search factors. Values between 0.01 and 0.2 tend to work well. - // see https://docs.vectara.com/docs/api-reference/search-apis/lexical-matching for more details. - lambda?: number; - // The number of sentences before/after the matching segment to add to the context. - contextConfig?: VectaraContextConfig; -} - -/** - * Interface for the context configuration used in Vectara API calls. - */ -export interface VectaraContextConfig { - // The number of sentences before the matching segment to add. Default is 2. - sentencesBefore?: number; - // The number of sentences after the matching segment to add. Default is 2. - sentencesAfter?: number; -} - -/** - * Class for interacting with the Vectara API. Extends the VectorStore - * class. - */ -export class VectaraStore extends VectorStore { - get lc_secrets(): { [key: string]: string } { - return { - apiKey: "VECTARA_API_KEY", - corpusId: "VECTARA_CORPUS_ID", - customerId: "VECTARA_CUSTOMER_ID", - }; - } - - get lc_aliases(): { [key: string]: string } { - return { - apiKey: "vectara_api_key", - corpusId: "vectara_corpus_id", - customerId: "vectara_customer_id", - }; - } - - declare FilterType: VectaraFilter; - - private apiEndpoint = "api.vectara.io"; - - private apiKey: string; - - private corpusId: number[]; - - private customerId: number; - - private verbose: boolean; - - private source: string; - - private vectaraApiTimeoutSeconds = 60; - - _vectorstoreType(): string { - return "vectara"; - } - - constructor(args: VectaraLibArgs) { - // Vectara doesn't need embeddings, but we need to pass something to the parent constructor - // The embeddings are abstracted out from the user in Vectara. - super(new FakeEmbeddings(), args); - - const apiKey = args.apiKey ?? 
getEnvironmentVariable("VECTARA_API_KEY"); - if (!apiKey) { - throw new Error("Vectara api key is not provided."); - } - this.apiKey = apiKey; - this.source = args.source ?? "langchainjs"; - - const corpusId = - args.corpusId ?? - getEnvironmentVariable("VECTARA_CORPUS_ID") - ?.split(",") - .map((id) => { - const num = Number(id); - if (Number.isNaN(num)) - throw new Error("Vectara corpus id is not a number."); - return num; - }); - if (!corpusId) { - throw new Error("Vectara corpus id is not provided."); - } - - if (typeof corpusId === "number") { - this.corpusId = [corpusId]; - } else { - if (corpusId.length === 0) - throw new Error("Vectara corpus id is not provided."); - this.corpusId = corpusId; - } - - const customerId = - args.customerId ?? getEnvironmentVariable("VECTARA_CUSTOMER_ID"); - if (!customerId) { - throw new Error("Vectara customer id is not provided."); - } - this.customerId = customerId; - - this.verbose = args.verbose ?? false; - } - - /** - * Returns a header for Vectara API calls. - * @returns A Promise that resolves to a VectaraCallHeader object. - */ - async getJsonHeader(): Promise { - return { - headers: { - "x-api-key": this.apiKey, - "Content-Type": "application/json", - "customer-id": this.customerId.toString(), - "X-Source": this.source, - }, - }; - } - - /** - * Throws an error, as this method is not implemented. Use addDocuments - * instead. - * @param _vectors Not used. - * @param _documents Not used. - * @returns Does not return a value. - */ - async addVectors( - _vectors: number[][], - _documents: Document[] - ): Promise { - throw new Error( - "Method not implemented. Please call addDocuments instead." - ); - } - - /** - * Method to delete data from the Vectara corpus. - * @param params an array of document IDs to be deleted - * @returns Promise that resolves when the deletion is complete. 
- */ - async deleteDocuments(ids: string[]): Promise { - if (ids && ids.length > 0) { - const headers = await this.getJsonHeader(); - for (const id of ids) { - const data = { - customer_id: this.customerId, - corpus_id: this.corpusId[0], - document_id: id, - }; - - try { - const controller = new AbortController(); - const timeout = setTimeout( - () => controller.abort(), - this.vectaraApiTimeoutSeconds * 1000 - ); - const response = await fetch( - `https://${this.apiEndpoint}/v1/delete-doc`, - { - method: "POST", - headers: headers?.headers, - body: JSON.stringify(data), - signal: controller.signal, - } - ); - clearTimeout(timeout); - if (response.status !== 200) { - throw new Error( - `Vectara API returned status code ${response.status} when deleting document ${id}` - ); - } - } catch (e) { - const error = new Error(`Error ${(e as Error).message}`); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code = 500; - throw error; - } - } - } else { - throw new Error(`no "ids" specified for deletion`); - } - } - - /** - * Adds documents to the Vectara store. - * @param documents An array of Document objects to add to the Vectara store. - * @returns A Promise that resolves to an array of document IDs indexed in Vectara. - */ - async addDocuments(documents: Document[]): Promise { - if (this.corpusId.length > 1) - throw new Error("addDocuments does not support multiple corpus ids"); - - const headers = await this.getJsonHeader(); - const doc_ids: string[] = []; - let countAdded = 0; - for (const document of documents) { - const doc_id: string = document.metadata?.document_id ?? uuid.v4(); - const data = { - customer_id: this.customerId, - corpus_id: this.corpusId[0], - document: { - document_id: doc_id, - title: document.metadata?.title ?? "", - metadata_json: JSON.stringify(document.metadata ?? 
{}), - section: [ - { - text: document.pageContent, - }, - ], - }, - }; - - try { - const controller = new AbortController(); - const timeout = setTimeout( - () => controller.abort(), - this.vectaraApiTimeoutSeconds * 1000 - ); - const response = await fetch(`https://${this.apiEndpoint}/v1/index`, { - method: "POST", - headers: headers?.headers, - body: JSON.stringify(data), - signal: controller.signal, - }); - clearTimeout(timeout); - const result = await response.json(); - if ( - result.status?.code !== "OK" && - result.status?.code !== "ALREADY_EXISTS" - ) { - const error = new Error( - `Vectara API returned status code ${ - result.status?.code - }: ${JSON.stringify(result.message)}` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code = 500; - throw error; - } else { - countAdded += 1; - doc_ids.push(doc_id); - } - } catch (e) { - const error = new Error( - `Error ${(e as Error).message} while adding document` - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code = 500; - throw error; - } - } - if (this.verbose) { - console.log(`Added ${countAdded} documents to Vectara`); - } - - return doc_ids; - } - - /** - * Vectara provides a way to add documents directly via their API. This API handles - * pre-processing and chunking internally in an optimal manner. This method is a wrapper - * to utilize that API within LangChain. - * - * @param files An array of VectaraFile objects representing the files and their respective file names to be uploaded to Vectara. - * @param metadata Optional. An array of metadata objects corresponding to each file in the `filePaths` array. - * @returns A Promise that resolves to the number of successfully uploaded files. 
- */ - async addFiles( - files: VectaraFile[], - metadatas: Record | undefined = undefined - ) { - if (this.corpusId.length > 1) - throw new Error("addFiles does not support multiple corpus ids"); - - const doc_ids: string[] = []; - - for (const [index, file] of files.entries()) { - const md = metadatas ? metadatas[index] : {}; - - const data = new FormData(); - data.append("file", file.blob, file.fileName); - data.append("doc-metadata", JSON.stringify(md)); - - const response = await fetch( - `https://api.vectara.io/v1/upload?c=${this.customerId}&o=${this.corpusId[0]}&d=true`, - { - method: "POST", - headers: { - "x-api-key": this.apiKey, - "X-Source": this.source, - }, - body: data, - } - ); - - const { status } = response; - if (status === 409) { - throw new Error(`File at index ${index} already exists in Vectara`); - } else if (status !== 200) { - throw new Error(`Vectara API returned status code ${status}`); - } else { - const result = await response.json(); - const doc_id = result.document.documentId; - doc_ids.push(doc_id); - } - } - - if (this.verbose) { - console.log(`Uploaded ${files.length} files to Vectara`); - } - - return doc_ids; - } - - /** - * Performs a similarity search and returns documents along with their - * scores. - * @param query The query string for the similarity search. - * @param k Optional. The number of results to return. Default is 10. - * @param filter Optional. A VectaraFilter object to refine the search results. - * @returns A Promise that resolves to an array of tuples, each containing a Document and its score. - */ - async similaritySearchWithScore( - query: string, - k = 10, - filter: VectaraFilter | undefined = undefined - ): Promise<[Document, number][]> { - const headers = await this.getJsonHeader(); - - const corpusKeys = this.corpusId.map((corpusId) => ({ - customerId: this.customerId, - corpusId, - metadataFilter: filter?.filter ?? "", - lexicalInterpolationConfig: { lambda: filter?.lambda ?? 
0.025 }, - })); - - const data = { - query: [ - { - query, - numResults: k, - contextConfig: { - sentencesAfter: filter?.contextConfig?.sentencesAfter ?? 2, - sentencesBefore: filter?.contextConfig?.sentencesBefore ?? 2, - }, - corpusKey: corpusKeys, - }, - ], - }; - - const controller = new AbortController(); - const timeout = setTimeout( - () => controller.abort(), - this.vectaraApiTimeoutSeconds * 1000 - ); - const response = await fetch(`https://${this.apiEndpoint}/v1/query`, { - method: "POST", - headers: headers?.headers, - body: JSON.stringify(data), - signal: controller.signal, - }); - clearTimeout(timeout); - if (response.status !== 200) { - throw new Error(`Vectara API returned status code ${response.status}`); - } - - const result = await response.json(); - const responses = result.responseSet[0].response; - const documents = result.responseSet[0].document; - - for (let i = 0; i < responses.length; i += 1) { - const responseMetadata = responses[i].metadata; - const documentMetadata = documents[responses[i].documentIndex].metadata; - const combinedMetadata: Record = {}; - - responseMetadata.forEach((item: { name: string; value: unknown }) => { - combinedMetadata[item.name] = item.value; - }); - - documentMetadata.forEach((item: { name: string; value: unknown }) => { - combinedMetadata[item.name] = item.value; - }); - - responses[i].metadata = combinedMetadata; - } - - const documentsAndScores = responses.map( - (response: { - text: string; - metadata: Record; - score: number; - }) => [ - new Document({ - pageContent: response.text, - metadata: response.metadata, - }), - response.score, - ] - ); - return documentsAndScores; - } - - /** - * Performs a similarity search and returns documents. - * @param query The query string for the similarity search. - * @param k Optional. The number of results to return. Default is 10. - * @param filter Optional. A VectaraFilter object to refine the search results. 
- * @returns A Promise that resolves to an array of Document objects. - */ - async similaritySearch( - query: string, - k = 10, - filter: VectaraFilter | undefined = undefined - ): Promise { - const resultWithScore = await this.similaritySearchWithScore( - query, - k, - filter - ); - return resultWithScore.map((result) => result[0]); - } - - /** - * Throws an error, as this method is not implemented. Use - * similaritySearch or similaritySearchWithScore instead. - * @param _query Not used. - * @param _k Not used. - * @param _filter Not used. - * @returns Does not return a value. - */ - async similaritySearchVectorWithScore( - _query: number[], - _k: number, - _filter?: VectaraFilter | undefined - ): Promise<[Document, number][]> { - throw new Error( - "Method not implemented. Please call similaritySearch or similaritySearchWithScore instead." - ); - } - - /** - * Creates a VectaraStore instance from texts. - * @param texts An array of text strings. - * @param metadatas Metadata for the texts. Can be a single object or an array of objects. - * @param _embeddings Not used. - * @param args A VectaraLibArgs object for initializing the VectaraStore instance. - * @returns A Promise that resolves to a VectaraStore instance. - */ - static fromTexts( - texts: string[], - metadatas: object | object[], - _embeddings: Embeddings, - args: VectaraLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return VectaraStore.fromDocuments(docs, new FakeEmbeddings(), args); - } - - /** - * Creates a VectaraStore instance from documents. - * @param docs An array of Document objects. - * @param _embeddings Not used. - * @param args A VectaraLibArgs object for initializing the VectaraStore instance. - * @returns A Promise that resolves to a VectaraStore instance. 
- */ - static async fromDocuments( - docs: Document[], - _embeddings: Embeddings, - args: VectaraLibArgs - ): Promise { - const instance = new this(args); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/vectara"; diff --git a/langchain/src/vectorstores/vercel_postgres.ts b/langchain/src/vectorstores/vercel_postgres.ts index 755d659faa3c..22ec83d4f4ad 100644 --- a/langchain/src/vectorstores/vercel_postgres.ts +++ b/langchain/src/vectorstores/vercel_postgres.ts @@ -1,393 +1 @@ -import { - type VercelPool, - type VercelPoolClient, - type VercelPostgresPoolConfig, - createPool, -} from "@vercel/postgres"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { getEnvironmentVariable } from "../util/env.js"; - -type Metadata = Record>; - -/** - * Interface that defines the arguments required to create a - * `VercelPostgres` instance. It includes Postgres connection options, - * table name, filter, and verbosity level. - */ -export interface VercelPostgresFields { - pool: VercelPool; - client: VercelPoolClient; - tableName?: string; - columns?: { - idColumnName?: string; - vectorColumnName?: string; - contentColumnName?: string; - metadataColumnName?: string; - }; - filter?: Metadata; - verbose?: boolean; -} - -/** - * Class that provides an interface to a Vercel Postgres vector database. It - * extends the `VectorStore` base class and implements methods for adding - * documents and vectors and performing similarity searches. 
- */ -export class VercelPostgres extends VectorStore { - declare FilterType: Metadata; - - tableName: string; - - idColumnName: string; - - vectorColumnName: string; - - contentColumnName: string; - - metadataColumnName: string; - - filter?: Metadata; - - _verbose?: boolean; - - pool: VercelPool; - - client: VercelPoolClient; - - _vectorstoreType(): string { - return "vercel"; - } - - private constructor(embeddings: Embeddings, config: VercelPostgresFields) { - super(embeddings, config); - this.tableName = config.tableName ?? "langchain_vectors"; - this.filter = config.filter; - - this.vectorColumnName = config.columns?.vectorColumnName ?? "embedding"; - this.contentColumnName = config.columns?.contentColumnName ?? "text"; - this.idColumnName = config.columns?.idColumnName ?? "id"; - this.metadataColumnName = config.columns?.metadataColumnName ?? "metadata"; - - this.pool = config.pool; - this.client = config.client; - - this._verbose = - getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ?? - !!config.verbose; - } - - /** - * Static method to create a new `VercelPostgres` instance from a - * connection. It creates a table if one does not exist, and calls - * `connect` to return a new instance of `VercelPostgres`. - * - * @param embeddings - Embeddings instance. - * @param fields - `VercelPostgres` configuration options. - * @returns A new instance of `VercelPostgres`. - */ - static async initialize( - embeddings: Embeddings, - config?: Partial & { - postgresConnectionOptions?: VercelPostgresPoolConfig; - } - ): Promise { - // Default maxUses to 1 for edge environments: - // https://github.com/vercel/storage/tree/main/packages/postgres#a-note-on-edge-environments - const pool = - config?.pool ?? - createPool({ maxUses: 1, ...config?.postgresConnectionOptions }); - const client = config?.client ?? 
(await pool.connect()); - const postgresqlVectorStore = new VercelPostgres(embeddings, { - ...config, - pool, - client, - }); - - await postgresqlVectorStore.ensureTableInDatabase(); - - return postgresqlVectorStore; - } - - /** - * Method to add documents to the vector store. It converts the documents into - * vectors, and adds them to the store. - * - * @param documents - Array of `Document` instances. - * @returns Promise that resolves when the documents have been added. - */ - async addDocuments( - documents: Document[], - options?: { ids?: string[] } - ): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Generates the SQL placeholders for a specific row at the provided index. - * - * @param index - The index of the row for which placeholders need to be generated. - * @returns The SQL placeholders for the row values. - */ - protected generatePlaceholderForRowAt( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - row: (string | Record)[], - index: number - ): string { - const base = index * row.length; - return `(${row.map((_, j) => `$${base + 1 + j}`)})`; - } - - /** - * Constructs the SQL query for inserting rows into the specified table. - * - * @param rows - The rows of data to be inserted, consisting of values and records. - * @param chunkIndex - The starting index for generating query placeholders based on chunk positioning. - * @returns The complete SQL INSERT INTO query string. - */ - protected async runInsertQuery( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - rows: (string | Record)[][], - useIdColumn: boolean - ) { - const values = rows.map((row, j) => - this.generatePlaceholderForRowAt(row, j) - ); - const flatValues = rows.flat(); - return this.client.query( - ` - INSERT INTO ${this.tableName} ( - ${useIdColumn ? 
`${this.idColumnName},` : ""} - ${this.contentColumnName}, - ${this.vectorColumnName}, - ${this.metadataColumnName} - ) VALUES ${values.join(", ")} - ON CONFLICT (${this.idColumnName}) - DO UPDATE - SET - ${this.contentColumnName} = EXCLUDED.${this.contentColumnName}, - ${this.vectorColumnName} = EXCLUDED.${this.vectorColumnName}, - ${this.metadataColumnName} = EXCLUDED.${this.metadataColumnName} - RETURNING ${this.idColumnName}`, - flatValues - ); - } - - /** - * Method to add vectors to the vector store. It converts the vectors into - * rows and inserts them into the database. - * - * @param vectors - Array of vectors. - * @param documents - Array of `Document` instances. - * @returns Promise that resolves when the vectors have been added. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ): Promise { - if (options?.ids !== undefined && options?.ids.length !== vectors.length) { - throw new Error( - `If provided, the length of "ids" must be the same as the number of vectors.` - ); - } - const rows = vectors.map((embedding, idx) => { - const embeddingString = `[${embedding.join(",")}]`; - const row = [ - documents[idx].pageContent, - embeddingString, - documents[idx].metadata, - ]; - if (options?.ids) { - return [options.ids[idx], ...row]; - } - return row; - }); - - const chunkSize = 500; - - const ids = []; - - for (let i = 0; i < rows.length; i += chunkSize) { - const chunk = rows.slice(i, i + chunkSize); - try { - const result = await this.runInsertQuery( - chunk, - options?.ids !== undefined - ); - ids.push(...result.rows.map((row) => row[this.idColumnName])); - } catch (e) { - console.error(e); - throw new Error(`Error inserting: ${(e as Error).message}`); - } - } - return ids; - } - - /** - * Method to perform a similarity search in the vector store. It returns - * the `k` most similar documents to the query vector, along with their - * similarity scores. - * - * @param query - Query vector. 
- * @param k - Number of most similar documents to return. - * @param filter - Optional filter to apply to the search. - * @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: this["FilterType"] - ): Promise<[Document, number][]> { - const embeddingString = `[${query.join(",")}]`; - const _filter: this["FilterType"] = filter ?? {}; - const whereClauses = []; - const values = [embeddingString, k]; - let paramCount = values.length; - - for (const [key, value] of Object.entries(_filter)) { - if (typeof value === "object" && value !== null) { - const currentParamCount = paramCount; - const placeholders = value.in - .map((_, index) => `$${currentParamCount + index + 1}`) - .join(","); - whereClauses.push( - `${this.metadataColumnName}->>'${key}' IN (${placeholders})` - ); - values.push(...value.in); - paramCount += value.in.length; - } else { - paramCount += 1; - whereClauses.push( - `${this.metadataColumnName}->>'${key}' = $${paramCount}` - ); - values.push(value); - } - } - - const whereClause = whereClauses.length - ? 
`WHERE ${whereClauses.join(" AND ")}` - : ""; - - const queryString = ` - SELECT *, ${this.vectorColumnName} <=> $1 as "_distance" - FROM ${this.tableName} - ${whereClause} - ORDER BY "_distance" ASC - LIMIT $2;`; - - const documents = (await this.client.query(queryString, values)).rows; - const results = [] as [Document, number][]; - for (const doc of documents) { - if (doc._distance != null && doc[this.contentColumnName] != null) { - const document = new Document({ - pageContent: doc[this.contentColumnName], - metadata: doc[this.metadataColumnName], - }); - results.push([document, doc._distance]); - } - } - return results; - } - - async delete(params: { ids?: string[]; deleteAll?: boolean }): Promise { - if (params.ids !== undefined) { - await this.client.query( - `DELETE FROM ${this.tableName} WHERE ${ - this.idColumnName - } IN (${params.ids.map((_, idx) => `$${idx + 1}`)})`, - params.ids - ); - } else if (params.deleteAll) { - await this.client.query(`TRUNCATE TABLE ${this.tableName}`); - } - } - - /** - * Method to ensure the existence of the table in the database. It creates - * the table if it does not already exist. - * - * @returns Promise that resolves when the table has been ensured. - */ - async ensureTableInDatabase(): Promise { - await this.client.query(`CREATE EXTENSION IF NOT EXISTS vector;`); - await this.client.query(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`); - await this.client.query(`CREATE TABLE IF NOT EXISTS "${this.tableName}" ( - "${this.idColumnName}" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, - "${this.contentColumnName}" text, - "${this.metadataColumnName}" jsonb, - "${this.vectorColumnName}" vector - );`); - } - - /** - * Static method to create a new `VercelPostgres` instance from an - * array of texts and their metadata. It converts the texts into - * `Document` instances and adds them to the store. - * - * @param texts - Array of texts. - * @param metadatas - Array of metadata objects or a single metadata object. 
- * @param embeddings - Embeddings instance. - * @param fields - `VercelPostgres` configuration options. - * @returns Promise that resolves with a new instance of `VercelPostgres`. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - dbConfig?: Partial & { - postgresConnectionOptions?: VercelPostgresPoolConfig; - } - ): Promise { - const docs = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - - return this.fromDocuments(docs, embeddings, dbConfig); - } - - /** - * Static method to create a new `VercelPostgres` instance from an - * array of `Document` instances. It adds the documents to the store. - * - * @param docs - Array of `Document` instances. - * @param embeddings - Embeddings instance. - * @param fields - `VercelPostgres` configuration options. - * @returns Promise that resolves with a new instance of `VercelPostgres`. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - dbConfig?: Partial & { - postgresConnectionOptions?: VercelPostgresPoolConfig; - } - ): Promise { - const instance = await this.initialize(embeddings, dbConfig); - await instance.addDocuments(docs); - - return instance; - } - - /** - * Closes all the clients in the pool and terminates the pool. - * - * @returns Promise that resolves when all clients are closed and the pool is terminated. 
- */ - async end(): Promise { - await this.client?.release(); - return this.pool.end(); - } -} +export * from "@langchain/community/vectorstores/vercel_postgres"; diff --git a/langchain/src/vectorstores/voy.ts b/langchain/src/vectorstores/voy.ts index c968b3dfb71f..09428f9e9734 100644 --- a/langchain/src/vectorstores/voy.ts +++ b/langchain/src/vectorstores/voy.ts @@ -1,191 +1 @@ -import type { Voy as VoyOriginClient, SearchResult } from "voy-search"; -import { Embeddings } from "../embeddings/base.js"; -import { VectorStore } from "./base.js"; -import { Document } from "../document.js"; - -export type VoyClient = Omit< - VoyOriginClient, - "remove" | "size" | "serialize" | "free" ->; - -/** - * Internal interface for storing documents mappings. - */ -interface InternalDoc { - embeddings: number[]; - document: Document; -} - -/** - * Class that extends `VectorStore`. It allows to perform similarity search using - * Voi similarity search engine. The class requires passing Voy Client as an input parameter. - */ -export class VoyVectorStore extends VectorStore { - client: VoyClient; - - numDimensions: number | null = null; - - docstore: InternalDoc[] = []; - - _vectorstoreType(): string { - return "voi"; - } - - constructor(client: VoyClient, embeddings: Embeddings) { - super(embeddings, {}); - this.client = client; - this.embeddings = embeddings; - } - - /** - * Adds documents to the Voy database. The documents are embedded using embeddings provided while instantiating the class. - * @param documents An array of `Document` instances associated with the vectors. 
- */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - if (documents.length === 0) { - return; - } - - const firstVector = ( - await this.embeddings.embedDocuments(texts.slice(0, 1)) - )[0]; - if (this.numDimensions === null) { - this.numDimensions = firstVector.length; - } else if (this.numDimensions !== firstVector.length) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${this.numDimensions})` - ); - } - const restResults = await this.embeddings.embedDocuments(texts.slice(1)); - await this.addVectors([firstVector, ...restResults], documents); - } - - /** - * Adds vectors to the Voy database. The vectors are associated with - * the provided documents. - * @param vectors An array of vectors to be added to the database. - * @param documents An array of `Document` instances associated with the vectors. - */ - async addVectors(vectors: number[][], documents: Document[]): Promise { - if (vectors.length === 0) { - return; - } - if (this.numDimensions === null) { - this.numDimensions = vectors[0].length; - } - - if (vectors.length !== documents.length) { - throw new Error(`Vectors and metadata must have the same length`); - } - if (!vectors.every((v) => v.length === this.numDimensions)) { - throw new Error( - `Vectors must have the same length as the number of dimensions (${this.numDimensions})` - ); - } - - vectors.forEach((item, idx) => { - const doc = documents[idx]; - this.docstore.push({ embeddings: item, document: doc }); - }); - const embeddings = this.docstore.map((item, idx) => ({ - id: String(idx), - embeddings: item.embeddings, - title: "", - url: "", - })); - this.client.index({ embeddings }); - } - - /** - * Searches for vectors in the Voy database that are similar to the - * provided query vector. - * @param query The query vector. - * @param k The number of similar vectors to return. 
- * @returns A promise that resolves with an array of tuples, each containing a `Document` instance and a similarity score. - */ - async similaritySearchVectorWithScore(query: number[], k: number) { - if (this.numDimensions === null) { - throw new Error("There aren't any elements in the index yet."); - } - if (query.length !== this.numDimensions) { - throw new Error( - `Query vector must have the same length as the number of dimensions (${this.numDimensions})` - ); - } - const itemsToQuery = Math.min(this.docstore.length, k); - if (itemsToQuery > this.docstore.length) { - console.warn( - `k (${k}) is greater than the number of elements in the index (${this.docstore.length}), setting k to ${itemsToQuery}` - ); - } - const results: SearchResult = this.client.search( - new Float32Array(query), - itemsToQuery - ); - return results.neighbors.map( - ({ id }, idx) => - [this.docstore[parseInt(id, 10)].document, idx] as [Document, number] - ); - } - - /** - * Method to delete data from the Voy index. It can delete data based - * on specific IDs or a filter. - * @param params Object that includes either an array of IDs or a filter for the data to be deleted. - * @returns Promise that resolves when the deletion is complete. - */ - async delete(params: { deleteAll?: boolean }): Promise { - if (params.deleteAll === true) { - await this.client.clear(); - } else { - throw new Error(`You must provide a "deleteAll" parameter.`); - } - } - - /** - * Creates a new `VoyVectorStore` instance from an array of text strings. The text - * strings are converted to `Document` instances and added to the Voy - * database. - * @param texts An array of text strings. - * @param metadatas An array of metadata objects or a single metadata object. If an array is provided, it must have the same length as the `texts` array. - * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. - * @param client An instance of Voy client to use in the underlying operations. 
- * @returns A promise that resolves with a new `VoyVectorStore` instance. - */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - client: VoyClient - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return VoyVectorStore.fromDocuments(docs, embeddings, client); - } - - /** - * Creates a new `VoyVectorStore` instance from an array of `Document` instances. - * The documents are added to the Voy database. - * @param docs An array of `Document` instances. - * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. - * @param client An instance of Voy client to use in the underlying operations. - * @returns A promise that resolves with a new `VoyVectorStore` instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - client: VoyClient - ): Promise { - const instance = new VoyVectorStore(client, embeddings); - await instance.addDocuments(docs); - return instance; - } -} +export * from "@langchain/community/vectorstores/voy"; diff --git a/langchain/src/vectorstores/weaviate.ts b/langchain/src/vectorstores/weaviate.ts index 1bd1fba5a77a..e7a76b406db7 100644 --- a/langchain/src/vectorstores/weaviate.ts +++ b/langchain/src/vectorstores/weaviate.ts @@ -1,432 +1 @@ -import * as uuid from "uuid"; -import type { - WeaviateClient, - WeaviateObject, - WhereFilter, -} from "weaviate-ts-client"; -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { maximalMarginalRelevance } from "../util/math.js"; - -// Note this function is not generic, it is designed specifically for Weaviate -// 
https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction -export const flattenObjectForWeaviate = ( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - obj: Record -) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const flattenedObject: Record = {}; - - for (const key in obj) { - if (!Object.hasOwn(obj, key)) { - continue; - } - const value = obj[key]; - if (typeof obj[key] === "object" && !Array.isArray(value)) { - const recursiveResult = flattenObjectForWeaviate(value); - - for (const deepKey in recursiveResult) { - if (Object.hasOwn(obj, key)) { - flattenedObject[`${key}_${deepKey}`] = recursiveResult[deepKey]; - } - } - } else if (Array.isArray(value)) { - if ( - value.length > 0 && - typeof value[0] !== "object" && - // eslint-disable-next-line @typescript-eslint/no-explicit-any - value.every((el: any) => typeof el === typeof value[0]) - ) { - // Weaviate only supports arrays of primitive types, - // where all elements are of the same type - flattenedObject[key] = value; - } - } else { - flattenedObject[key] = value; - } - } - - return flattenedObject; -}; - -/** - * Interface that defines the arguments required to create a new instance - * of the `WeaviateStore` class. It includes the Weaviate client, the name - * of the class in Weaviate, and optional keys for text and metadata. - */ -export interface WeaviateLibArgs { - client: WeaviateClient; - /** - * The name of the class in Weaviate. Must start with a capital letter. - */ - indexName: string; - textKey?: string; - metadataKeys?: string[]; - tenant?: string; -} - -interface ResultRow { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [key: string]: any; -} - -/** - * Interface that defines a filter for querying data from Weaviate. It - * includes a distance and a `WhereFilter`. - */ -export interface WeaviateFilter { - distance?: number; - where: WhereFilter; -} - -/** - * Class that extends the `VectorStore` base class. 
It provides methods to - * interact with a Weaviate index, including adding vectors and documents, - * deleting data, and performing similarity searches. - */ -export class WeaviateStore extends VectorStore { - declare FilterType: WeaviateFilter; - - private client: WeaviateClient; - - private indexName: string; - - private textKey: string; - - private queryAttrs: string[]; - - private tenant?: string; - - _vectorstoreType(): string { - return "weaviate"; - } - - constructor(public embeddings: Embeddings, args: WeaviateLibArgs) { - super(embeddings, args); - - this.client = args.client; - this.indexName = args.indexName; - this.textKey = args.textKey || "text"; - this.queryAttrs = [this.textKey]; - this.tenant = args.tenant; - - if (args.metadataKeys) { - this.queryAttrs = [ - ...new Set([ - ...this.queryAttrs, - ...args.metadataKeys.filter((k) => { - // https://spec.graphql.org/June2018/#sec-Names - // queryAttrs need to be valid GraphQL Names - const keyIsValid = /^[_A-Za-z][_0-9A-Za-z]*$/.test(k); - if (!keyIsValid) { - console.warn( - `Skipping metadata key ${k} as it is not a valid GraphQL Name` - ); - } - return keyIsValid; - }), - ]), - ]; - } - } - - /** - * Method to add vectors and corresponding documents to the Weaviate - * index. - * @param vectors Array of vectors to be added. - * @param documents Array of documents corresponding to the vectors. - * @param options Optional parameter that can include specific IDs for the documents. - * @returns An array of document IDs. - */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ) { - const documentIds = options?.ids ?? documents.map((_) => uuid.v4()); - const batch: WeaviateObject[] = documents.map((document, index) => { - if (Object.hasOwn(document.metadata, "id")) - throw new Error( - "Document inserted to Weaviate vectorstore should not have `id` in their metadata." 
- ); - - const flattenedMetadata = flattenObjectForWeaviate(document.metadata); - return { - ...(this.tenant ? { tenant: this.tenant } : {}), - class: this.indexName, - id: documentIds[index], - vector: vectors[index], - properties: { - [this.textKey]: document.pageContent, - ...flattenedMetadata, - }, - }; - }); - - try { - const responses = await this.client.batch - .objectsBatcher() - .withObjects(...batch) - .do(); - // if storing vectors fails, we need to know why - const errorMessages: string[] = []; - responses.forEach((response) => { - if (response?.result?.errors?.error) { - errorMessages.push( - ...response.result.errors.error.map( - (err) => - err.message ?? - "!! Unfortunately no error message was presented in the API response !!" - ) - ); - } - }); - if (errorMessages.length > 0) { - throw new Error(errorMessages.join("\n")); - } - } catch (e) { - throw Error(`Error adding vectors: ${e}`); - } - return documentIds; - } - - /** - * Method to add documents to the Weaviate index. It first generates - * vectors for the documents using the embeddings, then adds the vectors - * and documents to the index. - * @param documents Array of documents to be added. - * @param options Optional parameter that can include specific IDs for the documents. - * @returns An array of document IDs. - */ - async addDocuments(documents: Document[], options?: { ids?: string[] }) { - return this.addVectors( - await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), - documents, - options - ); - } - - /** - * Method to delete data from the Weaviate index. It can delete data based - * on specific IDs or a filter. - * @param params Object that includes either an array of IDs or a filter for the data to be deleted. - * @returns Promise that resolves when the deletion is complete. 
- */ - async delete(params: { - ids?: string[]; - filter?: WeaviateFilter; - }): Promise { - const { ids, filter } = params; - - if (ids && ids.length > 0) { - for (const id of ids) { - let deleter = this.client.data - .deleter() - .withClassName(this.indexName) - .withId(id); - - if (this.tenant) { - deleter = deleter.withTenant(this.tenant); - } - - await deleter.do(); - } - } else if (filter) { - let batchDeleter = this.client.batch - .objectsBatchDeleter() - .withClassName(this.indexName) - .withWhere(filter.where); - - if (this.tenant) { - batchDeleter = batchDeleter.withTenant(this.tenant); - } - - await batchDeleter.do(); - } else { - throw new Error( - `This method requires either "ids" or "filter" to be set in the input object` - ); - } - } - - /** - * Method to perform a similarity search on the stored vectors in the - * Weaviate index. It returns the top k most similar documents and their - * similarity scores. - * @param query The query vector. - * @param k The number of most similar documents to return. - * @param filter Optional filter to apply to the search. - * @returns An array of tuples, where each tuple contains a document and its similarity score. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: WeaviateFilter - ): Promise<[Document, number][]> { - const resultsWithEmbedding = - await this.similaritySearchVectorWithScoreAndEmbedding(query, k, filter); - return resultsWithEmbedding.map(([document, score, _embedding]) => [ - document, - score, - ]); - } - - /** - * Method to perform a similarity search on the stored vectors in the - * Weaviate index. It returns the top k most similar documents, their - * similarity scores and embedding vectors. - * @param query The query vector. - * @param k The number of most similar documents to return. - * @param filter Optional filter to apply to the search. 
- * @returns An array of tuples, where each tuple contains a document, its similarity score and its embedding vector. - */ - async similaritySearchVectorWithScoreAndEmbedding( - query: number[], - k: number, - filter?: WeaviateFilter - ): Promise<[Document, number, number[]][]> { - try { - let builder = this.client.graphql - .get() - .withClassName(this.indexName) - .withFields( - `${this.queryAttrs.join(" ")} _additional { distance vector }` - ) - .withNearVector({ - vector: query, - distance: filter?.distance, - }) - .withLimit(k); - - if (this.tenant) { - builder = builder.withTenant(this.tenant); - } - - if (filter?.where) { - builder = builder.withWhere(filter.where); - } - - const result = await builder.do(); - - const documents: [Document, number, number[]][] = []; - for (const data of result.data.Get[this.indexName]) { - const { [this.textKey]: text, _additional, ...rest }: ResultRow = data; - - documents.push([ - new Document({ - pageContent: text, - metadata: rest, - }), - _additional.distance, - _additional.vector, - ]); - } - return documents; - } catch (e) { - throw Error(`'Error in similaritySearch' ${e}`); - } - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {this["FilterType"]} options.filter - Optional filter - * @param _callbacks - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. 
- */ - override async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions, - _callbacks?: undefined - ): Promise { - const { k, fetchK = 20, lambda = 0.5, filter } = options; - const queryEmbedding: number[] = await this.embeddings.embedQuery(query); - const allResults: [Document, number, number[]][] = - await this.similaritySearchVectorWithScoreAndEmbedding( - queryEmbedding, - fetchK, - filter - ); - const embeddingList = allResults.map( - ([_doc, _score, embedding]) => embedding - ); - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - lambda, - k - ); - return mmrIndexes - .filter((idx) => idx !== -1) - .map((idx) => allResults[idx][0]); - } - - /** - * Static method to create a new `WeaviateStore` instance from a list of - * texts. It first creates documents from the texts and metadata, then - * adds the documents to the Weaviate index. - * @param texts Array of texts. - * @param metadatas Metadata for the texts. Can be a single object or an array of objects. - * @param embeddings Embeddings to be used for the texts. - * @param args Arguments required to create a new `WeaviateStore` instance. - * @returns A new `WeaviateStore` instance. - */ - static fromTexts( - texts: string[], - metadatas: object | object[], - embeddings: Embeddings, - args: WeaviateLibArgs - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return WeaviateStore.fromDocuments(docs, embeddings, args); - } - - /** - * Static method to create a new `WeaviateStore` instance from a list of - * documents. It adds the documents to the Weaviate index. - * @param docs Array of documents. - * @param embeddings Embeddings to be used for the documents. 
- * @param args Arguments required to create a new `WeaviateStore` instance. - * @returns A new `WeaviateStore` instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - args: WeaviateLibArgs - ): Promise { - const instance = new this(embeddings, args); - await instance.addDocuments(docs); - return instance; - } - - /** - * Static method to create a new `WeaviateStore` instance from an existing - * Weaviate index. - * @param embeddings Embeddings to be used for the Weaviate index. - * @param args Arguments required to create a new `WeaviateStore` instance. - * @returns A new `WeaviateStore` instance. - */ - static async fromExistingIndex( - embeddings: Embeddings, - args: WeaviateLibArgs - ): Promise { - return new this(embeddings, args); - } -} +export * from "@langchain/community/vectorstores/weaviate"; diff --git a/langchain/src/vectorstores/xata.ts b/langchain/src/vectorstores/xata.ts index ccd6089ea4e9..9ec25bef187b 100644 --- a/langchain/src/vectorstores/xata.ts +++ b/langchain/src/vectorstores/xata.ts @@ -1,149 +1 @@ -import { BaseClient } from "@xata.io/client"; -import { VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; - -/** - * Interface for the arguments required to create a XataClient. Includes - * the client instance and the table name. - */ -export interface XataClientArgs { - readonly client: XataClient; - readonly table: string; -} - -/** - * Type for the filter object used in Xata database queries. - */ -type XataFilter = object; - -/** - * Class for interacting with a Xata database as a VectorStore. Provides - * methods to add documents and vectors to the database, delete entries, - * and perform similarity searches. 
- */ -export class XataVectorSearch< - XataClient extends BaseClient -> extends VectorStore { - declare FilterType: XataFilter; - - private readonly client: XataClient; - - private readonly table: string; - - _vectorstoreType(): string { - return "xata"; - } - - constructor(embeddings: Embeddings, args: XataClientArgs) { - super(embeddings, args); - - this.client = args.client; - this.table = args.table; - } - - /** - * Method to add documents to the Xata database. Maps the page content of - * each document, embeds the documents using the embeddings, and adds the - * vectors to the database. - * @param documents Array of documents to be added. - * @param options Optional object containing an array of ids. - * @returns Promise resolving to an array of ids of the added documents. - */ - async addDocuments(documents: Document[], options?: { ids?: string[] }) { - const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents, - options - ); - } - - /** - * Method to add vectors to the Xata database. Maps each vector to a row - * with the document's content, embedding, and metadata. Creates or - * replaces these rows in the Xata database. - * @param vectors Array of vectors to be added. - * @param documents Array of documents corresponding to the vectors. - * @param options Optional object containing an array of ids. - * @returns Promise resolving to an array of ids of the added vectors. 
- */ - async addVectors( - vectors: number[][], - documents: Document[], - options?: { ids?: string[] } - ) { - const rows = vectors - .map((embedding, idx) => ({ - content: documents[idx].pageContent, - embedding, - ...documents[idx].metadata, - })) - .map((row, idx) => { - if (options?.ids) { - return { id: options.ids[idx], ...row }; - } - return row; - }); - - const res = await this.client.db[this.table].createOrReplace(rows); - // Since we have an untyped BaseClient, it doesn't know the - // actual return type of the overload. - const results = res as unknown as { id: string }[]; - const returnedIds = results.map((row) => row.id); - return returnedIds; - } - - /** - * Method to delete entries from the Xata database. Deletes the entries - * with the provided ids. - * @param params Object containing an array of ids of the entries to be deleted. - * @returns Promise resolving to void. - */ - async delete(params: { ids: string[] }): Promise { - const { ids } = params; - await this.client.db[this.table].delete(ids); - } - - /** - * Method to perform a similarity search in the Xata database. Returns the - * k most similar documents along with their scores. - * @param query Query vector for the similarity search. - * @param k Number of most similar documents to return. - * @param filter Optional filter for the search. - * @returns Promise resolving to an array of tuples, each containing a Document and its score. 
- */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: XataFilter | undefined - ): Promise<[Document, number][]> { - const { records } = await this.client.db[this.table].vectorSearch( - "embedding", - query, - { - size: k, - filter, - } - ); - - return ( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - records?.map((record: any) => [ - new Document({ - pageContent: record.content, - metadata: Object.fromEntries( - Object.entries(record).filter( - ([key]) => - key !== "content" && - key !== "embedding" && - key !== "xata" && - key !== "id" - ) - ), - }), - record.xata.score, - ]) ?? [] - ); - } -} +export * from "@langchain/community/vectorstores/xata"; diff --git a/langchain/src/vectorstores/zep.ts b/langchain/src/vectorstores/zep.ts index 63919fa34970..ea9578092dbe 100644 --- a/langchain/src/vectorstores/zep.ts +++ b/langchain/src/vectorstores/zep.ts @@ -1,424 +1 @@ -import { - DocumentCollection, - IDocument, - NotFoundError, - ZepClient, -} from "@getzep/zep-js"; - -import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js"; -import { Embeddings } from "../embeddings/base.js"; -import { Document } from "../document.js"; -import { FakeEmbeddings } from "../embeddings/fake.js"; -import { Callbacks } from "../callbacks/index.js"; -import { maximalMarginalRelevance } from "../util/math.js"; - -/** - * Interface for the arguments required to initialize a ZepVectorStore - * instance. - */ -export interface IZepArgs { - collection: DocumentCollection; -} - -/** - * Interface for the configuration options for a ZepVectorStore instance. - */ -export interface IZepConfig { - apiUrl: string; - apiKey?: string; - collectionName: string; - description?: string; - metadata?: Record; - embeddingDimensions?: number; - isAutoEmbedded?: boolean; -} - -/** - * Interface for the parameters required to delete documents from a - * ZepVectorStore instance. 
- */ -export interface IZepDeleteParams { - uuids: string[]; -} - -/** - * ZepVectorStore is a VectorStore implementation that uses the Zep long-term memory store as a backend. - * - * If the collection does not exist, it will be created automatically. - * - * Requires `zep-js` to be installed: - * ```bash - * npm install @getzep/zep-js - * ``` - * - * @property {ZepClient} client - The ZepClient instance used to interact with Zep's API. - * @property {Promise} initPromise - A promise that resolves when the collection is initialized. - * @property {DocumentCollection} collection - The Zep document collection. - */ -export class ZepVectorStore extends VectorStore { - public client: ZepClient; - - public collection: DocumentCollection; - - private initPromise: Promise; - - private autoEmbed = false; - - constructor(embeddings: Embeddings, args: IZepConfig) { - super(embeddings, args); - - this.embeddings = embeddings; - - // eslint-disable-next-line no-instanceof/no-instanceof - if (this.embeddings instanceof FakeEmbeddings) { - this.autoEmbed = true; - } - - this.initPromise = this.initCollection(args).catch((err) => { - console.error("Error initializing collection:", err); - throw err; - }); - } - - /** - * Initializes the document collection. If the collection does not exist, it creates a new one. - * - * @param {IZepConfig} args - The configuration object for the Zep API. 
- */ - private async initCollection(args: IZepConfig) { - this.client = await ZepClient.init(args.apiUrl, args.apiKey); - try { - this.collection = await this.client.document.getCollection( - args.collectionName - ); - - // If the Embedding passed in is fake, but the collection is not auto embedded, throw an error - // eslint-disable-next-line no-instanceof/no-instanceof - if (!this.collection.is_auto_embedded && this.autoEmbed) { - throw new Error(`You can't pass in FakeEmbeddings when collection ${args.collectionName} - is not set to auto-embed.`); - } - } catch (err) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (err instanceof Error) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (err instanceof NotFoundError || err.name === "NotFoundError") { - await this.createCollection(args); - } else { - throw err; - } - } - } - } - - /** - * Creates a new document collection. - * - * @param {IZepConfig} args - The configuration object for the Zep API. - */ - private async createCollection(args: IZepConfig) { - if (!args.embeddingDimensions) { - throw new Error(`Collection ${args.collectionName} not found. - You can create a new Collection by providing embeddingDimensions.`); - } - - this.collection = await this.client.document.addCollection({ - name: args.collectionName, - description: args.description, - metadata: args.metadata, - embeddingDimensions: args.embeddingDimensions, - isAutoEmbedded: this.autoEmbed, - }); - - console.info("Created new collection:", args.collectionName); - } - - /** - * Adds vectors and corresponding documents to the collection. - * - * @param {number[][]} vectors - The vectors to add. - * @param {Document[]} documents - The corresponding documents to add. - * @returns {Promise} - A promise that resolves with the UUIDs of the added documents. 
- */ - async addVectors( - vectors: number[][], - documents: Document[] - ): Promise { - if (!this.autoEmbed && vectors.length === 0) { - throw new Error(`Vectors must be provided if autoEmbed is false`); - } - if (!this.autoEmbed && vectors.length !== documents.length) { - throw new Error(`Vectors and documents must have the same length`); - } - - const docs: Array = []; - for (let i = 0; i < documents.length; i += 1) { - const doc: IDocument = { - content: documents[i].pageContent, - metadata: documents[i].metadata, - embedding: vectors.length > 0 ? vectors[i] : undefined, - }; - docs.push(doc); - } - // Wait for collection to be initialized - await this.initPromise; - return await this.collection.addDocuments(docs); - } - - /** - * Adds documents to the collection. The documents are first embedded into vectors - * using the provided embedding model. - * - * @param {Document[]} documents - The documents to add. - * @returns {Promise} - A promise that resolves with the UUIDs of the added documents. - */ - async addDocuments(documents: Document[]): Promise { - const texts = documents.map(({ pageContent }) => pageContent); - let vectors: number[][] = []; - if (!this.autoEmbed) { - vectors = await this.embeddings.embedDocuments(texts); - } - return this.addVectors(vectors, documents); - } - - _vectorstoreType(): string { - return "zep"; - } - - /** - * Deletes documents from the collection. - * - * @param {IZepDeleteParams} params - The list of Zep document UUIDs to delete. - * @returns {Promise} - */ - async delete(params: IZepDeleteParams): Promise { - // Wait for collection to be initialized - await this.initPromise; - for (const uuid of params.uuids) { - await this.collection.deleteDocument(uuid); - } - } - - /** - * Performs a similarity search in the collection and returns the results with their scores. - * - * @param {number[]} query - The query vector. - * @param {number} k - The number of results to return. 
- * @param {Record} filter - The filter to apply to the search. Zep only supports Record as filter. - * @returns {Promise<[Document, number][]>} - A promise that resolves with the search results and their scores. - */ - async similaritySearchVectorWithScore( - query: number[], - k: number, - filter?: Record | undefined - ): Promise<[Document, number][]> { - await this.initPromise; - const results = await this.collection.search( - { - embedding: new Float32Array(query), - metadata: assignMetadata(filter), - }, - k - ); - return zepDocsToDocumentsAndScore(results); - } - - async _similaritySearchWithScore( - query: string, - k: number, - filter?: Record | undefined - ): Promise<[Document, number][]> { - await this.initPromise; - const results = await this.collection.search( - { - text: query, - metadata: assignMetadata(filter), - }, - k - ); - return zepDocsToDocumentsAndScore(results); - } - - async similaritySearchWithScore( - query: string, - k = 4, - filter: Record | undefined = undefined, - _callbacks = undefined // implement passing to embedQuery later - ): Promise<[Document, number][]> { - if (this.autoEmbed) { - return this._similaritySearchWithScore(query, k, filter); - } else { - return this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k, - filter - ); - } - } - - /** - * Performs a similarity search on the Zep collection. - * - * @param {string} query - The query string to search for. - * @param {number} [k=4] - The number of results to return. Defaults to 4. - * @param {this["FilterType"] | undefined} [filter=undefined] - An optional set of JSONPath filters to apply to the search. - * @param {Callbacks | undefined} [_callbacks=undefined] - Optional callbacks. Currently not implemented. - * @returns {Promise} - A promise that resolves to an array of Documents that are similar to the query. 
- * - * @async - */ - async similaritySearch( - query: string, - k = 4, - filter: this["FilterType"] | undefined = undefined, - _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later - ): Promise { - await this.initPromise; - - let results: [Document, number][]; - if (this.autoEmbed) { - const zepResults = await this.collection.search( - { text: query, metadata: assignMetadata(filter) }, - k - ); - results = zepDocsToDocumentsAndScore(zepResults); - } else { - results = await this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k, - assignMetadata(filter) - ); - } - - return results.map((result) => result[0]); - } - - /** - * Return documents selected using the maximal marginal relevance. - * Maximal marginal relevance optimizes for similarity to the query AND diversity - * among selected documents. - * - * @param {string} query - Text to look up documents similar to. - * @param options - * @param {number} options.k - Number of documents to return. - * @param {number} options.fetchK=20- Number of documents to fetch before passing to the MMR algorithm. - * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, - * where 0 corresponds to maximum diversity and 1 to minimum diversity. - * @param {Record} options.filter - Optional Zep JSONPath query to pre-filter on document metadata field - * - * @returns {Promise} - List of documents selected by maximal marginal relevance. 
- */ - async maxMarginalRelevanceSearch( - query: string, - options: MaxMarginalRelevanceSearchOptions - ): Promise { - const { k, fetchK = 20, lambda = 0.5, filter } = options; - - let queryEmbedding: number[]; - let zepResults: IDocument[]; - if (!this.autoEmbed) { - queryEmbedding = await this.embeddings.embedQuery(query); - zepResults = await this.collection.search( - { - embedding: new Float32Array(queryEmbedding), - metadata: assignMetadata(filter), - }, - fetchK - ); - } else { - let queryEmbeddingArray: Float32Array; - [zepResults, queryEmbeddingArray] = - await this.collection.searchReturnQueryVector( - { text: query, metadata: assignMetadata(filter) }, - fetchK - ); - queryEmbedding = Array.from(queryEmbeddingArray); - } - - const results = zepDocsToDocumentsAndScore(zepResults); - - const embeddingList = zepResults.map((doc) => - Array.from(doc.embedding ? doc.embedding : []) - ); - - const mmrIndexes = maximalMarginalRelevance( - queryEmbedding, - embeddingList, - lambda, - k - ); - - return mmrIndexes.filter((idx) => idx !== -1).map((idx) => results[idx][0]); - } - - /** - * Creates a new ZepVectorStore instance from an array of texts. Each text is converted into a Document and added to the collection. - * - * @param {string[]} texts - The texts to convert into Documents. - * @param {object[] | object} metadatas - The metadata to associate with each Document. If an array is provided, each element is associated with the corresponding Document. If an object is provided, it is associated with all Documents. - * @param {Embeddings} embeddings - The embeddings to use for vectorizing the texts. - * @param {IZepConfig} zepConfig - The configuration object for the Zep API. - * @returns {Promise} - A promise that resolves with the new ZepVectorStore instance. 
- */ - static async fromTexts( - texts: string[], - metadatas: object[] | object, - embeddings: Embeddings, - zepConfig: IZepConfig - ): Promise { - const docs: Document[] = []; - for (let i = 0; i < texts.length; i += 1) { - const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; - const newDoc = new Document({ - pageContent: texts[i], - metadata, - }); - docs.push(newDoc); - } - return ZepVectorStore.fromDocuments(docs, embeddings, zepConfig); - } - - /** - * Creates a new ZepVectorStore instance from an array of Documents. Each Document is added to a Zep collection. - * - * @param {Document[]} docs - The Documents to add. - * @param {Embeddings} embeddings - The embeddings to use for vectorizing the Document contents. - * @param {IZepConfig} zepConfig - The configuration object for the Zep API. - * @returns {Promise} - A promise that resolves with the new ZepVectorStore instance. - */ - static async fromDocuments( - docs: Document[], - embeddings: Embeddings, - zepConfig: IZepConfig - ): Promise { - const instance = new this(embeddings, zepConfig); - // Wait for collection to be initialized - await instance.initPromise; - await instance.addDocuments(docs); - return instance; - } -} - -function zepDocsToDocumentsAndScore( - results: IDocument[] -): [Document, number][] { - return results.map((d) => [ - new Document({ - pageContent: d.content, - metadata: d.metadata, - }), - d.score ? 
d.score : 0, - ]); -} - -function assignMetadata( - value: string | Record | object | undefined -): Record | undefined { - if (typeof value === "object" && value !== null) { - return value as Record; - } - if (value !== undefined) { - console.warn("Metadata filters must be an object, Record, or undefined."); - } - return undefined; -} +export * from "@langchain/community/vectorstores/zep"; diff --git a/libs/langchain-community/.eslintrc.cjs b/libs/langchain-community/.eslintrc.cjs new file mode 100644 index 000000000000..344f8a9d6cd9 --- /dev/null +++ b/libs/langchain-community/.eslintrc.cjs @@ -0,0 +1,66 @@ +module.exports = { + extends: [ + "airbnb-base", + "eslint:recommended", + "prettier", + "plugin:@typescript-eslint/recommended", + ], + parserOptions: { + ecmaVersion: 12, + parser: "@typescript-eslint/parser", + project: "./tsconfig.json", + sourceType: "module", + }, + plugins: ["@typescript-eslint", "no-instanceof"], + ignorePatterns: [ + ".eslintrc.cjs", + "scripts", + "node_modules", + "dist", + "dist-cjs", + "*.js", + "*.cjs", + "*.d.ts", + ], + rules: { + "no-process-env": 2, + "no-instanceof/no-instanceof": 2, + "@typescript-eslint/explicit-module-boundary-types": 0, + "@typescript-eslint/no-empty-function": 0, + "@typescript-eslint/no-shadow": 0, + "@typescript-eslint/no-empty-interface": 0, + "@typescript-eslint/no-use-before-define": ["error", "nofunc"], + "@typescript-eslint/no-unused-vars": ["warn", { args: "none" }], + "@typescript-eslint/no-floating-promises": "error", + "@typescript-eslint/no-misused-promises": "error", + camelcase: 0, + "class-methods-use-this": 0, + "import/extensions": [2, "ignorePackages"], + "import/no-extraneous-dependencies": [ + "error", + { devDependencies: ["**/*.test.ts"] }, + ], + "import/no-unresolved": 0, + "import/prefer-default-export": 0, + "keyword-spacing": "error", + "max-classes-per-file": 0, + "max-len": 0, + "no-await-in-loop": 0, + "no-bitwise": 0, + "no-console": 0, + "no-restricted-syntax": 0, + 
"no-shadow": 0, + "no-continue": 0, + "no-void": 0, + "no-underscore-dangle": 0, + "no-use-before-define": 0, + "no-useless-constructor": 0, + "no-return-await": 0, + "consistent-return": 0, + "no-else-return": 0, + "func-names": 0, + "no-lonely-if": 0, + "prefer-rest-params": 0, + "new-cap": ["error", { properties: false, capIsNew: false }], + }, +}; diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore new file mode 100644 index 000000000000..9e044f8930e5 --- /dev/null +++ b/libs/langchain-community/.gitignore @@ -0,0 +1,417 @@ +load.cjs +load.js +load.d.ts +load/serializable.cjs +load/serializable.js +load/serializable.d.ts +tools/aiplugin.cjs +tools/aiplugin.js +tools/aiplugin.d.ts +tools/aws_sfn.cjs +tools/aws_sfn.js +tools/aws_sfn.d.ts +tools/bingserpapi.cjs +tools/bingserpapi.js +tools/bingserpapi.d.ts +tools/brave_search.cjs +tools/brave_search.js +tools/brave_search.d.ts +tools/connery.cjs +tools/connery.js +tools/connery.d.ts +tools/dadjokeapi.cjs +tools/dadjokeapi.js +tools/dadjokeapi.d.ts +tools/dataforseo_api_search.cjs +tools/dataforseo_api_search.js +tools/dataforseo_api_search.d.ts +tools/gmail.cjs +tools/gmail.js +tools/gmail.d.ts +tools/google_custom_search.cjs +tools/google_custom_search.js +tools/google_custom_search.d.ts +tools/google_places.cjs +tools/google_places.js +tools/google_places.d.ts +tools/ifttt.cjs +tools/ifttt.js +tools/ifttt.d.ts +tools/searchapi.cjs +tools/searchapi.js +tools/searchapi.d.ts +tools/searxng_search.cjs +tools/searxng_search.js +tools/searxng_search.d.ts +tools/serpapi.cjs +tools/serpapi.js +tools/serpapi.d.ts +tools/serper.cjs +tools/serper.js +tools/serper.d.ts +tools/wikipedia_query_run.cjs +tools/wikipedia_query_run.js +tools/wikipedia_query_run.d.ts +tools/wolframalpha.cjs +tools/wolframalpha.js +tools/wolframalpha.d.ts +embeddings/bedrock.cjs +embeddings/bedrock.js +embeddings/bedrock.d.ts +embeddings/cloudflare_workersai.cjs +embeddings/cloudflare_workersai.js 
+embeddings/cloudflare_workersai.d.ts +embeddings/cohere.cjs +embeddings/cohere.js +embeddings/cohere.d.ts +embeddings/googlepalm.cjs +embeddings/googlepalm.js +embeddings/googlepalm.d.ts +embeddings/googlevertexai.cjs +embeddings/googlevertexai.js +embeddings/googlevertexai.d.ts +embeddings/gradient_ai.cjs +embeddings/gradient_ai.js +embeddings/gradient_ai.d.ts +embeddings/hf.cjs +embeddings/hf.js +embeddings/hf.d.ts +embeddings/hf_transformers.cjs +embeddings/hf_transformers.js +embeddings/hf_transformers.d.ts +embeddings/llama_cpp.cjs +embeddings/llama_cpp.js +embeddings/llama_cpp.d.ts +embeddings/minimax.cjs +embeddings/minimax.js +embeddings/minimax.d.ts +embeddings/ollama.cjs +embeddings/ollama.js +embeddings/ollama.d.ts +embeddings/tensorflow.cjs +embeddings/tensorflow.js +embeddings/tensorflow.d.ts +embeddings/voyage.cjs +embeddings/voyage.js +embeddings/voyage.d.ts +llms/ai21.cjs +llms/ai21.js +llms/ai21.d.ts +llms/aleph_alpha.cjs +llms/aleph_alpha.js +llms/aleph_alpha.d.ts +llms/bedrock.cjs +llms/bedrock.js +llms/bedrock.d.ts +llms/bedrock/web.cjs +llms/bedrock/web.js +llms/bedrock/web.d.ts +llms/cloudflare_workersai.cjs +llms/cloudflare_workersai.js +llms/cloudflare_workersai.d.ts +llms/cohere.cjs +llms/cohere.js +llms/cohere.d.ts +llms/fireworks.cjs +llms/fireworks.js +llms/fireworks.d.ts +llms/googlepalm.cjs +llms/googlepalm.js +llms/googlepalm.d.ts +llms/googlevertexai.cjs +llms/googlevertexai.js +llms/googlevertexai.d.ts +llms/googlevertexai/web.cjs +llms/googlevertexai/web.js +llms/googlevertexai/web.d.ts +llms/gradient_ai.cjs +llms/gradient_ai.js +llms/gradient_ai.d.ts +llms/hf.cjs +llms/hf.js +llms/hf.d.ts +llms/llama_cpp.cjs +llms/llama_cpp.js +llms/llama_cpp.d.ts +llms/ollama.cjs +llms/ollama.js +llms/ollama.d.ts +llms/portkey.cjs +llms/portkey.js +llms/portkey.d.ts +llms/raycast.cjs +llms/raycast.js +llms/raycast.d.ts +llms/replicate.cjs +llms/replicate.js +llms/replicate.d.ts +llms/sagemaker_endpoint.cjs +llms/sagemaker_endpoint.js 
+llms/sagemaker_endpoint.d.ts +llms/watsonx_ai.cjs +llms/watsonx_ai.js +llms/watsonx_ai.d.ts +llms/writer.cjs +llms/writer.js +llms/writer.d.ts +llms/yandex.cjs +llms/yandex.js +llms/yandex.d.ts +vectorstores/analyticdb.cjs +vectorstores/analyticdb.js +vectorstores/analyticdb.d.ts +vectorstores/cassandra.cjs +vectorstores/cassandra.js +vectorstores/cassandra.d.ts +vectorstores/chroma.cjs +vectorstores/chroma.js +vectorstores/chroma.d.ts +vectorstores/clickhouse.cjs +vectorstores/clickhouse.js +vectorstores/clickhouse.d.ts +vectorstores/closevector/node.cjs +vectorstores/closevector/node.js +vectorstores/closevector/node.d.ts +vectorstores/closevector/web.cjs +vectorstores/closevector/web.js +vectorstores/closevector/web.d.ts +vectorstores/cloudflare_vectorize.cjs +vectorstores/cloudflare_vectorize.js +vectorstores/cloudflare_vectorize.d.ts +vectorstores/convex.cjs +vectorstores/convex.js +vectorstores/convex.d.ts +vectorstores/elasticsearch.cjs +vectorstores/elasticsearch.js +vectorstores/elasticsearch.d.ts +vectorstores/faiss.cjs +vectorstores/faiss.js +vectorstores/faiss.d.ts +vectorstores/googlevertexai.cjs +vectorstores/googlevertexai.js +vectorstores/googlevertexai.d.ts +vectorstores/hnswlib.cjs +vectorstores/hnswlib.js +vectorstores/hnswlib.d.ts +vectorstores/lancedb.cjs +vectorstores/lancedb.js +vectorstores/lancedb.d.ts +vectorstores/memory.cjs +vectorstores/memory.js +vectorstores/memory.d.ts +vectorstores/milvus.cjs +vectorstores/milvus.js +vectorstores/milvus.d.ts +vectorstores/momento_vector_index.cjs +vectorstores/momento_vector_index.js +vectorstores/momento_vector_index.d.ts +vectorstores/mongodb_atlas.cjs +vectorstores/mongodb_atlas.js +vectorstores/mongodb_atlas.d.ts +vectorstores/myscale.cjs +vectorstores/myscale.js +vectorstores/myscale.d.ts +vectorstores/neo4j_vector.cjs +vectorstores/neo4j_vector.js +vectorstores/neo4j_vector.d.ts +vectorstores/opensearch.cjs +vectorstores/opensearch.js +vectorstores/opensearch.d.ts +vectorstores/pgvector.cjs 
+vectorstores/pgvector.js +vectorstores/pgvector.d.ts +vectorstores/pinecone.cjs +vectorstores/pinecone.js +vectorstores/pinecone.d.ts +vectorstores/prisma.cjs +vectorstores/prisma.js +vectorstores/prisma.d.ts +vectorstores/qdrant.cjs +vectorstores/qdrant.js +vectorstores/qdrant.d.ts +vectorstores/redis.cjs +vectorstores/redis.js +vectorstores/redis.d.ts +vectorstores/rockset.cjs +vectorstores/rockset.js +vectorstores/rockset.d.ts +vectorstores/singlestore.cjs +vectorstores/singlestore.js +vectorstores/singlestore.d.ts +vectorstores/supabase.cjs +vectorstores/supabase.js +vectorstores/supabase.d.ts +vectorstores/tigris.cjs +vectorstores/tigris.js +vectorstores/tigris.d.ts +vectorstores/typeorm.cjs +vectorstores/typeorm.js +vectorstores/typeorm.d.ts +vectorstores/typesense.cjs +vectorstores/typesense.js +vectorstores/typesense.d.ts +vectorstores/usearch.cjs +vectorstores/usearch.js +vectorstores/usearch.d.ts +vectorstores/vectara.cjs +vectorstores/vectara.js +vectorstores/vectara.d.ts +vectorstores/vercel_postgres.cjs +vectorstores/vercel_postgres.js +vectorstores/vercel_postgres.d.ts +vectorstores/voy.cjs +vectorstores/voy.js +vectorstores/voy.d.ts +vectorstores/weaviate.cjs +vectorstores/weaviate.js +vectorstores/weaviate.d.ts +vectorstores/xata.cjs +vectorstores/xata.js +vectorstores/xata.d.ts +vectorstores/zep.cjs +vectorstores/zep.js +vectorstores/zep.d.ts +chat_models/baiduwenxin.cjs +chat_models/baiduwenxin.js +chat_models/baiduwenxin.d.ts +chat_models/bedrock.cjs +chat_models/bedrock.js +chat_models/bedrock.d.ts +chat_models/bedrock/web.cjs +chat_models/bedrock/web.js +chat_models/bedrock/web.d.ts +chat_models/cloudflare_workersai.cjs +chat_models/cloudflare_workersai.js +chat_models/cloudflare_workersai.d.ts +chat_models/fireworks.cjs +chat_models/fireworks.js +chat_models/fireworks.d.ts +chat_models/googlevertexai.cjs +chat_models/googlevertexai.js +chat_models/googlevertexai.d.ts +chat_models/googlevertexai/web.cjs +chat_models/googlevertexai/web.js 
+chat_models/googlevertexai/web.d.ts +chat_models/googlepalm.cjs +chat_models/googlepalm.js +chat_models/googlepalm.d.ts +chat_models/iflytek_xinghuo.cjs +chat_models/iflytek_xinghuo.js +chat_models/iflytek_xinghuo.d.ts +chat_models/iflytek_xinghuo/web.cjs +chat_models/iflytek_xinghuo/web.js +chat_models/iflytek_xinghuo/web.d.ts +chat_models/llama_cpp.cjs +chat_models/llama_cpp.js +chat_models/llama_cpp.d.ts +chat_models/minimax.cjs +chat_models/minimax.js +chat_models/minimax.d.ts +chat_models/ollama.cjs +chat_models/ollama.js +chat_models/ollama.d.ts +chat_models/portkey.cjs +chat_models/portkey.js +chat_models/portkey.d.ts +chat_models/yandex.cjs +chat_models/yandex.js +chat_models/yandex.d.ts +callbacks/handlers/llmonitor.cjs +callbacks/handlers/llmonitor.js +callbacks/handlers/llmonitor.d.ts +retrievers/amazon_kendra.cjs +retrievers/amazon_kendra.js +retrievers/amazon_kendra.d.ts +retrievers/chaindesk.cjs +retrievers/chaindesk.js +retrievers/chaindesk.d.ts +retrievers/databerry.cjs +retrievers/databerry.js +retrievers/databerry.d.ts +retrievers/metal.cjs +retrievers/metal.js +retrievers/metal.d.ts +retrievers/supabase.cjs +retrievers/supabase.js +retrievers/supabase.d.ts +retrievers/tavily_search_api.cjs +retrievers/tavily_search_api.js +retrievers/tavily_search_api.d.ts +retrievers/zep.cjs +retrievers/zep.js +retrievers/zep.d.ts +caches/cloudflare_kv.cjs +caches/cloudflare_kv.js +caches/cloudflare_kv.d.ts +caches/momento.cjs +caches/momento.js +caches/momento.d.ts +caches/upstash_redis.cjs +caches/upstash_redis.js +caches/upstash_redis.d.ts +graphs/neo4j_graph.cjs +graphs/neo4j_graph.js +graphs/neo4j_graph.d.ts +utils/event_source_parse.cjs +utils/event_source_parse.js +utils/event_source_parse.d.ts +document_transformers/html_to_text.cjs +document_transformers/html_to_text.js +document_transformers/html_to_text.d.ts +document_transformers/mozilla_readability.cjs +document_transformers/mozilla_readability.js +document_transformers/mozilla_readability.d.ts 
+storage/convex.cjs +storage/convex.js +storage/convex.d.ts +storage/ioredis.cjs +storage/ioredis.js +storage/ioredis.d.ts +storage/upstash_redis.cjs +storage/upstash_redis.js +storage/upstash_redis.d.ts +storage/vercel_kv.cjs +storage/vercel_kv.js +storage/vercel_kv.d.ts +stores/doc/base.cjs +stores/doc/base.js +stores/doc/base.d.ts +stores/doc/in_memory.cjs +stores/doc/in_memory.js +stores/doc/in_memory.d.ts +stores/message/cassandra.cjs +stores/message/cassandra.js +stores/message/cassandra.d.ts +stores/message/cloudflare_d1.cjs +stores/message/cloudflare_d1.js +stores/message/cloudflare_d1.d.ts +stores/message/convex.cjs +stores/message/convex.js +stores/message/convex.d.ts +stores/message/dynamodb.cjs +stores/message/dynamodb.js +stores/message/dynamodb.d.ts +stores/message/firestore.cjs +stores/message/firestore.js +stores/message/firestore.d.ts +stores/message/ioredis.cjs +stores/message/ioredis.js +stores/message/ioredis.d.ts +stores/message/momento.cjs +stores/message/momento.js +stores/message/momento.d.ts +stores/message/mongodb.cjs +stores/message/mongodb.js +stores/message/mongodb.d.ts +stores/message/planetscale.cjs +stores/message/planetscale.js +stores/message/planetscale.d.ts +stores/message/redis.cjs +stores/message/redis.js +stores/message/redis.d.ts +stores/message/upstash_redis.cjs +stores/message/upstash_redis.js +stores/message/upstash_redis.d.ts +stores/message/xata.cjs +stores/message/xata.js +stores/message/xata.d.ts diff --git a/libs/langchain-community/LICENSE b/libs/langchain-community/LICENSE new file mode 100644 index 000000000000..8cd8f501eb49 --- /dev/null +++ b/libs/langchain-community/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2023 LangChain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/libs/langchain-community/jest.config.cjs b/libs/langchain-community/jest.config.cjs new file mode 100644 index 000000000000..5cc0b1ab72c6 --- /dev/null +++ b/libs/langchain-community/jest.config.cjs @@ -0,0 +1,19 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: "ts-jest/presets/default-esm", + testEnvironment: "./jest.env.cjs", + modulePathIgnorePatterns: ["dist/", "docs/"], + moduleNameMapper: { + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + transform: { + "^.+\\.tsx?$": ["@swc/jest"], + }, + transformIgnorePatterns: [ + "/node_modules/", + "\\.pnp\\.[^\\/]+$", + "./scripts/jest-setup-after-env.js", + ], + setupFiles: ["dotenv/config"], + testTimeout: 20_000, +}; diff --git a/libs/langchain-community/jest.env.cjs b/libs/langchain-community/jest.env.cjs new file mode 100644 index 000000000000..2ccedccb8672 --- /dev/null +++ b/libs/langchain-community/jest.env.cjs @@ -0,0 +1,12 @@ +const { TestEnvironment } = require("jest-environment-node"); + +class AdjustedTestEnvironmentToSupportFloat32Array extends TestEnvironment { + constructor(config, context) { + // Make `instanceof Float32Array` return true in tests + // to avoid 
https://github.com/xenova/transformers.js/issues/57 and https://github.com/jestjs/jest/issues/2549 + super(config, context); + this.global.Float32Array = Float32Array; + } +} + +module.exports = AdjustedTestEnvironmentToSupportFloat32Array; diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json new file mode 100644 index 000000000000..e159d7b5949e --- /dev/null +++ b/libs/langchain-community/package.json @@ -0,0 +1,1631 @@ +{ + "name": "@langchain/community", + "version": "0.0.0", + "description": "Sample integration for LangChain.js", + "type": "module", + "engines": { + "node": ">=18" + }, + "main": "./index.js", + "types": "./index.d.ts", + "repository": { + "type": "git", + "url": "git@github.com:langchain-ai/langchainjs.git" + }, + "scripts": { + "build": "yarn clean && yarn build:esm && yarn build:cjs && yarn build:scripts", + "build:esm": "NODE_OPTIONS=--max-old-space-size=4096 tsc --outDir dist/ && rm -rf dist/tests dist/**/tests", + "build:cjs": "NODE_OPTIONS=--max-old-space-size=4096 tsc --outDir dist-cjs/ -p tsconfig.cjs.json && node scripts/move-cjs-to-dist.js && rm -rf dist-cjs", + "build:watch": "node scripts/create-entrypoints.js && tsc --outDir dist/ --watch", + "build:scripts": "node scripts/create-entrypoints.js && node scripts/check-tree-shaking.js", + "lint": "NODE_OPTIONS=--max-old-space-size=4096 eslint src && dpdm --exit-code circular:1 --no-warning --no-tree src/*.ts src/**/*.ts", + "lint:fix": "yarn lint --fix", + "clean": "rm -rf dist/ && NODE_OPTIONS=--max-old-space-size=4096 node scripts/create-entrypoints.js pre", + "prepack": "yarn build", + "release": "release-it --only-version --config .release-it.json", + "test": "NODE_OPTIONS=--experimental-vm-modules jest --testPathIgnorePatterns=\\.int\\.test.ts --testTimeout 30000 --maxWorkers=50%", + "test:watch": "NODE_OPTIONS=--experimental-vm-modules jest --watch --testPathIgnorePatterns=\\.int\\.test.ts", + "test:single": 
"NODE_OPTIONS=--experimental-vm-modules yarn run jest --config jest.config.cjs --testTimeout 100000", + "test:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.int\\.test.ts --testTimeout 100000 --maxWorkers=50%", + "format": "prettier --write \"src\"", + "format:check": "prettier --check \"src\"" + }, + "author": "LangChain", + "license": "MIT", + "dependencies": { + "@langchain/core": "~0.0.11-rc.1", + "@langchain/openai": "~0.0.2-rc.0", + "flat": "^5.0.2", + "langsmith": "~0.0.48", + "ml-distance": "^4.0.0", + "uuid": "^9.0.0", + "zod": "^3.22.3" + }, + "devDependencies": { + "@aws-crypto/sha256-js": "^5.0.0", + "@aws-sdk/client-bedrock-runtime": "^3.422.0", + "@aws-sdk/client-dynamodb": "^3.310.0", + "@aws-sdk/client-kendra": "^3.352.0", + "@aws-sdk/client-lambda": "^3.310.0", + "@aws-sdk/client-sagemaker-runtime": "^3.414.0", + "@aws-sdk/client-sfn": "^3.362.0", + "@aws-sdk/credential-provider-node": "^3.388.0", + "@aws-sdk/types": "^3.357.0", + "@clickhouse/client": "^0.2.5", + "@cloudflare/ai": "^1.0.12", + "@cloudflare/workers-types": "^4.20230922.0", + "@elastic/elasticsearch": "^8.4.0", + "@getmetal/metal-sdk": "^4.0.0", + "@getzep/zep-js": "^0.9.0", + "@gomomento/sdk": "^1.51.1", + "@gomomento/sdk-core": "^1.51.1", + "@google-ai/generativelanguage": "^0.2.1", + "@google-cloud/storage": "^6.10.1", + "@gradientai/nodejs-sdk": "^1.2.0", + "@huggingface/inference": "^2.6.4", + "@jest/globals": "^29.5.0", + "@mozilla/readability": "^0.4.4", + "@notionhq/client": "^2.2.10", + "@opensearch-project/opensearch": "^2.2.0", + "@pinecone-database/pinecone": "^1.1.0", + "@planetscale/database": "^1.8.0", + "@qdrant/js-client-rest": "^1.2.0", + "@raycast/api": "^1.55.2", + "@rockset/client": "^0.9.1", + "@smithy/eventstream-codec": "^2.0.5", + "@smithy/protocol-http": "^3.0.6", + "@smithy/signature-v4": "^2.0.10", + "@smithy/util-utf8": "^2.0.0", + "@supabase/postgrest-js": "^1.1.1", + "@supabase/supabase-js": "^2.10.0", + "@swc/core": "^1.3.90", 
+ "@swc/jest": "^0.2.29", + "@tensorflow-models/universal-sentence-encoder": "^1.3.3", + "@tensorflow/tfjs-backend-cpu": "^3", + "@tensorflow/tfjs-converter": "^3.6.0", + "@tensorflow/tfjs-core": "^3.6.0", + "@tsconfig/recommended": "^1.0.2", + "@types/flat": "^5.0.2", + "@types/html-to-text": "^9", + "@types/jsdom": "^21.1.1", + "@types/lodash": "^4", + "@types/mozilla-readability": "^0.2.1", + "@types/pg": "^8", + "@types/pg-copy-streams": "^1.2.2", + "@types/uuid": "^9", + "@types/ws": "^8", + "@typescript-eslint/eslint-plugin": "^5.58.0", + "@typescript-eslint/parser": "^5.58.0", + "@upstash/redis": "^1.20.6", + "@vercel/kv": "^0.2.3", + "@vercel/postgres": "^0.5.0", + "@writerai/writer-sdk": "^0.40.2", + "@xata.io/client": "^0.28.0", + "@xenova/transformers": "^2.5.4", + "@zilliz/milvus2-sdk-node": ">=2.2.11", + "axios": "^0.26.0", + "cassandra-driver": "^4.7.2", + "chromadb": "^1.5.3", + "closevector-common": "0.1.0-alpha.1", + "closevector-node": "0.1.0-alpha.10", + "closevector-web": "0.1.0-alpha.15", + "cohere-ai": ">=6.0.0", + "convex": "^1.3.1", + "d3-dsv": "^2.0.0", + "dotenv": "^16.0.3", + "dpdm": "^3.12.0", + "eslint": "^8.33.0", + "eslint-config-airbnb-base": "^15.0.0", + "eslint-config-prettier": "^8.6.0", + "eslint-plugin-import": "^2.27.5", + "eslint-plugin-jest": "^27.6.0", + "eslint-plugin-no-instanceof": "^1.0.1", + "eslint-plugin-prettier": "^4.2.1", + "faiss-node": "^0.5.1", + "fast-xml-parser": "^4.2.7", + "firebase-admin": "^11.9.0", + "google-auth-library": "^8.9.0", + "googleapis": "^126.0.1", + "graphql": "^16.6.0", + "hnswlib-node": "^1.4.2", + "html-to-text": "^9.0.5", + "ignore": "^5.2.0", + "ioredis": "^5.3.2", + "jest": "^29.5.0", + "jest-environment-node": "^29.6.4", + "jsdom": "^22.1.0", + "llmonitor": "^0.5.9", + "lodash": "^4.17.21", + "mammoth": "^1.5.1", + "mongodb": "^5.2.0", + "mysql2": "^3.3.3", + "neo4j-driver": "^5.12.0", + "node-llama-cpp": "2.7.3", + "pg": "^8.11.0", + "pg-copy-streams": "^6.0.5", + "pickleparser": 
"^0.2.1", + "portkey-ai": "^0.1.11", + "prettier": "^2.8.3", + "pyodide": "^0.24.1", + "redis": "^4.6.6", + "release-it": "^15.10.1", + "replicate": "^0.18.0", + "rollup": "^3.19.1", + "sqlite3": "^5.1.4", + "ts-jest": "^29.1.0", + "typeorm": "^0.3.12", + "typescript": "~5.1.6", + "typesense": "^1.5.3", + "usearch": "^1.1.1", + "vectordb": "^0.1.4", + "voy-search": "0.6.2", + "weaviate-ts-client": "^1.4.0", + "web-auth-library": "^1.0.3" + }, + "peerDependencies": { + "@aws-crypto/sha256-js": "^5.0.0", + "@aws-sdk/client-bedrock-runtime": "^3.422.0", + "@aws-sdk/client-dynamodb": "^3.310.0", + "@aws-sdk/client-kendra": "^3.352.0", + "@aws-sdk/client-lambda": "^3.310.0", + "@aws-sdk/client-sagemaker-runtime": "^3.310.0", + "@aws-sdk/client-sfn": "^3.310.0", + "@aws-sdk/credential-provider-node": "^3.388.0", + "@clickhouse/client": "^0.2.5", + "@cloudflare/ai": "^1.0.12", + "@elastic/elasticsearch": "^8.4.0", + "@faker-js/faker": "^7.6.0", + "@getmetal/metal-sdk": "*", + "@getzep/zep-js": "^0.9.0", + "@gomomento/sdk": "^1.51.1", + "@gomomento/sdk-core": "^1.51.1", + "@gomomento/sdk-web": "^1.51.1", + "@google-ai/generativelanguage": "^0.2.1", + "@google-cloud/storage": "^6.10.1", + "@gradientai/nodejs-sdk": "^1.2.0", + "@huggingface/inference": "^2.6.4", + "@mozilla/readability": "*", + "@notionhq/client": "^2.2.10", + "@opensearch-project/opensearch": "*", + "@pinecone-database/pinecone": "^1.1.0", + "@planetscale/database": "^1.8.0", + "@qdrant/js-client-rest": "^1.2.0", + "@raycast/api": "^1.55.2", + "@rockset/client": "^0.9.1", + "@smithy/eventstream-codec": "^2.0.5", + "@smithy/protocol-http": "^3.0.6", + "@smithy/signature-v4": "^2.0.10", + "@smithy/util-utf8": "^2.0.0", + "@supabase/postgrest-js": "^1.1.1", + "@supabase/supabase-js": "^2.10.0", + "@tensorflow-models/universal-sentence-encoder": "*", + "@tensorflow/tfjs-converter": "*", + "@tensorflow/tfjs-core": "*", + "@upstash/redis": "^1.20.6", + "@vercel/kv": "^0.2.3", + "@vercel/postgres": "^0.5.0", + 
"@writerai/writer-sdk": "^0.40.2", + "@xata.io/client": "^0.28.0", + "@xenova/transformers": "^2.5.4", + "@zilliz/milvus2-sdk-node": ">=2.2.7", + "axios": "*", + "cassandra-driver": "^4.7.2", + "chromadb": "*", + "closevector-common": "0.1.0-alpha.1", + "closevector-node": "0.1.0-alpha.10", + "closevector-web": "0.1.0-alpha.16", + "cohere-ai": ">=6.0.0", + "convex": "^1.3.1", + "d3-dsv": "^2.0.0", + "faiss-node": "^0.5.1", + "fast-xml-parser": "^4.2.7", + "firebase-admin": "^11.9.0", + "google-auth-library": "^8.9.0", + "googleapis": "^126.0.1", + "hnswlib-node": "^1.4.2", + "html-to-text": "^9.0.5", + "ignore": "^5.2.0", + "ioredis": "^5.3.2", + "jsdom": "*", + "llmonitor": "^0.5.9", + "lodash": "^4.17.21", + "mammoth": "*", + "mongodb": "^5.2.0", + "mysql2": "^3.3.3", + "neo4j-driver": "*", + "node-llama-cpp": "*", + "pg": "^8.11.0", + "pg-copy-streams": "^6.0.5", + "pickleparser": "^0.2.1", + "portkey-ai": "^0.1.11", + "pyodide": "^0.24.1", + "redis": "^4.6.4", + "replicate": "^0.18.0", + "typeorm": "^0.3.12", + "typesense": "^1.5.3", + "usearch": "^1.1.1", + "vectordb": "^0.1.4", + "voy-search": "0.6.2", + "weaviate-ts-client": "^1.4.0", + "web-auth-library": "^1.0.3", + "ws": "^8.14.2" + }, + "peerDependenciesMeta": { + "@aws-crypto/sha256-js": { + "optional": true + }, + "@aws-sdk/client-bedrock-runtime": { + "optional": true + }, + "@aws-sdk/client-dynamodb": { + "optional": true + }, + "@aws-sdk/client-kendra": { + "optional": true + }, + "@aws-sdk/client-lambda": { + "optional": true + }, + "@aws-sdk/client-sagemaker-runtime": { + "optional": true + }, + "@aws-sdk/client-sfn": { + "optional": true + }, + "@aws-sdk/credential-provider-node": { + "optional": true + }, + "@clickhouse/client": { + "optional": true + }, + "@cloudflare/ai": { + "optional": true + }, + "@elastic/elasticsearch": { + "optional": true + }, + "@getmetal/metal-sdk": { + "optional": true + }, + "@getzep/zep-js": { + "optional": true + }, + "@gomomento/sdk": { + "optional": true + }, + 
"@gomomento/sdk-core": { + "optional": true + }, + "@gomomento/sdk-web": { + "optional": true + }, + "@google-ai/generativelanguage": { + "optional": true + }, + "@google-cloud/storage": { + "optional": true + }, + "@gradientai/nodejs-sdk": { + "optional": true + }, + "@huggingface/inference": { + "optional": true + }, + "@mozilla/readability": { + "optional": true + }, + "@notionhq/client": { + "optional": true + }, + "@opensearch-project/opensearch": { + "optional": true + }, + "@pinecone-database/pinecone": { + "optional": true + }, + "@planetscale/database": { + "optional": true + }, + "@qdrant/js-client-rest": { + "optional": true + }, + "@raycast/api": { + "optional": true + }, + "@rockset/client": { + "optional": true + }, + "@smithy/eventstream-codec": { + "optional": true + }, + "@smithy/protocol-http": { + "optional": true + }, + "@smithy/signature-v4": { + "optional": true + }, + "@smithy/util-utf8": { + "optional": true + }, + "@supabase/postgrest-js": { + "optional": true + }, + "@supabase/supabase-js": { + "optional": true + }, + "@tensorflow-models/universal-sentence-encoder": { + "optional": true + }, + "@tensorflow/tfjs-converter": { + "optional": true + }, + "@tensorflow/tfjs-core": { + "optional": true + }, + "@upstash/redis": { + "optional": true + }, + "@vercel/kv": { + "optional": true + }, + "@vercel/postgres": { + "optional": true + }, + "@writerai/writer-sdk": { + "optional": true + }, + "@xata.io/client": { + "optional": true + }, + "@xenova/transformers": { + "optional": true + }, + "@zilliz/milvus2-sdk-node": { + "optional": true + }, + "axios": { + "optional": true + }, + "cassandra-driver": { + "optional": true + }, + "chromadb": { + "optional": true + }, + "closevector-common": { + "optional": true + }, + "closevector-node": { + "optional": true + }, + "closevector-web": { + "optional": true + }, + "cohere-ai": { + "optional": true + }, + "convex": { + "optional": true + }, + "d3-dsv": { + "optional": true + }, + "faiss-node": { + 
"optional": true + }, + "fast-xml-parser": { + "optional": true + }, + "firebase-admin": { + "optional": true + }, + "google-auth-library": { + "optional": true + }, + "googleapis": { + "optional": true + }, + "hnswlib-node": { + "optional": true + }, + "html-to-text": { + "optional": true + }, + "ignore": { + "optional": true + }, + "ioredis": { + "optional": true + }, + "jsdom": { + "optional": true + }, + "llmonitor": { + "optional": true + }, + "lodash": { + "optional": true + }, + "mammoth": { + "optional": true + }, + "mongodb": { + "optional": true + }, + "mysql2": { + "optional": true + }, + "neo4j-driver": { + "optional": true + }, + "node-llama-cpp": { + "optional": true + }, + "pg": { + "optional": true + }, + "pg-copy-streams": { + "optional": true + }, + "pickleparser": { + "optional": true + }, + "portkey-ai": { + "optional": true + }, + "pyodide": { + "optional": true + }, + "redis": { + "optional": true + }, + "replicate": { + "optional": true + }, + "typeorm": { + "optional": true + }, + "typesense": { + "optional": true + }, + "usearch": { + "optional": true + }, + "vectordb": { + "optional": true + }, + "voy-search": { + "optional": true + }, + "weaviate-ts-client": { + "optional": true + }, + "web-auth-library": { + "optional": true + }, + "ws": { + "optional": true + } + }, + "publishConfig": { + "access": "public" + }, + "exports": { + "./load": { + "types": "./load.d.ts", + "import": "./load.js", + "require": "./load.cjs" + }, + "./load/serializable": { + "types": "./load/serializable.d.ts", + "import": "./load/serializable.js", + "require": "./load/serializable.cjs" + }, + "./tools/aiplugin": { + "types": "./tools/aiplugin.d.ts", + "import": "./tools/aiplugin.js", + "require": "./tools/aiplugin.cjs" + }, + "./tools/aws_sfn": { + "types": "./tools/aws_sfn.d.ts", + "import": "./tools/aws_sfn.js", + "require": "./tools/aws_sfn.cjs" + }, + "./tools/bingserpapi": { + "types": "./tools/bingserpapi.d.ts", + "import": "./tools/bingserpapi.js", + 
"require": "./tools/bingserpapi.cjs" + }, + "./tools/brave_search": { + "types": "./tools/brave_search.d.ts", + "import": "./tools/brave_search.js", + "require": "./tools/brave_search.cjs" + }, + "./tools/connery": { + "types": "./tools/connery.d.ts", + "import": "./tools/connery.js", + "require": "./tools/connery.cjs" + }, + "./tools/dadjokeapi": { + "types": "./tools/dadjokeapi.d.ts", + "import": "./tools/dadjokeapi.js", + "require": "./tools/dadjokeapi.cjs" + }, + "./tools/dataforseo_api_search": { + "types": "./tools/dataforseo_api_search.d.ts", + "import": "./tools/dataforseo_api_search.js", + "require": "./tools/dataforseo_api_search.cjs" + }, + "./tools/gmail": { + "types": "./tools/gmail.d.ts", + "import": "./tools/gmail.js", + "require": "./tools/gmail.cjs" + }, + "./tools/google_custom_search": { + "types": "./tools/google_custom_search.d.ts", + "import": "./tools/google_custom_search.js", + "require": "./tools/google_custom_search.cjs" + }, + "./tools/google_places": { + "types": "./tools/google_places.d.ts", + "import": "./tools/google_places.js", + "require": "./tools/google_places.cjs" + }, + "./tools/ifttt": { + "types": "./tools/ifttt.d.ts", + "import": "./tools/ifttt.js", + "require": "./tools/ifttt.cjs" + }, + "./tools/searchapi": { + "types": "./tools/searchapi.d.ts", + "import": "./tools/searchapi.js", + "require": "./tools/searchapi.cjs" + }, + "./tools/searxng_search": { + "types": "./tools/searxng_search.d.ts", + "import": "./tools/searxng_search.js", + "require": "./tools/searxng_search.cjs" + }, + "./tools/serpapi": { + "types": "./tools/serpapi.d.ts", + "import": "./tools/serpapi.js", + "require": "./tools/serpapi.cjs" + }, + "./tools/serper": { + "types": "./tools/serper.d.ts", + "import": "./tools/serper.js", + "require": "./tools/serper.cjs" + }, + "./tools/wikipedia_query_run": { + "types": "./tools/wikipedia_query_run.d.ts", + "import": "./tools/wikipedia_query_run.js", + "require": "./tools/wikipedia_query_run.cjs" + }, + 
"./tools/wolframalpha": { + "types": "./tools/wolframalpha.d.ts", + "import": "./tools/wolframalpha.js", + "require": "./tools/wolframalpha.cjs" + }, + "./embeddings/bedrock": { + "types": "./embeddings/bedrock.d.ts", + "import": "./embeddings/bedrock.js", + "require": "./embeddings/bedrock.cjs" + }, + "./embeddings/cloudflare_workersai": { + "types": "./embeddings/cloudflare_workersai.d.ts", + "import": "./embeddings/cloudflare_workersai.js", + "require": "./embeddings/cloudflare_workersai.cjs" + }, + "./embeddings/cohere": { + "types": "./embeddings/cohere.d.ts", + "import": "./embeddings/cohere.js", + "require": "./embeddings/cohere.cjs" + }, + "./embeddings/googlepalm": { + "types": "./embeddings/googlepalm.d.ts", + "import": "./embeddings/googlepalm.js", + "require": "./embeddings/googlepalm.cjs" + }, + "./embeddings/googlevertexai": { + "types": "./embeddings/googlevertexai.d.ts", + "import": "./embeddings/googlevertexai.js", + "require": "./embeddings/googlevertexai.cjs" + }, + "./embeddings/gradient_ai": { + "types": "./embeddings/gradient_ai.d.ts", + "import": "./embeddings/gradient_ai.js", + "require": "./embeddings/gradient_ai.cjs" + }, + "./embeddings/hf": { + "types": "./embeddings/hf.d.ts", + "import": "./embeddings/hf.js", + "require": "./embeddings/hf.cjs" + }, + "./embeddings/hf_transformers": { + "types": "./embeddings/hf_transformers.d.ts", + "import": "./embeddings/hf_transformers.js", + "require": "./embeddings/hf_transformers.cjs" + }, + "./embeddings/llama_cpp": { + "types": "./embeddings/llama_cpp.d.ts", + "import": "./embeddings/llama_cpp.js", + "require": "./embeddings/llama_cpp.cjs" + }, + "./embeddings/minimax": { + "types": "./embeddings/minimax.d.ts", + "import": "./embeddings/minimax.js", + "require": "./embeddings/minimax.cjs" + }, + "./embeddings/ollama": { + "types": "./embeddings/ollama.d.ts", + "import": "./embeddings/ollama.js", + "require": "./embeddings/ollama.cjs" + }, + "./embeddings/tensorflow": { + "types": 
"./embeddings/tensorflow.d.ts", + "import": "./embeddings/tensorflow.js", + "require": "./embeddings/tensorflow.cjs" + }, + "./embeddings/voyage": { + "types": "./embeddings/voyage.d.ts", + "import": "./embeddings/voyage.js", + "require": "./embeddings/voyage.cjs" + }, + "./llms/ai21": { + "types": "./llms/ai21.d.ts", + "import": "./llms/ai21.js", + "require": "./llms/ai21.cjs" + }, + "./llms/aleph_alpha": { + "types": "./llms/aleph_alpha.d.ts", + "import": "./llms/aleph_alpha.js", + "require": "./llms/aleph_alpha.cjs" + }, + "./llms/bedrock": { + "types": "./llms/bedrock.d.ts", + "import": "./llms/bedrock.js", + "require": "./llms/bedrock.cjs" + }, + "./llms/bedrock/web": { + "types": "./llms/bedrock/web.d.ts", + "import": "./llms/bedrock/web.js", + "require": "./llms/bedrock/web.cjs" + }, + "./llms/cloudflare_workersai": { + "types": "./llms/cloudflare_workersai.d.ts", + "import": "./llms/cloudflare_workersai.js", + "require": "./llms/cloudflare_workersai.cjs" + }, + "./llms/cohere": { + "types": "./llms/cohere.d.ts", + "import": "./llms/cohere.js", + "require": "./llms/cohere.cjs" + }, + "./llms/fireworks": { + "types": "./llms/fireworks.d.ts", + "import": "./llms/fireworks.js", + "require": "./llms/fireworks.cjs" + }, + "./llms/googlepalm": { + "types": "./llms/googlepalm.d.ts", + "import": "./llms/googlepalm.js", + "require": "./llms/googlepalm.cjs" + }, + "./llms/googlevertexai": { + "types": "./llms/googlevertexai.d.ts", + "import": "./llms/googlevertexai.js", + "require": "./llms/googlevertexai.cjs" + }, + "./llms/googlevertexai/web": { + "types": "./llms/googlevertexai/web.d.ts", + "import": "./llms/googlevertexai/web.js", + "require": "./llms/googlevertexai/web.cjs" + }, + "./llms/gradient_ai": { + "types": "./llms/gradient_ai.d.ts", + "import": "./llms/gradient_ai.js", + "require": "./llms/gradient_ai.cjs" + }, + "./llms/hf": { + "types": "./llms/hf.d.ts", + "import": "./llms/hf.js", + "require": "./llms/hf.cjs" + }, + "./llms/llama_cpp": { + "types": 
"./llms/llama_cpp.d.ts", + "import": "./llms/llama_cpp.js", + "require": "./llms/llama_cpp.cjs" + }, + "./llms/ollama": { + "types": "./llms/ollama.d.ts", + "import": "./llms/ollama.js", + "require": "./llms/ollama.cjs" + }, + "./llms/portkey": { + "types": "./llms/portkey.d.ts", + "import": "./llms/portkey.js", + "require": "./llms/portkey.cjs" + }, + "./llms/raycast": { + "types": "./llms/raycast.d.ts", + "import": "./llms/raycast.js", + "require": "./llms/raycast.cjs" + }, + "./llms/replicate": { + "types": "./llms/replicate.d.ts", + "import": "./llms/replicate.js", + "require": "./llms/replicate.cjs" + }, + "./llms/sagemaker_endpoint": { + "types": "./llms/sagemaker_endpoint.d.ts", + "import": "./llms/sagemaker_endpoint.js", + "require": "./llms/sagemaker_endpoint.cjs" + }, + "./llms/watsonx_ai": { + "types": "./llms/watsonx_ai.d.ts", + "import": "./llms/watsonx_ai.js", + "require": "./llms/watsonx_ai.cjs" + }, + "./llms/writer": { + "types": "./llms/writer.d.ts", + "import": "./llms/writer.js", + "require": "./llms/writer.cjs" + }, + "./llms/yandex": { + "types": "./llms/yandex.d.ts", + "import": "./llms/yandex.js", + "require": "./llms/yandex.cjs" + }, + "./vectorstores/analyticdb": { + "types": "./vectorstores/analyticdb.d.ts", + "import": "./vectorstores/analyticdb.js", + "require": "./vectorstores/analyticdb.cjs" + }, + "./vectorstores/cassandra": { + "types": "./vectorstores/cassandra.d.ts", + "import": "./vectorstores/cassandra.js", + "require": "./vectorstores/cassandra.cjs" + }, + "./vectorstores/chroma": { + "types": "./vectorstores/chroma.d.ts", + "import": "./vectorstores/chroma.js", + "require": "./vectorstores/chroma.cjs" + }, + "./vectorstores/clickhouse": { + "types": "./vectorstores/clickhouse.d.ts", + "import": "./vectorstores/clickhouse.js", + "require": "./vectorstores/clickhouse.cjs" + }, + "./vectorstores/closevector/node": { + "types": "./vectorstores/closevector/node.d.ts", + "import": "./vectorstores/closevector/node.js", + "require": 
"./vectorstores/closevector/node.cjs" + }, + "./vectorstores/closevector/web": { + "types": "./vectorstores/closevector/web.d.ts", + "import": "./vectorstores/closevector/web.js", + "require": "./vectorstores/closevector/web.cjs" + }, + "./vectorstores/cloudflare_vectorize": { + "types": "./vectorstores/cloudflare_vectorize.d.ts", + "import": "./vectorstores/cloudflare_vectorize.js", + "require": "./vectorstores/cloudflare_vectorize.cjs" + }, + "./vectorstores/convex": { + "types": "./vectorstores/convex.d.ts", + "import": "./vectorstores/convex.js", + "require": "./vectorstores/convex.cjs" + }, + "./vectorstores/elasticsearch": { + "types": "./vectorstores/elasticsearch.d.ts", + "import": "./vectorstores/elasticsearch.js", + "require": "./vectorstores/elasticsearch.cjs" + }, + "./vectorstores/faiss": { + "types": "./vectorstores/faiss.d.ts", + "import": "./vectorstores/faiss.js", + "require": "./vectorstores/faiss.cjs" + }, + "./vectorstores/googlevertexai": { + "types": "./vectorstores/googlevertexai.d.ts", + "import": "./vectorstores/googlevertexai.js", + "require": "./vectorstores/googlevertexai.cjs" + }, + "./vectorstores/hnswlib": { + "types": "./vectorstores/hnswlib.d.ts", + "import": "./vectorstores/hnswlib.js", + "require": "./vectorstores/hnswlib.cjs" + }, + "./vectorstores/lancedb": { + "types": "./vectorstores/lancedb.d.ts", + "import": "./vectorstores/lancedb.js", + "require": "./vectorstores/lancedb.cjs" + }, + "./vectorstores/memory": { + "types": "./vectorstores/memory.d.ts", + "import": "./vectorstores/memory.js", + "require": "./vectorstores/memory.cjs" + }, + "./vectorstores/milvus": { + "types": "./vectorstores/milvus.d.ts", + "import": "./vectorstores/milvus.js", + "require": "./vectorstores/milvus.cjs" + }, + "./vectorstores/momento_vector_index": { + "types": "./vectorstores/momento_vector_index.d.ts", + "import": "./vectorstores/momento_vector_index.js", + "require": "./vectorstores/momento_vector_index.cjs" + }, + 
"./vectorstores/mongodb_atlas": { + "types": "./vectorstores/mongodb_atlas.d.ts", + "import": "./vectorstores/mongodb_atlas.js", + "require": "./vectorstores/mongodb_atlas.cjs" + }, + "./vectorstores/myscale": { + "types": "./vectorstores/myscale.d.ts", + "import": "./vectorstores/myscale.js", + "require": "./vectorstores/myscale.cjs" + }, + "./vectorstores/neo4j_vector": { + "types": "./vectorstores/neo4j_vector.d.ts", + "import": "./vectorstores/neo4j_vector.js", + "require": "./vectorstores/neo4j_vector.cjs" + }, + "./vectorstores/opensearch": { + "types": "./vectorstores/opensearch.d.ts", + "import": "./vectorstores/opensearch.js", + "require": "./vectorstores/opensearch.cjs" + }, + "./vectorstores/pgvector": { + "types": "./vectorstores/pgvector.d.ts", + "import": "./vectorstores/pgvector.js", + "require": "./vectorstores/pgvector.cjs" + }, + "./vectorstores/pinecone": { + "types": "./vectorstores/pinecone.d.ts", + "import": "./vectorstores/pinecone.js", + "require": "./vectorstores/pinecone.cjs" + }, + "./vectorstores/prisma": { + "types": "./vectorstores/prisma.d.ts", + "import": "./vectorstores/prisma.js", + "require": "./vectorstores/prisma.cjs" + }, + "./vectorstores/qdrant": { + "types": "./vectorstores/qdrant.d.ts", + "import": "./vectorstores/qdrant.js", + "require": "./vectorstores/qdrant.cjs" + }, + "./vectorstores/redis": { + "types": "./vectorstores/redis.d.ts", + "import": "./vectorstores/redis.js", + "require": "./vectorstores/redis.cjs" + }, + "./vectorstores/rockset": { + "types": "./vectorstores/rockset.d.ts", + "import": "./vectorstores/rockset.js", + "require": "./vectorstores/rockset.cjs" + }, + "./vectorstores/singlestore": { + "types": "./vectorstores/singlestore.d.ts", + "import": "./vectorstores/singlestore.js", + "require": "./vectorstores/singlestore.cjs" + }, + "./vectorstores/supabase": { + "types": "./vectorstores/supabase.d.ts", + "import": "./vectorstores/supabase.js", + "require": "./vectorstores/supabase.cjs" + }, + 
"./vectorstores/tigris": { + "types": "./vectorstores/tigris.d.ts", + "import": "./vectorstores/tigris.js", + "require": "./vectorstores/tigris.cjs" + }, + "./vectorstores/typeorm": { + "types": "./vectorstores/typeorm.d.ts", + "import": "./vectorstores/typeorm.js", + "require": "./vectorstores/typeorm.cjs" + }, + "./vectorstores/typesense": { + "types": "./vectorstores/typesense.d.ts", + "import": "./vectorstores/typesense.js", + "require": "./vectorstores/typesense.cjs" + }, + "./vectorstores/usearch": { + "types": "./vectorstores/usearch.d.ts", + "import": "./vectorstores/usearch.js", + "require": "./vectorstores/usearch.cjs" + }, + "./vectorstores/vectara": { + "types": "./vectorstores/vectara.d.ts", + "import": "./vectorstores/vectara.js", + "require": "./vectorstores/vectara.cjs" + }, + "./vectorstores/vercel_postgres": { + "types": "./vectorstores/vercel_postgres.d.ts", + "import": "./vectorstores/vercel_postgres.js", + "require": "./vectorstores/vercel_postgres.cjs" + }, + "./vectorstores/voy": { + "types": "./vectorstores/voy.d.ts", + "import": "./vectorstores/voy.js", + "require": "./vectorstores/voy.cjs" + }, + "./vectorstores/weaviate": { + "types": "./vectorstores/weaviate.d.ts", + "import": "./vectorstores/weaviate.js", + "require": "./vectorstores/weaviate.cjs" + }, + "./vectorstores/xata": { + "types": "./vectorstores/xata.d.ts", + "import": "./vectorstores/xata.js", + "require": "./vectorstores/xata.cjs" + }, + "./vectorstores/zep": { + "types": "./vectorstores/zep.d.ts", + "import": "./vectorstores/zep.js", + "require": "./vectorstores/zep.cjs" + }, + "./chat_models/baiduwenxin": { + "types": "./chat_models/baiduwenxin.d.ts", + "import": "./chat_models/baiduwenxin.js", + "require": "./chat_models/baiduwenxin.cjs" + }, + "./chat_models/bedrock": { + "types": "./chat_models/bedrock.d.ts", + "import": "./chat_models/bedrock.js", + "require": "./chat_models/bedrock.cjs" + }, + "./chat_models/bedrock/web": { + "types": "./chat_models/bedrock/web.d.ts", 
+ "import": "./chat_models/bedrock/web.js", + "require": "./chat_models/bedrock/web.cjs" + }, + "./chat_models/cloudflare_workersai": { + "types": "./chat_models/cloudflare_workersai.d.ts", + "import": "./chat_models/cloudflare_workersai.js", + "require": "./chat_models/cloudflare_workersai.cjs" + }, + "./chat_models/fireworks": { + "types": "./chat_models/fireworks.d.ts", + "import": "./chat_models/fireworks.js", + "require": "./chat_models/fireworks.cjs" + }, + "./chat_models/googlevertexai": { + "types": "./chat_models/googlevertexai.d.ts", + "import": "./chat_models/googlevertexai.js", + "require": "./chat_models/googlevertexai.cjs" + }, + "./chat_models/googlevertexai/web": { + "types": "./chat_models/googlevertexai/web.d.ts", + "import": "./chat_models/googlevertexai/web.js", + "require": "./chat_models/googlevertexai/web.cjs" + }, + "./chat_models/googlepalm": { + "types": "./chat_models/googlepalm.d.ts", + "import": "./chat_models/googlepalm.js", + "require": "./chat_models/googlepalm.cjs" + }, + "./chat_models/iflytek_xinghuo": { + "types": "./chat_models/iflytek_xinghuo.d.ts", + "import": "./chat_models/iflytek_xinghuo.js", + "require": "./chat_models/iflytek_xinghuo.cjs" + }, + "./chat_models/iflytek_xinghuo/web": { + "types": "./chat_models/iflytek_xinghuo/web.d.ts", + "import": "./chat_models/iflytek_xinghuo/web.js", + "require": "./chat_models/iflytek_xinghuo/web.cjs" + }, + "./chat_models/llama_cpp": { + "types": "./chat_models/llama_cpp.d.ts", + "import": "./chat_models/llama_cpp.js", + "require": "./chat_models/llama_cpp.cjs" + }, + "./chat_models/minimax": { + "types": "./chat_models/minimax.d.ts", + "import": "./chat_models/minimax.js", + "require": "./chat_models/minimax.cjs" + }, + "./chat_models/ollama": { + "types": "./chat_models/ollama.d.ts", + "import": "./chat_models/ollama.js", + "require": "./chat_models/ollama.cjs" + }, + "./chat_models/portkey": { + "types": "./chat_models/portkey.d.ts", + "import": "./chat_models/portkey.js", + 
"require": "./chat_models/portkey.cjs" + }, + "./chat_models/yandex": { + "types": "./chat_models/yandex.d.ts", + "import": "./chat_models/yandex.js", + "require": "./chat_models/yandex.cjs" + }, + "./callbacks/handlers/llmonitor": { + "types": "./callbacks/handlers/llmonitor.d.ts", + "import": "./callbacks/handlers/llmonitor.js", + "require": "./callbacks/handlers/llmonitor.cjs" + }, + "./retrievers/amazon_kendra": { + "types": "./retrievers/amazon_kendra.d.ts", + "import": "./retrievers/amazon_kendra.js", + "require": "./retrievers/amazon_kendra.cjs" + }, + "./retrievers/chaindesk": { + "types": "./retrievers/chaindesk.d.ts", + "import": "./retrievers/chaindesk.js", + "require": "./retrievers/chaindesk.cjs" + }, + "./retrievers/databerry": { + "types": "./retrievers/databerry.d.ts", + "import": "./retrievers/databerry.js", + "require": "./retrievers/databerry.cjs" + }, + "./retrievers/metal": { + "types": "./retrievers/metal.d.ts", + "import": "./retrievers/metal.js", + "require": "./retrievers/metal.cjs" + }, + "./retrievers/supabase": { + "types": "./retrievers/supabase.d.ts", + "import": "./retrievers/supabase.js", + "require": "./retrievers/supabase.cjs" + }, + "./retrievers/tavily_search_api": { + "types": "./retrievers/tavily_search_api.d.ts", + "import": "./retrievers/tavily_search_api.js", + "require": "./retrievers/tavily_search_api.cjs" + }, + "./retrievers/zep": { + "types": "./retrievers/zep.d.ts", + "import": "./retrievers/zep.js", + "require": "./retrievers/zep.cjs" + }, + "./caches/cloudflare_kv": { + "types": "./caches/cloudflare_kv.d.ts", + "import": "./caches/cloudflare_kv.js", + "require": "./caches/cloudflare_kv.cjs" + }, + "./caches/momento": { + "types": "./caches/momento.d.ts", + "import": "./caches/momento.js", + "require": "./caches/momento.cjs" + }, + "./caches/upstash_redis": { + "types": "./caches/upstash_redis.d.ts", + "import": "./caches/upstash_redis.js", + "require": "./caches/upstash_redis.cjs" + }, + "./graphs/neo4j_graph": { + 
"types": "./graphs/neo4j_graph.d.ts", + "import": "./graphs/neo4j_graph.js", + "require": "./graphs/neo4j_graph.cjs" + }, + "./utils/event_source_parse": { + "types": "./utils/event_source_parse.d.ts", + "import": "./utils/event_source_parse.js", + "require": "./utils/event_source_parse.cjs" + }, + "./document_transformers/html_to_text": { + "types": "./document_transformers/html_to_text.d.ts", + "import": "./document_transformers/html_to_text.js", + "require": "./document_transformers/html_to_text.cjs" + }, + "./document_transformers/mozilla_readability": { + "types": "./document_transformers/mozilla_readability.d.ts", + "import": "./document_transformers/mozilla_readability.js", + "require": "./document_transformers/mozilla_readability.cjs" + }, + "./storage/convex": { + "types": "./storage/convex.d.ts", + "import": "./storage/convex.js", + "require": "./storage/convex.cjs" + }, + "./storage/ioredis": { + "types": "./storage/ioredis.d.ts", + "import": "./storage/ioredis.js", + "require": "./storage/ioredis.cjs" + }, + "./storage/upstash_redis": { + "types": "./storage/upstash_redis.d.ts", + "import": "./storage/upstash_redis.js", + "require": "./storage/upstash_redis.cjs" + }, + "./storage/vercel_kv": { + "types": "./storage/vercel_kv.d.ts", + "import": "./storage/vercel_kv.js", + "require": "./storage/vercel_kv.cjs" + }, + "./stores/doc/base": { + "types": "./stores/doc/base.d.ts", + "import": "./stores/doc/base.js", + "require": "./stores/doc/base.cjs" + }, + "./stores/doc/in_memory": { + "types": "./stores/doc/in_memory.d.ts", + "import": "./stores/doc/in_memory.js", + "require": "./stores/doc/in_memory.cjs" + }, + "./stores/message/cassandra": { + "types": "./stores/message/cassandra.d.ts", + "import": "./stores/message/cassandra.js", + "require": "./stores/message/cassandra.cjs" + }, + "./stores/message/cloudflare_d1": { + "types": "./stores/message/cloudflare_d1.d.ts", + "import": "./stores/message/cloudflare_d1.js", + "require": 
"./stores/message/cloudflare_d1.cjs" + }, + "./stores/message/convex": { + "types": "./stores/message/convex.d.ts", + "import": "./stores/message/convex.js", + "require": "./stores/message/convex.cjs" + }, + "./stores/message/dynamodb": { + "types": "./stores/message/dynamodb.d.ts", + "import": "./stores/message/dynamodb.js", + "require": "./stores/message/dynamodb.cjs" + }, + "./stores/message/firestore": { + "types": "./stores/message/firestore.d.ts", + "import": "./stores/message/firestore.js", + "require": "./stores/message/firestore.cjs" + }, + "./stores/message/ioredis": { + "types": "./stores/message/ioredis.d.ts", + "import": "./stores/message/ioredis.js", + "require": "./stores/message/ioredis.cjs" + }, + "./stores/message/momento": { + "types": "./stores/message/momento.d.ts", + "import": "./stores/message/momento.js", + "require": "./stores/message/momento.cjs" + }, + "./stores/message/mongodb": { + "types": "./stores/message/mongodb.d.ts", + "import": "./stores/message/mongodb.js", + "require": "./stores/message/mongodb.cjs" + }, + "./stores/message/planetscale": { + "types": "./stores/message/planetscale.d.ts", + "import": "./stores/message/planetscale.js", + "require": "./stores/message/planetscale.cjs" + }, + "./stores/message/redis": { + "types": "./stores/message/redis.d.ts", + "import": "./stores/message/redis.js", + "require": "./stores/message/redis.cjs" + }, + "./stores/message/upstash_redis": { + "types": "./stores/message/upstash_redis.d.ts", + "import": "./stores/message/upstash_redis.js", + "require": "./stores/message/upstash_redis.cjs" + }, + "./stores/message/xata": { + "types": "./stores/message/xata.d.ts", + "import": "./stores/message/xata.js", + "require": "./stores/message/xata.cjs" + }, + "./package.json": "./package.json" + }, + "files": [ + "dist/", + "load.cjs", + "load.js", + "load.d.ts", + "load/serializable.cjs", + "load/serializable.js", + "load/serializable.d.ts", + "tools/aiplugin.cjs", + "tools/aiplugin.js", + 
"tools/aiplugin.d.ts", + "tools/aws_sfn.cjs", + "tools/aws_sfn.js", + "tools/aws_sfn.d.ts", + "tools/bingserpapi.cjs", + "tools/bingserpapi.js", + "tools/bingserpapi.d.ts", + "tools/brave_search.cjs", + "tools/brave_search.js", + "tools/brave_search.d.ts", + "tools/connery.cjs", + "tools/connery.js", + "tools/connery.d.ts", + "tools/dadjokeapi.cjs", + "tools/dadjokeapi.js", + "tools/dadjokeapi.d.ts", + "tools/dataforseo_api_search.cjs", + "tools/dataforseo_api_search.js", + "tools/dataforseo_api_search.d.ts", + "tools/gmail.cjs", + "tools/gmail.js", + "tools/gmail.d.ts", + "tools/google_custom_search.cjs", + "tools/google_custom_search.js", + "tools/google_custom_search.d.ts", + "tools/google_places.cjs", + "tools/google_places.js", + "tools/google_places.d.ts", + "tools/ifttt.cjs", + "tools/ifttt.js", + "tools/ifttt.d.ts", + "tools/searchapi.cjs", + "tools/searchapi.js", + "tools/searchapi.d.ts", + "tools/searxng_search.cjs", + "tools/searxng_search.js", + "tools/searxng_search.d.ts", + "tools/serpapi.cjs", + "tools/serpapi.js", + "tools/serpapi.d.ts", + "tools/serper.cjs", + "tools/serper.js", + "tools/serper.d.ts", + "tools/wikipedia_query_run.cjs", + "tools/wikipedia_query_run.js", + "tools/wikipedia_query_run.d.ts", + "tools/wolframalpha.cjs", + "tools/wolframalpha.js", + "tools/wolframalpha.d.ts", + "embeddings/bedrock.cjs", + "embeddings/bedrock.js", + "embeddings/bedrock.d.ts", + "embeddings/cloudflare_workersai.cjs", + "embeddings/cloudflare_workersai.js", + "embeddings/cloudflare_workersai.d.ts", + "embeddings/cohere.cjs", + "embeddings/cohere.js", + "embeddings/cohere.d.ts", + "embeddings/googlepalm.cjs", + "embeddings/googlepalm.js", + "embeddings/googlepalm.d.ts", + "embeddings/googlevertexai.cjs", + "embeddings/googlevertexai.js", + "embeddings/googlevertexai.d.ts", + "embeddings/gradient_ai.cjs", + "embeddings/gradient_ai.js", + "embeddings/gradient_ai.d.ts", + "embeddings/hf.cjs", + "embeddings/hf.js", + "embeddings/hf.d.ts", + 
"embeddings/hf_transformers.cjs", + "embeddings/hf_transformers.js", + "embeddings/hf_transformers.d.ts", + "embeddings/llama_cpp.cjs", + "embeddings/llama_cpp.js", + "embeddings/llama_cpp.d.ts", + "embeddings/minimax.cjs", + "embeddings/minimax.js", + "embeddings/minimax.d.ts", + "embeddings/ollama.cjs", + "embeddings/ollama.js", + "embeddings/ollama.d.ts", + "embeddings/tensorflow.cjs", + "embeddings/tensorflow.js", + "embeddings/tensorflow.d.ts", + "embeddings/voyage.cjs", + "embeddings/voyage.js", + "embeddings/voyage.d.ts", + "llms/ai21.cjs", + "llms/ai21.js", + "llms/ai21.d.ts", + "llms/aleph_alpha.cjs", + "llms/aleph_alpha.js", + "llms/aleph_alpha.d.ts", + "llms/bedrock.cjs", + "llms/bedrock.js", + "llms/bedrock.d.ts", + "llms/bedrock/web.cjs", + "llms/bedrock/web.js", + "llms/bedrock/web.d.ts", + "llms/cloudflare_workersai.cjs", + "llms/cloudflare_workersai.js", + "llms/cloudflare_workersai.d.ts", + "llms/cohere.cjs", + "llms/cohere.js", + "llms/cohere.d.ts", + "llms/fireworks.cjs", + "llms/fireworks.js", + "llms/fireworks.d.ts", + "llms/googlepalm.cjs", + "llms/googlepalm.js", + "llms/googlepalm.d.ts", + "llms/googlevertexai.cjs", + "llms/googlevertexai.js", + "llms/googlevertexai.d.ts", + "llms/googlevertexai/web.cjs", + "llms/googlevertexai/web.js", + "llms/googlevertexai/web.d.ts", + "llms/gradient_ai.cjs", + "llms/gradient_ai.js", + "llms/gradient_ai.d.ts", + "llms/hf.cjs", + "llms/hf.js", + "llms/hf.d.ts", + "llms/llama_cpp.cjs", + "llms/llama_cpp.js", + "llms/llama_cpp.d.ts", + "llms/ollama.cjs", + "llms/ollama.js", + "llms/ollama.d.ts", + "llms/portkey.cjs", + "llms/portkey.js", + "llms/portkey.d.ts", + "llms/raycast.cjs", + "llms/raycast.js", + "llms/raycast.d.ts", + "llms/replicate.cjs", + "llms/replicate.js", + "llms/replicate.d.ts", + "llms/sagemaker_endpoint.cjs", + "llms/sagemaker_endpoint.js", + "llms/sagemaker_endpoint.d.ts", + "llms/watsonx_ai.cjs", + "llms/watsonx_ai.js", + "llms/watsonx_ai.d.ts", + "llms/writer.cjs", + "llms/writer.js", + 
"llms/writer.d.ts", + "llms/yandex.cjs", + "llms/yandex.js", + "llms/yandex.d.ts", + "vectorstores/analyticdb.cjs", + "vectorstores/analyticdb.js", + "vectorstores/analyticdb.d.ts", + "vectorstores/cassandra.cjs", + "vectorstores/cassandra.js", + "vectorstores/cassandra.d.ts", + "vectorstores/chroma.cjs", + "vectorstores/chroma.js", + "vectorstores/chroma.d.ts", + "vectorstores/clickhouse.cjs", + "vectorstores/clickhouse.js", + "vectorstores/clickhouse.d.ts", + "vectorstores/closevector/node.cjs", + "vectorstores/closevector/node.js", + "vectorstores/closevector/node.d.ts", + "vectorstores/closevector/web.cjs", + "vectorstores/closevector/web.js", + "vectorstores/closevector/web.d.ts", + "vectorstores/cloudflare_vectorize.cjs", + "vectorstores/cloudflare_vectorize.js", + "vectorstores/cloudflare_vectorize.d.ts", + "vectorstores/convex.cjs", + "vectorstores/convex.js", + "vectorstores/convex.d.ts", + "vectorstores/elasticsearch.cjs", + "vectorstores/elasticsearch.js", + "vectorstores/elasticsearch.d.ts", + "vectorstores/faiss.cjs", + "vectorstores/faiss.js", + "vectorstores/faiss.d.ts", + "vectorstores/googlevertexai.cjs", + "vectorstores/googlevertexai.js", + "vectorstores/googlevertexai.d.ts", + "vectorstores/hnswlib.cjs", + "vectorstores/hnswlib.js", + "vectorstores/hnswlib.d.ts", + "vectorstores/lancedb.cjs", + "vectorstores/lancedb.js", + "vectorstores/lancedb.d.ts", + "vectorstores/memory.cjs", + "vectorstores/memory.js", + "vectorstores/memory.d.ts", + "vectorstores/milvus.cjs", + "vectorstores/milvus.js", + "vectorstores/milvus.d.ts", + "vectorstores/momento_vector_index.cjs", + "vectorstores/momento_vector_index.js", + "vectorstores/momento_vector_index.d.ts", + "vectorstores/mongodb_atlas.cjs", + "vectorstores/mongodb_atlas.js", + "vectorstores/mongodb_atlas.d.ts", + "vectorstores/myscale.cjs", + "vectorstores/myscale.js", + "vectorstores/myscale.d.ts", + "vectorstores/neo4j_vector.cjs", + "vectorstores/neo4j_vector.js", + "vectorstores/neo4j_vector.d.ts", 
+ "vectorstores/opensearch.cjs", + "vectorstores/opensearch.js", + "vectorstores/opensearch.d.ts", + "vectorstores/pgvector.cjs", + "vectorstores/pgvector.js", + "vectorstores/pgvector.d.ts", + "vectorstores/pinecone.cjs", + "vectorstores/pinecone.js", + "vectorstores/pinecone.d.ts", + "vectorstores/prisma.cjs", + "vectorstores/prisma.js", + "vectorstores/prisma.d.ts", + "vectorstores/qdrant.cjs", + "vectorstores/qdrant.js", + "vectorstores/qdrant.d.ts", + "vectorstores/redis.cjs", + "vectorstores/redis.js", + "vectorstores/redis.d.ts", + "vectorstores/rockset.cjs", + "vectorstores/rockset.js", + "vectorstores/rockset.d.ts", + "vectorstores/singlestore.cjs", + "vectorstores/singlestore.js", + "vectorstores/singlestore.d.ts", + "vectorstores/supabase.cjs", + "vectorstores/supabase.js", + "vectorstores/supabase.d.ts", + "vectorstores/tigris.cjs", + "vectorstores/tigris.js", + "vectorstores/tigris.d.ts", + "vectorstores/typeorm.cjs", + "vectorstores/typeorm.js", + "vectorstores/typeorm.d.ts", + "vectorstores/typesense.cjs", + "vectorstores/typesense.js", + "vectorstores/typesense.d.ts", + "vectorstores/usearch.cjs", + "vectorstores/usearch.js", + "vectorstores/usearch.d.ts", + "vectorstores/vectara.cjs", + "vectorstores/vectara.js", + "vectorstores/vectara.d.ts", + "vectorstores/vercel_postgres.cjs", + "vectorstores/vercel_postgres.js", + "vectorstores/vercel_postgres.d.ts", + "vectorstores/voy.cjs", + "vectorstores/voy.js", + "vectorstores/voy.d.ts", + "vectorstores/weaviate.cjs", + "vectorstores/weaviate.js", + "vectorstores/weaviate.d.ts", + "vectorstores/xata.cjs", + "vectorstores/xata.js", + "vectorstores/xata.d.ts", + "vectorstores/zep.cjs", + "vectorstores/zep.js", + "vectorstores/zep.d.ts", + "chat_models/baiduwenxin.cjs", + "chat_models/baiduwenxin.js", + "chat_models/baiduwenxin.d.ts", + "chat_models/bedrock.cjs", + "chat_models/bedrock.js", + "chat_models/bedrock.d.ts", + "chat_models/bedrock/web.cjs", + "chat_models/bedrock/web.js", + 
"chat_models/bedrock/web.d.ts", + "chat_models/cloudflare_workersai.cjs", + "chat_models/cloudflare_workersai.js", + "chat_models/cloudflare_workersai.d.ts", + "chat_models/fireworks.cjs", + "chat_models/fireworks.js", + "chat_models/fireworks.d.ts", + "chat_models/googlevertexai.cjs", + "chat_models/googlevertexai.js", + "chat_models/googlevertexai.d.ts", + "chat_models/googlevertexai/web.cjs", + "chat_models/googlevertexai/web.js", + "chat_models/googlevertexai/web.d.ts", + "chat_models/googlepalm.cjs", + "chat_models/googlepalm.js", + "chat_models/googlepalm.d.ts", + "chat_models/iflytek_xinghuo.cjs", + "chat_models/iflytek_xinghuo.js", + "chat_models/iflytek_xinghuo.d.ts", + "chat_models/iflytek_xinghuo/web.cjs", + "chat_models/iflytek_xinghuo/web.js", + "chat_models/iflytek_xinghuo/web.d.ts", + "chat_models/llama_cpp.cjs", + "chat_models/llama_cpp.js", + "chat_models/llama_cpp.d.ts", + "chat_models/minimax.cjs", + "chat_models/minimax.js", + "chat_models/minimax.d.ts", + "chat_models/ollama.cjs", + "chat_models/ollama.js", + "chat_models/ollama.d.ts", + "chat_models/portkey.cjs", + "chat_models/portkey.js", + "chat_models/portkey.d.ts", + "chat_models/yandex.cjs", + "chat_models/yandex.js", + "chat_models/yandex.d.ts", + "callbacks/handlers/llmonitor.cjs", + "callbacks/handlers/llmonitor.js", + "callbacks/handlers/llmonitor.d.ts", + "retrievers/amazon_kendra.cjs", + "retrievers/amazon_kendra.js", + "retrievers/amazon_kendra.d.ts", + "retrievers/chaindesk.cjs", + "retrievers/chaindesk.js", + "retrievers/chaindesk.d.ts", + "retrievers/databerry.cjs", + "retrievers/databerry.js", + "retrievers/databerry.d.ts", + "retrievers/metal.cjs", + "retrievers/metal.js", + "retrievers/metal.d.ts", + "retrievers/supabase.cjs", + "retrievers/supabase.js", + "retrievers/supabase.d.ts", + "retrievers/tavily_search_api.cjs", + "retrievers/tavily_search_api.js", + "retrievers/tavily_search_api.d.ts", + "retrievers/zep.cjs", + "retrievers/zep.js", + "retrievers/zep.d.ts", + 
"caches/cloudflare_kv.cjs", + "caches/cloudflare_kv.js", + "caches/cloudflare_kv.d.ts", + "caches/momento.cjs", + "caches/momento.js", + "caches/momento.d.ts", + "caches/upstash_redis.cjs", + "caches/upstash_redis.js", + "caches/upstash_redis.d.ts", + "graphs/neo4j_graph.cjs", + "graphs/neo4j_graph.js", + "graphs/neo4j_graph.d.ts", + "utils/event_source_parse.cjs", + "utils/event_source_parse.js", + "utils/event_source_parse.d.ts", + "document_transformers/html_to_text.cjs", + "document_transformers/html_to_text.js", + "document_transformers/html_to_text.d.ts", + "document_transformers/mozilla_readability.cjs", + "document_transformers/mozilla_readability.js", + "document_transformers/mozilla_readability.d.ts", + "storage/convex.cjs", + "storage/convex.js", + "storage/convex.d.ts", + "storage/ioredis.cjs", + "storage/ioredis.js", + "storage/ioredis.d.ts", + "storage/upstash_redis.cjs", + "storage/upstash_redis.js", + "storage/upstash_redis.d.ts", + "storage/vercel_kv.cjs", + "storage/vercel_kv.js", + "storage/vercel_kv.d.ts", + "stores/doc/base.cjs", + "stores/doc/base.js", + "stores/doc/base.d.ts", + "stores/doc/in_memory.cjs", + "stores/doc/in_memory.js", + "stores/doc/in_memory.d.ts", + "stores/message/cassandra.cjs", + "stores/message/cassandra.js", + "stores/message/cassandra.d.ts", + "stores/message/cloudflare_d1.cjs", + "stores/message/cloudflare_d1.js", + "stores/message/cloudflare_d1.d.ts", + "stores/message/convex.cjs", + "stores/message/convex.js", + "stores/message/convex.d.ts", + "stores/message/dynamodb.cjs", + "stores/message/dynamodb.js", + "stores/message/dynamodb.d.ts", + "stores/message/firestore.cjs", + "stores/message/firestore.js", + "stores/message/firestore.d.ts", + "stores/message/ioredis.cjs", + "stores/message/ioredis.js", + "stores/message/ioredis.d.ts", + "stores/message/momento.cjs", + "stores/message/momento.js", + "stores/message/momento.d.ts", + "stores/message/mongodb.cjs", + "stores/message/mongodb.js", + 
"stores/message/mongodb.d.ts", + "stores/message/planetscale.cjs", + "stores/message/planetscale.js", + "stores/message/planetscale.d.ts", + "stores/message/redis.cjs", + "stores/message/redis.js", + "stores/message/redis.d.ts", + "stores/message/upstash_redis.cjs", + "stores/message/upstash_redis.js", + "stores/message/upstash_redis.d.ts", + "stores/message/xata.cjs", + "stores/message/xata.js", + "stores/message/xata.d.ts" + ] +} diff --git a/libs/langchain-community/scripts/check-tree-shaking.js b/libs/langchain-community/scripts/check-tree-shaking.js new file mode 100644 index 000000000000..62f289fd855e --- /dev/null +++ b/libs/langchain-community/scripts/check-tree-shaking.js @@ -0,0 +1,88 @@ +import fs from "fs/promises"; +import { rollup } from "rollup"; + +const packageJson = JSON.parse(await fs.readFile("package.json", "utf-8")); + +export function listEntrypoints() { + const exports = packageJson.exports; + const entrypoints = []; + + for (const [key, value] of Object.entries(exports)) { + if (key === "./package.json") { + continue; + } + if (typeof value === "string") { + entrypoints.push(value); + } else if (typeof value === "object" && value.import) { + entrypoints.push(value.import); + } + } + + return entrypoints; +} + +export function listExternals() { + return [ + ...Object.keys(packageJson.dependencies), + ...Object.keys(packageJson.peerDependencies ?? {}), + /node\:/, + /@langchain\/core\//, + "convex", + "convex/server", + "convex/values", + "@rockset/client/dist/codegen/api.js", + "mysql2/promise", + "web-auth-library/google", + "firebase-admin/app", + "firebase-admin/firestore", + ]; +} + +export async function checkTreeShaking() { + const externals = listExternals(); + const entrypoints = listEntrypoints(); + const consoleLog = console.log; + const reportMap = new Map(); + + for (const entrypoint of entrypoints) { + let sideEffects = ""; + + console.log = function (...args) { + const line = args.length ? 
args.join(" ") : ""; + if (line.trim().startsWith("First side effect in")) { + sideEffects += line + "\n"; + } + }; + + await rollup({ + external: externals, + input: entrypoint, + experimentalLogSideEffects: true, + }); + + reportMap.set(entrypoint, { + log: sideEffects, + hasSideEffects: sideEffects.length > 0, + }); + } + + console.log = consoleLog; + + let failed = false; + for (const [entrypoint, report] of reportMap) { + if (report.hasSideEffects) { + failed = true; + console.log("---------------------------------"); + console.log(`Tree shaking failed for ${entrypoint}`); + console.log(report.log); + } + } + + if (failed) { + process.exit(1); + } else { + console.log("Tree shaking checks passed!"); + } +} + +checkTreeShaking(); diff --git a/libs/langchain-community/scripts/create-entrypoints.js b/libs/langchain-community/scripts/create-entrypoints.js new file mode 100644 index 000000000000..6d371face720 --- /dev/null +++ b/libs/langchain-community/scripts/create-entrypoints.js @@ -0,0 +1,451 @@ +import * as fs from "fs"; +import * as path from "path"; +import { identifySecrets } from "./identify-secrets.js"; + +// This lists all the entrypoints for the library. Each key corresponds to an +// importable path, eg. `import { AgentExecutor } from "langchain/agents"`. +// The value is the path to the file in `src/` that exports the entrypoint. +// This is used to generate the `exports` field in package.json. +// Order is not important. 
+const entrypoints = { + load: "load/index", + "load/serializable": "load/serializable", + "tools/aiplugin": "tools/aiplugin", + "tools/aws_sfn": "tools/aws_sfn", + "tools/bingserpapi": "tools/bingserpapi", + "tools/brave_search": "tools/brave_search", + "tools/connery": "tools/connery", + "tools/dadjokeapi": "tools/dadjokeapi", + "tools/dataforseo_api_search": "tools/dataforseo_api_search", + "tools/gmail": "tools/gmail/index", + "tools/google_custom_search": "tools/google_custom_search", + "tools/google_places": "tools/google_places", + "tools/ifttt": "tools/ifttt", + "tools/searchapi": "tools/searchapi", + "tools/searxng_search": "tools/searxng_search", + "tools/serpapi": "tools/serpapi", + "tools/serper": "tools/serper", + "tools/wikipedia_query_run": "tools/wikipedia_query_run", + "tools/wolframalpha": "tools/wolframalpha", + // embeddings + "embeddings/bedrock": "embeddings/bedrock", + "embeddings/cloudflare_workersai": "embeddings/cloudflare_workersai", + "embeddings/cohere": "embeddings/cohere", + "embeddings/googlepalm": "embeddings/googlepalm", + "embeddings/googlevertexai": "embeddings/googlevertexai", + "embeddings/gradient_ai": "embeddings/gradient_ai", + "embeddings/hf": "embeddings/hf", + "embeddings/hf_transformers": "embeddings/hf_transformers", + "embeddings/llama_cpp": "embeddings/llama_cpp", + "embeddings/minimax": "embeddings/minimax", + "embeddings/ollama": "embeddings/ollama", + "embeddings/tensorflow": "embeddings/tensorflow", + "embeddings/voyage": "embeddings/voyage", + // llms + "llms/ai21": "llms/ai21", + "llms/aleph_alpha": "llms/aleph_alpha", + "llms/bedrock": "llms/bedrock/index", + "llms/bedrock/web": "llms/bedrock/web", + "llms/cloudflare_workersai": "llms/cloudflare_workersai", + "llms/cohere": "llms/cohere", + "llms/fireworks": "llms/fireworks", + "llms/googlepalm": "llms/googlepalm", + "llms/googlevertexai": "llms/googlevertexai/index", + "llms/googlevertexai/web": "llms/googlevertexai/web", + "llms/gradient_ai": 
"llms/gradient_ai", + "llms/hf": "llms/hf", + "llms/llama_cpp": "llms/llama_cpp", + "llms/ollama": "llms/ollama", + "llms/portkey": "llms/portkey", + "llms/raycast": "llms/raycast", + "llms/replicate": "llms/replicate", + "llms/sagemaker_endpoint": "llms/sagemaker_endpoint", + "llms/watsonx_ai": "llms/watsonx_ai", + "llms/writer": "llms/writer", + "llms/yandex": "llms/yandex", + // vectorstores + "vectorstores/analyticdb": "vectorstores/analyticdb", + "vectorstores/cassandra": "vectorstores/cassandra", + "vectorstores/chroma": "vectorstores/chroma", + "vectorstores/clickhouse": "vectorstores/clickhouse", + "vectorstores/closevector/node": "vectorstores/closevector/node", + "vectorstores/closevector/web": "vectorstores/closevector/web", + "vectorstores/cloudflare_vectorize": "vectorstores/cloudflare_vectorize", + "vectorstores/convex": "vectorstores/convex", + "vectorstores/elasticsearch": "vectorstores/elasticsearch", + "vectorstores/faiss": "vectorstores/faiss", + "vectorstores/googlevertexai": "vectorstores/googlevertexai", + "vectorstores/hnswlib": "vectorstores/hnswlib", + "vectorstores/lancedb": "vectorstores/lancedb", + "vectorstores/memory": "vectorstores/memory", + "vectorstores/milvus": "vectorstores/milvus", + "vectorstores/momento_vector_index": "vectorstores/momento_vector_index", + "vectorstores/mongodb_atlas": "vectorstores/mongodb_atlas", + "vectorstores/myscale": "vectorstores/myscale", + "vectorstores/neo4j_vector": "vectorstores/neo4j_vector", + "vectorstores/opensearch": "vectorstores/opensearch", + "vectorstores/pgvector": "vectorstores/pgvector", + "vectorstores/pinecone": "vectorstores/pinecone", + "vectorstores/prisma": "vectorstores/prisma", + "vectorstores/qdrant": "vectorstores/qdrant", + "vectorstores/redis": "vectorstores/redis", + "vectorstores/rockset": "vectorstores/rockset", + "vectorstores/singlestore": "vectorstores/singlestore", + "vectorstores/supabase": "vectorstores/supabase", + "vectorstores/tigris": "vectorstores/tigris", + 
"vectorstores/typeorm": "vectorstores/typeorm", + "vectorstores/typesense": "vectorstores/typesense", + "vectorstores/usearch": "vectorstores/usearch", + "vectorstores/vectara": "vectorstores/vectara", + "vectorstores/vercel_postgres": "vectorstores/vercel_postgres", + "vectorstores/voy": "vectorstores/voy", + "vectorstores/weaviate": "vectorstores/weaviate", + "vectorstores/xata": "vectorstores/xata", + "vectorstores/zep": "vectorstores/zep", + // chat_models + "chat_models/baiduwenxin": "chat_models/baiduwenxin", + "chat_models/bedrock": "chat_models/bedrock/index", + "chat_models/bedrock/web": "chat_models/bedrock/web", + "chat_models/cloudflare_workersai": "chat_models/cloudflare_workersai", + "chat_models/fireworks": "chat_models/fireworks", + "chat_models/googlevertexai": "chat_models/googlevertexai/index", + "chat_models/googlevertexai/web": "chat_models/googlevertexai/web", + "chat_models/googlepalm": "chat_models/googlepalm", + "chat_models/iflytek_xinghuo": "chat_models/iflytek_xinghuo/index", + "chat_models/iflytek_xinghuo/web": "chat_models/iflytek_xinghuo/web", + "chat_models/llama_cpp": "chat_models/llama_cpp", + "chat_models/minimax": "chat_models/minimax", + "chat_models/ollama": "chat_models/ollama", + "chat_models/portkey": "chat_models/portkey", + "chat_models/yandex": "chat_models/yandex", + // callbacks + "callbacks/handlers/llmonitor": "callbacks/handlers/llmonitor", + // retrievers + "retrievers/amazon_kendra": "retrievers/amazon_kendra", + "retrievers/chaindesk": "retrievers/chaindesk", + "retrievers/databerry": "retrievers/databerry", + "retrievers/metal": "retrievers/metal", + "retrievers/supabase": "retrievers/supabase", + "retrievers/tavily_search_api": "retrievers/tavily_search_api", + "retrievers/zep": "retrievers/zep", + // cache + "caches/cloudflare_kv": "caches/cloudflare_kv", + "caches/momento": "caches/momento", + "caches/upstash_redis": "caches/upstash_redis", + // graphs + "graphs/neo4j_graph": "graphs/neo4j_graph", + 
"utils/event_source_parse": "utils/event_source_parse", + // document transformers + "document_transformers/html_to_text": "document_transformers/html_to_text", + "document_transformers/mozilla_readability": + "document_transformers/mozilla_readability", + // storage + "storage/convex": "storage/convex", + "storage/ioredis": "storage/ioredis", + "storage/upstash_redis": "storage/upstash_redis", + "storage/vercel_kv": "storage/vercel_kv", + // stores + "stores/doc/base": "stores/doc/base", + "stores/doc/in_memory": "stores/doc/in_memory", + "stores/message/cassandra": "stores/message/cassandra", + "stores/message/cloudflare_d1": "stores/message/cloudflare_d1", + "stores/message/convex": "stores/message/convex", + "stores/message/dynamodb": "stores/message/dynamodb", + "stores/message/firestore": "stores/message/firestore", + "stores/message/ioredis": "stores/message/ioredis", + "stores/message/momento": "stores/message/momento", + "stores/message/mongodb": "stores/message/mongodb", + "stores/message/planetscale": "stores/message/planetscale", + "stores/message/redis": "stores/message/redis", + "stores/message/upstash_redis": "stores/message/upstash_redis", + "stores/message/xata": "stores/message/xata", +}; + +// Entrypoints in this list will +// 1. Be excluded from the documentation +// 2. Be only available in Node.js environments (for backwards compatibility) +const deprecatedNodeOnly = []; + +// Entrypoints in this list require an optional dependency to be installed. +// Therefore they are not tested in the generated test-exports-* packages. 
+const requiresOptionalDependency = [ + "tools/aws_sfn", + "tools/gmail", + "callbacks/handlers/llmonitor", + "embeddings/bedrock", + "embeddings/cloudflare_workersai", + "embeddings/cohere", + "embeddings/googlevertexai", + "embeddings/googlepalm", + "embeddings/tensorflow", + "embeddings/hf", + "embeddings/hf_transformers", + "embeddings/llama_cpp", + "embeddings/gradient_ai", + "llms/load", + "llms/cohere", + "llms/googlevertexai", + "llms/googlevertexai/web", + "llms/googlepalm", + "llms/gradient_ai", + "llms/hf", + "llms/raycast", + "llms/replicate", + "llms/sagemaker_endpoint", + "llms/watsonx_ai", + "llms/bedrock", + "llms/bedrock/web", + "llms/llama_cpp", + "llms/writer", + "llms/portkey", + "vectorstores/analyticdb", + "vectorstores/cassandra", + "vectorstores/chroma", + "vectorstores/clickhouse", + "vectorstores/closevector/node", + "vectorstores/closevector/web", + "vectorstores/cloudflare_vectorize", + "vectorstores/convex", + "vectorstores/elasticsearch", + "vectorstores/faiss", + "vectorstores/googlevertexai", + "vectorstores/hnswlib", + "vectorstores/lancedb", + "vectorstores/milvus", + "vectorstores/momento_vector_index", + "vectorstores/mongodb_atlas", + "vectorstores/myscale", + "vectorstores/neo4j_vector", + "vectorstores/opensearch", + "vectorstores/pgvector", + "vectorstores/pinecone", + "vectorstores/qdrant", + "vectorstores/redis", + "vectorstores/rockset", + "vectorstores/singlestore", + "vectorstores/supabase", + "vectorstores/tigris", + "vectorstores/typeorm", + "vectorstores/typesense", + "vectorstores/usearch", + "vectorstores/vercel_postgres", + "vectorstores/voy", + "vectorstores/weaviate", + "vectorstores/xata", + "vectorstores/zep", + "chat_models/bedrock", + "chat_models/bedrock/web", + "chat_models/googlevertexai", + "chat_models/googlevertexai/web", + "chat_models/googlepalm", + "chat_models/llama_cpp", + "chat_models/portkey", + "chat_models/iflytek_xinghuo", + "chat_models/iflytek_xinghuo/web", + "retrievers/amazon_kendra", + 
"retrievers/supabase", + "retrievers/zep", + "retrievers/metal", + "cache/cloudflare_kv", + "cache/momento", + "cache/upstash_redis", + "graphs/neo4j_graph", + // document_transformers + "document_transformers/html_to_text", + "document_transformers/mozilla_readability", + // storage + "storage/convex", + "storage/ioredis", + "storage/upstash_redis", + "storage/vercel_kv", + // stores + "stores/message/cassandra", + "stores/message/cloudflare_d1", + "stores/message/convex", + "stores/message/dynamodb", + "stores/message/firestore", + "stores/message/ioredis", + "stores/message/momento", + "stores/message/mongodb", + "stores/message/planetscale", + "stores/message/redis", + "stores/message/upstash_redis", + "stores/message/xata", +]; + +const updateJsonFile = (relativePath, updateFunction) => { + const contents = fs.readFileSync(relativePath).toString(); + const res = updateFunction(JSON.parse(contents)); + fs.writeFileSync(relativePath, JSON.stringify(res, null, 2) + "\n"); +}; + +const generateFiles = () => { + const files = [...Object.entries(entrypoints)].flatMap( + ([key, value]) => { + const nrOfDots = key.split("/").length - 1; + const relativePath = "../".repeat(nrOfDots) || "./"; + const compiledPath = `${relativePath}dist/${value}.js`; + return [ + [ + `${key}.cjs`, + `module.exports = require('${relativePath}dist/${value}.cjs');`, + ], + [`${key}.js`, `export * from '${compiledPath}'`], + [`${key}.d.ts`, `export * from '${compiledPath}'`], + ]; + } + ); + + return Object.fromEntries(files); +}; + +const updateConfig = () => { + const generatedFiles = generateFiles(); + const filenames = Object.keys(generatedFiles); + + // Update package.json `exports` and `files` fields + updateJsonFile("./package.json", (json) => ({ + ...json, + exports: Object.assign( + Object.fromEntries( + [...Object.keys(entrypoints)].map((key) => { + let entryPoint = { + types: `./${key}.d.ts`, + import: `./${key}.js`, + require: `./${key}.cjs`, + }; + + if 
(deprecatedNodeOnly.includes(key)) { + entryPoint = { + node: entryPoint, + }; + } + + return [`./${key}`, entryPoint]; + }) + ), + { "./package.json": "./package.json" } + ), + files: ["dist/", ...filenames], + })); + + // Write generated files + Object.entries(generatedFiles).forEach(([filename, content]) => { + fs.mkdirSync(path.dirname(filename), { recursive: true }); + fs.writeFileSync(filename, content); + }); + + // Update .gitignore + fs.writeFileSync("./.gitignore", filenames.join("\n") + "\n"); + + // Update test-exports-*/entrypoints.js + const entrypointsToTest = Object.keys(entrypoints) + .filter((key) => !deprecatedNodeOnly.includes(key)) + .filter((key) => !requiresOptionalDependency.includes(key)); +}; + +const cleanGenerated = () => { + const filenames = Object.keys(generateFiles()); + filenames.forEach((fname) => { + try { + fs.unlinkSync(fname); + } catch { + // ignore error + } + }); +}; + +// Tuple describing the auto-generated import map (used by langchain/load) +// [package name, import statement, import map path] +// This will not include entrypoints deprecated or requiring optional deps. +const importMap = [ + "langchain-community", + (k, p) => `export * as ${k.replace(/\//g, "__")} from "../${p}.js";`, + "src/load/import_map.ts", +]; + +const generateImportMap = () => { + // Generate import map + const entrypointsToInclude = Object.keys(entrypoints) + .filter((key) => key !== "load") + .filter((key) => !deprecatedNodeOnly.includes(key)) + .filter((key) => !requiresOptionalDependency.includes(key)); + const [pkg, importStatement, importMapPath] = importMap; + const contents = + entrypointsToInclude + .map((key) => importStatement(key, entrypoints[key])) + .join("\n") + "\n"; + fs.writeFileSync( + `../${pkg}/${importMapPath}`, + "// Auto-generated by `scripts/create-entrypoints.js`. 
Do not edit manually.\n\n" + + contents + ); +}; + +const importTypes = [ + "langchain-community", + (k, p) => + ` "@langchain/community/${k}"?: + | typeof import("../${p}.js") + | Promise;`, + "src/load/import_type.d.ts", +]; + +const generateImportTypes = () => { + // Generate import types + const [pkg, importStatement, importTypesPath] = importTypes; + fs.writeFileSync( + `../${pkg}/${importTypesPath}`, + `// Auto-generated by \`scripts/create-entrypoints.js\`. Do not edit manually. + +export interface OptionalImportMap { +${Object.keys(entrypoints) + .filter((key) => !deprecatedNodeOnly.includes(key)) + .filter((key) => requiresOptionalDependency.includes(key)) + .map((key) => importStatement(key, entrypoints[key])) + .join("\n")} +} + +export interface SecretMap { +${[...identifySecrets()] + .sort() + .map((secret) => ` ${secret}?: string;`) + .join("\n")} +} +` + ); +}; + +const importConstants = [ + "langchain-community", + (k) => ` "langchain_community/${k}"`, + "src/load/import_constants.ts", +]; + +const generateImportConstants = () => { + // Generate import constants + const entrypointsToInclude = Object.keys(entrypoints) + .filter((key) => !deprecatedNodeOnly.includes(key)) + .filter((key) => requiresOptionalDependency.includes(key)); + const [pkg, importStatement, importConstantsPath] = importConstants; + const contents = + entrypointsToInclude + .map((key) => importStatement(key, entrypoints[key])) + .join(",\n") + ",\n];\n"; + fs.writeFileSync( + `../${pkg}/${importConstantsPath}`, + "// Auto-generated by `scripts/create-entrypoints.js`. 
Do not edit manually.\n\nexport const optionalImportEntrypoints = [\n" + + contents + ); +}; + +const command = process.argv[2]; + +if (command === "pre") { + cleanGenerated(); + generateImportMap(); + generateImportTypes(); + generateImportConstants(); +} else { + updateConfig(); +} diff --git a/libs/langchain-community/scripts/identify-secrets.js b/libs/langchain-community/scripts/identify-secrets.js new file mode 100644 index 000000000000..c54bdd97c870 --- /dev/null +++ b/libs/langchain-community/scripts/identify-secrets.js @@ -0,0 +1,77 @@ +import ts from "typescript"; +import * as fs from "fs"; + +export function identifySecrets() { + const secrets = new Set(); + + const tsConfig = ts.parseJsonConfigFileContent( + ts.readJsonConfigFile("./tsconfig.json", (p) => + fs.readFileSync(p, "utf-8") + ), + ts.sys, + "./src/" + ); + + for (const fileName of tsConfig.fileNames.filter( + (fn) => !fn.endsWith("test.ts") + )) { + const sourceFile = ts.createSourceFile( + fileName, + fs.readFileSync(fileName, "utf-8"), + tsConfig.options.target, + true + ); + sourceFile.forEachChild((node) => { + switch (node.kind) { + case ts.SyntaxKind.ClassDeclaration: + case ts.SyntaxKind.ClassExpression: { + node.forEachChild((node) => { + // look for get lc_secrets() + switch (node.kind) { + case ts.SyntaxKind.GetAccessor: { + const property = node; + if (property.name.getText() === "lc_secrets") { + // look for return { ... 
} + property.body.statements.forEach((stmt) => { + if ( + stmt.kind === ts.SyntaxKind.ReturnStatement && + stmt.expression.kind === + ts.SyntaxKind.ObjectLiteralExpression + ) { + // collect secret identifier + stmt.expression.properties.forEach((element) => { + if ( + element.initializer.kind === + ts.SyntaxKind.StringLiteral + ) { + const secret = element.initializer.text; + + if (secret.toUpperCase() !== secret) { + throw new Error( + `Secret identifier must be uppercase: ${secret} at ${fileName}` + ); + } + if (/\s/.test(secret)) { + throw new Error( + `Secret identifier must not contain whitespace: ${secret} at ${fileName}` + ); + } + + secrets.add(secret); + } + }); + } + }); + } + break; + } + } + }); + break; + } + } + }); + } + + return secrets; +} diff --git a/libs/langchain-community/scripts/move-cjs-to-dist.js b/libs/langchain-community/scripts/move-cjs-to-dist.js new file mode 100644 index 000000000000..1e89ccca88e9 --- /dev/null +++ b/libs/langchain-community/scripts/move-cjs-to-dist.js @@ -0,0 +1,38 @@ +import { resolve, dirname, parse, format } from "node:path"; +import { readdir, readFile, writeFile } from "node:fs/promises"; +import { fileURLToPath } from "node:url"; + +function abs(relativePath) { + return resolve(dirname(fileURLToPath(import.meta.url)), relativePath); +} + +async function moveAndRename(source, dest) { + for (const file of await readdir(abs(source), { withFileTypes: true })) { + if (file.isDirectory()) { + await moveAndRename(`${source}/${file.name}`, `${dest}/${file.name}`); + } else if (file.isFile()) { + const parsed = parse(file.name); + + // Ignore anything that's not a .js file + if (parsed.ext !== ".js") { + continue; + } + + // Rewrite any require statements to use .cjs + const content = await readFile(abs(`${source}/${file.name}`), "utf8"); + const rewritten = content.replace(/require\("(\..+?).js"\)/g, (_, p1) => { + return `require("${p1}.cjs")`; + }); + + // Rename the file to .cjs + const renamed = format({ name: 
parsed.name, ext: ".cjs" }); + + await writeFile(abs(`${dest}/${renamed}`), rewritten, "utf8"); + } + } +} + +moveAndRename("../dist-cjs", "../dist").catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/libs/langchain-community/scripts/release-branch.sh b/libs/langchain-community/scripts/release-branch.sh new file mode 100644 index 000000000000..7504238c5561 --- /dev/null +++ b/libs/langchain-community/scripts/release-branch.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +if [[ $(git branch --show-current) == "main" ]]; then + git checkout -B release + git push -u origin release +fi diff --git a/libs/langchain-community/src/caches/cloudflare_kv.ts b/libs/langchain-community/src/caches/cloudflare_kv.ts new file mode 100644 index 000000000000..7e3e11eded77 --- /dev/null +++ b/libs/langchain-community/src/caches/cloudflare_kv.ts @@ -0,0 +1,78 @@ +import type { KVNamespace } from "@cloudflare/workers-types"; + +import { + BaseCache, + getCacheKey, + serializeGeneration, + deserializeStoredGeneration, +} from "@langchain/core/caches"; +import { Generation } from "@langchain/core/outputs"; + +/** + * Represents a specific implementation of a caching mechanism using Cloudflare KV + * as the underlying storage system. It extends the `BaseCache` class and + * overrides its methods to provide the Cloudflare KV-specific logic. + * @example + * ```typescript + * // Example of using OpenAI with Cloudflare KV as cache in a Cloudflare Worker + * const cache = new CloudflareKVCache(env.KV_NAMESPACE); + * const model = new ChatAnthropic({ + * cache, + * }); + * const response = await model.invoke("How are you today?"); + * return new Response(JSON.stringify(response), { + * headers: { "content-type": "application/json" }, + * }); + * + * ``` + */ +export class CloudflareKVCache extends BaseCache { + private binding: KVNamespace; + + constructor(binding: KVNamespace) { + super(); + this.binding = binding; + } + + /** + * Retrieves data from the cache. 
It constructs a cache key from the given + * `prompt` and `llmKey`, and retrieves the corresponding value from the + * Cloudflare KV namespace. + * @param prompt The prompt used to construct the cache key. + * @param llmKey The LLM key used to construct the cache key. + * @returns An array of Generations if found, null otherwise. + */ + public async lookup(prompt: string, llmKey: string) { + let idx = 0; + let key = getCacheKey(prompt, llmKey, String(idx)); + let value = await this.binding.get(key); + const generations: Generation[] = []; + + while (value) { + generations.push(deserializeStoredGeneration(JSON.parse(value))); + idx += 1; + key = getCacheKey(prompt, llmKey, String(idx)); + value = await this.binding.get(key); + } + + return generations.length > 0 ? generations : null; + } + + /** + * Updates the cache with new data. It constructs a cache key from the + * given `prompt` and `llmKey`, and stores the `value` in the Cloudflare KV + * namespace. + * @param prompt The prompt used to construct the cache key. + * @param llmKey The LLM key used to construct the cache key. + * @param value The value to be stored in the cache. 
+ */ + public async update(prompt: string, llmKey: string, value: Generation[]) { + for (let i = 0; i < value.length; i += 1) { + const key = getCacheKey(prompt, llmKey, String(i)); + await this.binding.put( + key, + JSON.stringify(serializeGeneration(value[i])) + ); + } + } +} diff --git a/libs/langchain-community/src/caches/momento.ts b/libs/langchain-community/src/caches/momento.ts new file mode 100644 index 000000000000..565286037518 --- /dev/null +++ b/libs/langchain-community/src/caches/momento.ts @@ -0,0 +1,175 @@ +/* eslint-disable no-instanceof/no-instanceof */ +import { + ICacheClient, + CacheGet, + CacheSet, + InvalidArgumentError, +} from "@gomomento/sdk-core"; + +import { + BaseCache, + deserializeStoredGeneration, + getCacheKey, + serializeGeneration, +} from "@langchain/core/caches"; +import { Generation } from "@langchain/core/outputs"; + +import { ensureCacheExists } from "../utils/momento.js"; + +/** + * The settings to instantiate the Momento standard cache. + */ +export interface MomentoCacheProps { + /** + * The Momento cache client. + */ + client: ICacheClient; + /** + * The name of the cache to use to store the data. + */ + cacheName: string; + /** + * The time to live for the cache items. If not specified, + * the cache client default is used. + */ + ttlSeconds?: number; + /** + * If true, ensure that the cache exists before returning. + * If false, the cache is not checked for existence. + * Defaults to true. + */ + ensureCacheExists?: true; +} + +/** + * A cache that uses Momento as the backing store. + * See https://gomomento.com. + * @example + * ```typescript + * const cache = new MomentoCache({ + * client: new CacheClient({ + * configuration: Configurations.Laptop.v1(), + * credentialProvider: CredentialProvider.fromEnvironmentVariable({ + * environmentVariableName: "MOMENTO_API_KEY", + * }), + * defaultTtlSeconds: 60 * 60 * 24, // Cache TTL set to 24 hours. 
+ * }), + * cacheName: "langchain", + * }); + * // Initialize the OpenAI model with Momento cache for caching responses + * const model = new ChatOpenAI({ + * cache, + * }); + * await model.invoke("How are you today?"); + * const cachedValues = await cache.lookup("How are you today?", "llmKey"); + * ``` + */ +export class MomentoCache extends BaseCache { + private client: ICacheClient; + + private readonly cacheName: string; + + private readonly ttlSeconds?: number; + + private constructor(props: MomentoCacheProps) { + super(); + this.client = props.client; + this.cacheName = props.cacheName; + + this.validateTtlSeconds(props.ttlSeconds); + this.ttlSeconds = props.ttlSeconds; + } + + /** + * Create a new standard cache backed by Momento. + * + * @param {MomentoCacheProps} props The settings to instantiate the cache. + * @param {ICacheClient} props.client The Momento cache client. + * @param {string} props.cacheName The name of the cache to use to store the data. + * @param {number} props.ttlSeconds The time to live for the cache items. If not specified, + * the cache client default is used. + * @param {boolean} props.ensureCacheExists If true, ensure that the cache exists before returning. + * If false, the cache is not checked for existence. Defaults to true. + * @throws {@link InvalidArgumentError} if {@link props.ttlSeconds} is not strictly positive. + * @returns The Momento-backed cache. + */ + public static async fromProps( + props: MomentoCacheProps + ): Promise { + const instance = new MomentoCache(props); + if (props.ensureCacheExists || props.ensureCacheExists === undefined) { + await ensureCacheExists(props.client, props.cacheName); + } + return instance; + } + + /** + * Validate the user-specified TTL, if provided, is strictly positive. + * @param ttlSeconds The TTL to validate. 
+ */ + private validateTtlSeconds(ttlSeconds?: number): void { + if (ttlSeconds !== undefined && ttlSeconds <= 0) { + throw new InvalidArgumentError("ttlSeconds must be positive."); + } + } + + /** + * Lookup LLM generations in cache by prompt and associated LLM key. + * @param prompt The prompt to lookup. + * @param llmKey The LLM key to lookup. + * @returns The generations associated with the prompt and LLM key, or null if not found. + */ + public async lookup( + prompt: string, + llmKey: string + ): Promise { + const key = getCacheKey(prompt, llmKey); + const getResponse = await this.client.get(this.cacheName, key); + + if (getResponse instanceof CacheGet.Hit) { + const value = getResponse.valueString(); + const parsedValue = JSON.parse(value); + if (!Array.isArray(parsedValue)) { + return null; + } + return JSON.parse(value).map(deserializeStoredGeneration); + } else if (getResponse instanceof CacheGet.Miss) { + return null; + } else if (getResponse instanceof CacheGet.Error) { + throw getResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${getResponse.toString()}`); + } + } + + /** + * Update the cache with the given generations. + * + * Note this overwrites any existing generations for the given prompt and LLM key. + * + * @param prompt The prompt to update. + * @param llmKey The LLM key to update. + * @param value The generations to store. 
+ */ + public async update( + prompt: string, + llmKey: string, + value: Generation[] + ): Promise { + const key = getCacheKey(prompt, llmKey); + const setResponse = await this.client.set( + this.cacheName, + key, + JSON.stringify(value.map(serializeGeneration)), + { ttl: this.ttlSeconds } + ); + + if (setResponse instanceof CacheSet.Success) { + // pass + } else if (setResponse instanceof CacheSet.Error) { + throw setResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${setResponse.toString()}`); + } + } +} diff --git a/libs/langchain-community/src/caches/tests/momento.test.ts b/libs/langchain-community/src/caches/tests/momento.test.ts new file mode 100644 index 000000000000..9ba02464cc8b --- /dev/null +++ b/libs/langchain-community/src/caches/tests/momento.test.ts @@ -0,0 +1,329 @@ +import { expect } from "@jest/globals"; + +import { + ICacheClient, + IMomentoCache, + CacheDelete, + CacheGet, + CacheIncrement, + CacheKeyExists, + CacheKeysExist, + CacheSet, + CacheSetIfNotExists, + CacheSetFetch, + CacheSetAddElements, + CacheSetAddElement, + CacheSetRemoveElements, + CacheSetRemoveElement, + CacheListFetch, + CacheListLength, + CacheListPushFront, + CacheListPushBack, + CacheListConcatenateBack, + CacheListConcatenateFront, + CacheListPopBack, + CacheListPopFront, + CacheListRemoveValue, + CacheListRetain, + CacheDictionarySetField, + CacheDictionarySetFields, + CacheDictionaryGetField, + CacheDictionaryGetFields, + CacheDictionaryFetch, + CacheDictionaryLength, + CacheDictionaryIncrement, + CacheDictionaryRemoveField, + CacheDictionaryRemoveFields, + CacheSortedSetFetch, + CacheSortedSetPutElement, + CacheSortedSetPutElements, + CacheSortedSetGetRank, + CacheSortedSetGetScore, + CacheSortedSetGetScores, + CacheSortedSetLength, + CacheSortedSetLengthByScore, + CacheSortedSetIncrementScore, + CacheSortedSetRemoveElement, + CacheItemGetType, + CacheItemGetTtl, + CreateCache, + ListCaches, + DeleteCache, + CacheFlush, + CacheUpdateTtl, + 
CacheIncreaseTtl, + CacheDecreaseTtl, +} from "@gomomento/sdk-core"; +import { Generation } from "@langchain/core/outputs"; + +import { MomentoCache } from "../momento.js"; + +class MockClient implements ICacheClient { + private _cache: Map; + + constructor() { + this._cache = new Map(); + } + + cache(): IMomentoCache { + throw new Error("Method not implemented."); + } + + public async get(_: string, key: string): Promise { + if (this._cache.has(key)) { + return new CacheGet.Hit(new TextEncoder().encode(this._cache.get(key))); + } else { + return new CacheGet.Miss(); + } + } + + public async set( + _: string, + key: string, + value: string + ): Promise { + this._cache.set(key, value); + return new CacheSet.Success(); + } + + public async createCache(): Promise { + return new CreateCache.Success(); + } + + deleteCache(): Promise { + throw new Error("Method not implemented."); + } + + listCaches(): Promise { + throw new Error("Method not implemented."); + } + + flushCache(): Promise { + throw new Error("Method not implemented."); + } + + ping(): Promise { + throw new Error("Method not implemented."); + } + + delete(): Promise { + throw new Error("Method not implemented."); + } + + increment(): Promise { + throw new Error("Method not implemented."); + } + + keyExists(): Promise { + throw new Error("Method not implemented."); + } + + keysExist(): Promise { + throw new Error("Method not implemented."); + } + + setIfNotExists(): Promise { + throw new Error("Method not implemented."); + } + + setFetch(): Promise { + throw new Error("Method not implemented."); + } + + setAddElement(): Promise { + throw new Error("Method not implemented."); + } + + setAddElements(): Promise { + throw new Error("Method not implemented."); + } + + setRemoveElement(): Promise { + throw new Error("Method not implemented."); + } + + setRemoveElements(): Promise { + throw new Error("Method not implemented."); + } + + listFetch(): Promise { + throw new Error("Method not implemented."); + } + + 
listLength(): Promise { + throw new Error("Method not implemented."); + } + + listPushFront(): Promise { + throw new Error("Method not implemented."); + } + + listPushBack(): Promise { + throw new Error("Method not implemented."); + } + + listConcatenateBack(): Promise { + throw new Error("Method not implemented."); + } + + listConcatenateFront(): Promise { + throw new Error("Method not implemented."); + } + + listPopBack(): Promise { + throw new Error("Method not implemented."); + } + + listPopFront(): Promise { + throw new Error("Method not implemented."); + } + + listRemoveValue(): Promise { + throw new Error("Method not implemented."); + } + + listRetain(): Promise { + throw new Error("Method not implemented."); + } + + dictionarySetField(): Promise { + throw new Error("Method not implemented."); + } + + dictionarySetFields(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryGetField(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryGetFields(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryFetch(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryIncrement(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryLength(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryRemoveField(): Promise { + throw new Error("Method not implemented."); + } + + dictionaryRemoveFields(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetFetchByRank(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetFetchByScore(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetPutElement(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetPutElements(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetGetRank(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetGetScore(): Promise { + throw new 
Error("Method not implemented."); + } + + sortedSetGetScores(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetIncrementScore(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetLength(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetLengthByScore(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetRemoveElement(): Promise { + throw new Error("Method not implemented."); + } + + sortedSetRemoveElements(): Promise { + throw new Error("Method not implemented."); + } + + itemGetType(): Promise { + throw new Error("Method not implemented."); + } + + itemGetTtl(): Promise { + throw new Error("Method not implemented."); + } + + updateTtl(): Promise { + throw new Error("Method not implemented."); + } + + increaseTtl(): Promise { + throw new Error("Method not implemented."); + } + + decreaseTtl(): Promise { + throw new Error("Method not implemented."); + } +} + +describe("MomentoCache", () => { + it("should return null on a cache miss", async () => { + const client = new MockClient(); + const cache = await MomentoCache.fromProps({ + client, + cacheName: "test-cache", + }); + expect(await cache.lookup("prompt", "llm-key")).toBeNull(); + }); + + it("should get a stored value", async () => { + const client = new MockClient(); + const cache = await MomentoCache.fromProps({ + client, + cacheName: "test-cache", + }); + const generations: Generation[] = [{ text: "foo" }]; + await cache.update("prompt", "llm-key", generations); + expect(await cache.lookup("prompt", "llm-key")).toStrictEqual(generations); + }); + + it("should work with multiple generations", async () => { + const client = new MockClient(); + const cache = await MomentoCache.fromProps({ + client, + cacheName: "test-cache", + }); + const generations: Generation[] = [ + { text: "foo" }, + { text: "bar" }, + { text: "baz" }, + ]; + await cache.update("prompt", "llm-key", generations); + expect(await 
cache.lookup("prompt", "llm-key")).toStrictEqual(generations); + }); +}); diff --git a/libs/langchain-community/src/caches/tests/upstash_redis.int.test.ts b/libs/langchain-community/src/caches/tests/upstash_redis.int.test.ts new file mode 100644 index 000000000000..b6ba8628cc21 --- /dev/null +++ b/libs/langchain-community/src/caches/tests/upstash_redis.int.test.ts @@ -0,0 +1,38 @@ +/* eslint-disable no-process-env */ +import { ChatOpenAI } from "@langchain/openai"; +import { UpstashRedisCache } from "../upstash_redis.js"; + +/** + * This test is a result of the `lookup` method trying to parse an + * incorrectly typed value Before it was being typed as a string, + * whereas in reality it was a JSON object. + */ +test.skip("UpstashRedisCache does not parse non string cached values", async () => { + if ( + !process.env.UPSTASH_REDIS_REST_URL || + !process.env.UPSTASH_REDIS_REST_TOKEN || + !process.env.OPENAI_API_KEY + ) { + throw new Error( + "Missing Upstash Redis REST URL // REST TOKEN or OpenAI API key" + ); + } + const upstashRedisCache = new UpstashRedisCache({ + config: { + url: process.env.UPSTASH_REDIS_REST_URL, + token: process.env.UPSTASH_REDIS_REST_TOKEN, + }, + }); + + const chat = new ChatOpenAI({ + temperature: 0, + cache: upstashRedisCache, + maxTokens: 10, + }); + + const prompt = "is the sky blue"; + const result1 = await chat.predict(prompt); + const result2 = await chat.predict(prompt); + + expect(result1).toEqual(result2); +}); diff --git a/libs/langchain-community/src/caches/tests/upstash_redis.test.ts b/libs/langchain-community/src/caches/tests/upstash_redis.test.ts new file mode 100644 index 000000000000..fc8cc5cc0f92 --- /dev/null +++ b/libs/langchain-community/src/caches/tests/upstash_redis.test.ts @@ -0,0 +1,21 @@ +import { test, expect, jest } from "@jest/globals"; +import { insecureHash } from "@langchain/core/utils/hash"; +import { StoredGeneration } from "@langchain/core/messages"; + +import { UpstashRedisCache } from 
"../upstash_redis.js"; + +const sha1 = (str: string) => insecureHash(str); + +test("UpstashRedisCache", async () => { + const redis = { + get: jest.fn(async (key: string): Promise => { + if (key === sha1("foo_bar_0")) { + return { text: "baz" }; + } + return null; + }), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cache = new UpstashRedisCache({ client: redis as any }); + expect(await cache.lookup("foo", "bar")).toEqual([{ text: "baz" }]); +}); diff --git a/libs/langchain-community/src/caches/upstash_redis.ts b/libs/langchain-community/src/caches/upstash_redis.ts new file mode 100644 index 000000000000..1cf89e82c826 --- /dev/null +++ b/libs/langchain-community/src/caches/upstash_redis.ts @@ -0,0 +1,93 @@ +import { Redis, type RedisConfigNodejs } from "@upstash/redis"; + +import { Generation } from "@langchain/core/outputs"; +import { + BaseCache, + deserializeStoredGeneration, + getCacheKey, + serializeGeneration, +} from "@langchain/core/caches"; +import { StoredGeneration } from "@langchain/core/messages"; + +export type UpstashRedisCacheProps = { + /** + * The config to use to instantiate an Upstash Redis client. + */ + config?: RedisConfigNodejs; + /** + * An existing Upstash Redis client. + */ + client?: Redis; +}; + +/** + * A cache that uses Upstash as the backing store. + * See https://docs.upstash.com/redis. 
+ * @example + * ```typescript + * const cache = new UpstashRedisCache({ + * config: { + * url: "UPSTASH_REDIS_REST_URL", + * token: "UPSTASH_REDIS_REST_TOKEN", + * }, + * }); + * // Initialize the OpenAI model with Upstash Redis cache for caching responses + * const model = new ChatOpenAI({ + * cache, + * }); + * await model.invoke("How are you today?"); + * const cachedValues = await cache.lookup("How are you today?", "llmKey"); + * ``` + */ +export class UpstashRedisCache extends BaseCache { + private redisClient: Redis; + + constructor(props: UpstashRedisCacheProps) { + super(); + const { config, client } = props; + + if (client) { + this.redisClient = client; + } else if (config) { + this.redisClient = new Redis(config); + } else { + throw new Error( + `Upstash Redis caches require either a config object or a pre-configured client.` + ); + } + } + + /** + * Lookup LLM generations in cache by prompt and associated LLM key. + */ + public async lookup(prompt: string, llmKey: string) { + let idx = 0; + let key = getCacheKey(prompt, llmKey, String(idx)); + let value = await this.redisClient.get(key); + const generations: Generation[] = []; + + while (value) { + generations.push(deserializeStoredGeneration(value)); + idx += 1; + key = getCacheKey(prompt, llmKey, String(idx)); + value = await this.redisClient.get(key); + } + + return generations.length > 0 ? generations : null; + } + + /** + * Update the cache with the given generations. + * + * Note this overwrites any existing generations for the given prompt and LLM key. 
+ */ + public async update(prompt: string, llmKey: string, value: Generation[]) { + for (let i = 0; i < value.length; i += 1) { + const key = getCacheKey(prompt, llmKey, String(i)); + await this.redisClient.set( + key, + JSON.stringify(serializeGeneration(value[i])) + ); + } + } +} diff --git a/libs/langchain-community/src/callbacks/handlers/llmonitor.ts b/libs/langchain-community/src/callbacks/handlers/llmonitor.ts new file mode 100644 index 000000000000..8359704322e6 --- /dev/null +++ b/libs/langchain-community/src/callbacks/handlers/llmonitor.ts @@ -0,0 +1,338 @@ +import monitor from "llmonitor"; +import { LLMonitorOptions, ChatMessage, cJSON } from "llmonitor/types"; +import { BaseRun, RunUpdate as BaseRunUpdate, KVMap } from "langsmith/schemas"; + +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { BaseMessage } from "@langchain/core/messages"; +import { ChainValues } from "@langchain/core/utils/types"; +import { LLMResult, Generation } from "@langchain/core/outputs"; +import { + BaseCallbackHandler, + BaseCallbackHandlerInput, +} from "@langchain/core/callbacks/base"; + +import { Serialized } from "../../load/serializable.js"; + +type Role = "user" | "ai" | "system" | "function" | "tool"; + +// Langchain Helpers +// Input can be either a single message, an array of message, or an array of array of messages (batch requests) + +const parseRole = (id: string[]): Role => { + const roleHint = id[id.length - 1]; + + if (roleHint.includes("Human")) return "user"; + if (roleHint.includes("System")) return "system"; + if (roleHint.includes("AI")) return "ai"; + if (roleHint.includes("Function")) return "function"; + if (roleHint.includes("Tool")) return "tool"; + + return "ai"; +}; + +type Message = BaseMessage | Generation | string; + +type OutputMessage = ChatMessage | string; + +const PARAMS_TO_CAPTURE = [ + "stop", + "stop_sequences", + "function_call", + "functions", + "tools", + "tool_choice", + "response_format", +]; + +export const 
convertToLLMonitorMessages = ( + input: Message | Message[] | Message[][] +): OutputMessage | OutputMessage[] | OutputMessage[][] => { + const parseMessage = (raw: Message): OutputMessage => { + if (typeof raw === "string") return raw; + // sometimes the message is nested in a "message" property + if ("message" in raw) return parseMessage(raw.message as Message); + + // Serialize + const message = JSON.parse(JSON.stringify(raw)); + + try { + // "id" contains an array describing the constructor, with last item actual schema type + const role = parseRole(message.id); + + const obj = message.kwargs; + const text = message.text ?? obj.content; + + return { + role, + text, + ...(obj.additional_kwargs ?? {}), + }; + } catch (e) { + // if parsing fails, return the original message + return message.text ?? message; + } + }; + + if (Array.isArray(input)) { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore Confuses the compiler + return input.length === 1 + ? convertToLLMonitorMessages(input[0]) + : input.map(convertToLLMonitorMessages); + } + return parseMessage(input); +}; + +const parseInput = (rawInput: Record) => { + if (!rawInput) return null; + + const { input, inputs, question } = rawInput; + + if (input) return input; + if (inputs) return inputs; + if (question) return question; + + return rawInput; +}; + +const parseOutput = (rawOutput: Record) => { + if (!rawOutput) return null; + + const { text, output, answer, result } = rawOutput; + + if (text) return text; + if (answer) return answer; + if (output) return output; + if (result) return result; + + return rawOutput; +}; + +const parseExtraAndName = ( + llm: Serialized, + extraParams?: KVMap, + metadata?: KVMap +) => { + const params = { + ...(extraParams?.invocation_params ?? {}), + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore this is a valid property + ...(llm?.kwargs ?? 
{}), + ...(metadata || {}), + }; + + const { model, model_name, modelName, model_id, userId, userProps, ...rest } = + params; + + const name = model || modelName || model_name || model_id || llm.id.at(-1); + + // Filter rest to only include params we want to capture + const extra = Object.fromEntries( + Object.entries(rest).filter( + ([key]) => + PARAMS_TO_CAPTURE.includes(key) || + ["string", "number", "boolean"].includes(typeof rest[key]) + ) + ) as cJSON; + + return { name, extra, userId, userProps }; +}; + +export interface Run extends BaseRun { + id: string; + child_runs: this[]; + child_execution_order: number; +} + +export interface RunUpdate extends BaseRunUpdate { + events: BaseRun["events"]; +} + +export interface LLMonitorHandlerFields + extends BaseCallbackHandlerInput, + LLMonitorOptions {} + +export class LLMonitorHandler + extends BaseCallbackHandler + implements LLMonitorHandlerFields +{ + name = "llmonitor_handler"; + + monitor: typeof monitor; + + constructor(fields: LLMonitorHandlerFields = {}) { + super(fields); + + this.monitor = monitor; + + if (fields) { + const { appId, apiUrl, verbose } = fields; + + this.monitor.init({ + verbose, + appId: appId ?? getEnvironmentVariable("LLMONITOR_APP_ID"), + apiUrl: apiUrl ?? 
getEnvironmentVariable("LLMONITOR_API_URL"), + }); + } + } + + async handleLLMStart( + llm: Serialized, + prompts: string[], + runId: string, + parentRunId?: string, + extraParams?: KVMap, + tags?: string[], + metadata?: KVMap + ): Promise { + const { name, extra, userId, userProps } = parseExtraAndName( + llm, + extraParams, + metadata + ); + + await this.monitor.trackEvent("llm", "start", { + runId, + parentRunId, + name, + input: convertToLLMonitorMessages(prompts), + extra, + userId, + userProps, + tags, + runtime: "langchain-js", + }); + } + + async handleChatModelStart( + llm: Serialized, + messages: BaseMessage[][], + runId: string, + parentRunId?: string, + extraParams?: KVMap, + tags?: string[], + metadata?: KVMap + ): Promise { + const { name, extra, userId, userProps } = parseExtraAndName( + llm, + extraParams, + metadata + ); + + await this.monitor.trackEvent("llm", "start", { + runId, + parentRunId, + name, + input: convertToLLMonitorMessages(messages), + extra, + userId, + userProps, + tags, + runtime: "langchain-js", + }); + } + + async handleLLMEnd(output: LLMResult, runId: string): Promise { + const { generations, llmOutput } = output; + + await this.monitor.trackEvent("llm", "end", { + runId, + output: convertToLLMonitorMessages(generations), + tokensUsage: { + completion: llmOutput?.tokenUsage?.completionTokens, + prompt: llmOutput?.tokenUsage?.promptTokens, + }, + }); + } + + async handleLLMError(error: Error, runId: string): Promise { + await this.monitor.trackEvent("llm", "error", { + runId, + error, + }); + } + + async handleChainStart( + chain: Serialized, + inputs: ChainValues, + runId: string, + parentRunId?: string, + tags?: string[], + metadata?: KVMap + ): Promise { + const { agentName, userId, userProps, ...rest } = metadata || {}; + + // allow the user to specify an agent name + const name = agentName || chain.id.at(-1); + + // Attempt to automatically detect if this is an agent or chain + const runType = + agentName || 
["AgentExecutor", "PlanAndExecute"].includes(name) + ? "agent" + : "chain"; + + await this.monitor.trackEvent(runType, "start", { + runId, + parentRunId, + name, + userId, + userProps, + input: parseInput(inputs) as cJSON, + extra: rest, + tags, + runtime: "langchain-js", + }); + } + + async handleChainEnd(outputs: ChainValues, runId: string): Promise { + await this.monitor.trackEvent("chain", "end", { + runId, + output: parseOutput(outputs) as cJSON, + }); + } + + async handleChainError(error: Error, runId: string): Promise { + await this.monitor.trackEvent("chain", "error", { + runId, + error, + }); + } + + async handleToolStart( + tool: Serialized, + input: string, + runId: string, + parentRunId?: string, + tags?: string[], + metadata?: KVMap + ): Promise { + const { toolName, userId, userProps, ...rest } = metadata || {}; + const name = toolName || tool.id.at(-1); + + await this.monitor.trackEvent("tool", "start", { + runId, + parentRunId, + name, + userId, + userProps, + input, + extra: rest, + tags, + runtime: "langchain-js", + }); + } + + async handleToolEnd(output: string, runId: string): Promise { + await this.monitor.trackEvent("tool", "end", { + runId, + output, + }); + } + + async handleToolError(error: Error, runId: string): Promise { + await this.monitor.trackEvent("tool", "error", { + runId, + error, + }); + } +} diff --git a/libs/langchain-community/src/chat_models/baiduwenxin.ts b/libs/langchain-community/src/chat_models/baiduwenxin.ts new file mode 100644 index 000000000000..f79973bed372 --- /dev/null +++ b/libs/langchain-community/src/chat_models/baiduwenxin.ts @@ -0,0 +1,540 @@ +import { + BaseChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { AIMessage, BaseMessage, ChatMessage } from "@langchain/core/messages"; +import { ChatGeneration, ChatResult } from "@langchain/core/outputs"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { getEnvironmentVariable 
} from "@langchain/core/utils/env"; + +/** + * Type representing the role of a message in the Wenxin chat model. + */ +export type WenxinMessageRole = "assistant" | "user"; + +/** + * Interface representing a message in the Wenxin chat model. + */ +interface WenxinMessage { + role: WenxinMessageRole; + content: string; +} + +/** + * Interface representing the usage of tokens in a chat completion. + */ +interface TokenUsage { + completionTokens?: number; + promptTokens?: number; + totalTokens?: number; +} + +/** + * Interface representing a request for a chat completion. + */ +interface ChatCompletionRequest { + messages: WenxinMessage[]; + stream?: boolean; + user_id?: string; + temperature?: number; + top_p?: number; + penalty_score?: number; + system?: string; +} + +/** + * Interface representing a response from a chat completion. + */ +interface ChatCompletionResponse { + id: string; + object: string; + created: number; + result: string; + need_clear_history: boolean; + usage: TokenUsage; +} + +/** + * Interface defining the input to the ChatBaiduWenxin class. + */ +declare interface BaiduWenxinChatInput { + /** Model name to use. Available options are: ERNIE-Bot, ERNIE-Bot-turbo, ERNIE-Bot-4 + * @default "ERNIE-Bot-turbo" + */ + modelName: string; + + /** Whether to stream the results or not. Defaults to false. */ + streaming?: boolean; + + /** Messages to pass as a prefix to the prompt */ + prefixMessages?: WenxinMessage[]; + + /** + * ID of the end-user who made requests. + */ + userId?: string; + + /** + * API key to use when making requests. Defaults to the value of + * `BAIDU_API_KEY` environment variable. + */ + baiduApiKey?: string; + + /** + * Secret key to use when making requests. Defaults to the value of + * `BAIDU_SECRET_KEY` environment variable. + */ + baiduSecretKey?: string; + + /** Amount of randomness injected into the response. Ranges + * from 0 to 1 (0 is not included). 
Use temp closer to 0 for analytical / + * multiple choice, and temp closer to 1 for creative + * and generative tasks. Defaults to 0.95. + */ + temperature?: number; + + /** Total probability mass of tokens to consider at each step. Range + * from 0 to 1.0. Defaults to 0.8. + */ + topP?: number; + + /** Penalizes repeated tokens according to frequency. Range + * from 1.0 to 2.0. Defaults to 1.0. + */ + penaltyScore?: number; +} + +/** + * Function that extracts the custom role of a generic chat message. + * @param message Chat message from which to extract the custom role. + * @returns The custom role of the chat message. + */ +function extractGenericMessageCustomRole(message: ChatMessage) { + if (message.role !== "assistant" && message.role !== "user") { + console.warn(`Unknown message role: ${message.role}`); + } + + return message.role as WenxinMessageRole; +} + +/** + * Function that converts a base message to a Wenxin message role. + * @param message Base message to convert. + * @returns The Wenxin message role. + */ +function messageToWenxinRole(message: BaseMessage): WenxinMessageRole { + const type = message._getType(); + switch (type) { + case "ai": + return "assistant"; + case "human": + return "user"; + case "system": + throw new Error("System messages should not be here"); + case "function": + throw new Error("Function messages not supported"); + case "generic": { + if (!ChatMessage.isInstance(message)) + throw new Error("Invalid generic chat message"); + return extractGenericMessageCustomRole(message); + } + default: + throw new Error(`Unknown message type: ${type}`); + } +} + +/** + * Wrapper around Baidu ERNIE large language models that use the Chat endpoint. + * + * To use you should have the `BAIDU_API_KEY` and `BAIDU_SECRET_KEY` + * environment variable set. 
+ * + * @augments BaseChatModel + * @augments BaiduWenxinChatInput + * @example + * ```typescript + * const ernieTurbo = new ChatBaiduWenxin({ + * baiduApiKey: "YOUR-API-KEY", + * baiduSecretKey: "YOUR-SECRET-KEY", + * }); + * + * const ernie = new ChatBaiduWenxin({ + * modelName: "ERNIE-Bot", + * temperature: 1, + * baiduApiKey: "YOUR-API-KEY", + * baiduSecretKey: "YOUR-SECRET-KEY", + * }); + * + * const messages = [new HumanMessage("Hello")]; + * + * let res = await ernieTurbo.call(messages); + * + * res = await ernie.call(messages); + * ``` + */ +export class ChatBaiduWenxin + extends BaseChatModel + implements BaiduWenxinChatInput +{ + static lc_name() { + return "ChatBaiduWenxin"; + } + + get callKeys(): string[] { + return ["stop", "signal", "options"]; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + baiduApiKey: "BAIDU_API_KEY", + baiduSecretKey: "BAIDU_SECRET_KEY", + }; + } + + get lc_aliases(): { [key: string]: string } | undefined { + return undefined; + } + + lc_serializable = true; + + baiduApiKey?: string; + + baiduSecretKey?: string; + + accessToken: string; + + streaming = false; + + prefixMessages?: WenxinMessage[]; + + userId?: string; + + modelName = "ERNIE-Bot-turbo"; + + apiUrl: string; + + temperature?: number | undefined; + + topP?: number | undefined; + + penaltyScore?: number | undefined; + + constructor(fields?: Partial & BaseChatModelParams) { + super(fields ?? {}); + + this.baiduApiKey = + fields?.baiduApiKey ?? getEnvironmentVariable("BAIDU_API_KEY"); + if (!this.baiduApiKey) { + throw new Error("Baidu API key not found"); + } + + this.baiduSecretKey = + fields?.baiduSecretKey ?? getEnvironmentVariable("BAIDU_SECRET_KEY"); + if (!this.baiduSecretKey) { + throw new Error("Baidu Secret key not found"); + } + + this.streaming = fields?.streaming ?? this.streaming; + this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages; + this.userId = fields?.userId ??
this.userId; + this.temperature = fields?.temperature ?? this.temperature; + this.topP = fields?.topP ?? this.topP; + this.penaltyScore = fields?.penaltyScore ?? this.penaltyScore; + + this.modelName = fields?.modelName ?? this.modelName; + + if (this.modelName === "ERNIE-Bot") { + this.apiUrl = + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"; + } else if (this.modelName === "ERNIE-Bot-turbo") { + this.apiUrl = + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"; + } else if (this.modelName === "ERNIE-Bot-4") { + this.apiUrl = + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"; + } else { + throw new Error(`Invalid model name: ${this.modelName}`); + } + } + + /** + * Method that retrieves the access token for making requests to the Baidu + * API. + * @param options Optional parsed call options. + * @returns The access token for making requests to the Baidu API. + */ + async getAccessToken(options?: this["ParsedCallOptions"]) { + const url = `https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=${this.baiduApiKey}&client_secret=${this.baiduSecretKey}`; + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + signal: options?.signal, + }); + if (!response.ok) { + const text = await response.text(); + const error = new Error( + `Baidu get access token failed with status code ${response.status}, response: ${text}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = response; + throw error; + } + const json = await response.json(); + return json.access_token; + } + + /** + * Get the parameters used to invoke the model + */ + invocationParams(): Omit { + return { + stream: this.streaming, + user_id: this.userId, + temperature: this.temperature, + top_p: this.topP, + penalty_score: this.penaltyScore, + }; + } + + /** 
+ * Get the identifying parameters for the model + */ + identifyingParams() { + return { + model_name: this.modelName, + ...this.invocationParams(), + }; + } + + /** @ignore */ + async _generate( + messages: BaseMessage[], + options?: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const tokenUsage: TokenUsage = {}; + + const params = this.invocationParams(); + + // Wenxin requires the system message to be put in the params, not messages array + const systemMessage = messages.find( + (message) => message._getType() === "system" + ); + if (systemMessage) { + // eslint-disable-next-line no-param-reassign + messages = messages.filter((message) => message !== systemMessage); + params.system = systemMessage.text; + } + const messagesMapped: WenxinMessage[] = messages.map((message) => ({ + role: messageToWenxinRole(message), + content: message.text, + })); + + const data = params.stream + ? await new Promise((resolve, reject) => { + let response: ChatCompletionResponse; + let rejected = false; + let resolved = false; + this.completionWithRetry( + { + ...params, + messages: messagesMapped, + }, + true, + options?.signal, + (event) => { + const data = JSON.parse(event.data); + + if (data?.error_code) { + if (rejected) { + return; + } + rejected = true; + reject(new Error(data?.error_msg)); + return; + } + + const message = data as { + id: string; + object: string; + created: number; + sentence_id?: number; + is_end: boolean; + result: string; + need_clear_history: boolean; + usage: TokenUsage; + }; + + // on the first message set the response properties + if (!response) { + response = { + id: message.id, + object: message.object, + created: message.created, + result: message.result, + need_clear_history: message.need_clear_history, + usage: message.usage, + }; + } else { + response.result += message.result; + response.created = message.created; + response.need_clear_history = message.need_clear_history; + response.usage = message.usage; 
+ } + + // TODO this should pass part.index to the callback + // when that's supported there + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(message.result ?? ""); + + if (message.is_end) { + if (resolved || rejected) { + return; + } + resolved = true; + resolve(response); + } + } + ).catch((error) => { + if (!rejected) { + rejected = true; + reject(error); + } + }); + }) + : await this.completionWithRetry( + { + ...params, + messages: messagesMapped, + }, + false, + options?.signal + ).then((data) => { + if (data?.error_code) { + throw new Error(data?.error_msg); + } + return data; + }); + + const { + completion_tokens: completionTokens, + prompt_tokens: promptTokens, + total_tokens: totalTokens, + } = data.usage ?? {}; + + if (completionTokens) { + tokenUsage.completionTokens = + (tokenUsage.completionTokens ?? 0) + completionTokens; + } + + if (promptTokens) { + tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens; + } + + if (totalTokens) { + tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens; + } + + const generations: ChatGeneration[] = []; + const text = data.result ?? 
""; + generations.push({ + text, + message: new AIMessage(text), + }); + return { + generations, + llmOutput: { tokenUsage }, + }; + } + + /** @ignore */ + async completionWithRetry( + request: ChatCompletionRequest, + stream: boolean, + signal?: AbortSignal, + onmessage?: (event: MessageEvent) => void + ) { + // The first run will get the accessToken + if (!this.accessToken) { + this.accessToken = await this.getAccessToken(); + } + + const makeCompletionRequest = async () => { + const url = `${this.apiUrl}?access_token=${this.accessToken}`; + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(request), + signal, + }); + + if (!stream) { + return response.json(); + } else { + if (response.body) { + // response will not be a stream if an error occurred + if ( + !response.headers + .get("content-type") + ?.startsWith("text/event-stream") + ) { + onmessage?.( + new MessageEvent("message", { + data: await response.text(), + }) + ); + return; + } + + const reader = response.body.getReader(); + + const decoder = new TextDecoder("utf-8"); + let data = ""; + + let continueReading = true; + while (continueReading) { + const { done, value } = await reader.read(); + if (done) { + continueReading = false; + break; + } + data += decoder.decode(value, { stream: true }); // stream mode buffers a multi-byte UTF-8 sequence split across chunks instead of emitting U+FFFD + + let continueProcessing = true; + while (continueProcessing) { + const newlineIndex = data.indexOf("\n"); + if (newlineIndex === -1) { + continueProcessing = false; + break; + } + const line = data.slice(0, newlineIndex); + data = data.slice(newlineIndex + 1); + + if (line.startsWith("data:")) { + const event = new MessageEvent("message", { + data: line.slice("data:".length).trim(), + }); + onmessage?.(event); + } + } + } + } + } + }; + return this.caller.call(makeCompletionRequest); + } + + _llmType() { + return "baiduwenxin"; + } + + /** @ignore */ + _combineLLMOutput() { + return []; + } +} diff --git
a/libs/langchain-community/src/chat_models/bedrock/index.ts b/libs/langchain-community/src/chat_models/bedrock/index.ts new file mode 100644 index 000000000000..a7c2cd0d116e --- /dev/null +++ b/libs/langchain-community/src/chat_models/bedrock/index.ts @@ -0,0 +1,40 @@ +import { defaultProvider } from "@aws-sdk/credential-provider-node"; + +import type { BaseChatModelParams } from "@langchain/core/language_models/chat_models"; + +import { BaseBedrockInput } from "../../utils/bedrock.js"; +import { BedrockChat as BaseBedrockChat } from "./web.js"; + +/** + * @example + * ```typescript + * const model = new BedrockChat({ + * model: "anthropic.claude-v2", + * region: "us-east-1", + * }); + * const res = await model.invoke([{ content: "Tell me a joke" }]); + * console.log(res); + * ``` + */ +export class BedrockChat extends BaseBedrockChat { + static lc_name() { + return "BedrockChat"; + } + + constructor(fields?: Partial & BaseChatModelParams) { + super({ + ...fields, + credentials: fields?.credentials ?? defaultProvider(), + }); + } +} + +export { + convertMessagesToPromptAnthropic, + convertMessagesToPrompt, +} from "./web.js"; + +/** + * @deprecated Use `BedrockChat` instead. 
+ */ +export const ChatBedrock = BedrockChat; diff --git a/libs/langchain-community/src/chat_models/bedrock/web.ts b/libs/langchain-community/src/chat_models/bedrock/web.ts new file mode 100644 index 000000000000..e0ac54c39f81 --- /dev/null +++ b/libs/langchain-community/src/chat_models/bedrock/web.ts @@ -0,0 +1,441 @@ +import { SignatureV4 } from "@smithy/signature-v4"; +import { HttpRequest } from "@smithy/protocol-http"; +import { EventStreamCodec } from "@smithy/eventstream-codec"; +import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; +import { Sha256 } from "@aws-crypto/sha256-js"; + +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + SimpleChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { + AIMessageChunk, + BaseMessage, + AIMessage, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; + +import { + BaseBedrockInput, + BedrockLLMInputOutputAdapter, + type CredentialType, +} from "../../utils/bedrock.js"; +import type { SerializedFields } from "../../load/map_keys.js"; + +/** + * Renders one message as a single prompt turn. Human and system messages are + * prefixed with `humanPrompt`, AI messages with `aiPrompt`, and generic chat + * messages with their capitalized role followed by a colon. + * @throws Error when the message type is none of the above. + */ +function convertOneMessageToText( + message: BaseMessage, + humanPrompt: string, + aiPrompt: string +): string { + if (message._getType() === "human") { + return `${humanPrompt} ${message.content}`; + } else if (message._getType() === "ai") { + return `${aiPrompt} ${message.content}`; + } else if (message._getType() === "system") { + return `${humanPrompt} ${message.content}`; + } else if (ChatMessage.isInstance(message)) { + return `\n\n${ + message.role[0].toUpperCase() + message.role.slice(1) + }: ${message.content}`; + } + throw new Error(`Unknown role: ${message._getType()}`); +} + +export function convertMessagesToPromptAnthropic( + messages: BaseMessage[], + humanPrompt = "\n\nHuman:", + aiPrompt = "\n\nAssistant:" +): string { + const messagesCopy = [...messages];
+ + if ( + messagesCopy.length === 0 || + messagesCopy[messagesCopy.length - 1]._getType() !== "ai" + ) { + messagesCopy.push(new AIMessage({ content: "" })); + } + + return messagesCopy + .map((message) => convertOneMessageToText(message, humanPrompt, aiPrompt)) + .join(""); +} + +/** + * Function that converts an array of messages into a single string prompt + * that can be used as input for a chat model. It delegates the conversion + * logic to the appropriate provider-specific function. + * @param messages Array of messages to be converted. + * @param provider Name of the model provider; only "anthropic" is supported and any other value throws. + * @returns A string prompt that can be used as input for a chat model. + */ +export function convertMessagesToPrompt( + messages: BaseMessage[], + provider: string +): string { + if (provider === "anthropic") { + return convertMessagesToPromptAnthropic(messages); + } + throw new Error(`Provider ${provider} does not support chat.`); +} + +/** + * A type of Large Language Model (LLM) that interacts with the Bedrock + * service. It extends the `SimpleChatModel` class and implements the + * `BaseBedrockInput` interface. The class is designed to authenticate and + * interact with the Bedrock service, which is a part of Amazon Web + * Services (AWS). It uses AWS credentials for authentication and can be + * configured with various parameters such as the model to use, the AWS + * region, and the maximum number of tokens to generate.
+ * @example + * ```typescript + * const model = new BedrockChat({ + * model: "anthropic.claude-v2", + * region: "us-east-1", + * }); + * const res = await model.invoke([{ content: "Tell me a joke" }]); + * console.log(res); + * ``` + */ +export class BedrockChat extends SimpleChatModel implements BaseBedrockInput { + model = "amazon.titan-tg1-large"; + + region: string; + + credentials: CredentialType; + + temperature?: number | undefined = undefined; + + maxTokens?: number | undefined = undefined; + + fetchFn: typeof fetch; + + endpointHost?: string; + + /** @deprecated */ + stopSequences?: string[]; + + modelKwargs?: Record; + + codec: EventStreamCodec = new EventStreamCodec(toUtf8, fromUtf8); + + streaming = false; + + lc_serializable = true; + + get lc_aliases(): Record { + return { + model: "model_id", + region: "region_name", + }; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID", + "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY", + }; + } + + get lc_attributes(): SerializedFields | undefined { + return { region: this.region }; + } + + _llmType() { + return "bedrock"; + } + + static lc_name() { + return "BedrockChat"; + } + + constructor(fields?: Partial & BaseChatModelParams) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"]; + if (!allowedModels.includes(this.model.split(".")[0])) { + throw new Error( + `Unknown model: '${this.model}', only these are supported: ${allowedModels}` + ); + } + const region = + fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION"); + if (!region) { + throw new Error( + "Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field." 
+ ); + } + this.region = region; + + const credentials = fields?.credentials; + if (!credentials) { + throw new Error( + "Please set the AWS credentials in the 'credentials' field." + ); + } + this.credentials = credentials; + + this.temperature = fields?.temperature ?? this.temperature; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis); + this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl; + this.stopSequences = fields?.stopSequences; + this.modelKwargs = fields?.modelKwargs; + this.streaming = fields?.streaming ?? this.streaming; + } + + /** Call out to Bedrock service model. + Arguments: + prompt: The prompt to pass into the model. + + Returns: + The string generated by the model. + + Example: + response = model.call("Tell me a joke.") + */ + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const service = "bedrock-runtime"; + const endpointHost = + this.endpointHost ?? `${service}.${this.region}.amazonaws.com`; + const provider = this.model.split(".")[0]; + if (this.streaming) { + const stream = this._streamResponseChunks(messages, options, runManager); + let finalResult: ChatGenerationChunk | undefined; + for await (const chunk of stream) { + if (finalResult === undefined) { + finalResult = chunk; + } else { + finalResult = finalResult.concat(chunk); + } + } + const messageContent = finalResult?.message.content; + if (messageContent && typeof messageContent !== "string") { + throw new Error( + "Non-string output for ChatBedrock is currently not supported." + ); + } + return messageContent ?? ""; + } + + const response = await this._signedFetch(messages, options, { + bedrockMethod: "invoke", + endpointHost, + provider, + }); + const json = await response.json(); + if (!response.ok) { + throw new Error( + `Error ${response.status}: ${json.message ?? 
JSON.stringify(json)}` + ); + } + const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); + return text; + } + + async _signedFetch( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + fields: { + bedrockMethod: "invoke" | "invoke-with-response-stream"; + endpointHost: string; + provider: string; + } + ) { + const { bedrockMethod, endpointHost, provider } = fields; + const inputBody = BedrockLLMInputOutputAdapter.prepareInput( + provider, + convertMessagesToPromptAnthropic(messages), + this.maxTokens, + this.temperature, + options.stop ?? this.stopSequences, + this.modelKwargs, + fields.bedrockMethod + ); + + const url = new URL( + `https://${endpointHost}/model/${this.model}/${bedrockMethod}` + ); + + const request = new HttpRequest({ + hostname: url.hostname, + path: url.pathname, + protocol: url.protocol, + method: "POST", // method must be uppercase + body: JSON.stringify(inputBody), + query: Object.fromEntries(url.searchParams.entries()), + headers: { + // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + host: url.host, + accept: "application/json", + "content-type": "application/json", + }, + }); + + const signer = new SignatureV4({ + credentials: this.credentials, + service: "bedrock", + region: this.region, + sha256: Sha256, + }); + + const signedRequest = await signer.sign(request); + + // Send request to AWS using the low-level fetch API + const response = await this.caller.callWithOptions( + { signal: options.signal }, + async () => + this.fetchFn(url, { + headers: signedRequest.headers, + body: signedRequest.body, + method: signedRequest.method, + }) + ); + return response; + } + + async *_streamResponseChunks( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const provider = this.model.split(".")[0]; + const service = "bedrock-runtime"; + + const endpointHost = + 
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`; + + const bedrockMethod = + provider === "anthropic" || provider === "cohere" || provider === "meta" + ? "invoke-with-response-stream" + : "invoke"; + + const response = await this._signedFetch(messages, options, { + bedrockMethod, + endpointHost, + provider, + }); + + if (response.status < 200 || response.status >= 300) { + throw Error( + `Failed to access underlying url '${endpointHost}': got ${ + response.status + } ${response.statusText}: ${await response.text()}` + ); + } + + if ( + provider === "anthropic" || + provider === "cohere" || + provider === "meta" + ) { + const reader = response.body?.getReader(); + const decoder = new TextDecoder(); + for await (const chunk of this._readChunks(reader)) { + const event = this.codec.decode(chunk); + if ( + (event.headers[":event-type"] !== undefined && + event.headers[":event-type"].value !== "chunk") || + event.headers[":content-type"].value !== "application/json" + ) { + throw Error(`Failed to get event chunk: got ${chunk}`); + } + const body = JSON.parse(decoder.decode(event.body)); + if (body.message) { + throw new Error(body.message); + } + if (body.bytes !== undefined) { + const chunkResult = JSON.parse( + decoder.decode( + Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 
0) + ) + ); + const text = BedrockLLMInputOutputAdapter.prepareOutput( + provider, + chunkResult + ); + yield new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ content: text }), + }); + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(text); + } + } + } else { + const json = await response.json(); + const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json); + yield new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ content: text }), + }); + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(text); + } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _readChunks(reader: any) { + function _concatChunks(a: Uint8Array, b: Uint8Array) { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; + } + + function getMessageLength(buffer: Uint8Array) { + if (buffer.byteLength === 0) return 0; + const view = new DataView( + buffer.buffer, + buffer.byteOffset, + buffer.byteLength + ); + + return view.getUint32(0, false); + } + + return { + async *[Symbol.asyncIterator]() { + let readResult = await reader.read(); + + let buffer: Uint8Array = new Uint8Array(0); + while (!readResult.done) { + const chunk: Uint8Array = readResult.value; + + buffer = _concatChunks(buffer, chunk); + let messageLength = getMessageLength(buffer); + + while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) { + yield buffer.slice(0, messageLength); + buffer = buffer.slice(messageLength); + messageLength = getMessageLength(buffer); + } + + readResult = await reader.read(); + } + }, + }; + } + + _combineLLMOutput() { + return {}; + } +} + +/** + * @deprecated Use `BedrockChat` instead. 
+ */ +export const ChatBedrock = BedrockChat; diff --git a/libs/langchain-community/src/chat_models/cloudflare_workersai.ts b/libs/langchain-community/src/chat_models/cloudflare_workersai.ts new file mode 100644 index 000000000000..c50de33f765a --- /dev/null +++ b/libs/langchain-community/src/chat_models/cloudflare_workersai.ts @@ -0,0 +1,251 @@ +import { + SimpleChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { + AIMessageChunk, + BaseMessage, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; + +import type { CloudflareWorkersAIInput } from "../llms/cloudflare_workersai.js"; +import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js"; + +/** + * An interface defining the options for a Cloudflare Workers AI call. It extends + * the BaseLanguageModelCallOptions interface. + */ +export interface ChatCloudflareWorkersAICallOptions + extends BaseLanguageModelCallOptions {} + +/** + * A class that enables calls to the Cloudflare Workers AI API to access large language + * models in a chat-like fashion. It extends the SimpleChatModel class and + * implements the CloudflareWorkersAIInput interface. 
+ * @example + * ```typescript + * const model = new ChatCloudflareWorkersAI({ + * model: "@cf/meta/llama-2-7b-chat-int8", + * cloudflareAccountId: process.env.CLOUDFLARE_ACCOUNT_ID, + * cloudflareApiToken: process.env.CLOUDFLARE_API_TOKEN + * }); + * + * const response = await model.invoke([ + * ["system", "You are a helpful assistant that translates English to German."], + * ["human", `Translate "I love programming".`] + * ]); + * + * console.log(response); + * ``` + */ +export class ChatCloudflareWorkersAI + extends SimpleChatModel + implements CloudflareWorkersAIInput +{ + static lc_name() { + return "ChatCloudflareWorkersAI"; + } + + lc_serializable = true; + + model = "@cf/meta/llama-2-7b-chat-int8"; + + cloudflareAccountId?: string; + + cloudflareApiToken?: string; + + baseUrl: string; + + streaming = false; + + constructor(fields?: CloudflareWorkersAIInput & BaseChatModelParams) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + this.streaming = fields?.streaming ?? this.streaming; + this.cloudflareAccountId = + fields?.cloudflareAccountId ?? + getEnvironmentVariable("CLOUDFLARE_ACCOUNT_ID"); + this.cloudflareApiToken = + fields?.cloudflareApiToken ?? + getEnvironmentVariable("CLOUDFLARE_API_TOKEN"); + this.baseUrl = + fields?.baseUrl ?? + `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`; + if (this.baseUrl.endsWith("/")) { + this.baseUrl = this.baseUrl.slice(0, -1); + } + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + cloudflareApiToken: "CLOUDFLARE_API_TOKEN", + }; + } + + _llmType() { + return "cloudflare"; + } + + /** Get the identifying parameters for this LLM. 
*/ + get identifyingParams() { + return { model: this.model }; + } + + /** + * Get the parameters used to invoke the model + */ + invocationParams(_options?: this["ParsedCallOptions"]) { + return { + model: this.model, + }; + } + + _combineLLMOutput() { + return {}; + } + + /** + * Ensures the Cloudflare account ID and API token are set, throwing an Error that names the missing one otherwise. + */ + validateEnvironment() { + if (!this.cloudflareAccountId) { + throw new Error( + `No Cloudflare account ID found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.` + ); + } + if (!this.cloudflareApiToken) { + throw new Error( + `No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_TOKEN" in your environment variables.` + ); + } + } + + async _request( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + stream?: boolean + ) { + this.validateEnvironment(); + const url = `${this.baseUrl}/${this.model}`; + const headers = { + Authorization: `Bearer ${this.cloudflareApiToken}`, + "Content-Type": "application/json", + }; + + const formattedMessages = this._formatMessages(messages); + + const data = { messages: formattedMessages, stream }; + return this.caller.call(async () => { + const response = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(data), + signal: options.signal, + }); + if (!response.ok) { + const error = new Error( + `Cloudflare LLM call failed with status code ${response.status}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = response; + throw error; + } + return response; + }); + } + + async *_streamResponseChunks( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const response = await this._request(messages, options, true); + if (!response.body) { + throw new Error("Empty response from Cloudflare.
Please try again."); + } + const stream = convertEventStreamToIterableReadableDataStream( + response.body + ); + for await (const chunk of stream) { + if (chunk !== "[DONE]") { + const parsedChunk = JSON.parse(chunk); + const generationChunk = new ChatGenerationChunk({ + message: new AIMessageChunk({ content: parsedChunk.response }), + text: parsedChunk.response, + }); + yield generationChunk; + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(generationChunk.text ?? ""); + } + } + } + + protected _formatMessages( + messages: BaseMessage[] + ): { role: string; content: string }[] { + const formattedMessages = messages.map((message) => { + let role; + if (message._getType() === "human") { + role = "user"; + } else if (message._getType() === "ai") { + role = "assistant"; + } else if (message._getType() === "system") { + role = "system"; + } else if (ChatMessage.isInstance(message)) { + role = message.role; + } else { + console.warn( + `Unsupported message type passed to Cloudflare: "${message._getType()}"` + ); + role = "user"; + } + if (typeof message.content !== "string") { + throw new Error( + "ChatCloudflareWorkersAI currently does not support non-string message content." 
+ ); + } + return { + role, + content: message.content, + }; + }); + return formattedMessages; + } + + /** @ignore */ + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + if (!this.streaming) { + const response = await this._request(messages, options); + + const responseData = await response.json(); + + return responseData.result.response; + } else { + const stream = this._streamResponseChunks(messages, options, runManager); + let finalResult: ChatGenerationChunk | undefined; + for await (const chunk of stream) { + if (finalResult === undefined) { + finalResult = chunk; + } else { + finalResult = finalResult.concat(chunk); + } + } + const messageContent = finalResult?.message.content; + if (messageContent && typeof messageContent !== "string") { + throw new Error( + "Non-string output for ChatCloudflareWorkersAI is currently not supported." + ); + } + return messageContent ?? ""; + } + } +} diff --git a/libs/langchain-community/src/chat_models/fireworks.ts b/libs/langchain-community/src/chat_models/fireworks.ts new file mode 100644 index 000000000000..b9de098bf2b2 --- /dev/null +++ b/libs/langchain-community/src/chat_models/fireworks.ts @@ -0,0 +1,141 @@ +import type { BaseChatModelParams } from "@langchain/core/language_models/chat_models"; +import { + type OpenAIClient, + type ChatOpenAICallOptions, + type OpenAIChatInput, + type OpenAICoreRequestOptions, + ChatOpenAI, +} from "@langchain/openai"; + +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +type FireworksUnsupportedArgs = + | "frequencyPenalty" + | "presencePenalty" + | "logitBias" + | "functions"; + +type FireworksUnsupportedCallOptions = "functions" | "function_call" | "tools"; + +export type ChatFireworksCallOptions = Partial< + Omit +>; + +/** + * Wrapper around Fireworks API for large language models fine-tuned for chat + * + * Fireworks API is compatible to the OpenAI API with some limitations 
described in + * https://readme.fireworks.ai/docs/openai-compatibility. + * + * To use, you should have the `openai` package installed and + * the `FIREWORKS_API_KEY` environment variable set. + * @example + * ```typescript + * const model = new ChatFireworks({ + * temperature: 0.9, + * fireworksApiKey: "YOUR-API-KEY", + * }); + * + * const response = await model.invoke("Hello, how are you?"); + * console.log(response); + * ``` + */ +export class ChatFireworks extends ChatOpenAI { + static lc_name() { + return "ChatFireworks"; + } + + _llmType() { + return "fireworks"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + fireworksApiKey: "FIREWORKS_API_KEY", + }; + } + + lc_serializable = true; + + fireworksApiKey?: string; + + constructor( + fields?: Partial< + Omit + > & + BaseChatModelParams & { fireworksApiKey?: string } + ) { + const fireworksApiKey = + fields?.fireworksApiKey || getEnvironmentVariable("FIREWORKS_API_KEY"); + + if (!fireworksApiKey) { + throw new Error( + `Fireworks API key not found. 
Please set the FIREWORKS_API_KEY environment variable or provide the key into "fireworksApiKey"` + ); + } + + super({ + ...fields, + modelName: + fields?.modelName || "accounts/fireworks/models/llama-v2-13b-chat", + openAIApiKey: fireworksApiKey, + configuration: { + baseURL: "https://api.fireworks.ai/inference/v1", + }, + }); + + this.fireworksApiKey = fireworksApiKey; + } + + toJSON() { + const result = super.toJSON(); + + if ( + "kwargs" in result && + typeof result.kwargs === "object" && + result.kwargs != null + ) { + delete result.kwargs.openai_api_key; + delete result.kwargs.configuration; + } + + return result; + } + + async completionWithRetry( + request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, + options?: OpenAICoreRequestOptions + ): Promise>; + + async completionWithRetry( + request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, + options?: OpenAICoreRequestOptions + ): Promise; + + /** + * Calls the Fireworks API with retry logic in case of failures. + * @param request The request to send to the Fireworks API. + * @param options Optional configuration for the API call. + * @returns The response from the Fireworks API. 
+ */ + async completionWithRetry( + request: + | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming + | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming, + options?: OpenAICoreRequestOptions + ): Promise< + | AsyncIterable + | OpenAIClient.Chat.Completions.ChatCompletion + > { + delete request.frequency_penalty; + delete request.presence_penalty; + delete request.logit_bias; + delete request.functions; + + if (request.stream === true) { + return super.completionWithRetry(request, options); + } + + return super.completionWithRetry(request, options); + } +} diff --git a/libs/langchain-community/src/chat_models/googlepalm.ts b/libs/langchain-community/src/chat_models/googlepalm.ts new file mode 100644 index 000000000000..aec212d7d04d --- /dev/null +++ b/libs/langchain-community/src/chat_models/googlepalm.ts @@ -0,0 +1,343 @@ +import { DiscussServiceClient } from "@google-ai/generativelanguage"; +import type { protos } from "@google-ai/generativelanguage"; +import { GoogleAuth } from "google-auth-library"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + AIMessage, + BaseMessage, + ChatMessage, + isBaseMessage, +} from "@langchain/core/messages"; +import { ChatResult } from "@langchain/core/outputs"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { + BaseChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; + +export type BaseMessageExamplePair = { + input: BaseMessage; + output: BaseMessage; +}; + +/** + * An interface defining the input to the ChatGooglePaLM class. + */ +export interface GooglePaLMChatInput extends BaseChatModelParams { + /** + * Model Name to use + * + * Note: The format must follow the pattern - `models/{model}` + */ + modelName?: string; + + /** + * Controls the randomness of the output. + * + * Values can range from [0.0,1.0], inclusive. 
A value closer to 1.0 + * will produce responses that are more varied and creative, while + * a value closer to 0.0 will typically result in less surprising + * responses from the model. + * + * Note: The default value varies by model + */ + temperature?: number; + + /** + * Top-p changes how the model selects tokens for output. + * + * Tokens are selected from most probable to least until the sum + * of their probabilities equals the top-p value. + * + * For example, if tokens A, B, and C have a probability of + * .3, .2, and .1 and the top-p value is .5, then the model will + * select either A or B as the next token (using temperature). + * + * Note: The default value varies by model + */ + topP?: number; + + /** + * Top-k changes how the model selects tokens for output. + * + * A top-k of 1 means the selected token is the most probable among + * all tokens in the model’s vocabulary (also called greedy decoding), + * while a top-k of 3 means that the next token is selected from + * among the 3 most probable tokens (using temperature). + * + * Note: The default value varies by model + */ + topK?: number; + + examples?: + | protos.google.ai.generativelanguage.v1beta2.IExample[] + | BaseMessageExamplePair[]; + + /** + * Google Palm API key to use + */ + apiKey?: string; +} + +function getMessageAuthor(message: BaseMessage) { + const type = message._getType(); + if (ChatMessage.isInstance(message)) { + return message.role; + } + return message.name ?? type; +} + +/** + * A class that wraps the Google Palm chat model. 
+ * @example + * ```typescript + * const model = new ChatGooglePaLM({ + * apiKey: "", + * temperature: 0.7, + * modelName: "models/chat-bison-001", + * topK: 40, + * topP: 1, + * examples: [ + * { + * input: new HumanMessage("What is your favorite sock color?"), + * output: new AIMessage("My favorite sock color be arrrr-ange!"), + * }, + * ], + * }); + * const questions = [ + * new SystemMessage( + * "You are a funny assistant that answers in pirate language." + * ), + * new HumanMessage("What is your favorite food?"), + * ]; + * const res = await model.call(questions); + * console.log({ res }); + * ``` + */ +export class ChatGooglePaLM + extends BaseChatModel + implements GooglePaLMChatInput +{ + static lc_name() { + return "ChatGooglePaLM"; + } + + lc_serializable = true; + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "GOOGLE_PALM_API_KEY", + }; + } + + modelName = "models/chat-bison-001"; + + temperature?: number; // default value chosen based on model + + topP?: number; // default value chosen based on model + + topK?: number; // default value chosen based on model + + examples: protos.google.ai.generativelanguage.v1beta2.IExample[] = []; + + apiKey?: string; + + private client: DiscussServiceClient; + + constructor(fields?: GooglePaLMChatInput) { + super(fields ?? {}); + + this.modelName = fields?.modelName ?? this.modelName; + + this.temperature = fields?.temperature ?? this.temperature; + if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { + throw new Error("`temperature` must be in the range of [0.0,1.0]"); + } + + this.topP = fields?.topP ?? this.topP; + if (this.topP && this.topP < 0) { + throw new Error("`topP` must be a positive integer"); + } + + this.topK = fields?.topK ?? 
this.topK; + if (this.topK && this.topK < 0) { + throw new Error("`topK` must be a positive integer"); + } + + this.examples = + fields?.examples?.map((example) => { + if ( + (isBaseMessage(example.input) && + typeof example.input.content !== "string") || + (isBaseMessage(example.output) && + typeof example.output.content !== "string") + ) { + throw new Error( + "GooglePaLM example messages may only have string content." + ); + } + return { + input: { + ...example.input, + content: example.input?.content as string, + }, + output: { + ...example.output, + content: example.output?.content as string, + }, + }; + }) ?? this.examples; + + this.apiKey = + fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PALM_API_KEY"); + if (!this.apiKey) { + throw new Error( + "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" + ); + } + + this.client = new DiscussServiceClient({ + authClient: new GoogleAuth().fromAPIKey(this.apiKey), + }); + } + + _combineLLMOutput() { + return []; + } + + _llmType() { + return "googlepalm"; + } + + async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const palmMessages = await this.caller.callWithOptions( + { signal: options.signal }, + this._generateMessage.bind(this), + this._mapBaseMessagesToPalmMessages(messages), + this._getPalmContextInstruction(messages), + this.examples + ); + const chatResult = this._mapPalmMessagesToChatResult(palmMessages); + + // Google Palm doesn't provide streaming as of now. But to support streaming handlers + // we call the handler with entire response text + void runManager?.handleLLMNewToken( + chatResult.generations.length > 0 ? 
chatResult.generations[0].text : "" + ); + + return chatResult; + } + + protected async _generateMessage( + messages: protos.google.ai.generativelanguage.v1beta2.IMessage[], + context?: string, + examples?: protos.google.ai.generativelanguage.v1beta2.IExample[] + ): Promise { + const [palmMessages] = await this.client.generateMessage({ + candidateCount: 1, + model: this.modelName, + temperature: this.temperature, + topK: this.topK, + topP: this.topP, + prompt: { + context, + examples, + messages, + }, + }); + return palmMessages; + } + + protected _getPalmContextInstruction( + messages: BaseMessage[] + ): string | undefined { + // get the first message and checks if it's a system 'system' messages + const systemMessage = + messages.length > 0 && getMessageAuthor(messages[0]) === "system" + ? messages[0] + : undefined; + if ( + systemMessage?.content !== undefined && + typeof systemMessage.content !== "string" + ) { + throw new Error("Non-string system message content is not supported."); + } + return systemMessage?.content; + } + + protected _mapBaseMessagesToPalmMessages( + messages: BaseMessage[] + ): protos.google.ai.generativelanguage.v1beta2.IMessage[] { + // remove all 'system' messages + const nonSystemMessages = messages.filter( + (m) => getMessageAuthor(m) !== "system" + ); + + // requires alternate human & ai messages. Throw error if two messages are consecutive + nonSystemMessages.forEach((msg, index) => { + if (index < 1) return; + if ( + getMessageAuthor(msg) === getMessageAuthor(nonSystemMessages[index - 1]) + ) { + throw new Error( + `Google PaLM requires alternate messages between authors` + ); + } + }); + + return nonSystemMessages.map((m) => { + if (typeof m.content !== "string") { + throw new Error( + "ChatGooglePaLM does not support non-string message content." 
+ ); + } + return { + author: getMessageAuthor(m), + content: m.content, + citationMetadata: { + citationSources: m.additional_kwargs.citationSources as + | protos.google.ai.generativelanguage.v1beta2.ICitationSource[] + | undefined, + }, + }; + }); + } + + protected _mapPalmMessagesToChatResult( + msgRes: protos.google.ai.generativelanguage.v1beta2.IGenerateMessageResponse + ): ChatResult { + if ( + msgRes.candidates && + msgRes.candidates.length > 0 && + msgRes.candidates[0] + ) { + const message = msgRes.candidates[0]; + return { + generations: [ + { + text: message.content ?? "", + message: new AIMessage({ + content: message.content ?? "", + name: message.author === null ? undefined : message.author, + additional_kwargs: { + citationSources: message.citationMetadata?.citationSources, + filters: msgRes.filters, // content filters applied + }, + }), + }, + ], + }; + } + // if rejected or error, return empty generations with reason in filters + return { + generations: [], + llmOutput: { + filters: msgRes.filters, + }, + }; + } +} diff --git a/langchain/src/chat_models/googlevertexai/common.ts b/libs/langchain-community/src/chat_models/googlevertexai/common.ts similarity index 96% rename from langchain/src/chat_models/googlevertexai/common.ts rename to libs/langchain-community/src/chat_models/googlevertexai/common.ts index 4ff8b170a271..d1208de1c6d9 100644 --- a/langchain/src/chat_models/googlevertexai/common.ts +++ b/libs/langchain-community/src/chat_models/googlevertexai/common.ts @@ -1,26 +1,29 @@ -import { BaseChatModel } from "../base.js"; +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { BaseChatModel } from "@langchain/core/language_models/chat_models"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { AIMessage, AIMessageChunk, BaseMessage, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGeneration, ChatGenerationChunk, - ChatMessage, ChatResult, 
LLMResult, -} from "../../schema/index.js"; +} from "@langchain/core/outputs"; + import { GoogleVertexAILLMConnection, GoogleVertexAIStream, -} from "../../util/googlevertexai-connection.js"; +} from "../../utils/googlevertexai-connection.js"; import { GoogleVertexAIBaseLLMInput, GoogleVertexAIBasePrediction, GoogleVertexAILLMPredictions, GoogleVertexAIModelParams, } from "../../types/googlevertexai-types.js"; -import { BaseLanguageModelCallOptions } from "../../base_language/index.js"; -import { CallbackManagerForLLMRun } from "../../callbacks/index.js"; /** * Represents a single "example" exchange that can be provided to diff --git a/libs/langchain-community/src/chat_models/googlevertexai/index.ts b/libs/langchain-community/src/chat_models/googlevertexai/index.ts new file mode 100644 index 000000000000..d93693e8fcc0 --- /dev/null +++ b/libs/langchain-community/src/chat_models/googlevertexai/index.ts @@ -0,0 +1,64 @@ +import { GoogleAuthOptions } from "google-auth-library"; +import { BaseChatGoogleVertexAI, GoogleVertexAIChatInput } from "./common.js"; +import { GoogleVertexAILLMConnection } from "../../utils/googlevertexai-connection.js"; +import { GAuthClient } from "../../utils/googlevertexai-gauth.js"; + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * Large Language Models in a chat-like fashion. + * + * To use, you will need to have one of the following authentication + * methods in place: + * - You are logged into an account permitted to the Google Cloud project + * using Vertex AI. + * - You are running this on a machine using a service account permitted to + * the Google Cloud project using Vertex AI. + * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the + * path of a credentials file for a service account permitted to the + * Google Cloud project using Vertex AI. 
+ * @example + * ```typescript + * const model = new ChatGoogleVertexAI({ + * temperature: 0.7, + * }); + * const result = await model.invoke("What is the capital of France?"); + * ``` + */ +export class ChatGoogleVertexAI extends BaseChatGoogleVertexAI { + static lc_name() { + return "ChatVertexAI"; + } + + constructor(fields?: GoogleVertexAIChatInput) { + super(fields); + + const client = new GAuthClient({ + scopes: "https://www.googleapis.com/auth/cloud-platform", + ...fields?.authOptions, + }); + + this.connection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + false + ); + + this.streamedConnection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + true + ); + } +} + +export type { + ChatExample, + GoogleVertexAIChatAuthor, + GoogleVertexAIChatInput, + GoogleVertexAIChatInstance, + GoogleVertexAIChatMessage, + GoogleVertexAIChatMessageFields, + GoogleVertexAIChatPrediction, +} from "./common.js"; diff --git a/libs/langchain-community/src/chat_models/googlevertexai/web.ts b/libs/langchain-community/src/chat_models/googlevertexai/web.ts new file mode 100644 index 000000000000..503058fc0e21 --- /dev/null +++ b/libs/langchain-community/src/chat_models/googlevertexai/web.ts @@ -0,0 +1,66 @@ +import { GoogleVertexAILLMConnection } from "../../utils/googlevertexai-connection.js"; +import { + WebGoogleAuthOptions, + WebGoogleAuth, +} from "../../utils/googlevertexai-webauth.js"; +import { BaseChatGoogleVertexAI, GoogleVertexAIChatInput } from "./common.js"; + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * Large Language Models in a chat-like fashion. + * + * This entrypoint and class are intended to be used in web environments like Edge + * functions where you do not have access to the file system. It supports passing + * service account credentials directly as a "GOOGLE_VERTEX_AI_WEB_CREDENTIALS" + * environment variable or directly as "authOptions.credentials". 
+ * @example + * ```typescript + * const model = new ChatGoogleVertexAI({ + * temperature: 0.7, + * }); + * const result = await model.invoke( + * "How do I implement a binary search algorithm in Python?", + * ); + * ``` + */ +export class ChatGoogleVertexAI extends BaseChatGoogleVertexAI { + static lc_name() { + return "ChatVertexAI"; + } + + get lc_secrets(): { [key: string]: string } { + return { + "authOptions.credentials": "GOOGLE_VERTEX_AI_WEB_CREDENTIALS", + }; + } + + constructor(fields?: GoogleVertexAIChatInput) { + super(fields); + + const client = new WebGoogleAuth(fields?.authOptions); + + this.connection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + false + ); + + this.streamedConnection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + true + ); + } +} + +export type { + ChatExample, + GoogleVertexAIChatAuthor, + GoogleVertexAIChatInput, + GoogleVertexAIChatInstance, + GoogleVertexAIChatMessage, + GoogleVertexAIChatMessageFields, + GoogleVertexAIChatPrediction, +} from "./common.js"; diff --git a/langchain/src/chat_models/iflytek_xinghuo/common.ts b/libs/langchain-community/src/chat_models/iflytek_xinghuo/common.ts similarity index 96% rename from langchain/src/chat_models/iflytek_xinghuo/common.ts rename to libs/langchain-community/src/chat_models/iflytek_xinghuo/common.ts index 6854dc22fdd9..af04bd8ec09f 100644 --- a/langchain/src/chat_models/iflytek_xinghuo/common.ts +++ b/libs/langchain-community/src/chat_models/iflytek_xinghuo/common.ts @@ -1,18 +1,16 @@ -import { CallbackManagerForLLMRun } from "../../callbacks/manager.js"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { AIMessage, BaseMessage, ChatMessage } from "@langchain/core/messages"; +import { ChatGeneration, ChatResult } from "@langchain/core/outputs"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { IterableReadableStream } from 
"@langchain/core/utils/stream"; import { - AIMessage, - BaseMessage, - ChatGeneration, - ChatMessage, - ChatResult, -} from "../../schema/index.js"; -import { getEnvironmentVariable } from "../../util/env.js"; -import { IterableReadableStream } from "../../util/stream.js"; -import { BaseChatModel, BaseChatModelParams } from "../base.js"; + BaseChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; import { BaseWebSocketStream, WebSocketStreamOptions, -} from "../../util/iflytek_websocket_stream.js"; +} from "../../utils/iflytek_websocket_stream.js"; /** * Type representing the role of a message in the Xinghuo chat model. diff --git a/libs/langchain-community/src/chat_models/iflytek_xinghuo/index.ts b/libs/langchain-community/src/chat_models/iflytek_xinghuo/index.ts new file mode 100644 index 000000000000..681d6bdbc299 --- /dev/null +++ b/libs/langchain-community/src/chat_models/iflytek_xinghuo/index.ts @@ -0,0 +1,43 @@ +import WebSocket from "ws"; +import { BaseChatIflytekXinghuo } from "./common.js"; +import { + BaseWebSocketStream, + WebSocketStreamOptions, +} from "../../utils/iflytek_websocket_stream.js"; + +class WebSocketStream extends BaseWebSocketStream { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + openWebSocket(url: string, options: WebSocketStreamOptions): WebSocket { + return new WebSocket(url, options.protocols ?? 
[]); + } +} + +/** + * @example + * ```typescript + * const model = new ChatIflytekXinghuo(); + * const response = await model.call([new HumanMessage("Nice to meet you!")]); + * console.log(response); + * ``` + */ +export class ChatIflytekXinghuo extends BaseChatIflytekXinghuo { + async openWebSocketStream( + options: WebSocketStreamOptions + ): Promise { + const host = "spark-api.xf-yun.com"; + const date = new Date().toUTCString(); + const url = `GET /${this.version}/chat HTTP/1.1`; + const { createHmac } = await import("node:crypto"); + const hash = createHmac("sha256", this.iflytekApiSecret) + .update(`host: ${host}\ndate: ${date}\n${url}`) + .digest("base64"); + const authorization_origin = `api_key="${this.iflytekApiKey}", algorithm="hmac-sha256", headers="host date request-line", signature="${hash}"`; + const authorization = Buffer.from(authorization_origin).toString("base64"); + let authWebSocketUrl = this.apiUrl; + authWebSocketUrl += `?authorization=${authorization}`; + authWebSocketUrl += `&host=${encodeURIComponent(host)}`; + authWebSocketUrl += `&date=${encodeURIComponent(date)}`; + return new WebSocketStream(authWebSocketUrl, options) as WebSocketStream; + } +} diff --git a/libs/langchain-community/src/chat_models/iflytek_xinghuo/web.ts b/libs/langchain-community/src/chat_models/iflytek_xinghuo/web.ts new file mode 100644 index 000000000000..df0db076085a --- /dev/null +++ b/libs/langchain-community/src/chat_models/iflytek_xinghuo/web.ts @@ -0,0 +1,49 @@ +import { BaseChatIflytekXinghuo } from "./common.js"; +import { + WebSocketStreamOptions, + BaseWebSocketStream, +} from "../../utils/iflytek_websocket_stream.js"; + +class WebSocketStream extends BaseWebSocketStream { + openWebSocket(url: string, options: WebSocketStreamOptions): WebSocket { + return new WebSocket(url, options.protocols ?? 
[]); + } +} + +/** + * @example + * ```typescript + * const model = new ChatIflytekXinghuo(); + * const response = await model.call([new HumanMessage("Nice to meet you!")]); + * console.log(response); + * ``` + */ +export class ChatIflytekXinghuo extends BaseChatIflytekXinghuo { + async openWebSocketStream( + options: WebSocketStreamOptions + ): Promise { + const host = "spark-api.xf-yun.com"; + const date = new Date().toUTCString(); + const url = `GET /${this.version}/chat HTTP/1.1`; + const keyBuffer = new TextEncoder().encode(this.iflytekApiSecret); + const dataBuffer = new TextEncoder().encode( + `host: ${host}\ndate: ${date}\n${url}` + ); + const cryptoKey = await crypto.subtle.importKey( + "raw", + keyBuffer, + { name: "HMAC", hash: "SHA-256" }, + false, + ["sign"] + ); + const signature = await crypto.subtle.sign("HMAC", cryptoKey, dataBuffer); + const hash = window.btoa(String.fromCharCode(...new Uint8Array(signature))); + const authorization_origin = `api_key="${this.iflytekApiKey}", algorithm="hmac-sha256", headers="host date request-line", signature="${hash}"`; + const authorization = window.btoa(authorization_origin); + let authWebSocketUrl = this.apiUrl; + authWebSocketUrl += `?authorization=${authorization}`; + authWebSocketUrl += `&host=${encodeURIComponent(host)}`; + authWebSocketUrl += `&date=${encodeURIComponent(date)}`; + return new WebSocketStream(authWebSocketUrl, options) as WebSocketStream; + } +} diff --git a/libs/langchain-community/src/chat_models/llama_cpp.ts b/libs/langchain-community/src/chat_models/llama_cpp.ts new file mode 100644 index 000000000000..5f2fc95468be --- /dev/null +++ b/libs/langchain-community/src/chat_models/llama_cpp.ts @@ -0,0 +1,328 @@ +import { + LlamaModel, + LlamaContext, + LlamaChatSession, + type ConversationInteraction, +} from "node-llama-cpp"; + +import { + SimpleChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import type { BaseLanguageModelCallOptions } from 
"@langchain/core/language_models/base"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + BaseMessage, + AIMessageChunk, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import { + LlamaBaseCppInputs, + createLlamaModel, + createLlamaContext, +} from "../utils/llama_cpp.js"; + +/** + * Note that the modelPath is the only required parameter. For testing you + * can set this in the environment variable `LLAMA_PATH`. + */ +export interface LlamaCppInputs + extends LlamaBaseCppInputs, + BaseChatModelParams {} + +export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions { + /** The maximum number of tokens the response should contain. */ + maxTokens?: number; + /** A function called when matching the provided token array */ + onToken?: (tokens: number[]) => void; +} + +/** + * To use this model you need to have the `node-llama-cpp` module installed. + * This can be installed using `npm install -S node-llama-cpp` and the minimum + * version supported in version 2.0.0. + * This also requires that have a locally built version of Llama2 installed. + * @example + * ```typescript + * // Initialize the ChatLlamaCpp model with the path to the model binary file. + * const model = new ChatLlamaCpp({ + * modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin", + * temperature: 0.5, + * }); + * + * // Call the model with a message and await the response. + * const response = await model.call([ + * new HumanMessage({ content: "My name is John." }), + * ]); + * + * // Log the response to the console. 
+ * console.log({ response }); + * + * ``` + */ +export class ChatLlamaCpp extends SimpleChatModel { + declare CallOptions: LlamaCppCallOptions; + + static inputs: LlamaCppInputs; + + maxTokens?: number; + + temperature?: number; + + topK?: number; + + topP?: number; + + trimWhitespaceSuffix?: boolean; + + _model: LlamaModel; + + _context: LlamaContext; + + _session: LlamaChatSession | null; + + lc_serializable = true; + + static lc_name() { + return "ChatLlamaCpp"; + } + + constructor(inputs: LlamaCppInputs) { + super(inputs); + this.maxTokens = inputs?.maxTokens; + this.temperature = inputs?.temperature; + this.topK = inputs?.topK; + this.topP = inputs?.topP; + this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix; + this._model = createLlamaModel(inputs); + this._context = createLlamaContext(this._model, inputs); + this._session = null; + } + + _llmType() { + return "llama2_cpp"; + } + + /** @ignore */ + _combineLLMOutput() { + return {}; + } + + invocationParams() { + return { + maxTokens: this.maxTokens, + temperature: this.temperature, + topK: this.topK, + topP: this.topP, + trimWhitespaceSuffix: this.trimWhitespaceSuffix, + }; + } + + /** @ignore */ + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"] + ): Promise { + let prompt = ""; + + if (messages.length > 1) { + // We need to build a new _session + prompt = this._buildSession(messages); + } else if (!this._session) { + prompt = this._buildSession(messages); + } else { + if (typeof messages[0].content !== "string") { + throw new Error( + "ChatLlamaCpp does not support non-string message content in sessions." 
+ ); + } + // If we already have a session then we should just have a single prompt + prompt = messages[0].content; + } + + try { + const promptOptions = { + onToken: options.onToken, + maxTokens: this?.maxTokens, + temperature: this?.temperature, + topK: this?.topK, + topP: this?.topP, + trimWhitespaceSuffix: this?.trimWhitespaceSuffix, + }; + // @ts-expect-error - TS2531: Object is possibly 'null'. + const completion = await this._session.prompt(prompt, promptOptions); + return completion; + } catch (e) { + throw new Error("Error getting prompt completion."); + } + } + + async *_streamResponseChunks( + input: BaseMessage[], + _options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const promptOptions = { + temperature: this?.temperature, + topK: this?.topK, + topP: this?.topP, + }; + + const prompt = this._buildPrompt(input); + + const stream = await this.caller.call(async () => + this._context.evaluate(this._context.encode(prompt), promptOptions) + ); + + for await (const chunk of stream) { + yield new ChatGenerationChunk({ + text: this._context.decode([chunk]), + message: new AIMessageChunk({ + content: this._context.decode([chunk]), + }), + generationInfo: {}, + }); + await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? 
""); + } + } + + // This constructs a new session if we need to adding in any sys messages or previous chats + protected _buildSession(messages: BaseMessage[]): string { + let prompt = ""; + let sysMessage = ""; + let noSystemMessages: BaseMessage[] = []; + let interactions: ConversationInteraction[] = []; + + // Let's see if we have a system message + if (messages.findIndex((msg) => msg._getType() === "system") !== -1) { + const sysMessages = messages.filter( + (message) => message._getType() === "system" + ); + + const systemMessageContent = sysMessages[sysMessages.length - 1].content; + + if (typeof systemMessageContent !== "string") { + throw new Error( + "ChatLlamaCpp does not support non-string message content in sessions." + ); + } + // Only use the last provided system message + sysMessage = systemMessageContent; + + // Now filter out the system messages + noSystemMessages = messages.filter( + (message) => message._getType() !== "system" + ); + } else { + noSystemMessages = messages; + } + + // Lets see if we just have a prompt left or are their previous interactions? + if (noSystemMessages.length > 1) { + // Is the last message a prompt? + if ( + noSystemMessages[noSystemMessages.length - 1]._getType() === "human" + ) { + const finalMessageContent = + noSystemMessages[noSystemMessages.length - 1].content; + if (typeof finalMessageContent !== "string") { + throw new Error( + "ChatLlamaCpp does not support non-string message content in sessions." + ); + } + prompt = finalMessageContent; + interactions = this._convertMessagesToInteractions( + noSystemMessages.slice(0, noSystemMessages.length - 1) + ); + } else { + interactions = this._convertMessagesToInteractions(noSystemMessages); + } + } else { + if (typeof noSystemMessages[0].content !== "string") { + throw new Error( + "ChatLlamaCpp does not support non-string message content in sessions." 
+ ); + } + // If there was only a single message we assume it's a prompt + prompt = noSystemMessages[0].content; + } + + // Now lets construct a session according to what we got + if (sysMessage !== "" && interactions.length > 0) { + this._session = new LlamaChatSession({ + context: this._context, + conversationHistory: interactions, + systemPrompt: sysMessage, + }); + } else if (sysMessage !== "" && interactions.length === 0) { + this._session = new LlamaChatSession({ + context: this._context, + systemPrompt: sysMessage, + }); + } else if (sysMessage === "" && interactions.length > 0) { + this._session = new LlamaChatSession({ + context: this._context, + conversationHistory: interactions, + }); + } else { + this._session = new LlamaChatSession({ + context: this._context, + }); + } + + return prompt; + } + + // This builds a an array of interactions + protected _convertMessagesToInteractions( + messages: BaseMessage[] + ): ConversationInteraction[] { + const result: ConversationInteraction[] = []; + + for (let i = 0; i < messages.length; i += 2) { + if (i + 1 < messages.length) { + const prompt = messages[i].content; + const response = messages[i + 1].content; + if (typeof prompt !== "string" || typeof response !== "string") { + throw new Error( + "ChatLlamaCpp does not support non-string message content." 
+ ); + } + result.push({ + prompt, + response, + }); + } + } + + return result; + } + + protected _buildPrompt(input: BaseMessage[]): string { + const prompt = input + .map((message) => { + let messageText; + if (message._getType() === "human") { + messageText = `[INST] ${message.content} [/INST]`; + } else if (message._getType() === "ai") { + messageText = message.content; + } else if (message._getType() === "system") { + messageText = `<> ${message.content} <>`; + } else if (ChatMessage.isInstance(message)) { + messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( + 1 + )}: ${message.content}`; + } else { + console.warn( + `Unsupported message type passed to llama_cpp: "${message._getType()}"` + ); + messageText = ""; + } + return messageText; + }) + .join("\n"); + + return prompt; + } +} diff --git a/libs/langchain-community/src/chat_models/minimax.ts b/libs/langchain-community/src/chat_models/minimax.ts new file mode 100644 index 000000000000..bb450bf92d38 --- /dev/null +++ b/libs/langchain-community/src/chat_models/minimax.ts @@ -0,0 +1,882 @@ +import type { OpenAIClient } from "@langchain/openai"; + +import { + BaseChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, +} from "@langchain/core/messages"; +import { ChatResult, ChatGeneration } from "@langchain/core/outputs"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { StructuredTool } from "@langchain/core/tools"; +import { BaseFunctionCallOptions } from "@langchain/core/language_models/base"; +import { formatToOpenAIFunction } from "@langchain/openai"; + +/** + * Type representing the sender_type of a message in the Minimax chat model. + */ +export type MinimaxMessageRole = "BOT" | "USER" | "FUNCTION"; + +/** + * Interface representing a message in the Minimax chat model. 
+ */ +interface MinimaxChatCompletionRequestMessage { + sender_type: MinimaxMessageRole; + sender_name?: string; + text: string; +} + +/** + * Interface representing a request for a chat completion. + */ +interface MinimaxChatCompletionRequest { + model: string; + messages: MinimaxChatCompletionRequestMessage[]; + stream?: boolean; + prompt?: string; + temperature?: number; + top_p?: number; + tokens_to_generate?: number; + skip_info_mask?: boolean; + mask_sensitive_info?: boolean; + beam_width?: number; + use_standard_sse?: boolean; + role_meta?: RoleMeta; + bot_setting?: BotSetting[]; + reply_constraints?: ReplyConstraints; + sample_messages?: MinimaxChatCompletionRequestMessage[]; + /** + * A list of functions the model may generate JSON inputs for. + * @type {Array} + */ + functions?: OpenAIClient.Chat.ChatCompletionCreateParams.Function[]; + plugins?: string[]; +} + +interface RoleMeta { + role_meta: string; + bot_name: string; +} + +interface RawGlyph { + type: "raw"; + raw_glyph: string; +} + +interface JsonGlyph { + type: "json_value"; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + json_properties: any; +} + +type ReplyConstraintsGlyph = RawGlyph | JsonGlyph; + +interface ReplyConstraints { + sender_type: string; + sender_name: string; + glyph?: ReplyConstraintsGlyph; +} + +interface BotSetting { + content: string; + bot_name: string; +} + +export declare interface ConfigurationParameters { + basePath?: string; + headers?: Record; +} + +/** + * Interface defining the input to the ChatMinimax class. + */ +declare interface MinimaxChatInputBase { + /** Model name to use + * @default "abab5.5-chat" + */ + modelName: string; + + /** Whether to stream the results or not. Defaults to false. */ + streaming?: boolean; + + prefixMessages?: MinimaxChatCompletionRequestMessage[]; + + /** + * API key to use when making requests. Defaults to the value of + * `MINIMAX_GROUP_ID` environment variable. 
+ */ + minimaxGroupId?: string; + + /** + * Secret key to use when making requests. Defaults to the value of + * `MINIMAX_API_KEY` environment variable. + */ + minimaxApiKey?: string; + + /** Amount of randomness injected into the response. Ranges + * from 0 to 1 (0 is not included). Use temp closer to 0 for analytical / + * multiple choice, and temp closer to 1 for creative + * and generative tasks. Defaults to 0.95. + */ + temperature?: number; + + /** + * The smaller the sampling method, the more determinate the result; + * the larger the number, the more random the result. + */ + topP?: number; + + /** + * Enable Chatcompletion pro + */ + proVersion?: boolean; + + /** + * Pay attention to the maximum number of tokens generated, + * this parameter does not affect the generation effect of the model itself, + * but only realizes the function by truncating the tokens exceeding the limit. + * It is necessary to ensure that the number of tokens of the input context plus this value is less than 6144 or 16384, + * otherwise the request will fail. + */ + tokensToGenerate?: number; +} + +declare interface MinimaxChatInputNormal { + /** + * Dialogue setting, characters, or functionality setting. + */ + prompt?: string; + /** + * Sensitize text information in the output that may involve privacy issues, + * currently including but not limited to emails, domain names, + * links, ID numbers, home addresses, etc. Default false, ie. enable sensitization. + */ + skipInfoMask?: boolean; + + /** + * Whether to use the standard SSE format, when set to true, + * the streaming results will be separated by two line breaks. + * This parameter only takes effect when stream is set to true. + */ + useStandardSse?: boolean; + + /** + * If it is true, this indicates that the current request is set to continuation mode, + * and the response is a continuation of the last sentence in the incoming messages; + * at this time, the last sender is not limited to USER, it can also be BOT. 
+ * Assuming the last sentence of incoming messages is {"sender_type": " U S E R", "text": "天生我材"}, + * the completion of the reply may be "It must be useful." + */ + continueLastMessage?: boolean; + + /** + * How many results to generate; the default is 1 and the maximum is not more than 4. + * Because beamWidth generates multiple results, it will consume more tokens. + */ + beamWidth?: number; + + /** + * Dialogue Metadata + */ + roleMeta?: RoleMeta; +} + +declare interface MinimaxChatInputPro extends MinimaxChatInputBase { + /** + * For the text information in the output that may involve privacy issues, + * code masking is currently included but not limited to emails, domains, links, ID numbers, home addresses, etc., + * with the default being true, that is, code masking is enabled. + */ + maskSensitiveInfo?: boolean; + + /** + * Default bot name + */ + defaultBotName?: string; + + /** + * Default user name + */ + defaultUserName?: string; + + /** + * Setting for each robot, only available for pro version. + */ + botSetting?: BotSetting[]; + + replyConstraints?: ReplyConstraints; +} + +type MinimaxChatInput = MinimaxChatInputNormal & MinimaxChatInputPro; + +/** + * Function that extracts the custom sender_type of a generic chat message. + * @param message Chat message from which to extract the custom sender_type. + * @returns The custom sender_type of the chat message. + */ +function extractGenericMessageCustomRole(message: ChatMessage) { + if (message.role !== "ai" && message.role !== "user") { + console.warn(`Unknown message role: ${message.role}`); + } + if (message.role === "ai") { + return "BOT" as MinimaxMessageRole; + } + if (message.role === "user") { + return "USER" as MinimaxMessageRole; + } + return message.role as MinimaxMessageRole; +} + +/** + * Function that converts a base message to a Minimax message sender_type. + * @param message Base message to convert. + * @returns The Minimax message sender_type. 
+ */ +function messageToMinimaxRole(message: BaseMessage): MinimaxMessageRole { + const type = message._getType(); + switch (type) { + case "ai": + return "BOT"; + case "human": + return "USER"; + case "system": + throw new Error("System messages not supported"); + case "function": + return "FUNCTION"; + case "generic": { + if (!ChatMessage.isInstance(message)) + throw new Error("Invalid generic chat message"); + return extractGenericMessageCustomRole(message); + } + default: + throw new Error(`Unknown message type: ${type}`); + } +} + +export interface ChatMinimaxCallOptions extends BaseFunctionCallOptions { + tools?: StructuredTool[]; + defaultUserName?: string; + defaultBotName?: string; + plugins?: string[]; + botSetting?: BotSetting[]; + replyConstraints?: ReplyConstraints; + sampleMessages?: BaseMessage[]; +} + +/** + * Wrapper around Minimax large language models that use the Chat endpoint. + * + * To use you should have the `MINIMAX_GROUP_ID` and `MINIMAX_API_KEY` + * environment variable set. 
+ * @example + * ```typescript + * // Define a chat prompt with a system message setting the context for translation + * const chatPrompt = ChatPromptTemplate.fromMessages([ + * SystemMessagePromptTemplate.fromTemplate( + * "You are a helpful assistant that translates {input_language} to {output_language}.", + * ), + * HumanMessagePromptTemplate.fromTemplate("{text}"), + * ]); + * + * // Create a new LLMChain with the chat model and the defined prompt + * const chainB = new LLMChain({ + * prompt: chatPrompt, + * llm: new ChatMinimax({ temperature: 0.01 }), + * }); + * + * // Call the chain with the input language, output language, and the text to translate + * const resB = await chainB.call({ + * input_language: "English", + * output_language: "Chinese", + * text: "I love programming.", + * }); + * + * // Log the result + * console.log({ resB }); + * + * ``` + */ +export class ChatMinimax + extends BaseChatModel + implements MinimaxChatInput +{ + static lc_name() { + return "ChatMinimax"; + } + + get callKeys(): (keyof ChatMinimaxCallOptions)[] { + return [ + ...(super.callKeys as (keyof ChatMinimaxCallOptions)[]), + "functions", + "tools", + "defaultBotName", + "defaultUserName", + "plugins", + "replyConstraints", + "botSetting", + "sampleMessages", + ]; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + minimaxApiKey: "MINIMAX_API_KEY", + minimaxGroupId: "MINIMAX_GROUP_ID", + }; + } + + lc_serializable = true; + + minimaxGroupId?: string; + + minimaxApiKey?: string; + + streaming = false; + + prompt?: string; + + modelName = "abab5.5-chat"; + + defaultBotName?: string = "Assistant"; + + defaultUserName?: string = "I"; + + prefixMessages?: MinimaxChatCompletionRequestMessage[]; + + apiUrl: string; + + basePath?: string = "https://api.minimax.chat/v1"; + + headers?: Record; + + temperature?: number = 0.9; + + topP?: number = 0.8; + + tokensToGenerate?: number; + + skipInfoMask?: boolean; + + proVersion?: boolean = true; + + beamWidth?: 
number; + + botSetting?: BotSetting[]; + + continueLastMessage?: boolean; + + maskSensitiveInfo?: boolean; + + roleMeta?: RoleMeta; + + useStandardSse?: boolean; + + replyConstraints?: ReplyConstraints; + + constructor( + fields?: Partial & + BaseChatModelParams & { + configuration?: ConfigurationParameters; + } + ) { + super(fields ?? {}); + + this.minimaxGroupId = + fields?.minimaxGroupId ?? getEnvironmentVariable("MINIMAX_GROUP_ID"); + if (!this.minimaxGroupId) { + throw new Error("Minimax GroupID not found"); + } + + this.minimaxApiKey = + fields?.minimaxApiKey ?? getEnvironmentVariable("MINIMAX_API_KEY"); + + if (!this.minimaxApiKey) { + throw new Error("Minimax ApiKey not found"); + } + + this.streaming = fields?.streaming ?? this.streaming; + this.prompt = fields?.prompt ?? this.prompt; + this.temperature = fields?.temperature ?? this.temperature; + this.topP = fields?.topP ?? this.topP; + this.skipInfoMask = fields?.skipInfoMask ?? this.skipInfoMask; + this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages; + this.maskSensitiveInfo = + fields?.maskSensitiveInfo ?? this.maskSensitiveInfo; + this.beamWidth = fields?.beamWidth ?? this.beamWidth; + this.continueLastMessage = + fields?.continueLastMessage ?? this.continueLastMessage; + this.tokensToGenerate = fields?.tokensToGenerate ?? this.tokensToGenerate; + this.roleMeta = fields?.roleMeta ?? this.roleMeta; + this.botSetting = fields?.botSetting ?? this.botSetting; + this.useStandardSse = fields?.useStandardSse ?? this.useStandardSse; + this.replyConstraints = fields?.replyConstraints ?? this.replyConstraints; + this.defaultBotName = fields?.defaultBotName ?? this.defaultBotName; + + this.modelName = fields?.modelName ?? this.modelName; + this.basePath = fields?.configuration?.basePath ?? this.basePath; + this.headers = fields?.configuration?.headers ?? this.headers; + this.proVersion = fields?.proVersion ?? this.proVersion; + + const modelCompletion = this.proVersion + ? 
"chatcompletion_pro" + : "chatcompletion"; + this.apiUrl = `${this.basePath}/text/${modelCompletion}`; + } + + fallbackBotName(options?: this["ParsedCallOptions"]) { + let botName = options?.defaultBotName ?? this.defaultBotName ?? "Assistant"; + if (this.botSetting) { + botName = this.botSetting[0].bot_name; + } + return botName; + } + + defaultReplyConstraints(options?: this["ParsedCallOptions"]) { + const constraints = options?.replyConstraints ?? this.replyConstraints; + if (!constraints) { + let botName = + options?.defaultBotName ?? this.defaultBotName ?? "Assistant"; + if (this.botSetting) { + botName = this.botSetting[0].bot_name; + } + + return { + sender_type: "BOT", + sender_name: botName, + }; + } + return constraints; + } + + /** + * Get the parameters used to invoke the model + */ + invocationParams( + options?: this["ParsedCallOptions"] + ): Omit { + return { + model: this.modelName, + stream: this.streaming, + prompt: this.prompt, + temperature: this.temperature, + top_p: this.topP, + tokens_to_generate: this.tokensToGenerate, + skip_info_mask: this.skipInfoMask, + mask_sensitive_info: this.maskSensitiveInfo, + beam_width: this.beamWidth, + use_standard_sse: this.useStandardSse, + role_meta: this.roleMeta, + bot_setting: options?.botSetting ?? this.botSetting, + reply_constraints: this.defaultReplyConstraints(options), + sample_messages: this.messageToMinimaxMessage( + options?.sampleMessages, + options + ), + functions: + options?.functions ?? + (options?.tools + ? options?.tools.map(formatToOpenAIFunction) + : undefined), + plugins: options?.plugins, + }; + } + + /** + * Get the identifying parameters for the model + */ + identifyingParams() { + return { + ...this.invocationParams(), + }; + } + + /** + * Convert a list of messages to the format expected by the model. 
+ * @param messages + * @param options + */ + messageToMinimaxMessage( + messages?: BaseMessage[], + options?: this["ParsedCallOptions"] + ): MinimaxChatCompletionRequestMessage[] | undefined { + return messages + ?.filter((message) => { + if (ChatMessage.isInstance(message)) { + return message.role !== "system"; + } + return message._getType() !== "system"; + }) + ?.map((message) => { + const sender_type = messageToMinimaxRole(message); + if (typeof message.content !== "string") { + throw new Error( + "ChatMinimax does not support non-string message content." + ); + } + return { + sender_type, + text: message.content, + sender_name: + message.name ?? + (sender_type === "BOT" + ? this.fallbackBotName() + : options?.defaultUserName ?? this.defaultUserName), + }; + }); + } + + /** @ignore */ + async _generate( + messages: BaseMessage[], + options?: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const tokenUsage = { totalTokens: 0 }; + this.botSettingFallback(options, messages); + + const params = this.invocationParams(options); + const messagesMapped: MinimaxChatCompletionRequestMessage[] = [ + ...(this.messageToMinimaxMessage(messages, options) ?? []), + ...(this.prefixMessages ?? []), + ]; + + const data = params.stream + ? await new Promise((resolve, reject) => { + let response: ChatCompletionResponse; + let rejected = false; + let resolved = false; + this.completionWithRetry( + { + ...params, + messages: messagesMapped, + }, + true, + options?.signal, + (event) => { + const data = JSON.parse(event.data); + + if (data?.error_code) { + if (rejected) { + return; + } + rejected = true; + reject(data); + return; + } + + const message = data as ChatCompletionResponse; + // on the first message set the response properties + + if (!message.choices[0].finish_reason) { + // the last stream message + let streamText; + if (this.proVersion) { + const messages = message.choices[0].messages ?? 
[]; + streamText = messages[0].text; + } else { + streamText = message.choices[0].delta; + } + + // TODO this should pass part.index to the callback + // when that's supported there + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(streamText ?? ""); + return; + } + + response = message; + if (!this.proVersion) { + response.choices[0].text = message.reply; + } + + if (resolved || rejected) { + return; + } + resolved = true; + resolve(response); + } + ).catch((error) => { + if (!rejected) { + rejected = true; + reject(error); + } + }); + }) + : await this.completionWithRetry( + { + ...params, + messages: messagesMapped, + }, + false, + options?.signal + ); + + const { total_tokens: totalTokens } = data.usage ?? {}; + + if (totalTokens) { + tokenUsage.totalTokens = totalTokens; + } + + if (data.base_resp?.status_code !== 0) { + throw new Error(`Minimax API error: ${data.base_resp?.status_msg}`); + } + const generations: ChatGeneration[] = []; + + if (this.proVersion) { + for (const choice of data.choices) { + const messages = choice.messages ?? []; + // 取最后一条消息 + if (messages) { + const message = messages[messages.length - 1]; + const text = message?.text ?? ""; + generations.push({ + text, + message: minimaxResponseToChatMessage(message), + }); + } + } + } else { + for (const choice of data.choices) { + const text = choice?.text ?? ""; + generations.push({ + text, + message: minimaxResponseToChatMessage({ + sender_type: "BOT", + sender_name: + options?.defaultBotName ?? this.defaultBotName ?? 
"Assistant", + text, + }), + }); + } + } + return { + generations, + llmOutput: { tokenUsage }, + }; + } + + /** @ignore */ + async completionWithRetry( + request: MinimaxChatCompletionRequest, + stream: boolean, + signal?: AbortSignal, + onmessage?: (event: MessageEvent) => void + ) { + // The first run will get the accessToken + const makeCompletionRequest = async () => { + const url = `${this.apiUrl}?GroupId=${this.minimaxGroupId}`; + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${this.minimaxApiKey}`, + ...this.headers, + }, + body: JSON.stringify(request), + signal, + }); + + if (!stream) { + const json = await response.json(); + return json as ChatCompletionResponse; + } else { + if (response.body) { + const reader = response.body.getReader(); + + const decoder = new TextDecoder("utf-8"); + let data = ""; + + let continueReading = true; + while (continueReading) { + const { done, value } = await reader.read(); + if (done) { + continueReading = false; + break; + } + data += decoder.decode(value); + + let continueProcessing = true; + while (continueProcessing) { + const newlineIndex = data.indexOf("\n"); + if (newlineIndex === -1) { + continueProcessing = false; + break; + } + const line = data.slice(0, newlineIndex); + data = data.slice(newlineIndex + 1); + + if (line.startsWith("data:")) { + const event = new MessageEvent("message", { + data: line.slice("data:".length).trim(), + }); + onmessage?.(event); + } + } + } + return {} as ChatCompletionResponse; + } + return {} as ChatCompletionResponse; + } + }; + return this.caller.call(makeCompletionRequest); + } + + _llmType() { + return "minimax"; + } + + /** @ignore */ + _combineLLMOutput() { + return []; + } + + private botSettingFallback( + options?: this["ParsedCallOptions"], + messages?: BaseMessage[] + ) { + const botSettings = options?.botSetting ?? 
this.botSetting; + if (!botSettings) { + const systemMessages = messages?.filter((message) => { + if (ChatMessage.isInstance(message)) { + return message.role === "system"; + } + return message._getType() === "system"; + }); + + // get the last system message + if (!systemMessages?.length) { + return; + } + const lastSystemMessage = systemMessages[systemMessages.length - 1]; + + if (typeof lastSystemMessage.content !== "string") { + throw new Error( + "ChatMinimax does not support non-string message content." + ); + } + + // setting the default botSetting. + this.botSetting = [ + { + content: lastSystemMessage.content, + bot_name: + options?.defaultBotName ?? this.defaultBotName ?? "Assistant", + }, + ]; + } + } +} + +function minimaxResponseToChatMessage( + message: ChatCompletionResponseMessage +): BaseMessage { + switch (message.sender_type) { + case "USER": + return new HumanMessage(message.text || ""); + case "BOT": + return new AIMessage(message.text || "", { + function_call: message.function_call, + }); + case "FUNCTION": + return new AIMessage(message.text || ""); + default: + return new ChatMessage( + message.text || "", + message.sender_type ?? "unknown" + ); + } +} + +/** ---Response Model---* */ +/** + * Interface representing a message responsed in the Minimax chat model. + */ +interface ChatCompletionResponseMessage { + sender_type: MinimaxMessageRole; + sender_name?: string; + text: string; + function_call?: ChatCompletionResponseMessageFunctionCall; +} + +/** + * Interface representing the usage of tokens in a chat completion. + */ +interface TokenUsage { + total_tokens?: number; +} + +interface BaseResp { + status_code?: number; + status_msg?: string; +} + +/** + * The name and arguments of a function that should be called, as generated by the model. + * @export + * @interface ChatCompletionResponseMessageFunctionCall + */ +export interface ChatCompletionResponseMessageFunctionCall { + /** + * The name of the function to call. 
+ * @type {string} + * @memberof ChatCompletionResponseMessageFunctionCall + */ + name?: string; + /** + * The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + * @type {string} + * @memberof ChatCompletionResponseMessageFunctionCall + */ + arguments?: string; +} + +/** + * + * @export + * @interface ChatCompletionResponseChoices + */ +export interface ChatCompletionResponseChoicesPro { + /** + * + * @type {string} + * @memberof ChatCompletionResponseChoices + */ + messages?: ChatCompletionResponseMessage[]; + + /** + * + * @type {string} + * @memberof ChatCompletionResponseChoices + */ + finish_reason?: string; +} + +interface ChatCompletionResponseChoices { + delta?: string; + text?: string; + index?: number; + finish_reason?: string; +} + +/** + * Interface representing a response from a chat completion. 
+ */ +interface ChatCompletionResponse { + model: string; + created: number; + reply: string; + input_sensitive?: boolean; + input_sensitive_type?: number; + output_sensitive?: boolean; + output_sensitive_type?: number; + usage?: TokenUsage; + base_resp?: BaseResp; + choices: Array< + ChatCompletionResponseChoicesPro & ChatCompletionResponseChoices + >; +} diff --git a/libs/langchain-community/src/chat_models/ollama.ts b/libs/langchain-community/src/chat_models/ollama.ts new file mode 100644 index 000000000000..aa0d413e8ac0 --- /dev/null +++ b/libs/langchain-community/src/chat_models/ollama.ts @@ -0,0 +1,302 @@ +import { + SimpleChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + AIMessageChunk, + BaseMessage, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import type { StringWithAutocomplete } from "@langchain/core/utils/types"; + +import { createOllamaStream, OllamaInput } from "../utils/ollama.js"; + +/** + * An interface defining the options for an Ollama API call. It extends + * the BaseLanguageModelCallOptions interface. + */ +export interface OllamaCallOptions extends BaseLanguageModelCallOptions {} + +/** + * A class that enables calls to the Ollama API to access large language + * models in a chat-like fashion. It extends the SimpleChatModel class and + * implements the OllamaInput interface. + * @example + * ```typescript + * const prompt = ChatPromptTemplate.fromMessages([ + * [ + * "system", + * `You are an expert translator. 
Format all responses as JSON objects with two keys: "original" and "translated".`, + * ], + * ["human", `Translate "{input}" into {language}.`], + * ]); + * + * const model = new ChatOllama({ + * baseUrl: "http://api.example.com", + * model: "llama2", + * format: "json", + * }); + * + * const chain = prompt.pipe(model); + * + * const result = await chain.invoke({ + * input: "I love programming", + * language: "German", + * }); + * + * ``` + */ +export class ChatOllama + extends SimpleChatModel + implements OllamaInput +{ + static lc_name() { + return "ChatOllama"; + } + + lc_serializable = true; + + model = "llama2"; + + baseUrl = "http://localhost:11434"; + + embeddingOnly?: boolean; + + f16KV?: boolean; + + frequencyPenalty?: number; + + logitsAll?: boolean; + + lowVram?: boolean; + + mainGpu?: number; + + mirostat?: number; + + mirostatEta?: number; + + mirostatTau?: number; + + numBatch?: number; + + numCtx?: number; + + numGpu?: number; + + numGqa?: number; + + numKeep?: number; + + numThread?: number; + + penalizeNewline?: boolean; + + presencePenalty?: number; + + repeatLastN?: number; + + repeatPenalty?: number; + + ropeFrequencyBase?: number; + + ropeFrequencyScale?: number; + + temperature?: number; + + stop?: string[]; + + tfsZ?: number; + + topK?: number; + + topP?: number; + + typicalP?: number; + + useMLock?: boolean; + + useMMap?: boolean; + + vocabOnly?: boolean; + + format?: StringWithAutocomplete<"json">; + + constructor(fields: OllamaInput & BaseChatModelParams) { + super(fields); + this.model = fields.model ?? this.model; + this.baseUrl = fields.baseUrl?.endsWith("/") + ? fields.baseUrl.slice(0, -1) + : fields.baseUrl ?? 
this.baseUrl; + this.embeddingOnly = fields.embeddingOnly; + this.f16KV = fields.f16KV; + this.frequencyPenalty = fields.frequencyPenalty; + this.logitsAll = fields.logitsAll; + this.lowVram = fields.lowVram; + this.mainGpu = fields.mainGpu; + this.mirostat = fields.mirostat; + this.mirostatEta = fields.mirostatEta; + this.mirostatTau = fields.mirostatTau; + this.numBatch = fields.numBatch; + this.numCtx = fields.numCtx; + this.numGpu = fields.numGpu; + this.numGqa = fields.numGqa; + this.numKeep = fields.numKeep; + this.numThread = fields.numThread; + this.penalizeNewline = fields.penalizeNewline; + this.presencePenalty = fields.presencePenalty; + this.repeatLastN = fields.repeatLastN; + this.repeatPenalty = fields.repeatPenalty; + this.ropeFrequencyBase = fields.ropeFrequencyBase; + this.ropeFrequencyScale = fields.ropeFrequencyScale; + this.temperature = fields.temperature; + this.stop = fields.stop; + this.tfsZ = fields.tfsZ; + this.topK = fields.topK; + this.topP = fields.topP; + this.typicalP = fields.typicalP; + this.useMLock = fields.useMLock; + this.useMMap = fields.useMMap; + this.vocabOnly = fields.vocabOnly; + this.format = fields.format; + } + + _llmType() { + return "ollama"; + } + + /** + * A method that returns the parameters for an Ollama API call. It + * includes model and options parameters. + * @param options Optional parsed call options. + * @returns An object containing the parameters for an Ollama API call. 
+ */ + invocationParams(options?: this["ParsedCallOptions"]) { + return { + model: this.model, + format: this.format, + options: { + embedding_only: this.embeddingOnly, + f16_kv: this.f16KV, + frequency_penalty: this.frequencyPenalty, + logits_all: this.logitsAll, + low_vram: this.lowVram, + main_gpu: this.mainGpu, + mirostat: this.mirostat, + mirostat_eta: this.mirostatEta, + mirostat_tau: this.mirostatTau, + num_batch: this.numBatch, + num_ctx: this.numCtx, + num_gpu: this.numGpu, + num_gqa: this.numGqa, + num_keep: this.numKeep, + num_thread: this.numThread, + penalize_newline: this.penalizeNewline, + presence_penalty: this.presencePenalty, + repeat_last_n: this.repeatLastN, + repeat_penalty: this.repeatPenalty, + rope_frequency_base: this.ropeFrequencyBase, + rope_frequency_scale: this.ropeFrequencyScale, + temperature: this.temperature, + stop: options?.stop ?? this.stop, + tfs_z: this.tfsZ, + top_k: this.topK, + top_p: this.topP, + typical_p: this.typicalP, + use_mlock: this.useMLock, + use_mmap: this.useMMap, + vocab_only: this.vocabOnly, + }, + }; + } + + _combineLLMOutput() { + return {}; + } + + async *_streamResponseChunks( + input: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const stream = await this.caller.call(async () => + createOllamaStream( + this.baseUrl, + { + ...this.invocationParams(options), + prompt: this._formatMessagesAsPrompt(input), + }, + options + ) + ); + for await (const chunk of stream) { + if (!chunk.done) { + yield new ChatGenerationChunk({ + text: chunk.response, + message: new AIMessageChunk({ content: chunk.response }), + }); + await runManager?.handleLLMNewToken(chunk.response ?? 
""); + } else { + yield new ChatGenerationChunk({ + text: "", + message: new AIMessageChunk({ content: "" }), + generationInfo: { + model: chunk.model, + total_duration: chunk.total_duration, + load_duration: chunk.load_duration, + prompt_eval_count: chunk.prompt_eval_count, + prompt_eval_duration: chunk.prompt_eval_duration, + eval_count: chunk.eval_count, + eval_duration: chunk.eval_duration, + }, + }); + } + } + } + + protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { + const formattedMessages = messages + .map((message) => { + let messageText; + if (message._getType() === "human") { + messageText = `[INST] ${message.content} [/INST]`; + } else if (message._getType() === "ai") { + messageText = message.content; + } else if (message._getType() === "system") { + messageText = `<> ${message.content} <>`; + } else if (ChatMessage.isInstance(message)) { + messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( + 1 + )}: ${message.content}`; + } else { + console.warn( + `Unsupported message type passed to Ollama: "${message._getType()}"` + ); + messageText = ""; + } + return messageText; + }) + .join("\n"); + return formattedMessages; + } + + /** @ignore */ + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const chunks = []; + for await (const chunk of this._streamResponseChunks( + messages, + options, + runManager + )) { + chunks.push(chunk.message.content); + } + return chunks.join(""); + } +} diff --git a/libs/langchain-community/src/chat_models/portkey.ts b/libs/langchain-community/src/chat_models/portkey.ts new file mode 100644 index 000000000000..9941e58c05ff --- /dev/null +++ b/libs/langchain-community/src/chat_models/portkey.ts @@ -0,0 +1,185 @@ +import { LLMOptions } from "portkey-ai"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + 
ChatMessageChunk, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +} from "@langchain/core/messages"; +import { + ChatResult, + ChatGeneration, + ChatGenerationChunk, +} from "@langchain/core/outputs"; +import { BaseChatModel } from "@langchain/core/language_models/chat_models"; + +import { PortkeySession, getPortkeySession } from "../llms/portkey.js"; + +interface Message { + role?: string; + content?: string; +} + +function portkeyResponseToChatMessage(message: Message): BaseMessage { + switch (message.role) { + case "user": + return new HumanMessage(message.content || ""); + case "assistant": + return new AIMessage(message.content || ""); + case "system": + return new SystemMessage(message.content || ""); + default: + return new ChatMessage(message.content || "", message.role ?? "unknown"); + } +} + +function _convertDeltaToMessageChunk( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delta: Record +) { + const { role } = delta; + const content = delta.content ?? ""; + let additional_kwargs; + if (delta.function_call) { + additional_kwargs = { + function_call: delta.function_call, + }; + } else { + additional_kwargs = {}; + } + if (role === "user") { + return new HumanMessageChunk({ content }); + } else if (role === "assistant") { + return new AIMessageChunk({ content, additional_kwargs }); + } else if (role === "system") { + return new SystemMessageChunk({ content }); + } else if (role === "function") { + return new FunctionMessageChunk({ + content, + additional_kwargs, + name: delta.name, + }); + } else { + return new ChatMessageChunk({ content, role }); + } +} + +export class PortkeyChat extends BaseChatModel { + apiKey?: string = undefined; + + baseURL?: string = undefined; + + mode?: string = undefined; + + llms?: [LLMOptions] | null = undefined; + + session: PortkeySession; + + constructor(init?: Partial) { + super(init ?? 
{}); + this.apiKey = init?.apiKey; + this.baseURL = init?.baseURL; + this.mode = init?.mode; + this.llms = init?.llms; + this.session = getPortkeySession({ + apiKey: this.apiKey, + baseURL: this.baseURL, + llms: this.llms, + mode: this.mode, + }); + } + + _llmType() { + return "portkey"; + } + + async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + _?: CallbackManagerForLLMRun + ): Promise { + const messagesList = messages.map((message) => { + if (typeof message.content !== "string") { + throw new Error( + "PortkeyChat does not support non-string message content." + ); + } + return { + role: message._getType() as string, + content: message.content, + }; + }); + const response = await this.session.portkey.chatCompletions.create({ + messages: messagesList, + ...options, + stream: false, + }); + const generations: ChatGeneration[] = []; + for (const data of response.choices ?? []) { + const text = data.message?.content ?? ""; + const generation: ChatGeneration = { + text, + message: portkeyResponseToChatMessage(data.message ?? {}), + }; + if (data.finish_reason) { + generation.generationInfo = { finish_reason: data.finish_reason }; + } + generations.push(generation); + } + + return { + generations, + }; + } + + async *_streamResponseChunks( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const messagesList = messages.map((message) => { + if (typeof message.content !== "string") { + throw new Error( + "PortkeyChat does not support non-string message content." 
+ ); + } + return { + role: message._getType() as string, + content: message.content, + }; + }); + const response = await this.session.portkey.chatCompletions.create({ + messages: messagesList, + ...options, + stream: true, + }); + for await (const data of response) { + const choice = data?.choices[0]; + if (!choice) { + continue; + } + const chunk = new ChatGenerationChunk({ + message: _convertDeltaToMessageChunk(choice.delta ?? {}), + text: choice.message?.content ?? "", + generationInfo: { + finishReason: choice.finish_reason, + }, + }); + yield chunk; + void runManager?.handleLLMNewToken(chunk.text ?? ""); + } + if (options.signal?.aborted) { + throw new Error("AbortError"); + } + } + + _combineLLMOutput() { + return {}; + } +} diff --git a/libs/langchain-community/src/chat_models/yandex.ts b/libs/langchain-community/src/chat_models/yandex.ts new file mode 100644 index 000000000000..dd982510609e --- /dev/null +++ b/libs/langchain-community/src/chat_models/yandex.ts @@ -0,0 +1,139 @@ +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { AIMessage, BaseMessage } from "@langchain/core/messages"; +import { ChatResult, ChatGeneration } from "@langchain/core/outputs"; +import { BaseChatModel } from "@langchain/core/language_models/chat_models"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +import { YandexGPTInputs } from "../llms/yandex.js"; + +const apiUrl = "https://llm.api.cloud.yandex.net/llm/v1alpha/chat"; + +interface ParsedMessage { + role: string; + text: string; +} + +function _parseChatHistory(history: BaseMessage[]): [ParsedMessage[], string] { + const chatHistory: ParsedMessage[] = []; + let instruction = ""; + + for (const message of history) { + if (typeof message.content !== "string") { + throw new Error( + "ChatYandexGPT does not support non-string message content." 
+ ); + } + if ("content" in message) { + if (message._getType() === "human") { + chatHistory.push({ role: "user", text: message.content }); + } else if (message._getType() === "ai") { + chatHistory.push({ role: "assistant", text: message.content }); + } else if (message._getType() === "system") { + instruction = message.content; + } + } + } + + return [chatHistory, instruction]; +} + +/** + * @example + * ```typescript + * const chat = new ChatYandexGPT({}); + * // The assistant is set to translate English to French. + * const res = await chat.call([ + * new SystemMessage( + * "You are a helpful assistant that translates English to French." + * ), + * new HumanMessage("I love programming."), + * ]); + * console.log(res); + * ``` + */ +export class ChatYandexGPT extends BaseChatModel { + apiKey?: string; + + iamToken?: string; + + temperature = 0.6; + + maxTokens = 1700; + + model = "general"; + + constructor(fields?: YandexGPTInputs) { + super(fields ?? {}); + + const apiKey = fields?.apiKey ?? getEnvironmentVariable("YC_API_KEY"); + + const iamToken = fields?.iamToken ?? getEnvironmentVariable("YC_IAM_TOKEN"); + + if (apiKey === undefined && iamToken === undefined) { + throw new Error( + "Please set the YC_API_KEY or YC_IAM_TOKEN environment variable or pass it to the constructor as the apiKey or iamToken field." + ); + } + + this.apiKey = apiKey; + this.iamToken = iamToken; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.temperature = fields?.temperature ?? this.temperature; + this.model = fields?.model ?? 
this.model; + } + + _llmType() { + return "yandexgpt"; + } + + _combineLLMOutput?() { + return {}; + } + + /** @ignore */ + async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + _?: CallbackManagerForLLMRun | undefined + ): Promise { + const [messageHistory, instruction] = _parseChatHistory(messages); + const headers = { "Content-Type": "application/json", Authorization: "" }; + if (this.apiKey !== undefined) { + headers.Authorization = `Api-Key ${this.apiKey}`; + } else { + headers.Authorization = `Bearer ${this.iamToken}`; + } + const bodyData = { + model: this.model, + generationOptions: { + temperature: this.temperature, + maxTokens: this.maxTokens, + }, + messages: messageHistory, + instructionText: instruction, + }; + const response = await fetch(apiUrl, { + method: "POST", + headers, + body: JSON.stringify(bodyData), + signal: options?.signal, + }); + if (!response.ok) { + throw new Error( + `Failed to fetch ${apiUrl} from YandexGPT: ${response.status}` + ); + } + const responseData = await response.json(); + const { result } = responseData; + const { text } = result.message; + const totalTokens = result.num_tokens; + const generations: ChatGeneration[] = [ + { text, message: new AIMessage(text) }, + ]; + + return { + generations, + llmOutput: { totalTokens }, + }; + } +} diff --git a/libs/langchain-community/src/document_transformers/html_to_text.ts b/libs/langchain-community/src/document_transformers/html_to_text.ts new file mode 100644 index 000000000000..8676517b65a9 --- /dev/null +++ b/libs/langchain-community/src/document_transformers/html_to_text.ts @@ -0,0 +1,44 @@ +import { htmlToText, type HtmlToTextOptions } from "html-to-text"; +import { + MappingDocumentTransformer, + Document, +} from "@langchain/core/documents"; + +/** + * A transformer that converts HTML content to plain text. 
+ * @example + * ```typescript + * const loader = new CheerioWebBaseLoader("https://example.com/some-page"); + * const docs = await loader.load(); + * + * const splitter = new RecursiveCharacterTextSplitter({ + * maxCharacterCount: 1000, + * }); + * const transformer = new HtmlToTextTransformer(); + * + * // The sequence of text splitting followed by HTML to text transformation + * const sequence = splitter.pipe(transformer); + * + * // Processing the loaded documents through the sequence + * const newDocuments = await sequence.invoke(docs); + * + * console.log(newDocuments); + * ``` + */ +export class HtmlToTextTransformer extends MappingDocumentTransformer { + static lc_name() { + return "HtmlToTextTransformer"; + } + + constructor(protected options: HtmlToTextOptions = {}) { + super(options); + } + + async _transformDocument(document: Document): Promise { + const extractedContent = htmlToText(document.pageContent, this.options); + return new Document({ + pageContent: extractedContent, + metadata: { ...document.metadata }, + }); + } +} diff --git a/libs/langchain-community/src/document_transformers/mozilla_readability.ts b/libs/langchain-community/src/document_transformers/mozilla_readability.ts new file mode 100644 index 000000000000..a26b42a6d6c7 --- /dev/null +++ b/libs/langchain-community/src/document_transformers/mozilla_readability.ts @@ -0,0 +1,54 @@ +import { Readability } from "@mozilla/readability"; +import { JSDOM } from "jsdom"; +import type { Options } from "mozilla-readability"; +import { + MappingDocumentTransformer, + Document, +} from "@langchain/core/documents"; + +/** + * A transformer that uses the Mozilla Readability library to extract the + * main content from a web page. 
+ * @example + * ```typescript + * const loader = new CheerioWebBaseLoader("https://example.com/article"); + * const docs = await loader.load(); + * + * const splitter = new RecursiveCharacterTextSplitter({ + * maxCharacterCount: 5000, + * }); + * const transformer = new MozillaReadabilityTransformer(); + * + * // The sequence processes the loaded documents through the splitter and then the transformer. + * const sequence = splitter.pipe(transformer); + * + * // Invoke the sequence to transform the documents into a more readable format. + * const newDocuments = await sequence.invoke(docs); + * + * console.log(newDocuments); + * ``` + */ +export class MozillaReadabilityTransformer extends MappingDocumentTransformer { + static lc_name() { + return "MozillaReadabilityTransformer"; + } + + constructor(protected options: Options = {}) { + super(options); + } + + async _transformDocument(document: Document): Promise { + const doc = new JSDOM(document.pageContent); + + const readability = new Readability(doc.window.document, this.options); + + const result = readability.parse(); + + return new Document({ + pageContent: result?.textContent ?? "", + metadata: { + ...document.metadata, + }, + }); + } +} diff --git a/libs/langchain-community/src/embeddings/bedrock.ts b/libs/langchain-community/src/embeddings/bedrock.ts new file mode 100644 index 000000000000..305387007cd0 --- /dev/null +++ b/libs/langchain-community/src/embeddings/bedrock.ts @@ -0,0 +1,142 @@ +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from "@aws-sdk/client-bedrock-runtime"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import type { CredentialType } from "../utils/bedrock.js"; + +/** + * Interface that extends EmbeddingsParams and defines additional + * parameters specific to the BedrockEmbeddings class. + */ +export interface BedrockEmbeddingsParams extends EmbeddingsParams { + /** + * Model Name to use. 
Defaults to `amazon.titan-embed-text-v1` if not provided + * + */ + model?: string; + + /** + * A client provided by the user that allows them to customize any + * SDK configuration options. + */ + client?: BedrockRuntimeClient; + + region?: string; + + credentials?: CredentialType; +} + +/** + * Class that extends the Embeddings class and provides methods for + * generating embeddings using the Bedrock API. + * @example + * ```typescript + * const embeddings = new BedrockEmbeddings({ + * region: "your-aws-region", + * credentials: { + * accessKeyId: "your-access-key-id", + * secretAccessKey: "your-secret-access-key", + * }, + * model: "amazon.titan-embed-text-v1", + * }); + * + * // Embed a query and log the result + * const res = await embeddings.embedQuery( + * "What would be a good company name for a company that makes colorful socks?" + * ); + * console.log({ res }); + * ``` + */ +export class BedrockEmbeddings + extends Embeddings + implements BedrockEmbeddingsParams +{ + model: string; + + client: BedrockRuntimeClient; + + batchSize = 512; + + constructor(fields?: BedrockEmbeddingsParams) { + super(fields ?? {}); + + this.model = fields?.model ?? "amazon.titan-embed-text-v1"; + + this.client = + fields?.client ?? + new BedrockRuntimeClient({ + region: fields?.region, + credentials: fields?.credentials, + }); + } + + /** + * Protected method to make a request to the Bedrock API to generate + * embeddings. Handles the retry logic and returns the response from the + * API. + * @param text Text to send to the Bedrock API for embedding. + * @returns Promise that resolves to the response from the API. + */ + protected async _embedText(text: string): Promise { + return this.caller.call(async () => { + try { + // replace newlines, which can negatively affect performance.
+ const cleanedText = text.replace(/\n/g, " "); + + const res = await this.client.send( + new InvokeModelCommand({ + modelId: this.model, + body: JSON.stringify({ + inputText: cleanedText, + }), + contentType: "application/json", + accept: "application/json", + }) + ); + + const body = new TextDecoder().decode(res.body); + return JSON.parse(body).embedding; + } catch (e) { + console.error({ + error: e, + }); + // eslint-disable-next-line no-instanceof/no-instanceof + if (e instanceof Error) { + throw new Error( + `An error occurred while embedding documents with Bedrock: ${e.message}` + ); + } + + throw new Error( + "An error occurred while embedding documents with Bedrock" + ); + } + }); + } + + /** + * Method that takes a document as input and returns a promise that + * resolves to an embedding for the document. It calls the _embedText + * method with the document as the input. + * @param document Document for which to generate an embedding. + * @returns Promise that resolves to an embedding for the input document. + */ + embedQuery(document: string): Promise { + return this.caller.callWithOptions( + {}, + this._embedText.bind(this), + document + ); + } + + /** + * Method to generate embeddings for an array of texts. Calls _embedText + * method which batches and handles retry logic when calling the AWS Bedrock API. + * @param documents Array of texts for which to generate embeddings. + * @returns Promise that resolves to a 2D array of embeddings for each input document. 
+ */ + async embedDocuments(documents: string[]): Promise { + return Promise.all(documents.map((document) => this._embedText(document))); + } +} diff --git a/libs/langchain-community/src/embeddings/cloudflare_workersai.ts b/libs/langchain-community/src/embeddings/cloudflare_workersai.ts new file mode 100644 index 000000000000..20ff93e070c7 --- /dev/null +++ b/libs/langchain-community/src/embeddings/cloudflare_workersai.ts @@ -0,0 +1,94 @@ +import { Ai } from "@cloudflare/ai"; +import { Fetcher } from "@cloudflare/workers-types"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { chunkArray } from "../utils/chunk.js"; + +type AiTextEmbeddingsInput = { + text: string | string[]; +}; + +type AiTextEmbeddingsOutput = { + shape: number[]; + data: number[][]; +}; + +export interface CloudflareWorkersAIEmbeddingsParams extends EmbeddingsParams { + /** Binding */ + binding: Fetcher; + + /** Model name to use */ + modelName?: string; + + /** + * The maximum number of documents to embed in a single request. + */ + batchSize?: number; + + /** + * Whether to strip new lines from the input text. This is recommended by + * OpenAI, but may not be suitable for all use cases. + */ + stripNewLines?: boolean; +} + +export class CloudflareWorkersAIEmbeddings extends Embeddings { + modelName = "@cf/baai/bge-base-en-v1.5"; + + batchSize = 50; + + stripNewLines = true; + + ai: Ai; + + constructor(fields: CloudflareWorkersAIEmbeddingsParams) { + super(fields); + + if (!fields.binding) { + throw new Error( + "Must supply a Workers AI binding, eg { binding: env.AI }" + ); + } + this.ai = new Ai(fields.binding); + this.modelName = fields.modelName ?? this.modelName; + this.stripNewLines = fields.stripNewLines ?? this.stripNewLines; + } + + async embedDocuments(texts: string[]): Promise { + const batches = chunkArray( + this.stripNewLines ? 
texts.map((t) => t.replace(/\n/g, " ")) : texts, + this.batchSize + ); + + const batchRequests = batches.map((batch) => this.runEmbedding(batch)); + const batchResponses = await Promise.all(batchRequests); + const embeddings: number[][] = []; + + for (let i = 0; i < batchResponses.length; i += 1) { + const batchResponse = batchResponses[i]; + for (let j = 0; j < batchResponse.length; j += 1) { + embeddings.push(batchResponse[j]); + } + } + + return embeddings; + } + + async embedQuery(text: string): Promise { + const data = await this.runEmbedding([ + this.stripNewLines ? text.replace(/\n/g, " ") : text, + ]); + return data[0]; + } + + private async runEmbedding(texts: string[]) { + return this.caller.call(async () => { + const response: AiTextEmbeddingsOutput = await this.ai.run( + this.modelName, + { + text: texts, + } as AiTextEmbeddingsInput + ); + return response.data; + }); + } +} diff --git a/libs/langchain-community/src/embeddings/cohere.ts b/libs/langchain-community/src/embeddings/cohere.ts new file mode 100644 index 000000000000..86e151ab7653 --- /dev/null +++ b/libs/langchain-community/src/embeddings/cohere.ts @@ -0,0 +1,155 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { chunkArray } from "../utils/chunk.js"; + +/** + * Interface that extends EmbeddingsParams and defines additional + * parameters specific to the CohereEmbeddings class. + */ +export interface CohereEmbeddingsParams extends EmbeddingsParams { + modelName: string; + + /** + * The maximum number of documents to embed in a single request. This is + * limited by the Cohere API to a maximum of 96. + */ + batchSize?: number; +} + +/** + * A class for generating embeddings using the Cohere API. 
+ * @example + * ```typescript + * // Embed a query using the CohereEmbeddings class + * const model = new CohereEmbeddings(); + * const res = await model.embedQuery( + * "What would be a good company name for a company that makes colorful socks?", + * ); + * console.log({ res }); + * + * ``` + */ +export class CohereEmbeddings + extends Embeddings + implements CohereEmbeddingsParams +{ + modelName = "small"; + + batchSize = 48; + + private apiKey: string; + + private client: typeof import("cohere-ai"); + + /** + * Constructor for the CohereEmbeddings class. + * @param fields - An optional object with properties to configure the instance. + */ + constructor( + fields?: Partial & { + verbose?: boolean; + apiKey?: string; + } + ) { + const fieldsWithDefaults = { maxConcurrency: 2, ...fields }; + + super(fieldsWithDefaults); + + const apiKey = + fieldsWithDefaults?.apiKey || getEnvironmentVariable("COHERE_API_KEY"); + + if (!apiKey) { + throw new Error("Cohere API key not found"); + } + + this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; + this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; + this.apiKey = apiKey; + } + + /** + * Generates embeddings for an array of texts. + * @param texts - An array of strings to generate embeddings for. + * @returns A Promise that resolves to an array of embeddings.
+ */ + async embedDocuments(texts: string[]): Promise { + await this.maybeInitClient(); + + const batches = chunkArray(texts, this.batchSize); + + const batchRequests = batches.map((batch) => + this.embeddingWithRetry({ + model: this.modelName, + texts: batch, + }) + ); + + const batchResponses = await Promise.all(batchRequests); + + const embeddings: number[][] = []; + + for (let i = 0; i < batchResponses.length; i += 1) { + const batch = batches[i]; + const { body: batchResponse } = batchResponses[i]; + for (let j = 0; j < batch.length; j += 1) { + embeddings.push(batchResponse.embeddings[j]); + } + } + + return embeddings; + } + + /** + * Generates an embedding for a single text. + * @param text - A string to generate an embedding for. + * @returns A Promise that resolves to an array of numbers representing the embedding. + */ + async embedQuery(text: string): Promise { + await this.maybeInitClient(); + + const { body } = await this.embeddingWithRetry({ + model: this.modelName, + texts: [text], + }); + return body.embeddings[0]; + } + + /** + * Generates embeddings with retry capabilities. + * @param request - An object containing the request parameters for generating embeddings. + * @returns A Promise that resolves to the API response. + */ + private async embeddingWithRetry( + request: Parameters[0] + ) { + await this.maybeInitClient(); + + return this.caller.call(this.client.embed.bind(this.client), request); + } + + /** + * Initializes the Cohere client if it hasn't been initialized already. + */ + private async maybeInitClient() { + if (!this.client) { + const { cohere } = await CohereEmbeddings.imports(); + + this.client = cohere; + this.client.init(this.apiKey); + } + } + + /** @ignore */ + static async imports(): Promise<{ + cohere: typeof import("cohere-ai"); + }> { + try { + const { default: cohere } = await import("cohere-ai"); + return { cohere }; + } catch (e) { + throw new Error( + "Please install cohere-ai as a dependency with, e.g. 
`yarn add cohere-ai`" + ); + } + } +} diff --git a/libs/langchain-community/src/embeddings/googlepalm.ts b/libs/langchain-community/src/embeddings/googlepalm.ts new file mode 100644 index 000000000000..93ed0743e601 --- /dev/null +++ b/libs/langchain-community/src/embeddings/googlepalm.ts @@ -0,0 +1,107 @@ +import { TextServiceClient } from "@google-ai/generativelanguage"; +import { GoogleAuth } from "google-auth-library"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface that extends EmbeddingsParams and defines additional + * parameters specific to the GooglePaLMEmbeddings class. + */ +export interface GooglePaLMEmbeddingsParams extends EmbeddingsParams { + /** + * Model Name to use + * + * Note: The format must follow the pattern - `models/{model}` + */ + modelName?: string; + /** + * Google Palm API key to use + */ + apiKey?: string; +} + +/** + * Class that extends the Embeddings class and provides methods for + * generating embeddings using the Google Palm API. + * @example + * ```typescript + * const model = new GooglePaLMEmbeddings({ + * apiKey: "", + * modelName: "models/embedding-gecko-001", + * }); + * + * // Embed a single query + * const res = await model.embedQuery( + * "What would be a good company name for a company that makes colorful socks?" + * ); + * console.log({ res }); + * + * // Embed multiple documents + * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]); + * console.log({ documentRes }); + * ``` + */ +export class GooglePaLMEmbeddings + extends Embeddings + implements GooglePaLMEmbeddingsParams +{ + apiKey?: string; + + modelName = "models/embedding-gecko-001"; + + private client: TextServiceClient; + + constructor(fields?: GooglePaLMEmbeddingsParams) { + super(fields ?? {}); + + this.modelName = fields?.modelName ?? this.modelName; + + this.apiKey = + fields?.apiKey ?? 
getEnvironmentVariable("GOOGLE_PALM_API_KEY"); + if (!this.apiKey) { + throw new Error( + "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" + ); + } + + this.client = new TextServiceClient({ + authClient: new GoogleAuth().fromAPIKey(this.apiKey), + }); + } + + protected async _embedText(text: string): Promise { + // replace newlines, which can negatively affect performance. + const cleanedText = text.replace(/\n/g, " "); + const res = await this.client.embedText({ + model: this.modelName, + text: cleanedText, + }); + return res[0].embedding?.value ?? []; + } + + /** + * Method that takes a document as input and returns a promise that + * resolves to an embedding for the document. It calls the _embedText + * method with the document as the input. + * @param document Document for which to generate an embedding. + * @returns Promise that resolves to an embedding for the input document. + */ + embedQuery(document: string): Promise { + return this.caller.callWithOptions( + {}, + this._embedText.bind(this), + document + ); + } + + /** + * Method that takes an array of documents as input and returns a promise + * that resolves to a 2D array of embeddings for each document. It calls + * the _embedText method for each document in the array. + * @param documents Array of documents for which to generate embeddings. + * @returns Promise that resolves to a 2D array of embeddings for each input document. 
+ */ + embedDocuments(documents: string[]): Promise { + return Promise.all(documents.map((document) => this._embedText(document))); + } +} diff --git a/libs/langchain-community/src/embeddings/googlevertexai.ts b/libs/langchain-community/src/embeddings/googlevertexai.ts new file mode 100644 index 000000000000..54b9e3d4ac89 --- /dev/null +++ b/libs/langchain-community/src/embeddings/googlevertexai.ts @@ -0,0 +1,145 @@ +import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { AsyncCallerCallOptions } from "@langchain/core/utils/async_caller"; +import { + GoogleVertexAIBasePrediction, + GoogleVertexAIBaseLLMInput, + GoogleVertexAILLMPredictions, +} from "../types/googlevertexai-types.js"; +import { GoogleVertexAILLMConnection } from "../utils/googlevertexai-connection.js"; +import { chunkArray } from "../utils/chunk.js"; + +/** + * Defines the parameters required to initialize a + * GoogleVertexAIEmbeddings instance. It extends EmbeddingsParams and + * GoogleVertexAIConnectionParams. + */ +export interface GoogleVertexAIEmbeddingsParams + extends EmbeddingsParams, + GoogleVertexAIBaseLLMInput {} + +/** + * Defines additional options specific to the + * GoogleVertexAILLMEmbeddingsInstance. It extends AsyncCallerCallOptions. + */ +interface GoogleVertexAILLMEmbeddingsOptions extends AsyncCallerCallOptions {} + +/** + * Represents an instance for generating embeddings using the Google + * Vertex AI API. It contains the content to be embedded. + */ +interface GoogleVertexAILLMEmbeddingsInstance { + content: string; +} + +/** + * Defines the structure of the embeddings results returned by the Google + * Vertex AI API. It extends GoogleVertexAIBasePrediction and contains the + * embeddings and their statistics. 
+ */ +interface GoogleVertexEmbeddingsResults extends GoogleVertexAIBasePrediction { + embeddings: { + statistics: { + token_count: number; + truncated: boolean; + }; + values: number[]; + }; +} + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * the embeddings generated by Large Language Models. + * + * To use, you will need to have one of the following authentication + * methods in place: + * - You are logged into an account permitted to the Google Cloud project + * using Vertex AI. + * - You are running this on a machine using a service account permitted to + * the Google Cloud project using Vertex AI. + * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the + * path of a credentials file for a service account permitted to the + * Google Cloud project using Vertex AI. + * @example + * ```typescript + * const model = new GoogleVertexAIEmbeddings(); + * const res = await model.embedQuery( + * "What would be a good company name for a company that makes colorful socks?" + * ); + * console.log({ res }); + * ``` + */ +export class GoogleVertexAIEmbeddings + extends Embeddings + implements GoogleVertexAIEmbeddingsParams +{ + model = "textembedding-gecko"; + + private connection: GoogleVertexAILLMConnection< + GoogleVertexAILLMEmbeddingsOptions, + GoogleVertexAILLMEmbeddingsInstance, + GoogleVertexEmbeddingsResults, + GoogleAuthOptions + >; + + constructor(fields?: GoogleVertexAIEmbeddingsParams) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + + this.connection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + new GoogleAuth({ + scopes: "https://www.googleapis.com/auth/cloud-platform", + ...fields?.authOptions, + }) + ); + } + + /** + * Takes an array of documents as input and returns a promise that + * resolves to a 2D array of embeddings for each document. It splits the + * documents into chunks and makes requests to the Google Vertex AI API to + * generate embeddings. 
+ * @param documents An array of documents to be embedded. + * @returns A promise that resolves to a 2D array of embeddings for each document. + */ + async embedDocuments(documents: string[]): Promise { + const instanceChunks: GoogleVertexAILLMEmbeddingsInstance[][] = chunkArray( + documents.map((document) => ({ + content: document, + })), + 5 + ); // Vertex AI accepts max 5 instances per prediction + const parameters = {}; + const options = {}; + const responses = await Promise.all( + instanceChunks.map((instances) => + this.connection.request(instances, parameters, options) + ) + ); + const result: number[][] = + responses + ?.map( + (response) => + ( + response?.data as GoogleVertexAILLMPredictions + )?.predictions?.map((result) => result.embeddings.values) ?? [] + ) + .flat() ?? []; + return result; + } + + /** + * Takes a document as input and returns a promise that resolves to an + * embedding for the document. It calls the embedDocuments method with the + * document as the input. + * @param document A document to be embedded. + * @returns A promise that resolves to an embedding for the document. + */ + async embedQuery(document: string): Promise { + const data = await this.embedDocuments([document]); + return data[0]; + } +} diff --git a/libs/langchain-community/src/embeddings/gradient_ai.ts b/libs/langchain-community/src/embeddings/gradient_ai.ts new file mode 100644 index 000000000000..f64f38483475 --- /dev/null +++ b/libs/langchain-community/src/embeddings/gradient_ai.ts @@ -0,0 +1,118 @@ +import { Gradient } from "@gradientai/nodejs-sdk"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { chunkArray } from "../utils/chunk.js"; + +/** + * Interface for GradientEmbeddings parameters. Extends EmbeddingsParams and + * defines additional parameters specific to the GradientEmbeddings class. 
+ */ +export interface GradientEmbeddingsParams extends EmbeddingsParams { + /** + * Gradient AI Access Token. + * Provide Access Token if you do not wish to automatically pull from env. + */ + gradientAccessKey?: string; + /** + * Gradient Workspace Id. + * Provide workspace id if you do not wish to automatically pull from env. + */ + workspaceId?: string; +} + +/** + * Class for generating embeddings using the Gradient AI's API. Extends the + * Embeddings class and implements GradientEmbeddingsParams. + */ +export class GradientEmbeddings + extends Embeddings + implements GradientEmbeddingsParams +{ + gradientAccessKey?: string; + + workspaceId?: string; + + batchSize = 128; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + model: any; + + constructor(fields: GradientEmbeddingsParams) { + super(fields); + + this.gradientAccessKey = + fields?.gradientAccessKey ?? + getEnvironmentVariable("GRADIENT_ACCESS_TOKEN"); + this.workspaceId = + fields?.workspaceId ?? getEnvironmentVariable("GRADIENT_WORKSPACE_ID"); + + if (!this.gradientAccessKey) { + throw new Error("Missing Gradient AI Access Token"); + } + + if (!this.workspaceId) { + throw new Error("Missing Gradient AI Workspace ID"); + } + } + + /** + * Method to generate embeddings for an array of documents. Splits the + * documents into batches and makes requests to the Gradient API to generate + * embeddings. + * @param texts Array of documents to generate embeddings for. + * @returns Promise that resolves to a 2D array of embeddings for each document.
+ */ + async embedDocuments(texts: string[]): Promise { + await this.setModel(); + + const mappedTexts = texts.map((text) => ({ input: text })); + + const batches = chunkArray(mappedTexts, this.batchSize); + + const batchRequests = batches.map((batch) => + this.caller.call(async () => + this.model.generateEmbeddings({ + inputs: batch, + }) + ) + ); + const batchResponses = await Promise.all(batchRequests); + + const embeddings: number[][] = []; + for (let i = 0; i < batchResponses.length; i += 1) { + const batch = batches[i]; + const { embeddings: batchResponse } = batchResponses[i]; + for (let j = 0; j < batch.length; j += 1) { + embeddings.push(batchResponse[j].embedding); + } + } + return embeddings; + } + + /** + * Method to generate an embedding for a single document. Calls the + * embedDocuments method with the document as the input. + * @param text Document to generate an embedding for. + * @returns Promise that resolves to an embedding for the document. + */ + async embedQuery(text: string): Promise { + const data = await this.embedDocuments([text]); + return data[0]; + } + + /** + * Method to set the model to use for generating embeddings. + * @sets the class' `model` value to that of the retrieved Embeddings Model. 
+   */
+  async setModel() {
+    if (this.model) return;
+
+    const gradient = new Gradient({
+      accessToken: this.gradientAccessKey,
+      workspaceId: this.workspaceId,
+    });
+    this.model = await gradient.getEmbeddingsModel({
+      slug: "bge-large",
+    });
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/hf.ts b/libs/langchain-community/src/embeddings/hf.ts
new file mode 100644
index 000000000000..cbe66a5b3a46
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/hf.ts
@@ -0,0 +1,77 @@
+import { HfInference, HfInferenceEndpoint } from "@huggingface/inference";
+import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+
+/**
+ * Interface that extends EmbeddingsParams and defines additional
+ * parameters specific to the HuggingFaceInferenceEmbeddings class.
+ */
+export interface HuggingFaceInferenceEmbeddingsParams extends EmbeddingsParams {
+  /**
+   * API key handed to `HfInference`. When omitted, the constructor falls
+   * back to the `HUGGINGFACEHUB_API_KEY` environment variable.
+   */
+  apiKey?: string;
+  /** Model id sent with each request. Defaults to "BAAI/bge-base-en-v1.5". */
+  model?: string;
+  /**
+   * When set, the client is created with `HfInference.endpoint(endpointUrl)`
+   * instead of the plain `HfInference` client.
+   */
+  endpointUrl?: string;
+}
+
+/**
+ * Class that extends the Embeddings class and provides methods for
+ * generating embeddings using Hugging Face models through the
+ * HuggingFaceInference API.
+ */
+export class HuggingFaceInferenceEmbeddings
+  extends Embeddings
+  implements HuggingFaceInferenceEmbeddingsParams
+{
+  apiKey?: string;
+
+  model: string;
+
+  endpointUrl?: string;
+
+  client: HfInference | HfInferenceEndpoint;
+
+  /**
+   * Resolves the model name and API key (explicit field first, then the
+   * `HUGGINGFACEHUB_API_KEY` environment variable for the key) and builds
+   * the client, using an endpoint-bound client when `endpointUrl` is given.
+   * @param fields Optional configuration; every property may be omitted.
+   */
+  constructor(fields?: HuggingFaceInferenceEmbeddingsParams) {
+    super(fields ?? {});
+
+    this.model = fields?.model ?? "BAAI/bge-base-en-v1.5";
+    this.apiKey =
+      fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY");
+    this.endpointUrl = fields?.endpointUrl;
+    this.client = this.endpointUrl
+      ? new HfInference(this.apiKey).endpoint(this.endpointUrl)
+      : new HfInference(this.apiKey);
+  }
+
+  /**
+   * Embeds all texts with a single `featureExtraction` call routed through
+   * `this.caller`.
+   * @param texts Texts to embed; newlines are replaced with spaces first.
+   * @returns The client's response, cast to the declared return type
+   * without any runtime validation.
+   */
+  async _embed(texts: string[]): Promise {
+    // replace newlines, which can negatively affect performance.
+    const clean = texts.map((text) => text.replace(/\n/g, " "));
+    return this.caller.call(() =>
+      this.client.featureExtraction({
+        model: this.model,
+        inputs: clean,
+      })
+    ) as Promise;
+  }
+
+  /**
+   * Method that takes a document as input and returns a promise that
+   * resolves to an embedding for the document. It calls the _embed method
+   * with the document as the input and returns the first embedding in the
+   * resulting array.
+   * @param document Document to generate an embedding for.
+   * @returns Promise that resolves to an embedding for the document.
+   */
+  embedQuery(document: string): Promise {
+    return this._embed([document]).then((embeddings) => embeddings[0]);
+  }
+
+  /**
+   * Method that takes an array of documents as input and returns a promise
+   * that resolves to a 2D array of embeddings for each document. It calls
+   * the _embed method with the documents as the input.
+   * @param documents Array of documents to generate embeddings for.
+   * @returns Promise that resolves to a 2D array of embeddings for each document.
+   */
+  embedDocuments(documents: string[]): Promise {
+    return this._embed(documents);
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/hf_transformers.ts b/libs/langchain-community/src/embeddings/hf_transformers.ts
new file mode 100644
index 000000000000..b70b9fe4c646
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/hf_transformers.ts
@@ -0,0 +1,105 @@
+import { Pipeline, pipeline } from "@xenova/transformers";
+import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
+import { chunkArray } from "../utils/chunk.js";
+
+export interface HuggingFaceTransformersEmbeddingsParams
+  extends EmbeddingsParams {
+  /** Model name to use */
+  modelName: string;
+
+  /**
+   * Timeout setting.
+   * NOTE(review): the class stores this value, but the code visible in this
+   * file never reads it and makes no request to OpenAI (the original
+   * comment said it did) - confirm the intended use.
+   */
+  timeout?: number;
+
+  /**
+   * The maximum number of documents to embed in a single request.
+   */
+  batchSize?: number;
+
+  /**
+   * Whether to strip new lines from the input text.
+   * Enabled by default; may not be suitable for all use cases.
+   */
+  stripNewLines?: boolean;
+}
+
+/**
+ * Embeddings computed with a `feature-extraction` pipeline created by
+ * `pipeline()` from `@xenova/transformers`. The pipeline is created on
+ * first use and reused afterwards.
+ *
+ * @example
+ * ```typescript
+ * const model = new HuggingFaceTransformersEmbeddings({
+ *   modelName: "Xenova/all-MiniLM-L6-v2",
+ * });
+ *
+ * // Embed a single query
+ * const res = await model.embedQuery(
+ *   "What would be a good company name for a company that makes colorful socks?"
+ * );
+ * console.log({ res });
+ *
+ * // Embed multiple documents
+ * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
+ * console.log({ documentRes });
+ * ```
+ */
+export class HuggingFaceTransformersEmbeddings
+  extends Embeddings
+  implements HuggingFaceTransformersEmbeddingsParams
+{
+  modelName = "Xenova/all-MiniLM-L6-v2";
+
+  batchSize = 512;
+
+  stripNewLines = true;
+
+  timeout?: number;
+
+  /** Lazily created pipeline, shared by every call to `runEmbedding`. */
+  private pipelinePromise: Promise<Pipeline>;
+
+  /**
+   * @param fields Optional overrides; omitted properties keep the class
+   * defaults declared above.
+   */
+  constructor(fields?: Partial<HuggingFaceTransformersEmbeddingsParams>) {
+    super(fields ?? {});
+
+    this.modelName = fields?.modelName ?? this.modelName;
+    // Previously missing: a caller-supplied batchSize never reached chunkArray.
+    this.batchSize = fields?.batchSize ?? this.batchSize;
+    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
+    this.timeout = fields?.timeout;
+  }
+
+  /**
+   * Splits the texts into chunks of `batchSize`, embeds every chunk and
+   * concatenates the results.
+   * @param texts Documents to embed; newlines are replaced with spaces
+   * first when `stripNewLines` is set.
+   * @returns The embeddings of all batches, concatenated in batch order.
+   */
+  async embedDocuments(texts: string[]): Promise<number[][]> {
+    const batches = chunkArray(
+      this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
+      this.batchSize
+    );
+
+    const batchResponses = await Promise.all(
+      batches.map((batch) => this.runEmbedding(batch))
+    );
+
+    // Promise.all keeps batch order, so a one-level flatten keeps input order.
+    return batchResponses.flat();
+  }
+
+  /**
+   * Embeds a single text by running it as a one-element batch.
+   * @param text Text to embed.
+   * @returns The first (only) embedding of the batch result.
+   */
+  async embedQuery(text: string): Promise<number[]> {
+    const data = await this.runEmbedding([
+      this.stripNewLines ? text.replace(/\n/g, " ") : text,
+    ]);
+    return data[0];
+  }
+
+  /**
+   * Runs the pipeline over `texts` through `this.caller`, creating the
+   * pipeline on first use.
+   */
+  private async runEmbedding(texts: string[]) {
+    const pipe = await (this.pipelinePromise ??= pipeline(
+      "feature-extraction",
+      this.modelName
+    ));
+
+    return this.caller.call(async () => {
+      const output = await pipe(texts, { pooling: "mean", normalize: true });
+      return output.tolist();
+    });
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/llama_cpp.ts b/libs/langchain-community/src/embeddings/llama_cpp.ts
new file mode 100644
index 000000000000..aad4163ac499
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/llama_cpp.ts
@@ -0,0 +1,103 @@
+import { LlamaModel, LlamaContext } from "node-llama-cpp";
+import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
+import {
+  LlamaBaseCppInputs,
+  createLlamaModel,
+  createLlamaContext,
+} from "../utils/llama_cpp.js";
+
+/**
+ * Note that the modelPath is the only required parameter. For testing you
+ * can set this in the environment variable `LLAMA_PATH`.
+ */
+export interface LlamaCppEmbeddingsParams
+  extends LlamaBaseCppInputs,
+    EmbeddingsParams {}
+
+/**
+ * @example
+ * ```typescript
+ * // Initialize LlamaCppEmbeddings with the path to the model file
+ * const embeddings = new LlamaCppEmbeddings({
+ *   modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin",
+ * });
+ *
+ * // Embed a query string using the Llama embeddings
+ * const res = await embeddings.embedQuery("Hello Llama!");
+ *
+ * // Output the resulting embeddings
+ * console.log(res);
+ *
+ * ```
+ */
+export class LlamaCppEmbeddings extends Embeddings {
+  _model: LlamaModel;
+
+  _context: LlamaContext;
+
+  /**
+   * Creates the model and context with `embedding` forced to `true`.
+   * @param inputs Model and embeddings configuration; left unmodified.
+   */
+  constructor(inputs: LlamaCppEmbeddingsParams) {
+    super(inputs);
+    // Copy instead of writing `embedding` onto the caller's own object.
+    const _inputs = { ...inputs, embedding: true };
+
+    this._model = createLlamaModel(_inputs);
+    this._context = createLlamaContext(this._model, _inputs);
+  }
+
+  /**
+   * Generates embeddings for an array of texts.
+   * Texts are processed one at a time, in input order.
+   * @param texts - An array of strings to generate embeddings for.
+   * @returns A Promise that resolves to an array of embeddings.
+   */
+  async embedDocuments(texts: string[]): Promise<number[][]> {
+    const embeddings: number[][] = [];
+
+    for (const text of texts) {
+      embeddings.push(await this.embedQuery(text));
+    }
+
+    return embeddings;
+  }
+
+  /**
+   * Generates an embedding for a single text.
+   *
+   * NOTE(review): the returned numbers are the elements produced by
+   * `LlamaContext.encode`, coerced with unary plus - presumably token ids
+   * rather than an embedding vector; confirm against node-llama-cpp.
+   * @param text - A string to generate an embedding for.
+   * @returns A Promise that resolves to an array of numbers representing the embedding.
+   */
+  async embedQuery(text: string): Promise<number[]> {
+    // An async arrow already yields a promise; no Promise constructor needed.
+    const encodings = await this.caller.call(async () =>
+      this._context.encode(text)
+    );
+
+    // Unary plus keeps the original element-wise numeric coercion.
+    return Array.from(encodings, (token) => +token);
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/minimax.ts b/libs/langchain-community/src/embeddings/minimax.ts
new file mode 100644
index 000000000000..bf4594f9b5d3
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/minimax.ts
@@ -0,0 +1,222 @@
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
+import { chunkArray } from "../utils/chunk.js";
+import { ConfigurationParameters } from "../chat_models/minimax.js";
+
+/**
+ * Interface for MinimaxEmbeddings parameters. Extends EmbeddingsParams and
+ * defines additional parameters specific to the MinimaxEmbeddings class.
+ */
+export interface MinimaxEmbeddingsParams extends EmbeddingsParams {
+  /** Model name to use */
+  modelName: string;
+
+  /**
+   * Group ID sent as the `GroupId` query parameter of each request.
+   * Defaults to the value of the `MINIMAX_GROUP_ID` environment variable.
+   */
+  minimaxGroupId?: string;
+
+  /**
+   * API key sent as the Bearer token of each request. Defaults to the
+   * value of the `MINIMAX_API_KEY` environment variable.
+   */
+  minimaxApiKey?: string;
+
+  /**
+   * The maximum number of documents to embed in a single request. This is
+   * limited by the Minimax API to a maximum of 4096.
+   */
+  batchSize?: number;
+
+  /**
+   * Whether to strip new lines from the input text. This is recommended by
+   * Minimax, but may not be suitable for all use cases.
+   */
+  stripNewLines?: boolean;
+
+  /**
+   * The target use-case after generating the vector.
+   * When using embeddings, the vector of the target content is first generated with `db` and stored in the vector database,
+   * and then the vector of the retrieval text is generated with `query`.
+   * Note: separate algorithms are used for `query` and `db`.
+   * Therefore, text that is to be stored and retrieved should use `db`,
+   * and text that is used to perform the retrieval should use `query`.
+   */
+  type?: "db" | "query";
+}
+
+/**
+ * Request body sent to the Minimax embeddings endpoint.
+ */
+export interface CreateMinimaxEmbeddingRequest {
+  /**
+   * Name of the embedding model.
+   * @memberof CreateMinimaxEmbeddingRequest
+   */
+  model: string;
+
+  /**
+   * Texts to generate vectors for.
+   * @memberof CreateMinimaxEmbeddingRequest
+   */
+  texts: string[];
+
+  /**
+   * The target use-case after generating the vector. When using embeddings,
+   * first generate the vector of the target content with `db` and store it in the vector database,
+   * and then generate the vector of the retrieval text with `query`.
+   * Note: separate algorithms are used for `query` and `db`,
+   * so text that is to be stored and retrieved should use `db`,
+   * and text that is used to perform the retrieval should use `query`.
+   * @memberof CreateMinimaxEmbeddingRequest
+   */
+  type: "db" | "query";
+}
+
+/**
+ * Class for generating embeddings using the Minimax API. Extends the
+ * Embeddings class and implements MinimaxEmbeddingsParams
+ * @example
+ * ```typescript
+ * const embeddings = new MinimaxEmbeddings();
+ *
+ * // Embed a single query
+ * const queryEmbedding = await embeddings.embedQuery("Hello world");
+ * console.log(queryEmbedding);
+ *
+ * // Embed multiple documents
+ * const documentsEmbedding = await embeddings.embedDocuments([
+ *   "Hello world",
+ *   "Bye bye",
+ * ]);
+ * console.log(documentsEmbedding);
+ * ```
+ */
+export class MinimaxEmbeddings
+  extends Embeddings
+  implements MinimaxEmbeddingsParams
+{
+  modelName = "embo-01";
+
+  batchSize = 512;
+
+  stripNewLines = true;
+
+  minimaxGroupId?: string;
+
+  minimaxApiKey?: string;
+
+  type: "db" | "query" = "db";
+
+  apiUrl: string;
+
+  basePath?: string = "https://api.minimax.chat/v1";
+
+  headers?: Record<string, string>;
+
+  /**
+   * @param fields Optional overrides. `maxConcurrency` defaults to 2 when
+   * not supplied. `configuration.basePath` and `configuration.headers`
+   * replace the default base path and add request headers.
+   * @throws Error when neither the fields nor the environment provide a
+   * group ID or an API key.
+   */
+  constructor(
+    fields?: Partial<MinimaxEmbeddingsParams> & {
+      configuration?: ConfigurationParameters;
+    }
+  ) {
+    const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
+    super(fieldsWithDefaults);
+
+    this.minimaxGroupId =
+      fields?.minimaxGroupId ?? getEnvironmentVariable("MINIMAX_GROUP_ID");
+    if (!this.minimaxGroupId) {
+      throw new Error("Minimax GroupID not found");
+    }
+
+    this.minimaxApiKey =
+      fields?.minimaxApiKey ?? getEnvironmentVariable("MINIMAX_API_KEY");
+
+    if (!this.minimaxApiKey) {
+      throw new Error("Minimax ApiKey not found");
+    }
+
+    this.modelName = fieldsWithDefaults.modelName ?? this.modelName;
+    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize;
+    this.type = fieldsWithDefaults.type ?? this.type;
+    this.stripNewLines =
+      fieldsWithDefaults.stripNewLines ?? this.stripNewLines;
+    this.basePath = fields?.configuration?.basePath ?? this.basePath;
+    this.apiUrl = `${this.basePath}/embeddings`;
+    this.headers = fields?.configuration?.headers ?? this.headers;
+  }
+
+  /**
+   * Method to generate embeddings for an array of documents. Splits the
+   * documents into batches and makes requests to the Minimax API to generate
+   * embeddings.
+   * @param texts Array of documents to generate embeddings for.
+   * @returns Promise that resolves to a 2D array of embeddings for each document.
+   */
+  async embedDocuments(texts: string[]): Promise<number[][]> {
+    const batches = chunkArray(
+      this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
+      this.batchSize
+    );
+
+    const batchRequests = batches.map((batch) =>
+      this.embeddingWithRetry({
+        model: this.modelName,
+        texts: batch,
+        type: this.type,
+      })
+    );
+    const batchResponses = await Promise.all(batchRequests);
+
+    const embeddings: number[][] = [];
+    for (let i = 0; i < batchResponses.length; i += 1) {
+      const batch = batches[i];
+      const { vectors: batchResponse } = batchResponses[i];
+      // One vector is read per input text of the batch.
+      for (let j = 0; j < batch.length; j += 1) {
+        embeddings.push(batchResponse[j]);
+      }
+    }
+    return embeddings;
+  }
+
+  /**
+   * Method to generate an embedding for a single document. Calls the
+   * embeddingWithRetry method with the document as the input.
+   * @param text Document to generate an embedding for.
+   * @returns Promise that resolves to an embedding for the document.
+   */
+  async embedQuery(text: string): Promise<number[]> {
+    const { vectors } = await this.embeddingWithRetry({
+      model: this.modelName,
+      texts: [this.stripNewLines ? text.replace(/\n/g, " ") : text],
+      type: this.type,
+    });
+    return vectors[0];
+  }
+
+  /**
+   * Private method to make a request to the Minimax API to generate
+   * embeddings. The request runs through `this.caller` and the parsed JSON
+   * body is returned.
+   * @param request Request to send to the Minimax API.
+   * @returns Promise that resolves to the response from the API.
+   * @throws Error when the HTTP response status is not OK.
+   */
+  private async embeddingWithRetry(request: CreateMinimaxEmbeddingRequest) {
+    const makeCompletionRequest = async () => {
+      const url = `${this.apiUrl}?GroupId=${this.minimaxGroupId}`;
+      const response = await fetch(url, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Authorization: `Bearer ${this.minimaxApiKey}`,
+          ...this.headers,
+        },
+        body: JSON.stringify(request),
+      });
+
+      // Fail with the HTTP status instead of handing an error body to the
+      // callers, which would then fail on the missing `vectors` field.
+      if (!response.ok) {
+        throw new Error(
+          `Request to Minimax API failed: ${response.status} ${response.statusText}`
+        );
+      }
+
+      const json = await response.json();
+      return json;
+    };
+
+    return this.caller.call(makeCompletionRequest);
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/ollama.ts b/libs/langchain-community/src/embeddings/ollama.ts
new file mode 100644
index 000000000000..b1e63b8d7005
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/ollama.ts
@@ -0,0 +1,148 @@
+import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
+import { OllamaInput, OllamaRequestParams } from "../utils/ollama.js";
+
+type CamelCasedRequestOptions = Omit<
+  OllamaInput,
+  "baseUrl" | "model" | "format"
+>;
+
+/**
+ * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
+ * defines additional parameters specific to the OllamaEmbeddings class.
+ */
+interface OllamaEmbeddingsParams extends EmbeddingsParams {
+  /** The Ollama model to use, e.g: "llama2:13b" */
+  model?: string;
+
+  /** Base URL of the Ollama server, defaults to "http://localhost:11434" */
+  baseUrl?: string;
+
+  /** Advanced Ollama API request parameters in camelCase, see
+   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+   * for details of the available parameters.
+   */
+  requestOptions?: CamelCasedRequestOptions;
+}
+
+/**
+ * Embeddings generated by POSTing each prompt to the `/api/embeddings`
+ * route of an Ollama server.
+ */
+export class OllamaEmbeddings extends Embeddings {
+  model = "llama2";
+
+  baseUrl = "http://localhost:11434";
+
+  requestOptions?: OllamaRequestParams["options"];
+
+  /**
+   * @param params Optional overrides; falsy `model` / `baseUrl` values keep
+   * the defaults, and `requestOptions` is converted to snake_case.
+   */
+  constructor(params?: OllamaEmbeddingsParams) {
+    super(params || {});
+
+    if (params?.model) {
+      this.model = params.model;
+    }
+
+    if (params?.baseUrl) {
+      this.baseUrl = params.baseUrl;
+    }
+
+    if (params?.requestOptions) {
+      this.requestOptions = this._convertOptions(params.requestOptions);
+    }
+  }
+
+  /** convert camelCased Ollama request options like "useMMap" to
+   * the snake_cased equivalent which the ollama API actually uses.
+   * Used only for consistency with the llms/Ollama and chatModels/Ollama classes
+   *
+   * Keys that have no entry in the mapping are dropped.
+   */
+  _convertOptions(requestOptions: CamelCasedRequestOptions) {
+    const snakeCasedOptions: Record<string, unknown> = {};
+    const mapping: Record<string, string> = {
+      embeddingOnly: "embedding_only",
+      f16KV: "f16_kv",
+      frequencyPenalty: "frequency_penalty",
+      logitsAll: "logits_all",
+      lowVram: "low_vram",
+      mainGpu: "main_gpu",
+      mirostat: "mirostat",
+      mirostatEta: "mirostat_eta",
+      mirostatTau: "mirostat_tau",
+      numBatch: "num_batch",
+      numCtx: "num_ctx",
+      numGpu: "num_gpu",
+      numGqa: "num_gqa",
+      numKeep: "num_keep",
+      numThread: "num_thread",
+      penalizeNewline: "penalize_newline",
+      presencePenalty: "presence_penalty",
+      repeatLastN: "repeat_last_n",
+      repeatPenalty: "repeat_penalty",
+      ropeFrequencyBase: "rope_frequency_base",
+      ropeFrequencyScale: "rope_frequency_scale",
+      temperature: "temperature",
+      stop: "stop",
+      tfsZ: "tfs_z",
+      topK: "top_k",
+      topP: "top_p",
+      typicalP: "typical_p",
+      useMLock: "use_mlock",
+      useMMap: "use_mmap",
+      vocabOnly: "vocab_only",
+    };
+
+    for (const [key, value] of Object.entries(requestOptions)) {
+      const snakeCasedOption = mapping[key as keyof CamelCasedRequestOptions];
+      if (snakeCasedOption) {
+        snakeCasedOptions[snakeCasedOption] = value;
+      }
+    }
+    return snakeCasedOptions;
+  }
+
+  /**
+   * Sends one prompt to the server and returns the `embedding` field of
+   * the JSON response.
+   * @param prompt Text to embed.
+   * @throws Error when the HTTP response status is not OK.
+   */
+  async _request(prompt: string): Promise<number[]> {
+    const { model, baseUrl, requestOptions } = this;
+
+    let formattedBaseUrl = baseUrl;
+    if (formattedBaseUrl.startsWith("http://localhost:")) {
+      // Node 18 has issues with resolving "localhost"
+      // See https://github.com/node-fetch/node-fetch/issues/1624
+      formattedBaseUrl = formattedBaseUrl.replace(
+        "http://localhost:",
+        "http://127.0.0.1:"
+      );
+    }
+
+    const response = await fetch(`${formattedBaseUrl}/api/embeddings`, {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({
+        prompt,
+        model,
+        options: requestOptions,
+      }),
+    });
+    if (!response.ok) {
+      throw new Error(
+        `Request to Ollama server failed: ${response.status} ${response.statusText}`
+      );
+    }
+
+    const json = await response.json();
+    return json.embedding;
+  }
+
+  /**
+   * Embeds the strings one after another, each through `this.caller`.
+   * @param strings Texts to embed.
+   * @returns One embedding per input string, in input order.
+   */
+  async _embed(strings: string[]): Promise<number[][]> {
+    const embeddings: number[][] = [];
+
+    // `strings` is a plain array, so a regular for...of is sufficient.
+    for (const prompt of strings) {
+      const embedding = await this.caller.call(() => this._request(prompt));
+      embeddings.push(embedding);
+    }
+
+    return embeddings;
+  }
+
+  /** Embeds every document; delegates to `_embed`. */
+  async embedDocuments(documents: string[]) {
+    return this._embed(documents);
+  }
+
+  /** Embeds a single text by embedding it as a one-element document list. */
+  async embedQuery(document: string) {
+    return (await this.embedDocuments([document]))[0];
+  }
+}
diff --git a/libs/langchain-community/src/embeddings/tensorflow.ts b/libs/langchain-community/src/embeddings/tensorflow.ts
new file mode 100644
index 000000000000..ae96f8cc9b30
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/tensorflow.ts
@@ -0,0 +1,91 @@
+import { load } from "@tensorflow-models/universal-sentence-encoder";
+import * as tf from "@tensorflow/tfjs-core";
+
+import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
+
+/**
+ * Interface that extends EmbeddingsParams and defines additional
+ * parameters specific to the TensorFlowEmbeddings class.
+ */
+export interface TensorFlowEmbeddingsParams extends EmbeddingsParams {}
+
+/**
+ * Class that extends the Embeddings class and provides methods for
+ * generating embeddings using the Universal Sentence Encoder model from
+ * TensorFlow.js.
+ * @example
+ * ```typescript
+ * const embeddings = new TensorFlowEmbeddings();
+ * const store = new MemoryVectorStore(embeddings);
+ *
+ * const documents = [
+ *   "A document",
+ *   "Some other piece of text",
+ *   "One more",
+ *   "And another",
+ * ];
+ *
+ * await store.addDocuments(
+ *   documents.map((pageContent) => new Document({ pageContent }))
+ * );
+ * ```
+ */
+export class TensorFlowEmbeddings extends Embeddings {
+  /**
+   * Calls `tf.backend()` and turns any exception it throws into a generic
+   * "No TensorFlow backend found" error; the caught exception itself is
+   * discarded.
+   * @param fields Optional base Embeddings parameters.
+   */
+  constructor(fields?: TensorFlowEmbeddingsParams) {
+    super(fields ?? {});
+
+    try {
+      tf.backend();
+    } catch (e) {
+      throw new Error("No TensorFlow backend found, see instructions at ...");
+    }
+  }
+
+  /** Result of the first `load()` call, reused by later calls. */
+  _cached: ReturnType;
+
+  /**
+   * Private method that loads the Universal Sentence Encoder model if it
+   * hasn't been loaded already. It returns a promise that resolves to the
+   * loaded model.
+   * @returns Promise that resolves to the loaded Universal Sentence Encoder model.
+   */
+  private async load() {
+    if (this._cached === undefined) {
+      this._cached = load();
+    }
+    return this._cached;
+  }
+
+  /**
+   * Awaits the cached model and runs `model.embed(texts)` through
+   * `this.caller`.
+   * @param texts Texts to embed in one call.
+   */
+  private _embed(texts: string[]) {
+    return this.caller.call(async () => {
+      const model = await this.load();
+      return model.embed(texts);
+    });
+  }
+
+  /**
+   * Method that takes a document as input and returns a promise that
+   * resolves to an embedding for the document. It calls the _embed method
+   * with the document as the input and processes the result to return a
+   * single embedding.
+   * @param document Document to generate an embedding for.
+   * @returns Promise that resolves to an embedding for the input document.
+   */
+  embedQuery(document: string): Promise {
+    return this._embed([document])
+      .then((embeddings) => embeddings.array())
+      .then((embeddings) => embeddings[0]);
+  }
+
+  /**
+   * Method that takes an array of documents as input and returns a promise
+   * that resolves to a 2D array of embeddings for each document. It calls
+   * the _embed method with the documents as the input and processes the
+   * result to return the embeddings.
+   * @param documents Array of documents to generate embeddings for.
+   * @returns Promise that resolves to a 2D array of embeddings for each input document.
+   */
+  embedDocuments(documents: string[]): Promise {
+    return this._embed(documents).then((embeddings) => embeddings.array());
+  }
+}
diff --git a/langchain/src/embeddings/tests/bedrock.int.test.ts b/libs/langchain-community/src/embeddings/tests/bedrock.int.test.ts
similarity index 100%
rename from langchain/src/embeddings/tests/bedrock.int.test.ts
rename to libs/langchain-community/src/embeddings/tests/bedrock.int.test.ts
diff --git a/langchain/src/embeddings/tests/cohere.int.test.ts b/libs/langchain-community/src/embeddings/tests/cohere.int.test.ts
similarity index 100%
rename from langchain/src/embeddings/tests/cohere.int.test.ts
rename to libs/langchain-community/src/embeddings/tests/cohere.int.test.ts
diff --git a/langchain/src/embeddings/tests/googlepalm.int.test.ts b/libs/langchain-community/src/embeddings/tests/googlepalm.int.test.ts
similarity index 100%
rename from langchain/src/embeddings/tests/googlepalm.int.test.ts
rename to libs/langchain-community/src/embeddings/tests/googlepalm.int.test.ts
diff --git a/langchain/src/embeddings/tests/googlevertexai.int.test.ts b/libs/langchain-community/src/embeddings/tests/googlevertexai.int.test.ts
similarity index 100%
rename from langchain/src/embeddings/tests/googlevertexai.int.test.ts
rename to libs/langchain-community/src/embeddings/tests/googlevertexai.int.test.ts
diff --git a/langchain/src/embeddings/tests/hf.int.test.ts
b/libs/langchain-community/src/embeddings/tests/hf.int.test.ts similarity index 95% rename from langchain/src/embeddings/tests/hf.int.test.ts rename to libs/langchain-community/src/embeddings/tests/hf.int.test.ts index 7de4c63dd40d..24abd5784abe 100644 --- a/langchain/src/embeddings/tests/hf.int.test.ts +++ b/libs/langchain-community/src/embeddings/tests/hf.int.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; import { HuggingFaceInferenceEmbeddings } from "../hf.js"; import { MemoryVectorStore } from "../../vectorstores/memory.js"; -import { Document } from "../../document.js"; test("HuggingFaceInferenceEmbeddings", async () => { const embeddings = new HuggingFaceInferenceEmbeddings(); diff --git a/langchain/src/embeddings/tests/hf_transformers.int.test.ts b/libs/langchain-community/src/embeddings/tests/hf_transformers.int.test.ts similarity index 95% rename from langchain/src/embeddings/tests/hf_transformers.int.test.ts rename to libs/langchain-community/src/embeddings/tests/hf_transformers.int.test.ts index 0a15a8e1d130..5948d2f0bf90 100644 --- a/langchain/src/embeddings/tests/hf_transformers.int.test.ts +++ b/libs/langchain-community/src/embeddings/tests/hf_transformers.int.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; import { HuggingFaceTransformersEmbeddings } from "../hf_transformers.js"; import { MemoryVectorStore } from "../../vectorstores/memory.js"; -import { Document } from "../../document.js"; test("HuggingFaceTransformersEmbeddings", async () => { const embeddings = new HuggingFaceTransformersEmbeddings(); diff --git a/langchain/src/embeddings/tests/llama_cpp.int.test.ts b/libs/langchain-community/src/embeddings/tests/llama_cpp.int.test.ts similarity index 94% rename from langchain/src/embeddings/tests/llama_cpp.int.test.ts rename to 
libs/langchain-community/src/embeddings/tests/llama_cpp.int.test.ts index 5ec6e33d4d6d..b1819f943a21 100644 --- a/langchain/src/embeddings/tests/llama_cpp.int.test.ts +++ b/libs/langchain-community/src/embeddings/tests/llama_cpp.int.test.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { test, expect } from "@jest/globals"; -import { getEnvironmentVariable } from "../../util/env.js"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { LlamaCppEmbeddings } from "../llama_cpp.js"; const llamaPath = getEnvironmentVariable("LLAMA_PATH")!; diff --git a/langchain/src/embeddings/tests/minimax.int.test.ts b/libs/langchain-community/src/embeddings/tests/minimax.int.test.ts similarity index 100% rename from langchain/src/embeddings/tests/minimax.int.test.ts rename to libs/langchain-community/src/embeddings/tests/minimax.int.test.ts diff --git a/langchain/src/embeddings/tests/ollama.int.test.ts b/libs/langchain-community/src/embeddings/tests/ollama.int.test.ts similarity index 100% rename from langchain/src/embeddings/tests/ollama.int.test.ts rename to libs/langchain-community/src/embeddings/tests/ollama.int.test.ts diff --git a/langchain/src/embeddings/tests/tensorflow.int.test.ts b/libs/langchain-community/src/embeddings/tests/tensorflow.int.test.ts similarity index 95% rename from langchain/src/embeddings/tests/tensorflow.int.test.ts rename to libs/langchain-community/src/embeddings/tests/tensorflow.int.test.ts index 4c51ef04a3a1..a471147572d4 100644 --- a/langchain/src/embeddings/tests/tensorflow.int.test.ts +++ b/libs/langchain-community/src/embeddings/tests/tensorflow.int.test.ts @@ -1,8 +1,8 @@ import { test, expect } from "@jest/globals"; import "@tensorflow/tfjs-backend-cpu"; +import { Document } from "@langchain/core/documents"; import { TensorFlowEmbeddings } from "../tensorflow.js"; import { MemoryVectorStore } from "../../vectorstores/memory.js"; -import { Document } from "../../document.js"; 
test("TensorflowEmbeddings", async () => {
   const embeddings = new TensorFlowEmbeddings();
diff --git a/langchain/src/embeddings/tests/voyage.int.test.ts b/libs/langchain-community/src/embeddings/tests/voyage.int.test.ts
similarity index 100%
rename from langchain/src/embeddings/tests/voyage.int.test.ts
rename to libs/langchain-community/src/embeddings/tests/voyage.int.test.ts
diff --git a/libs/langchain-community/src/embeddings/voyage.ts b/libs/langchain-community/src/embeddings/voyage.ts
new file mode 100644
index 000000000000..8a023d5d7e28
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/voyage.ts
@@ -0,0 +1,152 @@
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
+import { chunkArray } from "../utils/chunk.js";
+
+/**
+ * Interface that extends EmbeddingsParams and defines additional
+ * parameters specific to the VoyageEmbeddings class.
+ */
+export interface VoyageEmbeddingsParams extends EmbeddingsParams {
+  /** Model name sent with each request. */
+  modelName: string;
+
+  /**
+   * The maximum number of documents to embed in a single request. This is
+   * limited by the Voyage AI API to a maximum of 8.
+   */
+  batchSize?: number;
+}
+
+/**
+ * Interface for the request body to generate embeddings.
+ */
+export interface CreateVoyageEmbeddingRequest {
+  /**
+   * Name of the embedding model.
+   * @memberof CreateVoyageEmbeddingRequest
+   */
+  model: string;
+
+  /**
+   * Text, or array of texts, to generate vectors for.
+   * @memberof CreateVoyageEmbeddingRequest
+   */
+  input: string | string[];
+}
+
+/**
+ * A class for generating embeddings using the Voyage AI API.
+ */
+export class VoyageEmbeddings
+  extends Embeddings
+  implements VoyageEmbeddingsParams
+{
+  modelName = "voyage-01";
+
+  batchSize = 8;
+
+  private apiKey: string;
+
+  basePath?: string = "https://api.voyageai.com/v1";
+
+  apiUrl: string;
+
+  headers?: Record<string, string>;
+
+  /**
+   * Constructor for the VoyageEmbeddings class.
+   * @param fields - An optional object with properties to configure the instance.
+   * @throws Error when no API key is given and `VOYAGEAI_API_KEY` is unset.
+   */
+  constructor(
+    fields?: Partial<VoyageEmbeddingsParams> & {
+      verbose?: boolean;
+      apiKey?: string;
+    }
+  ) {
+    const fieldsWithDefaults = { ...fields };
+
+    super(fieldsWithDefaults);
+
+    // `||` (not `??`) so that an empty string also falls back to the env var.
+    const apiKey =
+      fieldsWithDefaults.apiKey || getEnvironmentVariable("VOYAGEAI_API_KEY");
+
+    if (!apiKey) {
+      throw new Error("Voyage AI API key not found");
+    }
+
+    this.modelName = fieldsWithDefaults.modelName ?? this.modelName;
+    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize;
+    this.apiKey = apiKey;
+    this.apiUrl = `${this.basePath}/embeddings`;
+  }
+
+  /**
+   * Generates embeddings for an array of texts.
+   * @param texts - An array of strings to generate embeddings for.
+   * @returns A Promise that resolves to an array of embeddings.
+   */
+  async embedDocuments(texts: string[]): Promise<number[][]> {
+    const batches = chunkArray(texts, this.batchSize);
+
+    const batchRequests = batches.map((batch) =>
+      this.embeddingWithRetry({
+        model: this.modelName,
+        input: batch,
+      })
+    );
+
+    const batchResponses = await Promise.all(batchRequests);
+
+    const embeddings: number[][] = [];
+
+    for (let i = 0; i < batchResponses.length; i += 1) {
+      const batch = batches[i];
+      const { data: batchResponse } = batchResponses[i];
+      // One `data` entry is read per input text of the batch.
+      for (let j = 0; j < batch.length; j += 1) {
+        embeddings.push(batchResponse[j].embedding);
+      }
+    }
+
+    return embeddings;
+  }
+
+  /**
+   * Generates an embedding for a single text.
+   * @param text - A string to generate an embedding for.
+   * @returns A Promise that resolves to an array of numbers representing the embedding.
+   */
+  async embedQuery(text: string): Promise<number[]> {
+    const { data } = await this.embeddingWithRetry({
+      model: this.modelName,
+      input: text,
+    });
+
+    return data[0].embedding;
+  }
+
+  /**
+   * Makes a request to the Voyage AI API to generate embeddings for an array of texts.
+   * @param request - An object with properties to configure the request.
+   * @returns A Promise that resolves to the response from the Voyage AI API.
+   * @throws Error when the HTTP response status is not OK.
+   */
+  private async embeddingWithRetry(request: CreateVoyageEmbeddingRequest) {
+    const makeCompletionRequest = async () => {
+      const response = await fetch(this.apiUrl, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Authorization: `Bearer ${this.apiKey}`,
+          ...this.headers,
+        },
+        body: JSON.stringify(request),
+      });
+
+      // Fail with the HTTP status instead of handing an error body to the
+      // callers, which would then fail on the missing `data` field.
+      if (!response.ok) {
+        throw new Error(
+          `Request to Voyage AI API failed: ${response.status} ${response.statusText}`
+        );
+      }
+
+      const json = await response.json();
+      return json;
+    };
+
+    return this.caller.call(makeCompletionRequest);
+  }
+}
diff --git a/libs/langchain-community/src/graphs/neo4j_graph.ts b/libs/langchain-community/src/graphs/neo4j_graph.ts
new file mode 100644
index 000000000000..c404e7e3b2ad
--- /dev/null
+++ b/libs/langchain-community/src/graphs/neo4j_graph.ts
@@ -0,0 +1,286 @@
+import neo4j, { Neo4jError } from "neo4j-driver";
+
+interface Neo4jGraphConfig {
+  url: string;
+  username: string;
+  password: string;
+  database?: string;
+}
+
+interface StructuredSchema {
+  nodeProps: { [key: NodeType["labels"]]: NodeType["properties"] };
+  relProps: { [key: RelType["type"]]: RelType["properties"] };
+  relationships: PathType[];
+}
+
+type NodeType = {
+  labels: string;
+  properties: { property: string; type: string }[];
+};
+type RelType = {
+  type: string;
+  properties: { property: string; type: string }[];
+};
+type PathType = { start: string; type: string; end: string };
+
+/**
+ * @security *Security note*: Make sure that the database connection uses credentials
+ * that are narrowly-scoped to only include necessary permissions.
+ * Failure to do so may result in data corruption or loss, since the calling
+ * code may attempt commands that would result in deletion, mutation
+ * of data if appropriately prompted or reading sensitive data if such
+ * data is present in the database.
/**
 * Wrapper around a Neo4j driver that runs Cypher queries against one database
 * and caches a textual and a structured description of the graph schema, both
 * built from `apoc.meta.data()`.
 *
 * @security *Security note*: Make sure that the database connection uses credentials
 * that are narrowly-scoped to only include necessary permissions.
 * Failure to do so may result in data corruption or loss, since the calling
 * code may attempt commands that would result in deletion, mutation
 * of data if appropriately prompted or reading sensitive data if such
 * data is present in the database.
 * The best way to guard against such negative outcomes is to (as appropriate)
 * limit the permissions granted to the credentials used with this tool.
 * For example, creating read only users for the database is a good way to
 * ensure that the calling code cannot mutate or delete data.
 *
 * @link See https://js.langchain.com/docs/security for more information.
 */
export class Neo4jGraph {
  private driver: neo4j.Driver;

  private database: string;

  /** Human-readable schema summary; empty until `refreshSchema` has run. */
  private schema = "";

  /** Machine-readable schema; empty until `refreshSchema` has run. */
  private structuredSchema: StructuredSchema = {
    nodeProps: {},
    relProps: {},
    relationships: [],
  };

  /**
   * Creates the driver with basic authentication. No connection check is
   * performed here; use `Neo4jGraph.initialize` for that.
   * @throws Error when the driver instance cannot be created.
   */
  constructor({
    url,
    username,
    password,
    database = "neo4j",
  }: Neo4jGraphConfig) {
    try {
      this.driver = neo4j.driver(url, neo4j.auth.basic(username, password));
      this.database = database;
    } catch (error) {
      throw new Error(
        "Could not create a Neo4j driver instance. Please check the connection details."
      );
    }
  }

  /**
   * Builds a graph, checks connectivity and loads the schema.
   * @param config - Connection settings.
   * @returns The initialized graph.
   * @throws Error when the schema refresh throws (reported as an APOC
   * availability problem).
   */
  static async initialize(config: Neo4jGraphConfig): Promise<Neo4jGraph> {
    const graph = new Neo4jGraph(config);

    try {
      await graph.verifyConnectivity();
    } catch (error) {
      // Best effort: a failed connectivity check is only reported, not fatal.
      console.log("Failed to verify connection.");
    }

    try {
      await graph.refreshSchema();
    } catch (error) {
      const message = [
        "Could not use APOC procedures.",
        "Please ensure the APOC plugin is installed in Neo4j and that",
        "'apoc.meta.data()' is allowed in Neo4j configuration",
      ].join("\n");

      throw new Error(message);
    }

    // Logged only after refreshSchema returned without throwing. This used to
    // live in a `finally` block and was therefore printed on failure as well.
    console.log("Schema refreshed successfully.");

    return graph;
  }

  /** Returns the textual schema built by the last `refreshSchema` call. */
  getSchema(): string {
    return this.schema;
  }

  /** Returns the structured schema built by the last `refreshSchema` call. */
  getStructuredSchema() {
    return this.structuredSchema;
  }

  /**
   * Runs a Cypher query against the configured database.
   * @param query - The Cypher statement.
   * @param params - Query parameters.
   * @returns The records converted by `toObjects`, or `undefined` when the
   * driver raised any error other than a missing procedure.
   * @throws Error when Neo4j reports `Neo.ClientError.Procedure.ProcedureNotFound`.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  async query(query: string, params: any = {}): Promise<any[] | undefined> {
    try {
      const result = await this.driver.executeQuery(query, params, {
        database: this.database,
      });
      return toObjects(result.records);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } catch (error: any) {
      if (
        // eslint-disable-next-line
        error instanceof Neo4jError &&
        error.code === "Neo.ClientError.Procedure.ProcedureNotFound"
      ) {
        throw new Error("Procedure not found in Neo4j.");
      }
      // Every other error is deliberately turned into an `undefined` result;
      // refreshSchema tolerates that via optional chaining.
    }
    return undefined;
  }

  /** Asks the driver to verify the configured credentials. */
  async verifyConnectivity() {
    await this.driver.verifyAuthentication();
  }

  /**
   * Re-reads node properties, relationship properties and relationship
   * patterns via `apoc.meta.data()` and rebuilds both schema representations.
   */
  async refreshSchema() {
    const nodePropertiesQuery = `
      CALL apoc.meta.data()
      YIELD label, other, elementType, type, property
      WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
      WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
      RETURN {labels: nodeLabels, properties: properties} AS output
    `;

    const relPropertiesQuery = `
      CALL apoc.meta.data()
      YIELD label, other, elementType, type, property
      WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
      WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
      RETURN {type: nodeLabels, properties: properties} AS output
    `;

    const relQuery = `
      CALL apoc.meta.data()
      YIELD label, other, elementType, type, property
      WHERE type = "RELATIONSHIP" AND elementType = "node"
      UNWIND other AS other_node
      RETURN {start: label, type: property, end: toString(other_node)} AS output
    `;

    // Each result is `undefined` when `query` swallowed a driver error.
    const nodeProperties: NodeType[] | undefined = (
      await this.query(nodePropertiesQuery)
    )?.map((el: { output: NodeType }) => el.output);

    const relationshipsProperties: RelType[] | undefined = (
      await this.query(relPropertiesQuery)
    )?.map((el: { output: RelType }) => el.output);

    const relationships: PathType[] | undefined = (
      await this.query(relQuery)
    )?.map((el: { output: PathType }) => el.output);

    // Structured schema: label/type -> property list, plus the raw paths.
    this.structuredSchema = {
      nodeProps: Object.fromEntries(
        nodeProperties?.map((el) => [el.labels, el.properties]) || []
      ),
      relProps: Object.fromEntries(
        relationshipsProperties?.map((el) => [el.type, el.properties]) || []
      ),
      relationships: relationships || [],
    };

    // Format node properties
    const formattedNodeProps = nodeProperties?.map((el) => {
      const propsStr = el.properties
        .map((prop) => `${prop.property}: ${prop.type}`)
        .join(", ");
      return `${el.labels} {${propsStr}}`;
    });

    // Format relationship properties
    const formattedRelProps = relationshipsProperties?.map((el) => {
      const propsStr = el.properties
        .map((prop) => `${prop.property}: ${prop.type}`)
        .join(", ");
      return `${el.type} {${propsStr}}`;
    });

    // Format relationships
    const formattedRels = relationships?.map(
      (el) => `(:${el.start})-[:${el.type}]->(:${el.end})`
    );

    // Combine all formatted elements into a single string. Sections whose
    // query produced no result are rendered as empty lines by `join`.
    this.schema = [
      "Node properties are the following:",
      formattedNodeProps?.join(", "),
      "Relationship properties are the following:",
      formattedRelProps?.join(", "),
      "The relationships are the following:",
      formattedRels?.join(", "),
    ].join("\n");
  }

  /** Closes the underlying driver. */
  async close() {
    await this.driver.close();
  }
}

/**
 * Converts driver records into plain objects, passing every field value
 * through `itemIntToString`.
 */
function toObjects(records: neo4j.Record[]) {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const recordValues: Record<string, any>[] = records.map((record) => {
    const rObj = record.toObject();
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const out: { [key: string]: any } = {};
    Object.keys(rObj).forEach((key) => {
      out[key] = itemIntToString(rObj[key]);
    });
    return out;
  });
  return recordValues;
}
/**
 * Recursively converts a value returned by the driver: Neo4j integers become
 * strings, arrays are converted element-wise, numbers/strings/booleans and
 * `null` pass through, and other objects go through `objIntToString`.
 * Any remaining kind of value yields `undefined`.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function itemIntToString(item: any): any {
  // The integer and array checks must run first: both are `typeof "object"`.
  if (neo4j.isInt(item)) return item.toString();
  if (Array.isArray(item)) return item.map((element) => itemIntToString(element));

  switch (typeof item) {
    case "number":
    case "string":
    case "boolean":
      return item;
    case "object":
      return item === null ? item : objIntToString(item);
    default:
      return undefined;
  }
}

/**
 * Unwraps Neo4j graph objects via `extractFromNeoObjects` and converts the
 * result. Returns `null` when the unwrapped value is neither an array nor a
 * non-null object.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function objIntToString(obj: any) {
  const unwrapped = extractFromNeoObjects(obj);

  if (Array.isArray(unwrapped)) {
    return unwrapped.map((element) => itemIntToString(element));
  }
  if (unwrapped === null || typeof unwrapped !== "object") {
    return null;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const converted: any = {};
  for (const key of Object.keys(unwrapped)) {
    converted[key] = itemIntToString(unwrapped[key]);
  }
  return converted;
}

/**
 * Nodes and relationships are reduced to their `properties`; paths are
 * reduced to a flat list of the rows produced by `extractPathForRows`;
 * anything else is returned untouched.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function extractFromNeoObjects(obj: any) {
  const isNodeOrRelationship =
    // eslint-disable-next-line
    obj instanceof (neo4j.types.Node as any) ||
    // eslint-disable-next-line
    obj instanceof (neo4j.types.Relationship as any);
  if (isNodeOrRelationship) {
    return obj.properties;
  }

  // eslint-disable-next-line
  if (obj instanceof (neo4j.types.Path as any)) {
    // One level of flattening: rows of parts -> a single list of parts.
    return extractPathForRows(obj).flat();
  }

  return obj;
}

/**
 * Turns a path into one row per segment, each row holding the converted
 * start, relationship and end (parts that convert to `null` are dropped).
 */
const extractPathForRows = (path: neo4j.Path) => {
  const hasSegments = Array.isArray(path.segments) && path.segments.length >= 1;
  // Zero length path. No relationship, end === start
  const segments = hasSegments
    ? path.segments
    : // eslint-disable-next-line @typescript-eslint/no-explicit-any
      [{ ...path, end: null } as any];

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  return segments.map((segment: any) => {
    const parts = [segment.start, segment.relationship, segment.end].map(
      (part) => objIntToString(part)
    );
    return parts.filter((part) => part !== null);
  });
};
/**
 * Class representing the AI21 language model. It extends the LLM (Large
 * Language Model) class, providing a standard interface for interacting
 * with the AI21 language model.
 */
export class AI21 extends LLM implements AI21Input {
  lc_serializable = true;

  model = "j2-jumbo-instruct";

  temperature = 0.7;

  maxTokens = 1024;

  minTokens = 0;

  topP = 1;

  presencePenalty = AI21.getDefaultAI21PenaltyData();

  countPenalty = AI21.getDefaultAI21PenaltyData();

  frequencyPenalty = AI21.getDefaultAI21PenaltyData();

  numResults = 1;

  logitBias?: Record<string, number>;

  ai21ApiKey?: string;

  stop?: string[];

  /** Optional API root; when unset, `_call` picks one based on `model`. */
  baseUrl?: string;

  /**
   * @param fields - Optional overrides for the defaults declared above. The
   * API key falls back to the `AI21_API_KEY` environment variable.
   */
  constructor(fields?: AI21Input) {
    super(fields ?? {});

    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.minTokens = fields?.minTokens ?? this.minTokens;
    this.topP = fields?.topP ?? this.topP;
    this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
    this.countPenalty = fields?.countPenalty ?? this.countPenalty;
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
    this.numResults = fields?.numResults ?? this.numResults;
    this.logitBias = fields?.logitBias;
    this.ai21ApiKey =
      fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY");
    this.stop = fields?.stop;
    this.baseUrl = fields?.baseUrl;
  }

  /**
   * Method to validate the environment. It checks if the AI21 API key is
   * set. If not, it throws an error.
   */
  validateEnvironment() {
    if (!this.ai21ApiKey) {
      throw new Error(
        `No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`
      );
    }
  }

  /**
   * Static method to get the default penalty data for AI21.
   * @returns A fresh AI21PenaltyData object with scale 0 and every
   * `applyTo*` flag enabled.
   */
  static getDefaultAI21PenaltyData(): AI21PenaltyData {
    return {
      scale: 0,
      applyToWhitespaces: true,
      applyToPunctuations: true,
      applyToNumbers: true,
      applyToStopwords: true,
      applyToEmojis: true,
    };
  }

  /** Get the type of LLM. */
  _llmType() {
    return "ai21";
  }

  /** Get the default parameters for calling AI21 API. */
  get defaultParams() {
    return {
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      minTokens: this.minTokens,
      topP: this.topP,
      presencePenalty: this.presencePenalty,
      countPenalty: this.countPenalty,
      frequencyPenalty: this.frequencyPenalty,
      numResults: this.numResults,
      logitBias: this.logitBias,
    };
  }

  /** Get the identifying parameters for this LLM. */
  get identifyingParams() {
    return { ...this.defaultParams, model: this.model };
  }

  /**
   * Call out to AI21's complete endpoint.
   * @param prompt - The prompt to pass into the model.
   * @param options - Call options; `stop` supplies stop sequences and
   * `signal` is forwarded to `fetch`.
   * @returns The text of the first completion, or `""` when it has no text.
   * @throws Error when the API key is missing, when stop sequences are given
   * both on the instance and in `options`, when the HTTP status is not ok, or
   * when the response contains no completion data.
   *
   * @example
   * const response = await ai21._call("Tell me a joke.", {});
   */
  async _call(
    prompt: string,
    options: this["ParsedCallOptions"]
  ): Promise<string> {
    let stop = options?.stop;
    this.validateEnvironment();
    if (this.stop && stop && this.stop.length > 0 && stop.length > 0) {
      throw new Error("`stop` found in both the input and default params.");
    }
    stop = this.stop ?? stop ?? [];

    // The model-based default is computed on its own so that `??` applies to
    // the whole fallback. Written inline, `a ?? b === c ? x : y` parses as
    // `(a ?? (b === c)) ? x : y`, which discarded a configured baseUrl.
    const defaultBaseUrl =
      this.model === "j1-grande-instruct"
        ? "https://api.ai21.com/studio/v1/experimental"
        : "https://api.ai21.com/studio/v1";
    const baseUrl = this.baseUrl ?? defaultBaseUrl;

    const url = `${baseUrl}/${this.model}/complete`;
    const headers = {
      Authorization: `Bearer ${this.ai21ApiKey}`,
      "Content-Type": "application/json",
    };
    const data = { prompt, stopSequences: stop, ...this.defaultParams };
    const responseData = await this.caller.callWithOptions({}, async () => {
      const response = await fetch(url, {
        method: "POST",
        headers,
        body: JSON.stringify(data),
        signal: options.signal,
      });
      if (!response.ok) {
        const error = new Error(
          `AI21 call failed with status code ${response.status}`
        );
        // Expose the raw response to whoever handles the failure.
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        (error as any).response = response;
        throw error;
      }
      return response.json();
    });

    if (
      !responseData.completions ||
      responseData.completions.length === 0 ||
      !responseData.completions[0].data
    ) {
      throw new Error("No completions found in response");
    }

    return responseData.completions[0].data.text ?? "";
  }
}
/**
 * Interface for the input parameters specific to the Aleph Alpha LLM.
 *
 * Property names are snake_case because the implementing class forwards most
 * of them unchanged in the JSON request body. Only `model`, `maximum_tokens`,
 * `raw_completion`, the two `*_first_token_only` flags, `control_log_additive`
 * and `base_url` are required by this type; everything else is optional.
 */
export interface AlephAlphaInput extends BaseLLMParams {
  model: string;
  maximum_tokens: number;
  minimum_tokens?: number;
  echo?: boolean;
  temperature?: number;
  top_k?: number;
  top_p?: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  sequence_penalty?: number;
  sequence_penalty_min_length?: number;
  repetition_penalties_include_prompt?: boolean;
  repetition_penalties_include_completion?: boolean;
  use_multiplicative_presence_penalty?: boolean;
  use_multiplicative_frequency_penalty?: boolean;
  use_multiplicative_sequence_penalty?: boolean;
  penalty_bias?: string;
  penalty_exceptions?: string[];
  penalty_exceptions_include_stop_sequences?: boolean;
  best_of?: number;
  n?: number;
  logit_bias?: object;
  log_probs?: number;
  tokens?: boolean;
  raw_completion: boolean;
  disable_optimizations?: boolean;
  completion_bias_inclusion?: string[];
  completion_bias_inclusion_first_token_only: boolean;
  completion_bias_exclusion?: string[];
  completion_bias_exclusion_first_token_only: boolean;
  contextual_control_threshold?: number;
  control_log_additive: boolean;
  // Stop sequences; the implementing class sends them as `stop_sequences`.
  stop?: string[];
  // API key; the implementing class defaults it from ALEPH_ALPHA_API_KEY.
  aleph_alpha_api_key?: string;
  // Full URL of the completion endpoint that requests are POSTed to.
  base_url: string;
}
/**
 * Specific implementation of a Large Language Model (LLM) designed to
 * interact with the Aleph Alpha API. It extends the base LLM class and
 * includes a variety of parameters for customizing the behavior of the
 * Aleph Alpha model.
 */
export class AlephAlpha extends LLM implements AlephAlphaInput {
  lc_serializable = true;

  model = "luminous-base";

  maximum_tokens = 64;

  minimum_tokens = 0;

  echo: boolean;

  temperature = 0.0;

  top_k: number;

  top_p = 0.0;

  presence_penalty?: number;

  frequency_penalty?: number;

  sequence_penalty?: number;

  sequence_penalty_min_length?: number;

  repetition_penalties_include_prompt?: boolean;

  repetition_penalties_include_completion?: boolean;

  use_multiplicative_presence_penalty?: boolean;

  use_multiplicative_frequency_penalty?: boolean;

  use_multiplicative_sequence_penalty?: boolean;

  penalty_bias?: string;

  penalty_exceptions?: string[];

  penalty_exceptions_include_stop_sequences?: boolean;

  best_of?: number;

  n?: number;

  logit_bias?: object;

  log_probs?: number;

  tokens?: boolean;

  raw_completion: boolean;

  disable_optimizations?: boolean;

  completion_bias_inclusion?: string[];

  completion_bias_inclusion_first_token_only: boolean;

  completion_bias_exclusion?: string[];

  completion_bias_exclusion_first_token_only: boolean;

  contextual_control_threshold?: number;

  control_log_additive: boolean;

  aleph_alpha_api_key? = getEnvironmentVariable("ALEPH_ALPHA_API_KEY");

  stop?: string[];

  /** Endpoint that `_call` POSTs to; can be overridden via `fields.base_url`. */
  base_url = "https://api.aleph-alpha.com/complete";

  /**
   * @param fields - Overrides for the defaults declared above. Any value left
   * `undefined`/`null` keeps the field's current value.
   */
  constructor(fields: Partial<AlephAlpha>) {
    super(fields ?? {});
    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maximum_tokens = fields?.maximum_tokens ?? this.maximum_tokens;
    this.minimum_tokens = fields?.minimum_tokens ?? this.minimum_tokens;
    this.top_k = fields?.top_k ?? this.top_k;
    this.top_p = fields?.top_p ?? this.top_p;
    this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
    this.frequency_penalty =
      fields?.frequency_penalty ?? this.frequency_penalty;
    this.sequence_penalty = fields?.sequence_penalty ?? this.sequence_penalty;
    this.sequence_penalty_min_length =
      fields?.sequence_penalty_min_length ?? this.sequence_penalty_min_length;
    this.repetition_penalties_include_prompt =
      fields?.repetition_penalties_include_prompt ??
      this.repetition_penalties_include_prompt;
    this.repetition_penalties_include_completion =
      fields?.repetition_penalties_include_completion ??
      this.repetition_penalties_include_completion;
    this.use_multiplicative_presence_penalty =
      fields?.use_multiplicative_presence_penalty ??
      this.use_multiplicative_presence_penalty;
    this.use_multiplicative_frequency_penalty =
      fields?.use_multiplicative_frequency_penalty ??
      this.use_multiplicative_frequency_penalty;
    this.use_multiplicative_sequence_penalty =
      fields?.use_multiplicative_sequence_penalty ??
      this.use_multiplicative_sequence_penalty;
    this.penalty_bias = fields?.penalty_bias ?? this.penalty_bias;
    this.penalty_exceptions =
      fields?.penalty_exceptions ?? this.penalty_exceptions;
    this.penalty_exceptions_include_stop_sequences =
      fields?.penalty_exceptions_include_stop_sequences ??
      this.penalty_exceptions_include_stop_sequences;
    this.best_of = fields?.best_of ?? this.best_of;
    this.n = fields?.n ?? this.n;
    this.logit_bias = fields?.logit_bias ?? this.logit_bias;
    this.log_probs = fields?.log_probs ?? this.log_probs;
    this.tokens = fields?.tokens ?? this.tokens;
    this.raw_completion = fields?.raw_completion ?? this.raw_completion;
    this.disable_optimizations =
      fields?.disable_optimizations ?? this.disable_optimizations;
    this.completion_bias_inclusion =
      fields?.completion_bias_inclusion ?? this.completion_bias_inclusion;
    this.completion_bias_inclusion_first_token_only =
      fields?.completion_bias_inclusion_first_token_only ??
      this.completion_bias_inclusion_first_token_only;
    this.completion_bias_exclusion =
      fields?.completion_bias_exclusion ?? this.completion_bias_exclusion;
    this.completion_bias_exclusion_first_token_only =
      fields?.completion_bias_exclusion_first_token_only ??
      this.completion_bias_exclusion_first_token_only;
    this.contextual_control_threshold =
      fields?.contextual_control_threshold ?? this.contextual_control_threshold;
    this.control_log_additive =
      fields?.control_log_additive ?? this.control_log_additive;
    this.aleph_alpha_api_key =
      fields?.aleph_alpha_api_key ?? this.aleph_alpha_api_key;
    this.stop = fields?.stop ?? this.stop;
    // Previously never copied, so a caller-supplied endpoint was ignored.
    this.base_url = fields?.base_url ?? this.base_url;
  }

  /**
   * Validates the environment by ensuring the necessary Aleph Alpha API key
   * is available. Throws an error if the API key is missing.
   */
  validateEnvironment() {
    if (!this.aleph_alpha_api_key) {
      throw new Error(
        "Aleph Alpha API Key is missing in environment variables."
      );
    }
  }

  /** Get the default parameters for calling Aleph Alpha API. */
  get defaultParams() {
    return {
      model: this.model,
      temperature: this.temperature,
      maximum_tokens: this.maximum_tokens,
      minimum_tokens: this.minimum_tokens,
      top_k: this.top_k,
      top_p: this.top_p,
      presence_penalty: this.presence_penalty,
      frequency_penalty: this.frequency_penalty,
      sequence_penalty: this.sequence_penalty,
      sequence_penalty_min_length: this.sequence_penalty_min_length,
      repetition_penalties_include_prompt:
        this.repetition_penalties_include_prompt,
      repetition_penalties_include_completion:
        this.repetition_penalties_include_completion,
      use_multiplicative_presence_penalty:
        this.use_multiplicative_presence_penalty,
      use_multiplicative_frequency_penalty:
        this.use_multiplicative_frequency_penalty,
      use_multiplicative_sequence_penalty:
        this.use_multiplicative_sequence_penalty,
      penalty_bias: this.penalty_bias,
      penalty_exceptions: this.penalty_exceptions,
      penalty_exceptions_include_stop_sequences:
        this.penalty_exceptions_include_stop_sequences,
      best_of: this.best_of,
      n: this.n,
      logit_bias: this.logit_bias,
      log_probs: this.log_probs,
      tokens: this.tokens,
      raw_completion: this.raw_completion,
      disable_optimizations: this.disable_optimizations,
      completion_bias_inclusion: this.completion_bias_inclusion,
      completion_bias_inclusion_first_token_only:
        this.completion_bias_inclusion_first_token_only,
      completion_bias_exclusion: this.completion_bias_exclusion,
      completion_bias_exclusion_first_token_only:
        this.completion_bias_exclusion_first_token_only,
      contextual_control_threshold: this.contextual_control_threshold,
      control_log_additive: this.control_log_additive,
    };
  }

  /** Get the identifying parameters for this LLM. */
  get identifyingParams() {
    return { ...this.defaultParams };
  }

  /** Get the type of LLM. */
  _llmType(): string {
    return "aleph_alpha";
  }

  /**
   * Sends the prompt to `base_url` and returns the first completion.
   * @param prompt - The prompt to complete.
   * @param options - Call options; `stop` supplies stop sequences and
   * `signal` is forwarded to `fetch`.
   * @returns The `completion` string of the first entry in the response,
   * which may be the empty string.
   * @throws Error when the API key is missing, when stop sequences are given
   * both on the instance and in `options`, when the HTTP status is not ok, or
   * when the response carries no completion.
   */
  async _call(
    prompt: string,
    options: this["ParsedCallOptions"]
  ): Promise<string> {
    let stop = options?.stop;
    this.validateEnvironment();
    if (this.stop && stop && this.stop.length > 0 && stop.length > 0) {
      throw new Error("`stop` found in both the input and default params.");
    }
    stop = this.stop ?? stop ?? [];
    const headers = {
      Authorization: `Bearer ${this.aleph_alpha_api_key}`,
      "Content-Type": "application/json",
      Accept: "application/json",
    };
    const data = { prompt, stop_sequences: stop, ...this.defaultParams };
    const responseData = await this.caller.call(async () => {
      const response = await fetch(this.base_url, {
        method: "POST",
        headers,
        body: JSON.stringify(data),
        signal: options.signal,
      });
      if (!response.ok) {
        // consume the response body to release the connection
        // https://undici.nodejs.org/#/?id=garbage-collection
        const text = await response.text();
        const error = new Error(
          `Aleph Alpha call failed with status ${response.status} and body ${text}`
        );
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        (error as any).response = response;
        throw error;
      }
      return response.json();
    });

    // Only a missing completion is an error. The previous truthiness test
    // also rejected an empty-string completion.
    const completion = responseData.completions?.[0]?.completion;
    if (completion === undefined || completion === null) {
      throw new Error("No completions found in response");
    }

    return completion;
  }
}
/**
 * Variant of the `./web.js` Bedrock class that does not require explicit
 * credentials: when `fields.credentials` is absent it falls back to
 * `defaultProvider()` from `@aws-sdk/credential-provider-node`.
 */
export class Bedrock extends BaseBedrock {
  static lc_name() {
    return "Bedrock";
  }

  /**
   * @param fields - Same options as the base class; `credentials` is optional here.
   */
  constructor(fields?: Partial<BaseBedrockInput> & BaseLLMParams) {
    // Resolved before delegating so the base constructor always receives credentials.
    const credentials = fields?.credentials ?? defaultProvider();
    super({ ...fields, credentials });
  }
}
/**
 * A type of Large Language Model (LLM) that interacts with the Bedrock
 * service. It extends the base `LLM` class and implements the
 * `BaseBedrockInput` interface. The class is designed to authenticate and
 * interact with the Bedrock service, which is a part of Amazon Web
 * Services (AWS). It uses AWS credentials for authentication and can be
 * configured with various parameters such as the model to use, the AWS
 * region, and the maximum number of tokens to generate.
 */
export class Bedrock extends LLM implements BaseBedrockInput {
  model = "amazon.titan-tg1-large";

  region: string;

  credentials: CredentialType;

  temperature?: number | undefined = undefined;

  maxTokens?: number | undefined = undefined;

  fetchFn: typeof fetch;

  endpointHost?: string;

  /** @deprecated */
  stopSequences?: string[];

  modelKwargs?: Record<string, unknown>;

  codec: EventStreamCodec = new EventStreamCodec(toUtf8, fromUtf8);

  streaming = false;

  lc_serializable = true;

  get lc_aliases(): Record<string, string> {
    return {
      model: "model_id",
      region: "region_name",
    };
  }

  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
      "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY",
    };
  }

  get lc_attributes(): SerializedFields | undefined {
    return { region: this.region };
  }

  _llmType() {
    return "bedrock";
  }

  static lc_name() {
    return "Bedrock";
  }

  /**
   * @param fields - Configuration. `credentials` is mandatory; `region`
   * falls back to the `AWS_DEFAULT_REGION` environment variable.
   * @throws Error when the model's provider prefix is not supported, or when
   * no region or no credentials are available.
   */
  constructor(fields?: Partial<BaseBedrockInput> & BaseLLMParams) {
    super(fields ?? {});

    this.model = fields?.model ?? this.model;
    // The provider is the part of the model id before the first ".".
    const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
    if (!allowedModels.includes(this.model.split(".")[0])) {
      throw new Error(
        `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
      );
    }
    const region =
      fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION");
    if (!region) {
      throw new Error(
        "Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field."
      );
    }
    this.region = region;

    const credentials = fields?.credentials;
    if (!credentials) {
      throw new Error(
        "Please set the AWS credentials in the 'credentials' field."
      );
    }
    this.credentials = credentials;

    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
    this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
    this.stopSequences = fields?.stopSequences;
    this.modelKwargs = fields?.modelKwargs;
    this.streaming = fields?.streaming ?? this.streaming;
  }

  /**
   * Call out to Bedrock service model.
   * @param prompt - The prompt to pass into the model.
   * @param options - Call options (`stop`, `signal`).
   * @param runManager - Optional callback manager, used when streaming.
   * @returns The string generated by the model. In streaming mode this is the
   * concatenation of all streamed chunks.
   * @throws Error when a non-streaming request returns a non-ok status.
   *
   * @example
   * const response = await model.call("Tell me a joke.");
   */
  async _call(
    prompt: string,
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    const service = "bedrock-runtime";
    const endpointHost =
      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
    const provider = this.model.split(".")[0];
    if (this.streaming) {
      const stream = this._streamResponseChunks(prompt, options, runManager);
      let finalResult: GenerationChunk | undefined;
      for await (const chunk of stream) {
        if (finalResult === undefined) {
          finalResult = chunk;
        } else {
          finalResult = finalResult.concat(chunk);
        }
      }
      return finalResult?.text ?? "";
    }
    const response = await this._signedFetch(prompt, options, {
      bedrockMethod: "invoke",
      endpointHost,
      provider,
    });
    const json = await response.json();
    if (!response.ok) {
      throw new Error(
        `Error ${response.status}: ${json.message ?? JSON.stringify(json)}`
      );
    }
    const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
    return text;
  }

  /**
   * Builds the provider-specific request body, signs the request with AWS
   * Signature V4 and sends it with `fetchFn` through `this.caller`.
   * @param prompt - The prompt to send.
   * @param options - Call options; `stop` overrides `stopSequences` and
   * `signal` is handed to the caller.
   * @param fields - Bedrock method, endpoint host and provider to target.
   * @returns The raw fetch response.
   */
  async _signedFetch(
    prompt: string,
    options: this["ParsedCallOptions"],
    fields: {
      bedrockMethod: "invoke" | "invoke-with-response-stream";
      endpointHost: string;
      provider: string;
    }
  ) {
    const { bedrockMethod, endpointHost, provider } = fields;
    const inputBody = BedrockLLMInputOutputAdapter.prepareInput(
      provider,
      prompt,
      this.maxTokens,
      this.temperature,
      options.stop ?? this.stopSequences,
      this.modelKwargs,
      fields.bedrockMethod
    );

    const url = new URL(
      `https://${endpointHost}/model/${this.model}/${bedrockMethod}`
    );

    const request = new HttpRequest({
      hostname: url.hostname,
      path: url.pathname,
      protocol: url.protocol,
      method: "POST", // method must be uppercase
      body: JSON.stringify(inputBody),
      query: Object.fromEntries(url.searchParams.entries()),
      headers: {
        // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
        host: url.host,
        accept: "application/json",
        "content-type": "application/json",
      },
    });

    const signer = new SignatureV4({
      credentials: this.credentials,
      service: "bedrock",
      region: this.region,
      sha256: Sha256,
    });

    const signedRequest = await signer.sign(request);

    // Send request to AWS using the low-level fetch API
    const response = await this.caller.callWithOptions(
      { signal: options.signal },
      async () =>
        this.fetchFn(url, {
          headers: signedRequest.headers,
          body: signedRequest.body,
          method: signedRequest.method,
        })
    );
    return response;
  }

  /** Returns the parameters that characterize an invocation of this model. */
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      model: this.model,
      region: this.region,
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      stop: options?.stop ?? this.stopSequences,
      modelKwargs: this.modelKwargs,
    };
  }

  /**
   * Yields the model output as generation chunks. The anthropic, cohere and
   * meta providers use `invoke-with-response-stream` and yield one chunk per
   * decoded event; every other provider performs a plain `invoke` and yields
   * the whole output as a single chunk.
   * @throws Error when the response status is outside 2xx, when an event is
   * not a JSON `chunk` event, or when an event body carries a `message`.
   */
  async *_streamResponseChunks(
    prompt: string,
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<GenerationChunk> {
    const provider = this.model.split(".")[0];
    const bedrockMethod =
      provider === "anthropic" || provider === "cohere" || provider === "meta"
        ? "invoke-with-response-stream"
        : "invoke";

    const service = "bedrock-runtime";
    const endpointHost =
      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

    // Send request to AWS using the low-level fetch API
    const response = await this._signedFetch(prompt, options, {
      bedrockMethod,
      endpointHost,
      provider,
    });

    if (response.status < 200 || response.status >= 300) {
      throw Error(
        `Failed to access underlying url '${endpointHost}': got ${
          response.status
        } ${response.statusText}: ${await response.text()}`
      );
    }

    if (
      provider === "anthropic" ||
      provider === "cohere" ||
      provider === "meta"
    ) {
      const reader = response.body?.getReader();
      const decoder = new TextDecoder();
      for await (const chunk of this._readChunks(reader)) {
        const event = this.codec.decode(chunk);
        if (
          (event.headers[":event-type"] !== undefined &&
            event.headers[":event-type"].value !== "chunk") ||
          event.headers[":content-type"].value !== "application/json"
        ) {
          throw Error(`Failed to get event chunk: got ${chunk}`);
        }
        const body = JSON.parse(decoder.decode(event.body));
        if (body.message) {
          throw new Error(body.message);
        }
        if (body.bytes !== undefined) {
          // `bytes` is base64; atob yields a binary string that is mapped
          // back to raw bytes before decoding the JSON payload.
          const chunkResult = JSON.parse(
            decoder.decode(
              Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)
            )
          );
          const text = BedrockLLMInputOutputAdapter.prepareOutput(
            provider,
            chunkResult
          );
          yield new GenerationChunk({
            text,
            generationInfo: {},
          });
          // eslint-disable-next-line no-void
          void runManager?.handleLLMNewToken(text);
        }
      }
    } else {
      const json = await response.json();
      const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
      yield new GenerationChunk({
        text,
        generationInfo: {},
      });
      // eslint-disable-next-line no-void
      void runManager?.handleLLMNewToken(text);
    }
  }

  /**
   * Re-frames the byte stream delivered by `reader` into complete messages.
   * Each message starts with a big-endian uint32 holding its total length;
   * a message is yielded once that many bytes have been buffered.
   * @param reader - A stream reader whose `read()` resolves to
   * `{ done, value: Uint8Array }`.
   * @returns An async iterable of message byte arrays.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  _readChunks(reader: any) {
    // Size of the uint32 length prefix that must be buffered before reading it.
    const LENGTH_PREFIX_BYTES = 4;

    function _concatChunks(a: Uint8Array, b: Uint8Array) {
      const newBuffer = new Uint8Array(a.length + b.length);
      newBuffer.set(a);
      newBuffer.set(b, a.length);
      return newBuffer;
    }

    function getMessageLength(buffer: Uint8Array) {
      // A network chunk may end inside the prefix. Reading a uint32 from
      // fewer than 4 bytes throws a RangeError, so report "unknown" (0)
      // until the whole prefix is available.
      if (buffer.byteLength < LENGTH_PREFIX_BYTES) return 0;
      const view = new DataView(
        buffer.buffer,
        buffer.byteOffset,
        buffer.byteLength
      );

      return view.getUint32(0, false);
    }

    return {
      async *[Symbol.asyncIterator]() {
        let readResult = await reader.read();

        let buffer: Uint8Array = new Uint8Array(0);
        while (!readResult.done) {
          const chunk: Uint8Array = readResult.value;

          buffer = _concatChunks(buffer, chunk);
          let messageLength = getMessageLength(buffer);

          // Emit every complete message currently buffered; a partial prefix
          // or partial message stays in the buffer for the next read.
          while (
            buffer.byteLength >= LENGTH_PREFIX_BYTES &&
            buffer.byteLength >= messageLength
          ) {
            yield buffer.slice(0, messageLength);
            buffer = buffer.slice(messageLength);
            messageLength = getMessageLength(buffer);
          }

          readResult = await reader.read();
        }
      },
    };
  }
}
getEnvironmentVariable } from "@langchain/core/utils/env"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { GenerationChunk } from "@langchain/core/outputs"; + +import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js"; + +/** + * Interface for CloudflareWorkersAI input parameters. + */ +export interface CloudflareWorkersAIInput { + cloudflareAccountId?: string; + cloudflareApiToken?: string; + model?: string; + baseUrl?: string; + streaming?: boolean; +} + +/** + * Class representing the CloudflareWorkersAI language model. It extends the LLM (Large + * Language Model) class, providing a standard interface for interacting + * with the CloudflareWorkersAI language model. + */ +export class CloudflareWorkersAI + extends LLM + implements CloudflareWorkersAIInput +{ + model = "@cf/meta/llama-2-7b-chat-int8"; + + cloudflareAccountId?: string; + + cloudflareApiToken?: string; + + baseUrl: string; + + streaming = false; + + static lc_name() { + return "CloudflareWorkersAI"; + } + + lc_serializable = true; + + constructor(fields?: CloudflareWorkersAIInput & BaseLLMParams) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + this.streaming = fields?.streaming ?? this.streaming; + this.cloudflareAccountId = + fields?.cloudflareAccountId ?? + getEnvironmentVariable("CLOUDFLARE_ACCOUNT_ID"); + this.cloudflareApiToken = + fields?.cloudflareApiToken ?? + getEnvironmentVariable("CLOUDFLARE_API_TOKEN"); + this.baseUrl = + fields?.baseUrl ?? + `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`; + if (this.baseUrl.endsWith("/")) { + this.baseUrl = this.baseUrl.slice(0, -1); + } + } + + /** + * Method to validate the environment. + */ + validateEnvironment() { + if (this.baseUrl === undefined) { + if (!this.cloudflareAccountId) { + throw new Error( + `No Cloudflare account ID found. 
Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.` + ); + } + if (!this.cloudflareApiToken) { + throw new Error( + `No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_KEY" in your environment variables.` + ); + } + } + } + + /** Get the identifying parameters for this LLM. */ + get identifyingParams() { + return { model: this.model }; + } + + /** + * Get the parameters used to invoke the model + */ + invocationParams() { + return { + model: this.model, + }; + } + + /** Get the type of LLM. */ + _llmType() { + return "cloudflare"; + } + + async _request( + prompt: string, + options: this["ParsedCallOptions"], + stream?: boolean + ) { + this.validateEnvironment(); + + const url = `${this.baseUrl}/${this.model}`; + const headers = { + Authorization: `Bearer ${this.cloudflareApiToken}`, + "Content-Type": "application/json", + }; + + const data = { prompt, stream }; + return this.caller.call(async () => { + const response = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(data), + signal: options.signal, + }); + if (!response.ok) { + const error = new Error( + `Cloudflare LLM call failed with status code ${response.status}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = response; + throw error; + } + return response; + }); + } + + async *_streamResponseChunks( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const response = await this._request(prompt, options, true); + if (!response.body) { + throw new Error("Empty response from Cloudflare. 
Please try again."); + } + const stream = convertEventStreamToIterableReadableDataStream( + response.body + ); + for await (const chunk of stream) { + if (chunk !== "[DONE]") { + const parsedChunk = JSON.parse(chunk); + const generationChunk = new GenerationChunk({ + text: parsedChunk.response, + }); + yield generationChunk; + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(generationChunk.text ?? ""); + } + } + } + + /** Call out to CloudflareWorkersAI's complete endpoint. + Args: + prompt: The prompt to pass into the model. + Returns: + The string generated by the model. + Example: + let response = CloudflareWorkersAI.call("Tell me a joke."); + */ + async _call( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + if (!this.streaming) { + const response = await this._request(prompt, options); + + const responseData = await response.json(); + + return responseData.result.response; + } else { + const stream = this._streamResponseChunks(prompt, options, runManager); + let finalResult: GenerationChunk | undefined; + for await (const chunk of stream) { + if (finalResult === undefined) { + finalResult = chunk; + } else { + finalResult = finalResult.concat(chunk); + } + } + return finalResult?.text ?? ""; + } + } +} diff --git a/libs/langchain-community/src/llms/cohere.ts b/libs/langchain-community/src/llms/cohere.ts new file mode 100644 index 000000000000..6d73e684cb91 --- /dev/null +++ b/libs/langchain-community/src/llms/cohere.ts @@ -0,0 +1,129 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; + +/** + * Interface for the input parameters specific to the Cohere model. + */ +export interface CohereInput extends BaseLLMParams { + /** Sampling temperature to use */ + temperature?: number; + + /** + * Maximum number of tokens to generate in the completion. 
+ */ + maxTokens?: number; + + /** Model to use */ + model?: string; + + apiKey?: string; +} + +/** + * Class representing a Cohere Large Language Model (LLM). It interacts + * with the Cohere API to generate text completions. + * @example + * ```typescript + * const model = new Cohere({ + * temperature: 0.7, + * maxTokens: 20, + * maxRetries: 5, + * }); + * + * const res = await model.call( + * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" + * ); + * console.log({ res }); + * ``` + */ +export class Cohere extends LLM implements CohereInput { + static lc_name() { + return "Cohere"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "COHERE_API_KEY", + }; + } + + get lc_aliases(): { [key: string]: string } | undefined { + return { + apiKey: "cohere_api_key", + }; + } + + lc_serializable = true; + + temperature = 0; + + maxTokens = 250; + + model: string; + + apiKey: string; + + constructor(fields?: CohereInput) { + super(fields ?? {}); + + const apiKey = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY"); + + if (!apiKey) { + throw new Error( + "Please set the COHERE_API_KEY environment variable or pass it to the constructor as the apiKey field." + ); + } + + this.apiKey = apiKey; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.temperature = fields?.temperature ?? this.temperature; + this.model = fields?.model ?? 
this.model; + } + + _llmType() { + return "cohere"; + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const { cohere } = await Cohere.imports(); + + cohere.init(this.apiKey); + + // Hit the `generate` endpoint on the `large` model + const generateResponse = await this.caller.callWithOptions( + { signal: options.signal }, + cohere.generate.bind(cohere), + { + prompt, + model: this.model, + max_tokens: this.maxTokens, + temperature: this.temperature, + end_sequences: options.stop, + } + ); + try { + return generateResponse.body.generations[0].text; + } catch { + console.log(generateResponse); + throw new Error("Could not parse response."); + } + } + + /** @ignore */ + static async imports(): Promise<{ + cohere: typeof import("cohere-ai"); + }> { + try { + const { default: cohere } = await import("cohere-ai"); + return { cohere }; + } catch (e) { + throw new Error( + "Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" + ); + } + } +} diff --git a/libs/langchain-community/src/llms/fireworks.ts b/libs/langchain-community/src/llms/fireworks.ts new file mode 100644 index 000000000000..5b28b8008af4 --- /dev/null +++ b/libs/langchain-community/src/llms/fireworks.ts @@ -0,0 +1,142 @@ +import { + type OpenAIClient, + type OpenAICallOptions, + type OpenAIInput, + type OpenAICoreRequestOptions, + OpenAI, +} from "@langchain/openai"; +import type { BaseLLMParams } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +type FireworksUnsupportedArgs = + | "frequencyPenalty" + | "presencePenalty" + | "bestOf" + | "logitBias"; + +type FireworksUnsupportedCallOptions = "functions" | "function_call" | "tools"; + +export type FireworksCallOptions = Partial< + Omit +>; + +/** + * Wrapper around Fireworks API for large language models + * + * Fireworks API is compatible to the OpenAI API with some limitations described in + * 
https://readme.fireworks.ai/docs/openai-compatibility. + * + * To use, you should have the `openai` package installed and + * the `FIREWORKS_API_KEY` environment variable set. + */ +export class Fireworks extends OpenAI { + static lc_name() { + return "Fireworks"; + } + + _llmType() { + return "fireworks"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + fireworksApiKey: "FIREWORKS_API_KEY", + }; + } + + lc_serializable = true; + + fireworksApiKey?: string; + + constructor( + fields?: Partial< + Omit + > & + BaseLLMParams & { fireworksApiKey?: string } + ) { + const fireworksApiKey = + fields?.fireworksApiKey || getEnvironmentVariable("FIREWORKS_API_KEY"); + + if (!fireworksApiKey) { + throw new Error( + `Fireworks API key not found. Please set the FIREWORKS_API_KEY environment variable or provide the key into "fireworksApiKey"` + ); + } + + super({ + ...fields, + openAIApiKey: fireworksApiKey, + modelName: fields?.modelName || "accounts/fireworks/models/llama-v2-13b", + configuration: { + baseURL: "https://api.fireworks.ai/inference/v1", + }, + }); + + this.fireworksApiKey = fireworksApiKey; + } + + toJSON() { + const result = super.toJSON(); + + if ( + "kwargs" in result && + typeof result.kwargs === "object" && + result.kwargs != null + ) { + delete result.kwargs.openai_api_key; + delete result.kwargs.configuration; + } + + return result; + } + + async completionWithRetry( + request: OpenAIClient.CompletionCreateParamsStreaming, + options?: OpenAICoreRequestOptions + ): Promise>; + + async completionWithRetry( + request: OpenAIClient.CompletionCreateParamsNonStreaming, + options?: OpenAICoreRequestOptions + ): Promise; + + /** + * Calls the Fireworks API with retry logic in case of failures. + * @param request The request to send to the Fireworks API. + * @param options Optional configuration for the API call. + * @returns The response from the Fireworks API. 
+ */ + async completionWithRetry( + request: + | OpenAIClient.CompletionCreateParamsStreaming + | OpenAIClient.CompletionCreateParamsNonStreaming, + options?: OpenAICoreRequestOptions + ): Promise< + AsyncIterable | OpenAIClient.Completions.Completion + > { + // https://readme.fireworks.ai/docs/openai-compatibility#api-compatibility + if (Array.isArray(request.prompt)) { + if (request.prompt.length > 1) { + throw new Error("Multiple prompts are not supported by Fireworks"); + } + + const prompt = request.prompt[0]; + if (typeof prompt !== "string") { + throw new Error("Only string prompts are supported by Fireworks"); + } + + request.prompt = prompt; + } + + delete request.frequency_penalty; + delete request.presence_penalty; + delete request.best_of; + delete request.logit_bias; + + if (request.stream === true) { + return super.completionWithRetry(request, options); + } + + return super.completionWithRetry(request, options); + } +} diff --git a/libs/langchain-community/src/llms/googlepalm.ts b/libs/langchain-community/src/llms/googlepalm.ts new file mode 100644 index 000000000000..0839b2376aac --- /dev/null +++ b/libs/langchain-community/src/llms/googlepalm.ts @@ -0,0 +1,205 @@ +import { TextServiceClient, protos } from "@google-ai/generativelanguage"; +import { GoogleAuth } from "google-auth-library"; +import { type BaseLLMParams, LLM } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Input for Text generation for Google Palm + */ +export interface GooglePaLMTextInput extends BaseLLMParams { + /** + * Model Name to use + * + * Note: The format must follow the pattern - `models/{model}` + */ + modelName?: string; + + /** + * Controls the randomness of the output. + * + * Values can range from [0.0,1.0], inclusive. 
A value closer to 1.0 + * will produce responses that are more varied and creative, while + * a value closer to 0.0 will typically result in more straightforward + * responses from the model. + * + * Note: The default value varies by model + */ + temperature?: number; + + /** + * Maximum number of tokens to generate in the completion. + */ + maxOutputTokens?: number; + + /** + * Top-p changes how the model selects tokens for output. + * + * Tokens are selected from most probable to least until the sum + * of their probabilities equals the top-p value. + * + * For example, if tokens A, B, and C have a probability of + * .3, .2, and .1 and the top-p value is .5, then the model will + * select either A or B as the next token (using temperature). + * + * Note: The default value varies by model + */ + topP?: number; + + /** + * Top-k changes how the model selects tokens for output. + * + * A top-k of 1 means the selected token is the most probable among + * all tokens in the model’s vocabulary (also called greedy decoding), + * while a top-k of 3 means that the next token is selected from + * among the 3 most probable tokens (using temperature). + * + * Note: The default value varies by model + */ + topK?: number; + + /** + * The set of character sequences (up to 5) that will stop output generation. + * If specified, the API will stop at the first appearance of a stop + * sequence. + * + * Note: The stop sequence will not be included as part of the response. + */ + stopSequences?: string[]; + + /** + * A list of unique `SafetySetting` instances for blocking unsafe content. The API will block + * any prompts and responses that fail to meet the thresholds set by these settings. If there + * is no `SafetySetting` for a given `SafetyCategory` provided in the list, the API will use + * the default safety setting for that category. 
+ */ + safetySettings?: protos.google.ai.generativelanguage.v1beta2.ISafetySetting[]; + + /** + * Google Palm API key to use + */ + apiKey?: string; +} + +/** + * Google Palm 2 Language Model Wrapper to generate texts + */ +export class GooglePaLM extends LLM implements GooglePaLMTextInput { + lc_serializable = true; + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "GOOGLE_PALM_API_KEY", + }; + } + + modelName = "models/text-bison-001"; + + temperature?: number; // default value chosen based on model + + maxOutputTokens?: number; // defaults to 64 + + topP?: number; // default value chosen based on model + + topK?: number; // default value chosen based on model + + stopSequences: string[] = []; + + safetySettings?: protos.google.ai.generativelanguage.v1beta2.ISafetySetting[]; // default safety setting for that category + + apiKey?: string; + + private client: TextServiceClient; + + constructor(fields?: GooglePaLMTextInput) { + super(fields ?? {}); + + this.modelName = fields?.modelName ?? this.modelName; + + this.temperature = fields?.temperature ?? this.temperature; + if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { + throw new Error("`temperature` must be in the range of [0.0,1.0]"); + } + + this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens; + if (this.maxOutputTokens && this.maxOutputTokens < 0) { + throw new Error("`maxOutputTokens` must be a positive integer"); + } + + this.topP = fields?.topP ?? this.topP; + if (this.topP && this.topP < 0) { + throw new Error("`topP` must be a positive integer"); + } + + if (this.topP && this.topP > 1) { + throw new Error("Google PaLM `topP` must in the range of [0,1]"); + } + + this.topK = fields?.topK ?? this.topK; + if (this.topK && this.topK < 0) { + throw new Error("`topK` must be a positive integer"); + } + + this.stopSequences = fields?.stopSequences ?? this.stopSequences; + + this.safetySettings = fields?.safetySettings ?? 
this.safetySettings; + if (this.safetySettings && this.safetySettings.length > 0) { + const safetySettingsSet = new Set( + this.safetySettings.map((s) => s.category) + ); + if (safetySettingsSet.size !== this.safetySettings.length) { + throw new Error( + "The categories in `safetySettings` array must be unique" + ); + } + } + + this.apiKey = + fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PALM_API_KEY"); + if (!this.apiKey) { + throw new Error( + "Please set an API key for Google Palm 2 in the environment variable GOOGLE_PALM_API_KEY or in the `apiKey` field of the GooglePalm constructor" + ); + } + + this.client = new TextServiceClient({ + authClient: new GoogleAuth().fromAPIKey(this.apiKey), + }); + } + + _llmType(): string { + return "googlepalm"; + } + + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const res = await this.caller.callWithOptions( + { signal: options.signal }, + this._generateText.bind(this), + prompt + ); + return res ?? ""; + } + + protected async _generateText( + prompt: string + ): Promise { + const res = await this.client.generateText({ + model: this.modelName, + temperature: this.temperature, + candidateCount: 1, + topK: this.topK, + topP: this.topP, + maxOutputTokens: this.maxOutputTokens, + stopSequences: this.stopSequences, + safetySettings: this.safetySettings, + prompt: { + text: prompt, + }, + }); + return res[0].candidates && res[0].candidates.length > 0 + ? 
res[0].candidates[0].output + : undefined; + } +} diff --git a/langchain/src/llms/googlevertexai/common.ts b/libs/langchain-community/src/llms/googlevertexai/common.ts similarity index 94% rename from langchain/src/llms/googlevertexai/common.ts rename to libs/langchain-community/src/llms/googlevertexai/common.ts index ff1663d7db77..add11058e56a 100644 --- a/langchain/src/llms/googlevertexai/common.ts +++ b/libs/langchain-community/src/llms/googlevertexai/common.ts @@ -1,18 +1,23 @@ -import { BaseLLM } from "../base.js"; -import { Generation, GenerationChunk, LLMResult } from "../../schema/index.js"; +import { BaseLLM } from "@langchain/core/language_models/llms"; +import { + Generation, + GenerationChunk, + LLMResult, +} from "@langchain/core/outputs"; +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; + import { GoogleVertexAILLMConnection, GoogleVertexAIStream, GoogleVertexAILLMResponse, -} from "../../util/googlevertexai-connection.js"; +} from "../../utils/googlevertexai-connection.js"; import { GoogleVertexAIBaseLLMInput, GoogleVertexAIBasePrediction, GoogleVertexAILLMPredictions, GoogleVertexAIModelParams, } from "../../types/googlevertexai-types.js"; -import { BaseLanguageModelCallOptions } from "../../base_language/index.js"; -import { CallbackManagerForLLMRun } from "../../callbacks/index.js"; /** * Interface representing the instance of text input to the Google Vertex diff --git a/libs/langchain-community/src/llms/googlevertexai/index.ts b/libs/langchain-community/src/llms/googlevertexai/index.ts new file mode 100644 index 000000000000..9406c3e01013 --- /dev/null +++ b/libs/langchain-community/src/llms/googlevertexai/index.ts @@ -0,0 +1,66 @@ +import { GoogleAuthOptions } from "google-auth-library"; +import { GoogleVertexAILLMConnection } from "../../utils/googlevertexai-connection.js"; +import { GoogleVertexAIBaseLLMInput } from 
"../../types/googlevertexai-types.js"; +import { BaseGoogleVertexAI } from "./common.js"; +import { GAuthClient } from "../../utils/googlevertexai-gauth.js"; + +/** + * Interface representing the input to the Google Vertex AI model. + */ +export interface GoogleVertexAITextInput + extends GoogleVertexAIBaseLLMInput {} + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * Large Language Models. + * + * To use, you will need to have one of the following authentication + * methods in place: + * - You are logged into an account permitted to the Google Cloud project + * using Vertex AI. + * - You are running this on a machine using a service account permitted to + * the Google Cloud project using Vertex AI. + * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the + * path of a credentials file for a service account permitted to the + * Google Cloud project using Vertex AI. + * @example + * ```typescript + * const model = new GoogleVertexAI({ + * temperature: 0.7, + * }); + * const stream = await model.stream( + * "What would be a good company name for a company that makes colorful socks?", + * ); + * for await (const chunk of stream) { + * console.log(chunk); + * } + * ``` + */ +export class GoogleVertexAI extends BaseGoogleVertexAI { + static lc_name() { + return "VertexAI"; + } + + constructor(fields?: GoogleVertexAITextInput) { + super(fields); + + const client = new GAuthClient({ + scopes: "https://www.googleapis.com/auth/cloud-platform", + ...fields?.authOptions, + }); + + this.connection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + false + ); + + this.streamedConnection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + true + ); + } +} diff --git a/libs/langchain-community/src/llms/googlevertexai/web.ts b/libs/langchain-community/src/llms/googlevertexai/web.ts new file mode 100644 index 000000000000..9ceb3cbf2285 --- /dev/null +++ 
b/libs/langchain-community/src/llms/googlevertexai/web.ts @@ -0,0 +1,66 @@ +import { + WebGoogleAuth, + WebGoogleAuthOptions, +} from "../../utils/googlevertexai-webauth.js"; +import { GoogleVertexAILLMConnection } from "../../utils/googlevertexai-connection.js"; +import { GoogleVertexAIBaseLLMInput } from "../../types/googlevertexai-types.js"; +import { BaseGoogleVertexAI } from "./common.js"; + +/** + * Interface representing the input to the Google Vertex AI model. + */ +export interface GoogleVertexAITextInput + extends GoogleVertexAIBaseLLMInput {} + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * Large Language Models. + * + * This entrypoint and class are intended to be used in web environments like Edge + * functions where you do not have access to the file system. It supports passing + * service account credentials directly as a "GOOGLE_VERTEX_AI_WEB_CREDENTIALS" + * environment variable or directly as "authOptions.credentials". + * @example + * ```typescript + * const model = new GoogleVertexAI({ + * temperature: 0.7, + * }); + * const stream = await model.stream( + * "What would be a good company name for a company that makes colorful socks?", + * ); + * for await (const chunk of stream) { + * console.log(chunk); + * } + * ``` + */ +export class GoogleVertexAI extends BaseGoogleVertexAI { + static lc_name() { + return "VertexAI"; + } + + get lc_secrets(): { [key: string]: string } { + return { + "authOptions.credentials": "GOOGLE_VERTEX_AI_WEB_CREDENTIALS", + }; + } + + constructor(fields?: GoogleVertexAITextInput) { + super(fields); + + const client = new WebGoogleAuth(fields?.authOptions); + + this.connection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + false + ); + + this.streamedConnection = new GoogleVertexAILLMConnection( + { ...fields, ...this }, + this.caller, + client, + true + ); + } +} diff --git a/libs/langchain-community/src/llms/gradient_ai.ts 
b/libs/langchain-community/src/llms/gradient_ai.ts new file mode 100644 index 000000000000..ad834e104d7f --- /dev/null +++ b/libs/langchain-community/src/llms/gradient_ai.ts @@ -0,0 +1,142 @@ +import { Gradient } from "@gradientai/nodejs-sdk"; +import { + type BaseLLMCallOptions, + type BaseLLMParams, + LLM, +} from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * The GradientLLMParams interface defines the input parameters for + * the GradientLLM class. + */ +export interface GradientLLMParams extends BaseLLMParams { + /** + * Gradient AI Access Token. + * Provide Access Token if you do not wish to automatically pull from env. + */ + gradientAccessKey?: string; + /** + * Gradient Workspace Id. + * Provide workspace id if you do not wish to automatically pull from env. + */ + workspaceId?: string; + /** + * Parameters accepted by the Gradient npm package. + */ + inferenceParameters?: Record; + /** + * Gradient AI Model Slug. + */ + modelSlug?: string; + /** + * Gradient Adapter ID for custom fine tuned models. + */ + adapterId?: string; +} + +/** + * The GradientLLM class is used to interact with Gradient AI inference Endpoint models. + * This requires your Gradient AI Access Token which is autoloaded if not specified. + */ +export class GradientLLM extends LLM { + static lc_name() { + return "GradientLLM"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + gradientAccessKey: "GRADIENT_ACCESS_TOKEN", + workspaceId: "GRADIENT_WORKSPACE_ID", + }; + } + + modelSlug = "llama2-7b-chat"; + + adapterId?: string; + + gradientAccessKey?: string; + + workspaceId?: string; + + inferenceParameters?: Record; + + lc_serializable = true; + + // Gradient AI does not export the BaseModel type. Once it does, we can use it here. 
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any + model: any; + + constructor(fields: GradientLLMParams) { + super(fields); + + this.modelSlug = fields?.modelSlug ?? this.modelSlug; + this.adapterId = fields?.adapterId; + this.gradientAccessKey = + fields?.gradientAccessKey ?? + getEnvironmentVariable("GRADIENT_ACCESS_TOKEN"); + this.workspaceId = + fields?.workspaceId ?? getEnvironmentVariable("GRADIENT_WORKSPACE_ID"); + + this.inferenceParameters = fields.inferenceParameters; + + if (!this.gradientAccessKey) { + throw new Error("Missing Gradient AI Access Token"); + } + + if (!this.workspaceId) { + throw new Error("Missing Gradient AI Workspace ID"); + } + } + + _llmType() { + return "gradient_ai"; + } + + /** + * Calls the Gradient AI endpoint and retrieves the result. + * @param {string} prompt The input prompt. + * @returns {Promise} A promise that resolves to the generated string. + */ + /** @ignore */ + async _call( + prompt: string, + _options: this["ParsedCallOptions"] + ): Promise { + await this.setModel(); + + // GradientLLM does not export the CompleteResponse type. Once it does, we can use it here. 
+ interface CompleteResponse { + finishReason: string; + generatedOutput: string; + } + + const response = (await this.caller.call(async () => + this.model.complete({ + query: prompt, + ...this.inferenceParameters, + }) + )) as CompleteResponse; + + return response.generatedOutput; + } + + async setModel() { + if (this.model) return; + + const gradient = new Gradient({ + accessToken: this.gradientAccessKey, + workspaceId: this.workspaceId, + }); + + if (this.adapterId) { + this.model = await gradient.getModelAdapter({ + modelAdapterId: this.adapterId, + }); + } else { + this.model = await gradient.getBaseModel({ + baseModelSlug: this.modelSlug, + }); + } + } +} diff --git a/libs/langchain-community/src/llms/hf.ts b/libs/langchain-community/src/llms/hf.ts new file mode 100644 index 000000000000..f1ceb58a23f9 --- /dev/null +++ b/libs/langchain-community/src/llms/hf.ts @@ -0,0 +1,157 @@ +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface defining the parameters for configuring the Hugging Face + * model for text generation. + */ +export interface HFInput { + /** Model to use */ + model: string; + + /** Custom inference endpoint URL to use */ + endpointUrl?: string; + + /** Sampling temperature to use */ + temperature?: number; + + /** + * Maximum number of tokens to generate in the completion. + */ + maxTokens?: number; + + /** Total probability mass of tokens to consider at each step */ + topP?: number; + + /** Integer to define the top tokens considered within the sample operation to create new text. */ + topK?: number; + + /** Penalizes repeated tokens according to frequency */ + frequencyPenalty?: number; + + /** API key to use. */ + apiKey?: string; + + /** + * Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all. 
+ */ + includeCredentials?: string | boolean; +} + +/** + * Class implementing the Large Language Model (LLM) interface using the + * Hugging Face Inference API for text generation. + * @example + * ```typescript + * const model = new HuggingFaceInference({ + * model: "gpt2", + * temperature: 0.7, + * maxTokens: 50, + * }); + * + * const res = await model.call( + * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" + * ); + * console.log({ res }); + * ``` + */ +export class HuggingFaceInference extends LLM implements HFInput { + lc_serializable = true; + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "HUGGINGFACEHUB_API_KEY", + }; + } + + model = "gpt2"; + + temperature: number | undefined = undefined; + + maxTokens: number | undefined = undefined; + + topP: number | undefined = undefined; + + topK: number | undefined = undefined; + + frequencyPenalty: number | undefined = undefined; + + apiKey: string | undefined = undefined; + + endpointUrl: string | undefined = undefined; + + includeCredentials: string | boolean | undefined = undefined; + + constructor(fields?: Partial & BaseLLMParams) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + this.temperature = fields?.temperature ?? this.temperature; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.topP = fields?.topP ?? this.topP; + this.topK = fields?.topK ?? this.topK; + this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; + this.apiKey = + fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY"); + this.endpointUrl = fields?.endpointUrl; + this.includeCredentials = fields?.includeCredentials; + + if (!this.apiKey) { + throw new Error( + "Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor." 
+ ); + } + } + + _llmType() { + return "hf"; + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const { HfInference } = await HuggingFaceInference.imports(); + const hf = this.endpointUrl + ? new HfInference(this.apiKey, { + includeCredentials: this.includeCredentials, + }).endpoint(this.endpointUrl) + : new HfInference(this.apiKey, { + includeCredentials: this.includeCredentials, + }); + + const res = await this.caller.callWithOptions( + { signal: options.signal }, + hf.textGeneration.bind(hf), + { + model: this.model, + parameters: { + // make it behave similar to openai, returning only the generated text + return_full_text: false, + temperature: this.temperature, + max_new_tokens: this.maxTokens, + top_p: this.topP, + top_k: this.topK, + repetition_penalty: this.frequencyPenalty, + }, + inputs: prompt, + } + ); + return res.generated_text; + } + + /** @ignore */ + static async imports(): Promise<{ + HfInference: typeof import("@huggingface/inference").HfInference; + }> { + try { + const { HfInference } = await import("@huggingface/inference"); + return { HfInference }; + } catch (e) { + throw new Error( + "Please install huggingface as a dependency with, e.g. 
`yarn add @huggingface/inference`" + ); + } + } +} diff --git a/libs/langchain-community/src/llms/llama_cpp.ts b/libs/langchain-community/src/llms/llama_cpp.ts new file mode 100644 index 000000000000..db03ffbf31ed --- /dev/null +++ b/libs/langchain-community/src/llms/llama_cpp.ts @@ -0,0 +1,123 @@ +import { LlamaModel, LlamaContext, LlamaChatSession } from "node-llama-cpp"; +import { + LLM, + type BaseLLMCallOptions, + type BaseLLMParams, +} from "@langchain/core/language_models/llms"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { GenerationChunk } from "@langchain/core/outputs"; + +import { + LlamaBaseCppInputs, + createLlamaModel, + createLlamaContext, + createLlamaSession, +} from "../utils/llama_cpp.js"; + +/** + * Note that the modelPath is the only required parameter. For testing you + * can set this in the environment variable `LLAMA_PATH`. + */ +export interface LlamaCppInputs extends LlamaBaseCppInputs, BaseLLMParams {} + +export interface LlamaCppCallOptions extends BaseLLMCallOptions { + /** The maximum number of tokens the response should contain. */ + maxTokens?: number; + /** A function called when matching the provided token array */ + onToken?: (tokens: number[]) => void; +} + +/** + * To use this model you need to have the `node-llama-cpp` module installed. + * This can be installed using `npm install -S node-llama-cpp` and the minimum + * version supported is version 2.0.0. + * This also requires that you have a locally built version of Llama2 installed. 
+ */ +export class LlamaCpp extends LLM { + lc_serializable = true; + + declare CallOptions: LlamaCppCallOptions; + + static inputs: LlamaCppInputs; + + maxTokens?: number; + + temperature?: number; + + topK?: number; + + topP?: number; + + trimWhitespaceSuffix?: boolean; + + _model: LlamaModel; + + _context: LlamaContext; + + _session: LlamaChatSession; + + static lc_name() { + return "LlamaCpp"; + } + + constructor(inputs: LlamaCppInputs) { + super(inputs); + this.maxTokens = inputs?.maxTokens; + this.temperature = inputs?.temperature; + this.topK = inputs?.topK; + this.topP = inputs?.topP; + this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix; + this._model = createLlamaModel(inputs); + this._context = createLlamaContext(this._model, inputs); + this._session = createLlamaSession(this._context); + } + + _llmType() { + return "llama2_cpp"; + } + + /** @ignore */ + async _call( + prompt: string, + options?: this["ParsedCallOptions"] + ): Promise { + try { + const promptOptions = { + onToken: options?.onToken, + maxTokens: this?.maxTokens, + temperature: this?.temperature, + topK: this?.topK, + topP: this?.topP, + trimWhitespaceSuffix: this?.trimWhitespaceSuffix, + }; + const completion = await this._session.prompt(prompt, promptOptions); + return completion; + } catch (e) { + throw new Error("Error getting prompt completion."); + } + } + + async *_streamResponseChunks( + prompt: string, + _options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const promptOptions = { + temperature: this?.temperature, + topK: this?.topK, + topP: this?.topP, + }; + + const stream = await this.caller.call(async () => + this._context.evaluate(this._context.encode(prompt), promptOptions) + ); + + for await (const chunk of stream) { + yield new GenerationChunk({ + text: this._context.decode([chunk]), + generationInfo: {}, + }); + await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? 
""); + } + } +} diff --git a/libs/langchain-community/src/llms/ollama.ts b/libs/langchain-community/src/llms/ollama.ts new file mode 100644 index 000000000000..f97fef1f31ca --- /dev/null +++ b/libs/langchain-community/src/llms/ollama.ts @@ -0,0 +1,246 @@ +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { GenerationChunk } from "@langchain/core/outputs"; +import type { StringWithAutocomplete } from "@langchain/core/utils/types"; +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; + +import { + createOllamaStream, + OllamaInput, + OllamaCallOptions, +} from "../utils/ollama.js"; + +/** + * Class that represents the Ollama language model. It extends the base + * LLM class and implements the OllamaInput interface. + * @example + * ```typescript + * const ollama = new Ollama({ + * baseUrl: "http://api.example.com", + * model: "llama2", + * }); + * + * // Streaming translation from English to German + * const stream = await ollama.stream( + * `Translate "I love programming" into German.` + * ); + * + * const chunks = []; + * for await (const chunk of stream) { + * chunks.push(chunk); + * } + * + * console.log(chunks.join("")); + * ``` + */ +export class Ollama extends LLM implements OllamaInput { + static lc_name() { + return "Ollama"; + } + + lc_serializable = true; + + model = "llama2"; + + baseUrl = "http://localhost:11434"; + + embeddingOnly?: boolean; + + f16KV?: boolean; + + frequencyPenalty?: number; + + logitsAll?: boolean; + + lowVram?: boolean; + + mainGpu?: number; + + mirostat?: number; + + mirostatEta?: number; + + mirostatTau?: number; + + numBatch?: number; + + numCtx?: number; + + numGpu?: number; + + numGqa?: number; + + numKeep?: number; + + numThread?: number; + + penalizeNewline?: boolean; + + presencePenalty?: number; + + repeatLastN?: number; + + repeatPenalty?: number; + + ropeFrequencyBase?: number; + + ropeFrequencyScale?: number; + + temperature?: number; + + stop?: string[]; + + 
tfsZ?: number; + + topK?: number; + + topP?: number; + + typicalP?: number; + + useMLock?: boolean; + + useMMap?: boolean; + + vocabOnly?: boolean; + + format?: StringWithAutocomplete<"json">; + + constructor(fields: OllamaInput & BaseLLMParams) { + super(fields); + this.model = fields.model ?? this.model; + this.baseUrl = fields.baseUrl?.endsWith("/") + ? fields.baseUrl.slice(0, -1) + : fields.baseUrl ?? this.baseUrl; + + this.embeddingOnly = fields.embeddingOnly; + this.f16KV = fields.f16KV; + this.frequencyPenalty = fields.frequencyPenalty; + this.logitsAll = fields.logitsAll; + this.lowVram = fields.lowVram; + this.mainGpu = fields.mainGpu; + this.mirostat = fields.mirostat; + this.mirostatEta = fields.mirostatEta; + this.mirostatTau = fields.mirostatTau; + this.numBatch = fields.numBatch; + this.numCtx = fields.numCtx; + this.numGpu = fields.numGpu; + this.numGqa = fields.numGqa; + this.numKeep = fields.numKeep; + this.numThread = fields.numThread; + this.penalizeNewline = fields.penalizeNewline; + this.presencePenalty = fields.presencePenalty; + this.repeatLastN = fields.repeatLastN; + this.repeatPenalty = fields.repeatPenalty; + this.ropeFrequencyBase = fields.ropeFrequencyBase; + this.ropeFrequencyScale = fields.ropeFrequencyScale; + this.temperature = fields.temperature; + this.stop = fields.stop; + this.tfsZ = fields.tfsZ; + this.topK = fields.topK; + this.topP = fields.topP; + this.typicalP = fields.typicalP; + this.useMLock = fields.useMLock; + this.useMMap = fields.useMMap; + this.vocabOnly = fields.vocabOnly; + this.format = fields.format; + } + + _llmType() { + return "ollama"; + } + + invocationParams(options?: this["ParsedCallOptions"]) { + return { + model: this.model, + format: this.format, + options: { + embedding_only: this.embeddingOnly, + f16_kv: this.f16KV, + frequency_penalty: this.frequencyPenalty, + logits_all: this.logitsAll, + low_vram: this.lowVram, + main_gpu: this.mainGpu, + mirostat: this.mirostat, + mirostat_eta: this.mirostatEta, 
+ mirostat_tau: this.mirostatTau, + num_batch: this.numBatch, + num_ctx: this.numCtx, + num_gpu: this.numGpu, + num_gqa: this.numGqa, + num_keep: this.numKeep, + num_thread: this.numThread, + penalize_newline: this.penalizeNewline, + presence_penalty: this.presencePenalty, + repeat_last_n: this.repeatLastN, + repeat_penalty: this.repeatPenalty, + rope_frequency_base: this.ropeFrequencyBase, + rope_frequency_scale: this.ropeFrequencyScale, + temperature: this.temperature, + stop: options?.stop ?? this.stop, + tfs_z: this.tfsZ, + top_k: this.topK, + top_p: this.topP, + typical_p: this.typicalP, + use_mlock: this.useMLock, + use_mmap: this.useMMap, + vocab_only: this.vocabOnly, + }, + }; + } + + async *_streamResponseChunks( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const stream = await this.caller.call(async () => + createOllamaStream( + this.baseUrl, + { ...this.invocationParams(options), prompt }, + options + ) + ); + for await (const chunk of stream) { + if (!chunk.done) { + yield new GenerationChunk({ + text: chunk.response, + generationInfo: { + ...chunk, + response: undefined, + }, + }); + await runManager?.handleLLMNewToken(chunk.response ?? 
""); + } else { + yield new GenerationChunk({ + text: "", + generationInfo: { + model: chunk.model, + total_duration: chunk.total_duration, + load_duration: chunk.load_duration, + prompt_eval_count: chunk.prompt_eval_count, + prompt_eval_duration: chunk.prompt_eval_duration, + eval_count: chunk.eval_count, + eval_duration: chunk.eval_duration, + }, + }); + } + } + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const chunks = []; + for await (const chunk of this._streamResponseChunks( + prompt, + options, + runManager + )) { + chunks.push(chunk.text); + } + return chunks.join(""); + } +} diff --git a/libs/langchain-community/src/llms/portkey.ts b/libs/langchain-community/src/llms/portkey.ts new file mode 100644 index 000000000000..71ae1c0725bf --- /dev/null +++ b/libs/langchain-community/src/llms/portkey.ts @@ -0,0 +1,179 @@ +import _ from "lodash"; +import { LLMOptions, Portkey as _Portkey } from "portkey-ai"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { GenerationChunk, LLMResult } from "@langchain/core/outputs"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { BaseLLM } from "@langchain/core/language_models/llms"; + +interface PortkeyOptions { + apiKey?: string; + baseURL?: string; + mode?: string; + llms?: [LLMOptions] | null; +} + +const readEnv = (env: string, default_val?: string): string | undefined => + getEnvironmentVariable(env) ?? 
default_val; + +export class PortkeySession { + portkey: _Portkey; + + constructor(options: PortkeyOptions = {}) { + if (!options.apiKey) { + /* eslint-disable no-param-reassign */ + options.apiKey = readEnv("PORTKEY_API_KEY"); + } + + if (!options.baseURL) { + /* eslint-disable no-param-reassign */ + options.baseURL = readEnv("PORTKEY_BASE_URL", "https://api.portkey.ai"); + } + + this.portkey = new _Portkey({}); + this.portkey.llms = [{}]; + if (!options.apiKey) { + throw new Error("Set Portkey ApiKey in PORTKEY_API_KEY env variable"); + } + + this.portkey = new _Portkey(options); + } +} + +const defaultPortkeySession: { + session: PortkeySession; + options: PortkeyOptions; +}[] = []; + +/** + * Get a session for the Portkey API. If one already exists with the same options, + * it will be returned. Otherwise, a new session will be created. + * @param options + * @returns + */ +export function getPortkeySession(options: PortkeyOptions = {}) { + let session = defaultPortkeySession.find((session) => + _.isEqual(session.options, options) + )?.session; + + if (!session) { + session = new PortkeySession(options); + defaultPortkeySession.push({ session, options }); + } + return session; +} + +/** + * @example + * ```typescript + * const model = new Portkey({ + * mode: "single", + * llms: [ + * { + * provider: "openai", + * virtual_key: "open-ai-key-1234", + * model: "text-davinci-003", + * max_tokens: 2000, + * }, + * ], + * }); + * + * // Stream the output of the model and process it + * const res = await model.stream( + * "Question: Write a story about a king\nAnswer:" + * ); + * for await (const i of res) { + * process.stdout.write(i); + * } + * ``` + */ +export class Portkey extends BaseLLM { + apiKey?: string = undefined; + + baseURL?: string = undefined; + + mode?: string = undefined; + + llms?: [LLMOptions] | null = undefined; + + session: PortkeySession; + + constructor(init?: Partial) { + super(init ?? 
{}); + this.apiKey = init?.apiKey; + + this.baseURL = init?.baseURL; + + this.mode = init?.mode; + + this.llms = init?.llms; + + this.session = getPortkeySession({ + apiKey: this.apiKey, + baseURL: this.baseURL, + llms: this.llms, + mode: this.mode, + }); + } + + _llmType() { + return "portkey"; + } + + async _generate( + prompts: string[], + options: this["ParsedCallOptions"], + _?: CallbackManagerForLLMRun + ): Promise { + const choices = []; + for (let i = 0; i < prompts.length; i += 1) { + const response = await this.session.portkey.completions.create({ + prompt: prompts[i], + ...options, + stream: false, + }); + choices.push(response.choices); + } + const generations = choices.map((promptChoices) => + promptChoices.map((choice) => ({ + text: choice.text ?? "", + generationInfo: { + finishReason: choice.finish_reason, + logprobs: choice.logprobs, + }, + })) + ); + + return { + generations, + }; + } + + async *_streamResponseChunks( + input: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const response = await this.session.portkey.completions.create({ + prompt: input, + ...options, + stream: true, + }); + for await (const data of response) { + const choice = data?.choices[0]; + if (!choice) { + continue; + } + const chunk = new GenerationChunk({ + text: choice.text ?? "", + generationInfo: { + finishReason: choice.finish_reason, + }, + }); + yield chunk; + void runManager?.handleLLMNewToken(chunk.text ?? 
""); + } + if (options.signal?.aborted) { + throw new Error("AbortError"); + } + } +} diff --git a/libs/langchain-community/src/llms/raycast.ts b/libs/langchain-community/src/llms/raycast.ts new file mode 100644 index 000000000000..233a02f00776 --- /dev/null +++ b/libs/langchain-community/src/llms/raycast.ts @@ -0,0 +1,101 @@ +import { AI, environment } from "@raycast/api"; +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; + +/** + * The input parameters for the RaycastAI class, which extends the BaseLLMParams interface. + */ +export interface RaycastAIInput extends BaseLLMParams { + model?: AI.Model; + creativity?: number; + rateLimitPerMinute?: number; +} + +const wait = (ms: number) => + new Promise((resolve) => { + setTimeout(resolve, ms); + }); + +/** + * The RaycastAI class, which extends the LLM class and implements the RaycastAIInput interface. + */ +export class RaycastAI extends LLM implements RaycastAIInput { + lc_serializable = true; + + /** + * The model to use for generating text. + */ + model: AI.Model; + + /** + * The creativity parameter, also known as the "temperature". + */ + creativity: number; + + /** + * The rate limit for API calls, in requests per minute. + */ + rateLimitPerMinute: number; + + /** + * The timestamp of the last API call, used to enforce the rate limit. + */ + private lastCallTimestamp = 0; + + /** + * Creates a new instance of the RaycastAI class. + * @param {RaycastAIInput} fields The input parameters for the RaycastAI class. + * @throws {Error} If the Raycast AI environment is not accessible. + */ + constructor(fields: RaycastAIInput) { + super(fields ?? {}); + + if (!environment.canAccess(AI)) { + throw new Error("Raycast AI environment is not accessible."); + } + + this.model = fields.model ?? "text-davinci-003"; + this.creativity = fields.creativity ?? 0.5; + this.rateLimitPerMinute = fields.rateLimitPerMinute ?? 10; + } + + /** + * Returns the type of the LLM, which is "raycast_ai". 
+ * @return {string} The type of the LLM. + * @ignore + */ + _llmType() { + return "raycast_ai"; + } + + /** + * Calls AI.ask with the given prompt and returns the generated text. + * @param {string} prompt The prompt to generate text from. + * @return {Promise} A Promise that resolves to the generated text. + * @ignore + */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const response = await this.caller.call(async () => { + // Rate limit calls to Raycast AI + const now = Date.now(); + const timeSinceLastCall = now - this.lastCallTimestamp; + const timeToWait = + (60 / this.rateLimitPerMinute) * 1000 - timeSinceLastCall; + + if (timeToWait > 0) { + await wait(timeToWait); + } + + return await AI.ask(prompt, { + model: this.model, + creativity: this.creativity, + signal: options.signal, + }); + }); + + // Since Raycast AI returns the response directly, no need for output transformation + return response; + } +} diff --git a/libs/langchain-community/src/llms/replicate.ts b/libs/langchain-community/src/llms/replicate.ts new file mode 100644 index 000000000000..fd433d412abd --- /dev/null +++ b/libs/langchain-community/src/llms/replicate.ts @@ -0,0 +1,158 @@ +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface defining the structure of the input data for the Replicate + * class. It includes details about the model to be used, any additional + * input parameters, and the API key for the Replicate service. + */ +export interface ReplicateInput { + // owner/model_name:version + model: `${string}/${string}:${string}`; + + input?: { + // different models accept different inputs + [key: string]: string | number | boolean; + }; + + apiKey?: string; + + /** The key used to pass prompts to the model. */ + promptKey?: string; +} + +/** + * Class responsible for managing the interaction with the Replicate API. 
+ * It handles the API key and model details, makes the actual API calls, + * and converts the API response into a format usable by the rest of the + * LangChain framework. + * @example + * ```typescript + * const model = new Replicate({ + * model: "replicate/flan-t5-xl:3ae0799123a1fe11f8c89fd99632f843fc5f7a761630160521c4253149754523", + * }); + * + * const res = await model.call( + * "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:" + * ); + * console.log({ res }); + * ``` + */ +export class Replicate extends LLM implements ReplicateInput { + static lc_name() { + return "Replicate"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "REPLICATE_API_TOKEN", + }; + } + + lc_serializable = true; + + model: ReplicateInput["model"]; + + input: ReplicateInput["input"]; + + apiKey: string; + + promptKey?: string; + + constructor(fields: ReplicateInput & BaseLLMParams) { + super(fields); + + const apiKey = + fields?.apiKey ?? + getEnvironmentVariable("REPLICATE_API_KEY") ?? // previous environment variable for backwards compatibility + getEnvironmentVariable("REPLICATE_API_TOKEN"); // current environment variable, matching the Python library + + if (!apiKey) { + throw new Error( + "Please set the REPLICATE_API_TOKEN environment variable" + ); + } + + this.apiKey = apiKey; + this.model = fields.model; + this.input = fields.input ?? 
{}; + this.promptKey = fields.promptKey; + } + + _llmType() { + return "replicate"; + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const imports = await Replicate.imports(); + + const replicate = new imports.Replicate({ + userAgent: "langchain", + auth: this.apiKey, + }); + + if (this.promptKey === undefined) { + const [modelString, versionString] = this.model.split(":"); + const version = await replicate.models.versions.get( + modelString.split("/")[0], + modelString.split("/")[1], + versionString + ); + const openapiSchema = version.openapi_schema; + const inputProperties: { "x-order": number | undefined }[] = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (openapiSchema as any)?.components?.schemas?.Input?.properties; + if (inputProperties === undefined) { + this.promptKey = "prompt"; + } else { + const sortedInputProperties = Object.entries(inputProperties).sort( + ([_keyA, valueA], [_keyB, valueB]) => { + const orderA = valueA["x-order"] || 0; + const orderB = valueB["x-order"] || 0; + return orderA - orderB; + } + ); + this.promptKey = sortedInputProperties[0][0] ?? "prompt"; + } + } + const output = await this.caller.callWithOptions( + { signal: options.signal }, + () => + replicate.run(this.model, { + input: { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + [this.promptKey!]: prompt, + ...this.input, + }, + }) + ); + + if (typeof output === "string") { + return output; + } else if (Array.isArray(output)) { + return output.join(""); + } else { + // Note this is a little odd, but the output format is not consistent + // across models, so it makes some amount of sense. 
+ return String(output); + } + } + + /** @ignore */ + static async imports(): Promise<{ + Replicate: typeof import("replicate").default; + }> { + try { + const { default: Replicate } = await import("replicate"); + return { Replicate }; + } catch (e) { + throw new Error( + "Please install replicate as a dependency with, e.g. `yarn add replicate`" + ); + } + } +} diff --git a/libs/langchain-community/src/llms/sagemaker_endpoint.ts b/libs/langchain-community/src/llms/sagemaker_endpoint.ts new file mode 100644 index 000000000000..43f99320cfd9 --- /dev/null +++ b/libs/langchain-community/src/llms/sagemaker_endpoint.ts @@ -0,0 +1,289 @@ +import { + InvokeEndpointCommand, + InvokeEndpointWithResponseStreamCommand, + SageMakerRuntimeClient, + SageMakerRuntimeClientConfig, +} from "@aws-sdk/client-sagemaker-runtime"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { GenerationChunk } from "@langchain/core/outputs"; +import { + type BaseLLMCallOptions, + type BaseLLMParams, + LLM, +} from "@langchain/core/language_models/llms"; + +/** + * A handler class to transform input from LLM to a format that SageMaker + * endpoint expects. Similarly, the class also handles transforming output from + * the SageMaker endpoint to a format that LLM class expects. 
+ * + * Example: + * ``` + * class ContentHandler implements ContentHandlerBase { + * contentType = "application/json" + * accepts = "application/json" + * + * transformInput(prompt: string, modelKwargs: Record) { + * const inputString = JSON.stringify({ + * prompt, + * ...modelKwargs + * }) + * return Buffer.from(inputString) + * } + * + * transformOutput(output: Uint8Array) { + * const responseJson = JSON.parse(Buffer.from(output).toString("utf-8")) + * return responseJson[0].generated_text + * } + * + * } + * ``` + */ +export abstract class BaseSageMakerContentHandler { + contentType = "text/plain"; + + accepts = "text/plain"; + + /** + * Transforms the prompt and model arguments into a specific format for sending to SageMaker. + * @param {InputType} prompt The prompt to be transformed. + * @param {Record} modelKwargs Additional arguments. + * @returns {Promise} A promise that resolves to the formatted data for sending. + */ + abstract transformInput( + prompt: InputType, + modelKwargs: Record + ): Promise; + + /** + * Transforms SageMaker output into a desired format. + * @param {Uint8Array} output The raw output from SageMaker. + * @returns {Promise} A promise that resolves to the transformed data. + */ + abstract transformOutput(output: Uint8Array): Promise; +} + +export type SageMakerLLMContentHandler = BaseSageMakerContentHandler< + string, + string +>; + +/** + * The SageMakerEndpointInput interface defines the input parameters for + * the SageMakerEndpoint class, which includes the endpoint name, client + * options for the SageMaker client, the content handler, and optional + * keyword arguments for the model and the endpoint. + */ +export interface SageMakerEndpointInput extends BaseLLMParams { + /** + * The name of the endpoint from the deployed SageMaker model. Must be unique + * within an AWS Region. + */ + endpointName: string; + /** + * Options passed to the SageMaker client. 
+ */ + clientOptions: SageMakerRuntimeClientConfig; + /** + * Key word arguments to pass to the model. + */ + modelKwargs?: Record; + /** + * Optional attributes passed to the InvokeEndpointCommand + */ + endpointKwargs?: Record; + /** + * The content handler class that provides an input and output transform + * functions to handle formats between LLM and the endpoint. + */ + contentHandler: SageMakerLLMContentHandler; + streaming?: boolean; +} + +/** + * The SageMakerEndpoint class is used to interact with SageMaker + * Inference Endpoint models. It uses the AWS client for authentication, + * which automatically loads credentials. + * If a specific credential profile is to be used, the name of the profile + * from the ~/.aws/credentials file must be passed. The credentials or + * roles used should have the required policies to access the SageMaker + * endpoint. + */ +export class SageMakerEndpoint extends LLM { + lc_serializable = true; + + static lc_name() { + return "SageMakerEndpoint"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + "clientOptions.credentials.accessKeyId": "AWS_ACCESS_KEY_ID", + "clientOptions.credentials.secretAccessKey": "AWS_SECRET_ACCESS_KEY", + "clientOptions.credentials.sessionToken": "AWS_SESSION_TOKEN", + }; + } + + endpointName: string; + + modelKwargs?: Record; + + endpointKwargs?: Record; + + client: SageMakerRuntimeClient; + + contentHandler: SageMakerLLMContentHandler; + + streaming: boolean; + + constructor(fields: SageMakerEndpointInput) { + super(fields); + + if (!fields.clientOptions.region) { + throw new Error( + `Please pass a "clientOptions" object with a "region" field to the constructor` + ); + } + + const endpointName = fields?.endpointName; + if (!endpointName) { + throw new Error(`Please pass an "endpointName" field to the constructor`); + } + + const contentHandler = fields?.contentHandler; + if (!contentHandler) { + throw new Error( + `Please pass a "contentHandler" field to the 
constructor` + ); + } + + this.endpointName = fields.endpointName; + this.contentHandler = fields.contentHandler; + this.endpointKwargs = fields.endpointKwargs; + this.modelKwargs = fields.modelKwargs; + this.streaming = fields.streaming ?? false; + this.client = new SageMakerRuntimeClient(fields.clientOptions); + } + + _llmType() { + return "sagemaker_endpoint"; + } + + /** + * Calls the SageMaker endpoint and retrieves the result. + * @param {string} prompt The input prompt. + * @param {this["ParsedCallOptions"]} options Parsed call options. + * @param {CallbackManagerForLLMRun} runManager Optional run manager. + * @returns {Promise} A promise that resolves to the generated string. + */ + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + return this.streaming + ? await this.streamingCall(prompt, options, runManager) + : await this.noStreamingCall(prompt, options); + } + + private async streamingCall( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const chunks = []; + for await (const chunk of this._streamResponseChunks( + prompt, + options, + runManager + )) { + chunks.push(chunk.text); + } + return chunks.join(""); + } + + private async noStreamingCall( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const body = await this.contentHandler.transformInput( + prompt, + this.modelKwargs ?? 
{} + ); + const { contentType, accepts } = this.contentHandler; + + const response = await this.caller.call(() => + this.client.send( + new InvokeEndpointCommand({ + EndpointName: this.endpointName, + Body: body, + ContentType: contentType, + Accept: accepts, + ...this.endpointKwargs, + }), + { abortSignal: options.signal } + ) + ); + + if (response.Body === undefined) { + throw new Error("Inference result missing Body"); + } + return this.contentHandler.transformOutput(response.Body); + } + + /** + * Streams response chunks from the SageMaker endpoint. + * @param {string} prompt The input prompt. + * @param {this["ParsedCallOptions"]} options Parsed call options. + * @returns {AsyncGenerator} An asynchronous generator yielding generation chunks. + */ + async *_streamResponseChunks( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const body = await this.contentHandler.transformInput( + prompt, + this.modelKwargs ?? {} + ); + const { contentType, accepts } = this.contentHandler; + + const stream = await this.caller.call(() => + this.client.send( + new InvokeEndpointWithResponseStreamCommand({ + EndpointName: this.endpointName, + Body: body, + ContentType: contentType, + Accept: accepts, + ...this.endpointKwargs, + }), + { abortSignal: options.signal } + ) + ); + + if (!stream.Body) { + throw new Error("Inference result missing Body"); + } + + for await (const chunk of stream.Body) { + if (chunk.PayloadPart && chunk.PayloadPart.Bytes) { + const text = await this.contentHandler.transformOutput( + chunk.PayloadPart.Bytes + ); + yield new GenerationChunk({ + text, + generationInfo: { + ...chunk, + response: undefined, + }, + }); + await runManager?.handleLLMNewToken(text); + } else if (chunk.InternalStreamFailure) { + throw new Error(chunk.InternalStreamFailure.message); + } else if (chunk.ModelStreamError) { + throw new Error(chunk.ModelStreamError.message); + } + } + } +} diff --git 
a/libs/langchain-community/src/llms/watsonx_ai.ts b/libs/langchain-community/src/llms/watsonx_ai.ts new file mode 100644 index 000000000000..2da98bb86bd7 --- /dev/null +++ b/libs/langchain-community/src/llms/watsonx_ai.ts @@ -0,0 +1,200 @@ +import { + type BaseLLMCallOptions, + type BaseLLMParams, + LLM, +} from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * The WatsonxAIParams interface defines the input parameters for + * the WatsonxAI class. + */ +export interface WatsonxAIParams extends BaseLLMParams { + /** + * WatsonX AI Complete Endpoint. + * Can be used if you want a fully custom endpoint. + */ + endpoint?: string; + /** + * IBM Cloud Compute Region. + * eg. us-south, us-east, etc. + */ + region?: string; + /** + * WatsonX AI Version. + * Date representing the WatsonX AI Version. + * eg. 2023-05-29 + */ + version?: string; + /** + * WatsonX AI Key. + * Provide API Key if you do not wish to automatically pull from env. + */ + ibmCloudApiKey?: string; + /** + * WatsonX AI Project ID. + * Provide Project ID if you do not wish to automatically pull from env. + */ + projectId?: string; + /** + * Parameters accepted by the WatsonX AI Endpoint. + */ + modelParameters?: Record; + /** + * WatsonX AI Model ID. + */ + modelId?: string; +} + +const endpointConstructor = (region: string, version: string) => + `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`; + +/** + * The WatsonxAI class is used to interact with Watsonx AI + * Inference Endpoint models. It uses IBM Cloud for authentication. + * This requires your IBM Cloud API Key which is autoloaded if not specified. 
+ */ + +export class WatsonxAI extends LLM { + lc_serializable = true; + + static lc_name() { + return "WatsonxAI"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + ibmCloudApiKey: "IBM_CLOUD_API_KEY", + projectId: "WATSONX_PROJECT_ID", + }; + } + + endpoint: string; + + region = "us-south"; + + version = "2023-05-29"; + + modelId = "meta-llama/llama-2-70b-chat"; + + modelKwargs?: Record; + + ibmCloudApiKey?: string; + + ibmCloudToken?: string; + + ibmCloudTokenExpiresAt?: number; + + projectId?: string; + + modelParameters?: Record; + + constructor(fields: WatsonxAIParams) { + super(fields); + + this.region = fields?.region ?? this.region; + this.version = fields?.version ?? this.version; + this.modelId = fields?.modelId ?? this.modelId; + this.ibmCloudApiKey = + fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY"); + this.projectId = + fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID"); + + this.endpoint = + fields?.endpoint ?? endpointConstructor(this.region, this.version); + this.modelParameters = fields.modelParameters; + + if (!this.ibmCloudApiKey) { + throw new Error("Missing IBM Cloud API Key"); + } + + if (!this.projectId) { + throw new Error("Missing WatsonX AI Project ID"); + } + } + + _llmType() { + return "watsonx_ai"; + } + + /** + * Calls the WatsonX AI endpoint and retrieves the result. + * @param {string} prompt The input prompt. + * @returns {Promise} A promise that resolves to the generated string. 
+ */ + /** @ignore */ + async _call( + prompt: string, + _options: this["ParsedCallOptions"] + ): Promise { + interface WatsonxAIResponse { + results: { + generated_text: string; + generated_token_count: number; + input_token_count: number; + }[]; + errors: { + code: string; + message: string; + }[]; + } + const response = (await this.caller.call(async () => + fetch(this.endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + Authorization: `Bearer ${await this.generateToken()}`, + }, + body: JSON.stringify({ + project_id: this.projectId, + model_id: this.modelId, + input: prompt, + parameters: this.modelParameters, + }), + }).then((res) => res.json()) + )) as WatsonxAIResponse; + + /** + * Handle Errors for invalid requests. + */ + if (response.errors) { + throw new Error(response.errors[0].message); + } + + return response.results[0].generated_text; + } + + async generateToken(): Promise { + if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) { + if (this.ibmCloudTokenExpiresAt > Date.now()) { + return this.ibmCloudToken; + } + } + + interface TokenResponse { + access_token: string; + expiration: number; + } + + const urlTokenParams = new URLSearchParams(); + urlTokenParams.append( + "grant_type", + "urn:ibm:params:oauth:grant-type:apikey" + ); + urlTokenParams.append("apikey", this.ibmCloudApiKey as string); + + const data = (await fetch("https://iam.cloud.ibm.com/identity/token", { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: urlTokenParams, + }).then((res) => res.json())) as TokenResponse; + + this.ibmCloudTokenExpiresAt = data.expiration * 1000; + this.ibmCloudToken = data.access_token; + + return this.ibmCloudToken; + } +} diff --git a/libs/langchain-community/src/llms/writer.ts b/libs/langchain-community/src/llms/writer.ts new file mode 100644 index 000000000000..1f72d3273e76 --- /dev/null +++ b/libs/langchain-community/src/llms/writer.ts @@ 
-0,0 +1,172 @@ +import { Writer as WriterClient } from "@writerai/writer-sdk"; + +import { type BaseLLMParams, LLM } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface for the input parameters specific to the Writer model. + */ +export interface WriterInput extends BaseLLMParams { + /** Writer API key */ + apiKey?: string; + + /** Writer organization ID */ + orgId?: string | number; + + /** Model to use */ + model?: string; + + /** Sampling temperature to use */ + temperature?: number; + + /** Minimum number of tokens to generate. */ + minTokens?: number; + + /** Maximum number of tokens to generate in the completion. */ + maxTokens?: number; + + /** Generates this many completions server-side and returns the "best". */ + bestOf?: number; + + /** Penalizes repeated tokens according to frequency. */ + frequencyPenalty?: number; + + /** Whether to return log probabilities. */ + logprobs?: number; + + /** Number of completions to generate. */ + n?: number; + + /** Penalizes repeated tokens regardless of frequency. */ + presencePenalty?: number; + + /** Total probability mass of tokens to consider at each step. */ + topP?: number; +} + +/** + * Class representing a Writer Large Language Model (LLM). It interacts + * with the Writer API to generate text completions. 
+ */ +export class Writer extends LLM implements WriterInput { + static lc_name() { + return "Writer"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "WRITER_API_KEY", + orgId: "WRITER_ORG_ID", + }; + } + + get lc_aliases(): { [key: string]: string } | undefined { + return { + apiKey: "writer_api_key", + orgId: "writer_org_id", + }; + } + + lc_serializable = true; + + apiKey: string; + + orgId: number; + + model = "palmyra-instruct"; + + temperature?: number; + + minTokens?: number; + + maxTokens?: number; + + bestOf?: number; + + frequencyPenalty?: number; + + logprobs?: number; + + n?: number; + + presencePenalty?: number; + + topP?: number; + + constructor(fields?: WriterInput) { + super(fields ?? {}); + + const apiKey = fields?.apiKey ?? getEnvironmentVariable("WRITER_API_KEY"); + const orgId = fields?.orgId ?? getEnvironmentVariable("WRITER_ORG_ID"); + + if (!apiKey) { + throw new Error( + "Please set the WRITER_API_KEY environment variable or pass it to the constructor as the apiKey field." + ); + } + + if (!orgId) { + throw new Error( + "Please set the WRITER_ORG_ID environment variable or pass it to the constructor as the orgId field." + ); + } + + this.apiKey = apiKey; + this.orgId = typeof orgId === "string" ? parseInt(orgId, 10) : orgId; + this.model = fields?.model ?? this.model; + this.temperature = fields?.temperature ?? this.temperature; + this.minTokens = fields?.minTokens ?? this.minTokens; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.bestOf = fields?.bestOf ?? this.bestOf; + this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; + this.logprobs = fields?.logprobs ?? this.logprobs; + this.n = fields?.n ?? this.n; + this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; + this.topP = fields?.topP ?? 
this.topP; + } + + _llmType() { + return "writer"; + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const sdk = new WriterClient({ + security: { + apiKey: this.apiKey, + }, + organizationId: this.orgId, + }); + + return this.caller.callWithOptions({ signal: options.signal }, async () => { + try { + const res = await sdk.completions.create({ + completionRequest: { + prompt, + stop: options.stop, + temperature: this.temperature, + minTokens: this.minTokens, + maxTokens: this.maxTokens, + bestOf: this.bestOf, + n: this.n, + frequencyPenalty: this.frequencyPenalty, + logprobs: this.logprobs, + presencePenalty: this.presencePenalty, + topP: this.topP, + }, + modelId: this.model, + }); + return ( + res.completionResponse?.choices?.[0].text ?? "No completion found." + ); + } catch (e) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (e as any).response = (e as any).rawResponse; + throw e; + } + }); + } +} diff --git a/libs/langchain-community/src/llms/yandex.ts b/libs/langchain-community/src/llms/yandex.ts new file mode 100644 index 000000000000..58b1f31111d3 --- /dev/null +++ b/libs/langchain-community/src/llms/yandex.ts @@ -0,0 +1,125 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; + +const apiUrl = "https://llm.api.cloud.yandex.net/llm/v1alpha/instruct"; + +export interface YandexGPTInputs extends BaseLLMParams { + /** + * What sampling temperature to use. + * Should be a double number between 0 (inclusive) and 1 (inclusive). + */ + temperature?: number; + + /** + * Maximum limit on the total number of tokens + * used for both the input prompt and the generated response. + */ + maxTokens?: number; + + /** Model name to use. */ + model?: string; + + /** + * Yandex Cloud Api Key for service account + * with the `ai.languageModels.user` role. 
+ */ + apiKey?: string; + + /** + * Yandex Cloud IAM token for service account + * with the `ai.languageModels.user` role. + */ + iamToken?: string; +} + +export class YandexGPT extends LLM implements YandexGPTInputs { + lc_serializable = true; + + static lc_name() { + return "Yandex GPT"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "YC_API_KEY", + iamToken: "YC_IAM_TOKEN", + }; + } + + temperature = 0.6; + + maxTokens = 1700; + + model = "general"; + + apiKey?: string; + + iamToken?: string; + + constructor(fields?: YandexGPTInputs) { + super(fields ?? {}); + + const apiKey = fields?.apiKey ?? getEnvironmentVariable("YC_API_KEY"); + + const iamToken = fields?.iamToken ?? getEnvironmentVariable("YC_IAM_TOKEN"); + + if (apiKey === undefined && iamToken === undefined) { + throw new Error( + "Please set the YC_API_KEY or YC_IAM_TOKEN environment variable or pass it to the constructor as the apiKey or iamToken field." + ); + } + + this.apiKey = apiKey; + this.iamToken = iamToken; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.temperature = fields?.temperature ?? this.temperature; + this.model = fields?.model ?? 
this.model; + } + + _llmType() { + return "yandexgpt"; + } + + /** @ignore */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + // POST the prompt to the YandexGPT instruct endpoint (`apiUrl`) using the configured model + return this.caller.callWithOptions({ signal: options.signal }, async () => { + const headers = { "Content-Type": "application/json", Authorization: "" }; + if (this.apiKey !== undefined) { + headers.Authorization = `Api-Key ${this.apiKey}`; + } else { + headers.Authorization = `Bearer ${this.iamToken}`; + } + const bodyData = { + model: this.model, + generationOptions: { + temperature: this.temperature, + maxTokens: this.maxTokens, + }, + + requestText: prompt, + }; + + try { + const response = await fetch(apiUrl, { + method: "POST", + headers, + body: JSON.stringify(bodyData), + }); + if (!response.ok) { + throw new Error( + `Failed to fetch ${apiUrl} from YandexGPT: ${response.status}` + ); + } + + const responseData = await response.json(); + return responseData.result.alternatives[0].text; + } catch (error) { + throw new Error(`Failed to fetch ${apiUrl} from YandexGPT ${error}`); + } + }); + } +} diff --git a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts new file mode 100644 index 000000000000..9be7c339c8ef --- /dev/null +++ b/libs/langchain-community/src/load/import_constants.ts @@ -0,0 +1,99 @@ +// Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually. 
+ +export const optionalImportEntrypoints = [ + "langchain_community/tools/aws_sfn", + "langchain_community/tools/gmail", + "langchain_community/embeddings/bedrock", + "langchain_community/embeddings/cloudflare_workersai", + "langchain_community/embeddings/cohere", + "langchain_community/embeddings/googlepalm", + "langchain_community/embeddings/googlevertexai", + "langchain_community/embeddings/gradient_ai", + "langchain_community/embeddings/hf", + "langchain_community/embeddings/hf_transformers", + "langchain_community/embeddings/llama_cpp", + "langchain_community/embeddings/tensorflow", + "langchain_community/llms/bedrock", + "langchain_community/llms/bedrock/web", + "langchain_community/llms/cohere", + "langchain_community/llms/googlepalm", + "langchain_community/llms/googlevertexai", + "langchain_community/llms/googlevertexai/web", + "langchain_community/llms/gradient_ai", + "langchain_community/llms/hf", + "langchain_community/llms/llama_cpp", + "langchain_community/llms/portkey", + "langchain_community/llms/raycast", + "langchain_community/llms/replicate", + "langchain_community/llms/sagemaker_endpoint", + "langchain_community/llms/watsonx_ai", + "langchain_community/llms/writer", + "langchain_community/vectorstores/analyticdb", + "langchain_community/vectorstores/cassandra", + "langchain_community/vectorstores/chroma", + "langchain_community/vectorstores/clickhouse", + "langchain_community/vectorstores/closevector/node", + "langchain_community/vectorstores/closevector/web", + "langchain_community/vectorstores/cloudflare_vectorize", + "langchain_community/vectorstores/convex", + "langchain_community/vectorstores/elasticsearch", + "langchain_community/vectorstores/faiss", + "langchain_community/vectorstores/googlevertexai", + "langchain_community/vectorstores/hnswlib", + "langchain_community/vectorstores/lancedb", + "langchain_community/vectorstores/milvus", + "langchain_community/vectorstores/momento_vector_index", + 
"langchain_community/vectorstores/mongodb_atlas", + "langchain_community/vectorstores/myscale", + "langchain_community/vectorstores/neo4j_vector", + "langchain_community/vectorstores/opensearch", + "langchain_community/vectorstores/pgvector", + "langchain_community/vectorstores/pinecone", + "langchain_community/vectorstores/qdrant", + "langchain_community/vectorstores/redis", + "langchain_community/vectorstores/rockset", + "langchain_community/vectorstores/singlestore", + "langchain_community/vectorstores/supabase", + "langchain_community/vectorstores/tigris", + "langchain_community/vectorstores/typeorm", + "langchain_community/vectorstores/typesense", + "langchain_community/vectorstores/usearch", + "langchain_community/vectorstores/vercel_postgres", + "langchain_community/vectorstores/voy", + "langchain_community/vectorstores/weaviate", + "langchain_community/vectorstores/xata", + "langchain_community/vectorstores/zep", + "langchain_community/chat_models/bedrock", + "langchain_community/chat_models/bedrock/web", + "langchain_community/chat_models/googlevertexai", + "langchain_community/chat_models/googlevertexai/web", + "langchain_community/chat_models/googlepalm", + "langchain_community/chat_models/iflytek_xinghuo", + "langchain_community/chat_models/iflytek_xinghuo/web", + "langchain_community/chat_models/llama_cpp", + "langchain_community/chat_models/portkey", + "langchain_community/callbacks/handlers/llmonitor", + "langchain_community/retrievers/amazon_kendra", + "langchain_community/retrievers/metal", + "langchain_community/retrievers/supabase", + "langchain_community/retrievers/zep", + "langchain_community/graphs/neo4j_graph", + "langchain_community/document_transformers/html_to_text", + "langchain_community/document_transformers/mozilla_readability", + "langchain_community/storage/convex", + "langchain_community/storage/ioredis", + "langchain_community/storage/upstash_redis", + "langchain_community/storage/vercel_kv", + 
"langchain_community/stores/message/cassandra", + "langchain_community/stores/message/cloudflare_d1", + "langchain_community/stores/message/convex", + "langchain_community/stores/message/dynamodb", + "langchain_community/stores/message/firestore", + "langchain_community/stores/message/ioredis", + "langchain_community/stores/message/momento", + "langchain_community/stores/message/mongodb", + "langchain_community/stores/message/planetscale", + "langchain_community/stores/message/redis", + "langchain_community/stores/message/upstash_redis", + "langchain_community/stores/message/xata", +]; diff --git a/libs/langchain-community/src/load/import_map.ts b/libs/langchain-community/src/load/import_map.ts new file mode 100644 index 000000000000..79e48309c105 --- /dev/null +++ b/libs/langchain-community/src/load/import_map.ts @@ -0,0 +1,45 @@ +// Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually. + +export * as load__serializable from "../load/serializable.js"; +export * as tools__aiplugin from "../tools/aiplugin.js"; +export * as tools__bingserpapi from "../tools/bingserpapi.js"; +export * as tools__brave_search from "../tools/brave_search.js"; +export * as tools__connery from "../tools/connery.js"; +export * as tools__dadjokeapi from "../tools/dadjokeapi.js"; +export * as tools__dataforseo_api_search from "../tools/dataforseo_api_search.js"; +export * as tools__google_custom_search from "../tools/google_custom_search.js"; +export * as tools__google_places from "../tools/google_places.js"; +export * as tools__ifttt from "../tools/ifttt.js"; +export * as tools__searchapi from "../tools/searchapi.js"; +export * as tools__searxng_search from "../tools/searxng_search.js"; +export * as tools__serpapi from "../tools/serpapi.js"; +export * as tools__serper from "../tools/serper.js"; +export * as tools__wikipedia_query_run from "../tools/wikipedia_query_run.js"; +export * as tools__wolframalpha from "../tools/wolframalpha.js"; +export * as embeddings__minimax 
from "../embeddings/minimax.js"; +export * as embeddings__ollama from "../embeddings/ollama.js"; +export * as embeddings__voyage from "../embeddings/voyage.js"; +export * as llms__ai21 from "../llms/ai21.js"; +export * as llms__aleph_alpha from "../llms/aleph_alpha.js"; +export * as llms__cloudflare_workersai from "../llms/cloudflare_workersai.js"; +export * as llms__fireworks from "../llms/fireworks.js"; +export * as llms__ollama from "../llms/ollama.js"; +export * as llms__yandex from "../llms/yandex.js"; +export * as vectorstores__memory from "../vectorstores/memory.js"; +export * as vectorstores__prisma from "../vectorstores/prisma.js"; +export * as vectorstores__vectara from "../vectorstores/vectara.js"; +export * as chat_models__baiduwenxin from "../chat_models/baiduwenxin.js"; +export * as chat_models__cloudflare_workersai from "../chat_models/cloudflare_workersai.js"; +export * as chat_models__fireworks from "../chat_models/fireworks.js"; +export * as chat_models__minimax from "../chat_models/minimax.js"; +export * as chat_models__ollama from "../chat_models/ollama.js"; +export * as chat_models__yandex from "../chat_models/yandex.js"; +export * as retrievers__chaindesk from "../retrievers/chaindesk.js"; +export * as retrievers__databerry from "../retrievers/databerry.js"; +export * as retrievers__tavily_search_api from "../retrievers/tavily_search_api.js"; +export * as caches__cloudflare_kv from "../caches/cloudflare_kv.js"; +export * as caches__momento from "../caches/momento.js"; +export * as caches__upstash_redis from "../caches/upstash_redis.js"; +export * as utils__event_source_parse from "../utils/event_source_parse.js"; +export * as stores__doc__base from "../stores/doc/base.js"; +export * as stores__doc__in_memory from "../stores/doc/in_memory.js"; diff --git a/libs/langchain-community/src/load/import_type.d.ts b/libs/langchain-community/src/load/import_type.d.ts new file mode 100644 index 000000000000..01637a94f0b2 --- /dev/null +++ 
b/libs/langchain-community/src/load/import_type.d.ts @@ -0,0 +1,342 @@ +// Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually. + +export interface OptionalImportMap { + "@langchain/community/tools/aws_sfn"?: + | typeof import("../tools/aws_sfn.js") + | Promise; + "@langchain/community/tools/gmail"?: + | typeof import("../tools/gmail/index.js") + | Promise; + "@langchain/community/embeddings/bedrock"?: + | typeof import("../embeddings/bedrock.js") + | Promise; + "@langchain/community/embeddings/cloudflare_workersai"?: + | typeof import("../embeddings/cloudflare_workersai.js") + | Promise; + "@langchain/community/embeddings/cohere"?: + | typeof import("../embeddings/cohere.js") + | Promise; + "@langchain/community/embeddings/googlepalm"?: + | typeof import("../embeddings/googlepalm.js") + | Promise; + "@langchain/community/embeddings/googlevertexai"?: + | typeof import("../embeddings/googlevertexai.js") + | Promise; + "@langchain/community/embeddings/gradient_ai"?: + | typeof import("../embeddings/gradient_ai.js") + | Promise; + "@langchain/community/embeddings/hf"?: + | typeof import("../embeddings/hf.js") + | Promise; + "@langchain/community/embeddings/hf_transformers"?: + | typeof import("../embeddings/hf_transformers.js") + | Promise; + "@langchain/community/embeddings/llama_cpp"?: + | typeof import("../embeddings/llama_cpp.js") + | Promise; + "@langchain/community/embeddings/tensorflow"?: + | typeof import("../embeddings/tensorflow.js") + | Promise; + "@langchain/community/llms/bedrock"?: + | typeof import("../llms/bedrock/index.js") + | Promise; + "@langchain/community/llms/bedrock/web"?: + | typeof import("../llms/bedrock/web.js") + | Promise; + "@langchain/community/llms/cohere"?: + | typeof import("../llms/cohere.js") + | Promise; + "@langchain/community/llms/googlepalm"?: + | typeof import("../llms/googlepalm.js") + | Promise; + "@langchain/community/llms/googlevertexai"?: + | typeof import("../llms/googlevertexai/index.js") + | 
Promise; + "@langchain/community/llms/googlevertexai/web"?: + | typeof import("../llms/googlevertexai/web.js") + | Promise; + "@langchain/community/llms/gradient_ai"?: + | typeof import("../llms/gradient_ai.js") + | Promise; + "@langchain/community/llms/hf"?: + | typeof import("../llms/hf.js") + | Promise; + "@langchain/community/llms/llama_cpp"?: + | typeof import("../llms/llama_cpp.js") + | Promise; + "@langchain/community/llms/portkey"?: + | typeof import("../llms/portkey.js") + | Promise; + "@langchain/community/llms/raycast"?: + | typeof import("../llms/raycast.js") + | Promise; + "@langchain/community/llms/replicate"?: + | typeof import("../llms/replicate.js") + | Promise; + "@langchain/community/llms/sagemaker_endpoint"?: + | typeof import("../llms/sagemaker_endpoint.js") + | Promise; + "@langchain/community/llms/watsonx_ai"?: + | typeof import("../llms/watsonx_ai.js") + | Promise; + "@langchain/community/llms/writer"?: + | typeof import("../llms/writer.js") + | Promise; + "@langchain/community/vectorstores/analyticdb"?: + | typeof import("../vectorstores/analyticdb.js") + | Promise; + "@langchain/community/vectorstores/cassandra"?: + | typeof import("../vectorstores/cassandra.js") + | Promise; + "@langchain/community/vectorstores/chroma"?: + | typeof import("../vectorstores/chroma.js") + | Promise; + "@langchain/community/vectorstores/clickhouse"?: + | typeof import("../vectorstores/clickhouse.js") + | Promise; + "@langchain/community/vectorstores/closevector/node"?: + | typeof import("../vectorstores/closevector/node.js") + | Promise; + "@langchain/community/vectorstores/closevector/web"?: + | typeof import("../vectorstores/closevector/web.js") + | Promise; + "@langchain/community/vectorstores/cloudflare_vectorize"?: + | typeof import("../vectorstores/cloudflare_vectorize.js") + | Promise; + "@langchain/community/vectorstores/convex"?: + | typeof import("../vectorstores/convex.js") + | Promise; + "@langchain/community/vectorstores/elasticsearch"?: + | 
typeof import("../vectorstores/elasticsearch.js") + | Promise; + "@langchain/community/vectorstores/faiss"?: + | typeof import("../vectorstores/faiss.js") + | Promise; + "@langchain/community/vectorstores/googlevertexai"?: + | typeof import("../vectorstores/googlevertexai.js") + | Promise; + "@langchain/community/vectorstores/hnswlib"?: + | typeof import("../vectorstores/hnswlib.js") + | Promise; + "@langchain/community/vectorstores/lancedb"?: + | typeof import("../vectorstores/lancedb.js") + | Promise; + "@langchain/community/vectorstores/milvus"?: + | typeof import("../vectorstores/milvus.js") + | Promise; + "@langchain/community/vectorstores/momento_vector_index"?: + | typeof import("../vectorstores/momento_vector_index.js") + | Promise; + "@langchain/community/vectorstores/mongodb_atlas"?: + | typeof import("../vectorstores/mongodb_atlas.js") + | Promise; + "@langchain/community/vectorstores/myscale"?: + | typeof import("../vectorstores/myscale.js") + | Promise; + "@langchain/community/vectorstores/neo4j_vector"?: + | typeof import("../vectorstores/neo4j_vector.js") + | Promise; + "@langchain/community/vectorstores/opensearch"?: + | typeof import("../vectorstores/opensearch.js") + | Promise; + "@langchain/community/vectorstores/pgvector"?: + | typeof import("../vectorstores/pgvector.js") + | Promise; + "@langchain/community/vectorstores/pinecone"?: + | typeof import("../vectorstores/pinecone.js") + | Promise; + "@langchain/community/vectorstores/qdrant"?: + | typeof import("../vectorstores/qdrant.js") + | Promise; + "@langchain/community/vectorstores/redis"?: + | typeof import("../vectorstores/redis.js") + | Promise; + "@langchain/community/vectorstores/rockset"?: + | typeof import("../vectorstores/rockset.js") + | Promise; + "@langchain/community/vectorstores/singlestore"?: + | typeof import("../vectorstores/singlestore.js") + | Promise; + "@langchain/community/vectorstores/supabase"?: + | typeof import("../vectorstores/supabase.js") + | Promise; + 
"@langchain/community/vectorstores/tigris"?: + | typeof import("../vectorstores/tigris.js") + | Promise; + "@langchain/community/vectorstores/typeorm"?: + | typeof import("../vectorstores/typeorm.js") + | Promise; + "@langchain/community/vectorstores/typesense"?: + | typeof import("../vectorstores/typesense.js") + | Promise; + "@langchain/community/vectorstores/usearch"?: + | typeof import("../vectorstores/usearch.js") + | Promise; + "@langchain/community/vectorstores/vercel_postgres"?: + | typeof import("../vectorstores/vercel_postgres.js") + | Promise; + "@langchain/community/vectorstores/voy"?: + | typeof import("../vectorstores/voy.js") + | Promise; + "@langchain/community/vectorstores/weaviate"?: + | typeof import("../vectorstores/weaviate.js") + | Promise; + "@langchain/community/vectorstores/xata"?: + | typeof import("../vectorstores/xata.js") + | Promise; + "@langchain/community/vectorstores/zep"?: + | typeof import("../vectorstores/zep.js") + | Promise; + "@langchain/community/chat_models/bedrock"?: + | typeof import("../chat_models/bedrock/index.js") + | Promise; + "@langchain/community/chat_models/bedrock/web"?: + | typeof import("../chat_models/bedrock/web.js") + | Promise; + "@langchain/community/chat_models/googlevertexai"?: + | typeof import("../chat_models/googlevertexai/index.js") + | Promise; + "@langchain/community/chat_models/googlevertexai/web"?: + | typeof import("../chat_models/googlevertexai/web.js") + | Promise; + "@langchain/community/chat_models/googlepalm"?: + | typeof import("../chat_models/googlepalm.js") + | Promise; + "@langchain/community/chat_models/iflytek_xinghuo"?: + | typeof import("../chat_models/iflytek_xinghuo/index.js") + | Promise; + "@langchain/community/chat_models/iflytek_xinghuo/web"?: + | typeof import("../chat_models/iflytek_xinghuo/web.js") + | Promise; + "@langchain/community/chat_models/llama_cpp"?: + | typeof import("../chat_models/llama_cpp.js") + | Promise; + "@langchain/community/chat_models/portkey"?: + | 
typeof import("../chat_models/portkey.js") + | Promise; + "@langchain/community/callbacks/handlers/llmonitor"?: + | typeof import("../callbacks/handlers/llmonitor.js") + | Promise; + "@langchain/community/retrievers/amazon_kendra"?: + | typeof import("../retrievers/amazon_kendra.js") + | Promise; + "@langchain/community/retrievers/metal"?: + | typeof import("../retrievers/metal.js") + | Promise; + "@langchain/community/retrievers/supabase"?: + | typeof import("../retrievers/supabase.js") + | Promise; + "@langchain/community/retrievers/zep"?: + | typeof import("../retrievers/zep.js") + | Promise; + "@langchain/community/graphs/neo4j_graph"?: + | typeof import("../graphs/neo4j_graph.js") + | Promise; + "@langchain/community/document_transformers/html_to_text"?: + | typeof import("../document_transformers/html_to_text.js") + | Promise; + "@langchain/community/document_transformers/mozilla_readability"?: + | typeof import("../document_transformers/mozilla_readability.js") + | Promise; + "@langchain/community/storage/convex"?: + | typeof import("../storage/convex.js") + | Promise; + "@langchain/community/storage/ioredis"?: + | typeof import("../storage/ioredis.js") + | Promise; + "@langchain/community/storage/upstash_redis"?: + | typeof import("../storage/upstash_redis.js") + | Promise; + "@langchain/community/storage/vercel_kv"?: + | typeof import("../storage/vercel_kv.js") + | Promise; + "@langchain/community/stores/message/cassandra"?: + | typeof import("../stores/message/cassandra.js") + | Promise; + "@langchain/community/stores/message/cloudflare_d1"?: + | typeof import("../stores/message/cloudflare_d1.js") + | Promise; + "@langchain/community/stores/message/convex"?: + | typeof import("../stores/message/convex.js") + | Promise; + "@langchain/community/stores/message/dynamodb"?: + | typeof import("../stores/message/dynamodb.js") + | Promise; + "@langchain/community/stores/message/firestore"?: + | typeof import("../stores/message/firestore.js") + | Promise; + 
"@langchain/community/stores/message/ioredis"?: + | typeof import("../stores/message/ioredis.js") + | Promise; + "@langchain/community/stores/message/momento"?: + | typeof import("../stores/message/momento.js") + | Promise; + "@langchain/community/stores/message/mongodb"?: + | typeof import("../stores/message/mongodb.js") + | Promise; + "@langchain/community/stores/message/planetscale"?: + | typeof import("../stores/message/planetscale.js") + | Promise; + "@langchain/community/stores/message/redis"?: + | typeof import("../stores/message/redis.js") + | Promise; + "@langchain/community/stores/message/upstash_redis"?: + | typeof import("../stores/message/upstash_redis.js") + | Promise; + "@langchain/community/stores/message/xata"?: + | typeof import("../stores/message/xata.js") + | Promise; +} + +export interface SecretMap { + AWS_ACCESS_KEY_ID?: string; + AWS_SECRETE_ACCESS_KEY?: string; + AWS_SECRET_ACCESS_KEY?: string; + AWS_SESSION_TOKEN?: string; + BAIDU_API_KEY?: string; + BAIDU_SECRET_KEY?: string; + BEDROCK_AWS_ACCESS_KEY_ID?: string; + BEDROCK_AWS_SECRET_ACCESS_KEY?: string; + CLOUDFLARE_API_TOKEN?: string; + COHERE_API_KEY?: string; + DATABERRY_API_KEY?: string; + FIREWORKS_API_KEY?: string; + GOOGLE_API_KEY?: string; + GOOGLE_PALM_API_KEY?: string; + GOOGLE_PLACES_API_KEY?: string; + GOOGLE_VERTEX_AI_WEB_CREDENTIALS?: string; + GRADIENT_ACCESS_TOKEN?: string; + GRADIENT_WORKSPACE_ID?: string; + HUGGINGFACEHUB_API_KEY?: string; + IBM_CLOUD_API_KEY?: string; + IFLYTEK_API_KEY?: string; + IFLYTEK_API_SECRET?: string; + MILVUS_PASSWORD?: string; + MILVUS_SSL?: string; + MILVUS_USERNAME?: string; + MINIMAX_API_KEY?: string; + MINIMAX_GROUP_ID?: string; + PLANETSCALE_DATABASE_URL?: string; + PLANETSCALE_HOST?: string; + PLANETSCALE_PASSWORD?: string; + PLANETSCALE_USERNAME?: string; + QDRANT_API_KEY?: string; + QDRANT_URL?: string; + REDIS_PASSWORD?: string; + REDIS_URL?: string; + REDIS_USERNAME?: string; + REPLICATE_API_TOKEN?: string; + SEARXNG_API_BASE?: 
string; + UPSTASH_REDIS_REST_TOKEN?: string; + UPSTASH_REDIS_REST_URL?: string; + VECTARA_API_KEY?: string; + VECTARA_CORPUS_ID?: string; + VECTARA_CUSTOMER_ID?: string; + WATSONX_PROJECT_ID?: string; + WRITER_API_KEY?: string; + WRITER_ORG_ID?: string; + YC_API_KEY?: string; + YC_IAM_TOKEN?: string; + ZEP_API_KEY?: string; + ZEP_API_URL?: string; +} diff --git a/libs/langchain-community/src/load/index.ts b/libs/langchain-community/src/load/index.ts new file mode 100644 index 000000000000..eaf66e52d699 --- /dev/null +++ b/libs/langchain-community/src/load/index.ts @@ -0,0 +1,3 @@ +export { type OptionalImportMap, type SecretMap } from "./import_type.js"; +export * as importMap from "./import_map.js"; +export { optionalImportEntrypoints } from "./import_constants.js"; diff --git a/libs/langchain-community/src/load/map_keys.ts b/libs/langchain-community/src/load/map_keys.ts new file mode 100644 index 000000000000..93a0ea6e4fa7 --- /dev/null +++ b/libs/langchain-community/src/load/map_keys.ts @@ -0,0 +1,4 @@ +export interface SerializedFields { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [key: string]: any; +} diff --git a/libs/langchain-community/src/load/serializable.ts b/libs/langchain-community/src/load/serializable.ts new file mode 100644 index 000000000000..b196ae7faac5 --- /dev/null +++ b/libs/langchain-community/src/load/serializable.ts @@ -0,0 +1 @@ +export * from "@langchain/core/load/serializable"; diff --git a/libs/langchain-community/src/retrievers/amazon_kendra.ts b/libs/langchain-community/src/retrievers/amazon_kendra.ts new file mode 100644 index 000000000000..d61600fbe76a --- /dev/null +++ b/libs/langchain-community/src/retrievers/amazon_kendra.ts @@ -0,0 +1,317 @@ +import { + AttributeFilter, + DocumentAttribute, + DocumentAttributeValue, + KendraClient, + KendraClientConfig, + QueryCommand, + QueryCommandOutput, + QueryResultItem, + RetrieveCommand, + RetrieveCommandOutput, + RetrieveResultItem, +} from 
/**
 * Interface for the arguments required to initialize an
 * AmazonKendraRetriever instance.
 */
export interface AmazonKendraRetrieverArgs {
  /** Identifier of the Kendra index to query. */
  indexId: string;
  /** Maximum number of documents to return; the constructor falls back to 10. */
  topK?: number;
  /** AWS region handed to the Kendra client. */
  region: string;
  /** Optional filter forwarded with every Retrieve/Query request. */
  attributeFilter?: AttributeFilter;
  /** Extra client configuration, spread after `region` so it can override it. */
  clientOptions?: KendraClientConfig;
}

/**
 * Class for interacting with Amazon Kendra, an intelligent search service
 * provided by AWS. Extends the BaseRetriever class.
 * @example
 * ```typescript
 * const retriever = new AmazonKendraRetriever({
 *   topK: 10,
 *   indexId: "YOUR_INDEX_ID",
 *   region: "us-east-2",
 *   clientOptions: {
 *     credentials: {
 *       accessKeyId: "YOUR_ACCESS_KEY_ID",
 *       secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
 *     },
 *   },
 * });
 *
 * const docs = await retriever.getRelevantDocuments("How are clouds formed?");
 * ```
 */
export class AmazonKendraRetriever extends BaseRetriever {
  static lc_name() {
    return "AmazonKendraRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "amazon_kendra"];

  indexId: string;

  topK: number;

  kendraClient: KendraClient;

  attributeFilter?: AttributeFilter;

  /**
   * @throws Error when `region` or `indexId` is missing or empty.
   */
  constructor({
    indexId,
    topK = 10,
    clientOptions,
    attributeFilter,
    region,
  }: AmazonKendraRetrieverArgs) {
    super();

    if (!region) {
      throw new Error("Please pass regionName field to the constructor!");
    }

    if (!indexId) {
      throw new Error("Please pass Kendra Index Id to the constructor");
    }

    this.topK = topK;
    this.kendraClient = new KendraClient({
      region,
      ...clientOptions,
    });
    this.attributeFilter = attributeFilter;
    this.indexId = indexId;
  }

  /**
   * Combines title and excerpt into a single string.
   * @param title The title of the document.
   * @param excerpt An excerpt from the document.
   * @returns A single string combining the title and excerpt; empty parts
   * are omitted.
   */
  combineText(title?: string, excerpt?: string): string {
    let text = "";
    if (title) {
      text += `Document Title: ${title}\n`;
    }
    if (excerpt) {
      text += `Document Excerpt: \n${excerpt}\n`;
    }
    return text;
  }

  /**
   * Cleans the result text by replacing sequences of whitespace with a
   * single space and removing ellipses.
   * @param resText The result text to clean.
   * @returns The cleaned result text.
   */
  cleanResult(resText: string) {
    return resText.replace(/\s+/g, " ").replace(/\.\.\./g, "");
  }

  /**
   * Extracts the attribute value from a DocumentAttributeValue object.
   * Fields are probed in the order date, long, string list, string.
   * @param value The DocumentAttributeValue object to extract the value from.
   * @returns The extracted attribute value, or "" when none is set.
   */
  getDocAttributeValue(value: DocumentAttributeValue) {
    if (value.DateValue) {
      return value.DateValue;
    }
    // A plain truthiness test would discard a legitimate numeric value of 0.
    if (typeof value.LongValue === "number") {
      return value.LongValue;
    }
    if (value.StringListValue) {
      return value.StringListValue;
    }
    if (value.StringValue) {
      return value.StringValue;
    }
    return "";
  }

  /**
   * Extracts the attribute key-value pairs from an array of
   * DocumentAttribute objects. Entries lacking a key or a value are skipped.
   * @param documentAttributes The array of DocumentAttribute objects to extract the key-value pairs from.
   * @returns An object containing the extracted attribute key-value pairs.
   */
  getDocAttributes(documentAttributes?: DocumentAttribute[]): {
    [key: string]: unknown;
  } {
    const attributes: { [key: string]: unknown } = {};
    for (const attr of documentAttributes ?? []) {
      if (attr.Key && attr.Value) {
        attributes[attr.Key] = this.getDocAttributeValue(attr.Value);
      }
    }
    return attributes;
  }

  /**
   * Converts a RetrieveResultItem object into a Document object.
   * @param item The RetrieveResultItem object to convert.
   * @returns A Document object.
   */
  convertRetrieverItem(item: RetrieveResultItem) {
    const title = item.DocumentTitle || "";
    const excerpt = item.Content ? this.cleanResult(item.Content) : "";
    const pageContent = this.combineText(title, excerpt);
    const metadata = {
      source: item.DocumentURI,
      title,
      excerpt,
      document_attributes: this.getDocAttributes(item.DocumentAttributes),
    };

    return new Document({ pageContent, metadata });
  }

  /**
   * Extracts the top-k documents from a RetrieveCommandOutput object.
   * @param response The RetrieveCommandOutput object to extract the documents from.
   * @param pageSize The number of documents to extract.
   * @returns An array of Document objects.
   */
  getRetrieverDocs(
    response: RetrieveCommandOutput,
    pageSize: number
  ): Document[] {
    if (!response.ResultItems) return [];
    // slice() already clamps to the array length.
    return response.ResultItems.slice(0, pageSize).map((item) =>
      this.convertRetrieverItem(item)
    );
  }

  /**
   * Extracts the excerpt text from a QueryResultItem object. The text of a
   * leading "AnswerText" additional attribute wins; otherwise the document
   * excerpt is used.
   * @param item The QueryResultItem object to extract the excerpt text from.
   * @returns The extracted excerpt text, or "" when neither is present.
   */
  getQueryItemExcerpt(item: QueryResultItem) {
    const firstAttribute = item.AdditionalAttributes?.[0];
    if (firstAttribute?.Key === "AnswerText") {
      return this.cleanResult(
        firstAttribute.Value?.TextWithHighlightsValue?.Text || ""
      );
    }
    if (item.DocumentExcerpt) {
      return this.cleanResult(item.DocumentExcerpt.Text || "");
    }
    return "";
  }

  /**
   * Converts a QueryResultItem object into a Document object.
   * @param item The QueryResultItem object to convert.
   * @returns A Document object.
   */
  convertQueryItem(item: QueryResultItem) {
    const title = item.DocumentTitle?.Text || "";
    const excerpt = this.getQueryItemExcerpt(item);
    const pageContent = this.combineText(title, excerpt);
    const metadata = {
      source: item.DocumentURI,
      title,
      excerpt,
      document_attributes: this.getDocAttributes(item.DocumentAttributes),
    };

    return new Document({ pageContent, metadata });
  }

  /**
   * Extracts the top-k documents from a QueryCommandOutput object.
   * @param response The QueryCommandOutput object to extract the documents from.
   * @param pageSize The number of documents to extract.
   * @returns An array of Document objects.
   */
  getQueryDocs(response: QueryCommandOutput, pageSize: number) {
    if (!response.ResultItems) return [];
    return response.ResultItems.slice(0, pageSize).map((item) =>
      this.convertQueryItem(item)
    );
  }

  /**
   * Sends a retrieve request to Kendra and, when it comes back with an
   * empty result list, falls back to a query request. Returns the top-k
   * documents.
   * @param query The query to send to Kendra.
   * @param topK The number of top documents to return.
   * @param attributeFilter Optional filter to apply when retrieving documents.
   * @returns A Promise that resolves to an array of Document objects.
   */
  async queryKendra(
    query: string,
    topK: number,
    attributeFilter?: AttributeFilter
  ) {
    const retrieveCommand = new RetrieveCommand({
      IndexId: this.indexId,
      QueryText: query,
      PageSize: topK,
      AttributeFilter: attributeFilter,
    });

    const retrieveResponse = await this.kendraClient.send(retrieveCommand);

    if (retrieveResponse.ResultItems?.length === 0) {
      // Retrieve API returned 0 results, call query API
      const queryCommand = new QueryCommand({
        IndexId: this.indexId,
        QueryText: query,
        PageSize: topK,
        AttributeFilter: attributeFilter,
      });

      const queryResponse = await this.kendraClient.send(queryCommand);
      // Trim with the caller-supplied topK, matching the PageSize requested.
      return this.getQueryDocs(queryResponse, topK);
    }
    return this.getRetrieverDocs(retrieveResponse, topK);
  }

  /**
   * Retriever hook: runs `queryKendra` with the instance's `topK` and
   * `attributeFilter`.
   * @param query The query string.
   * @returns A Promise that resolves to the matching Document objects.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    return this.queryKendra(query, this.topK, this.attributeFilter);
  }
}
/**
 * Shape of one entry in the `results` array of the datastore query
 * response, as consumed by this retriever.
 */
interface Berry {
  text: string;
  score: number;
  source?: string;
  [key: string]: unknown;
}

/**
 * Retriever that queries a Chaindesk datastore over HTTP and maps each
 * returned entry to a Document.
 * @example
 * ```typescript
 * const retriever = new ChaindeskRetriever({
 *   datastoreId: "DATASTORE_ID",
 *   apiKey: "CHAINDESK_API_KEY",
 *   topK: 8,
 * });
 * const docs = await retriever.getRelevantDocuments("hello");
 * ```
 */
export class ChaindeskRetriever extends BaseRetriever {
  static lc_name() {
    return "ChaindeskRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "chaindesk"];

  caller: AsyncCaller;

  datastoreId: string;

  topK?: number;

  filter?: Record<string, unknown>;

  apiKey?: string;

  constructor(fields: ChaindeskRetrieverArgs) {
    // Forward the BaseRetrieverInput part of the arguments instead of
    // silently discarding it.
    super(fields);
    const { datastoreId, apiKey, topK, filter, ...rest } = fields;

    this.caller = new AsyncCaller(rest);
    this.datastoreId = datastoreId;
    this.apiKey = apiKey;
    this.topK = topK;
    this.filter = filter;
  }

  /**
   * Retriever hook: POSTs the query to the datastore's query endpoint.
   * `topK`, `filters` and the bearer token are only sent when configured.
   * @param query The query string.
   * @returns One Document per returned entry; `score`, `source` and any
   * extra fields of the entry end up in the metadata.
   * @throws Error when the endpoint answers with a non-2xx status.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    const r = await this.caller.call(
      fetch,
      `https://app.chaindesk.ai/api/datastores/${this.datastoreId}/query`,
      {
        method: "POST",
        body: JSON.stringify({
          query,
          ...(this.topK ? { topK: this.topK } : {}),
          ...(this.filter ? { filters: this.filter } : {}),
        }),
        headers: {
          "Content-Type": "application/json",
          ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
        },
      }
    );

    // Without this, an error payload would surface later as an unrelated
    // TypeError on `results.map`.
    if (!r.ok) {
      throw new Error(`Chaindesk request failed with status ${r.status}`);
    }

    const { results } = (await r.json()) as { results: Berry[] };

    return results.map(
      ({ text, score, source, ...rest }) =>
        new Document({
          pageContent: text,
          metadata: {
            score,
            source,
            ...rest,
          },
        })
    );
  }
}
/**
 * A specific implementation of a document retriever for the Databerry
 * API. It extends the BaseRetriever class, which is an abstract base
 * class for a document retrieval system in LangChain.
 *
 * @deprecated Use "langchain/retrievers/chaindesk" instead
 */
export class DataberryRetriever extends BaseRetriever {
  static lc_name() {
    return "DataberryRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "databerry"];

  get lc_secrets() {
    return { apiKey: "DATABERRY_API_KEY" };
  }

  get lc_aliases() {
    return { apiKey: "api_key" };
  }

  caller: AsyncCaller;

  datastoreUrl: string;

  topK?: number;

  apiKey?: string;

  constructor(fields: DataberryRetrieverArgs) {
    super(fields);
    const { datastoreUrl, apiKey, topK, ...rest } = fields;

    this.caller = new AsyncCaller(rest);
    this.datastoreUrl = datastoreUrl;
    this.apiKey = apiKey;
    this.topK = topK;
  }

  /**
   * Retriever hook: POSTs the query to `datastoreUrl`. `topK` and the
   * bearer token are only sent when configured.
   * @param query The query string.
   * @returns One Document per returned entry; `score`, `source` and any
   * extra fields of the entry end up in the metadata.
   * @throws Error when the endpoint answers with a non-2xx status.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    const r = await this.caller.call(fetch, this.datastoreUrl, {
      method: "POST",
      body: JSON.stringify({
        query,
        ...(this.topK ? { topK: this.topK } : {}),
      }),
      headers: {
        "Content-Type": "application/json",
        ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
      },
    });

    // Without this, an error payload would surface later as an unrelated
    // TypeError on `results.map`.
    if (!r.ok) {
      throw new Error(`Databerry request failed with status ${r.status}`);
    }

    const { results } = (await r.json()) as { results: Berry[] };

    return results.map(
      ({ text, score, source, ...rest }) =>
        new Document({
          pageContent: text,
          metadata: {
            score,
            source,
            ...rest,
          },
        })
    );
  }
}
/**
 * Interface for the fields required during the initialization of a
 * `MetalRetriever` instance. It extends the `BaseRetrieverInput`
 * interface and adds a `client` field of type `Metal`.
 */
export interface MetalRetrieverFields extends BaseRetrieverInput {
  client: Metal;
}

/**
 * Interface to represent a response item from the Metal service. It
 * contains a `text` field and an index signature to allow for additional
 * unknown properties.
 */
interface ResponseItem {
  text: string;
  [key: string]: unknown;
}

/**
 * Class used to interact with the Metal service, a managed retrieval &
 * memory platform. It allows you to index your data into Metal and run
 * semantic search and retrieval on it. It extends the `BaseRetriever`
 * class and requires a `Metal` instance during its initialization.
 * @example
 * ```typescript
 * const retriever = new MetalRetriever({
 *   client: new Metal(
 *     process.env.METAL_API_KEY,
 *     process.env.METAL_CLIENT_ID,
 *     process.env.METAL_INDEX_ID,
 *   ),
 * });
 * const docs = await retriever.getRelevantDocuments("hello");
 * ```
 */
export class MetalRetriever extends BaseRetriever {
  static lc_name() {
    return "MetalRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "metal"];

  private client: Metal;

  constructor(fields: MetalRetrieverFields) {
    super(fields);

    this.client = fields.client;
  }

  /**
   * Retriever hook: runs a text search through the Metal client.
   * @param query The query string.
   * @returns One Document per response item, built from its `text` and
   * `metadata` fields.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    const res = await this.client.search({ text: query });

    // The items are either wrapped in a `data` property or returned bare.
    const items = ("data" in res ? res.data : res) as ResponseItem[];
    return items.map(
      ({ text, metadata }) =>
        new Document({
          pageContent: text,
          metadata: metadata as Record<string, unknown>,
        })
    );
  }
}
/** Parameters describing one hybrid search request. */
export interface SupabaseHybridSearchParams {
  query: string;
  similarityK: number;
  keywordK: number;
}

/**
 * Class for performing hybrid search operations on a Supabase database.
 * It extends the `BaseRetriever` class and implements methods for
 * similarity search, keyword search, and hybrid search.
 */
export class SupabaseHybridSearch extends BaseRetriever {
  static lc_name() {
    return "SupabaseHybridSearch";
  }

  lc_namespace = ["langchain", "retrievers", "supabase"];

  similarityK: number;

  // NOTE(review): declared but never assigned anywhere in this class —
  // confirm whether it can be dropped.
  query: string;

  keywordK: number;

  similarityQueryName: string;

  client: SupabaseClient;

  tableName: string;

  keywordQueryName: string;

  embeddings: Embeddings;

  constructor(embeddings: Embeddings, args: SupabaseLibArgs) {
    super(args);
    this.embeddings = embeddings;
    this.client = args.client;
    this.tableName = args.tableName || "documents";
    this.similarityQueryName = args.similarityQueryName || "match_documents";
    this.keywordQueryName = args.keywordQueryName || "kw_match_documents";
    this.similarityK = args.similarityK || 2;
    this.keywordK = args.keywordK || 2;
  }

  /**
   * Performs a similarity search on the Supabase database using the
   * provided query and returns the top 'k' similar documents. When the
   * retriever's `metadata` is non-empty it is sent along as the filter.
   * @param query The query to use for the similarity search.
   * @param k The number of top similar documents to return.
   * @param _callbacks Optional callbacks to pass to the embedQuery method.
   * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its similarity score, and its ID.
   * @throws Error when the remote function reports an error.
   */
  protected async similaritySearch(
    query: string,
    k: number,
    _callbacks?: Callbacks // implement passing to embedQuery later
  ): Promise<SearchResult[]> {
    const embeddedQuery = await this.embeddings.embedQuery(query);

    const matchDocumentsParams: SearchEmbeddingsParams = {
      query_embedding: embeddedQuery,
      match_count: k,
    };

    if (Object.keys(this.metadata ?? {}).length > 0) {
      matchDocumentsParams.filter = this.metadata;
    }

    const { data: searches, error } = await this.client.rpc(
      this.similarityQueryName,
      matchDocumentsParams
    );

    if (error) {
      throw new Error(
        `Error searching for documents: ${error.code} ${error.message} ${error.details}`
      );
    }

    // A null payload without an error is treated as "no rows".
    return ((searches ?? []) as SearchResponseRow[]).map((resp) => [
      new Document({
        metadata: resp.metadata,
        pageContent: resp.content,
      }),
      resp.similarity,
      resp.id,
    ]);
  }

  /**
   * Performs a keyword search on the Supabase database using the provided
   * query and returns the top 'k' documents that match the keywords.
   * @param query The query to use for the keyword search.
   * @param k The number of top documents to return that match the keywords.
   * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its similarity score multiplied by 10, and its ID.
   * @throws Error when the remote function reports an error.
   */
  protected async keywordSearch(
    query: string,
    k: number
  ): Promise<SearchResult[]> {
    const kwMatchDocumentsParams: SearchKeywordParams = {
      query_text: query,
      match_count: k,
    };

    const { data: searches, error } = await this.client.rpc(
      this.keywordQueryName,
      kwMatchDocumentsParams
    );

    if (error) {
      throw new Error(
        `Error searching for documents: ${error.code} ${error.message} ${error.details}`
      );
    }

    // A null payload without an error is treated as "no rows".
    return ((searches ?? []) as SearchResponseRow[]).map((resp) => [
      new Document({
        metadata: resp.metadata,
        pageContent: resp.content,
      }),
      resp.similarity * 10,
      resp.id,
    ]);
  }

  /**
   * Combines the results of the `similaritySearch` and `keywordSearch`
   * methods. Both searches run concurrently; when the same ID shows up
   * more than once only the highest-scoring entry is kept, and the
   * survivors are ordered by descending score.
   * @param query The query to use for the hybrid search.
   * @param similarityK The number of top similar documents to return.
   * @param keywordK The number of top documents to return that match the keywords.
   * @param callbacks Optional callbacks to pass to the similaritySearch method.
   * @returns A promise that resolves to an array of search results. Each result is a tuple containing a Document, its combined score, and its ID.
   */
  protected async hybridSearch(
    query: string,
    similarityK: number,
    keywordK: number,
    callbacks?: Callbacks
  ): Promise<SearchResult[]> {
    const [similarityResults, keywordResults] = await Promise.all([
      this.similaritySearch(query, similarityK, callbacks),
      this.keywordSearch(query, keywordK),
    ]);

    const bestById = new Map<number, SearchResult>();
    for (const result of [...similarityResults, ...keywordResults]) {
      const [, score, id] = result;
      const previous = bestById.get(id);
      if (previous === undefined || score > previous[1]) {
        bestById.set(id, result);
      }
    }

    return Array.from(bestById.values()).sort((a, b) => b[1] - a[1]);
  }

  /**
   * Retriever hook: runs the hybrid search with the configured
   * `similarityK` / `keywordK` and strips scores and IDs from the result.
   * @param query The query string.
   * @param runManager Optional run manager; a "hybrid_search" child is
   * handed to the search.
   * @returns A promise that resolves to the matching Document objects.
   */
  async _getRelevantDocuments(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ): Promise<Document[]> {
    const searchResults = await this.hybridSearch(
      query,
      this.similarityK,
      this.keywordK,
      runManager?.getChild("hybrid_search")
    );

    return searchResults.map(([doc]) => doc);
  }
}
/**
 * Options for the TavilySearchAPIRetriever class. All fields are optional;
 * the API key falls back to the `TAVILY_API_KEY` environment variable.
 */
export type TavilySearchAPIRetrieverFields = BaseRetrieverInput & {
  /** Sent as `max_results`. Defaults to 10. */
  k?: number;
  /** Sent as `include_answer`; also appends a "Suggested Answer" document. */
  includeGeneratedAnswer?: boolean;
  /** Sent as `include_raw_content`; also selects `raw_content` as page content. */
  includeRawContent?: boolean;
  /** Sent as `include_images`. */
  includeImages?: boolean;
  /** Sent as `search_depth`. Defaults to "basic". */
  searchDepth?: "basic" | "advanced";
  /** Sent as `include_domains` when set. */
  includeDomains?: string[];
  /** Sent as `exclude_domains` when set. */
  excludeDomains?: string[];
  /** Extra request-body entries; they override the generated ones. */
  kwargs?: Record<string, unknown>;
  apiKey?: string;
};

/**
 * A class for retrieving documents related to a given search term
 * using the Tavily Search API.
 */
export class TavilySearchAPIRetriever extends BaseRetriever {
  static lc_name() {
    return "TavilySearchAPIRetriever";
  }

  get lc_namespace(): string[] {
    return ["langchain", "retrievers", "tavily_search_api"];
  }

  k = 10;

  includeGeneratedAnswer = false;

  includeRawContent = false;

  includeImages = false;

  searchDepth = "basic";

  includeDomains?: string[];

  excludeDomains?: string[];

  kwargs: Record<string, unknown> = {};

  apiKey?: string;

  /**
   * @throws Error when no API key is passed and `TAVILY_API_KEY` is unset.
   */
  constructor(fields?: TavilySearchAPIRetrieverFields) {
    super(fields);
    this.k = fields?.k ?? this.k;
    this.includeGeneratedAnswer =
      fields?.includeGeneratedAnswer ?? this.includeGeneratedAnswer;
    this.includeRawContent =
      fields?.includeRawContent ?? this.includeRawContent;
    this.includeImages = fields?.includeImages ?? this.includeImages;
    this.searchDepth = fields?.searchDepth ?? this.searchDepth;
    this.includeDomains = fields?.includeDomains ?? this.includeDomains;
    this.excludeDomains = fields?.excludeDomains ?? this.excludeDomains;
    this.kwargs = fields?.kwargs ?? this.kwargs;
    this.apiKey = fields?.apiKey ?? getEnvironmentVariable("TAVILY_API_KEY");
    if (this.apiKey === undefined) {
      throw new Error(
        `No Tavily API key found. Either set an environment variable named "TAVILY_API_KEY" or pass an API key as "apiKey".`
      );
    }
  }

  /**
   * Retriever hook: POSTs the search request and maps every result to a
   * Document. Result fields other than `content`, `title`, `url` and
   * `raw_content` are copied into the metadata, together with the
   * response-level `images`.
   * @param query The search term.
   * @param _runManager Unused.
   * @returns The result documents, followed by a "Suggested Answer"
   * document when a generated answer was requested and returned.
   * @throws Error on a non-2xx response or when `results` is not an array.
   */
  async _getRelevantDocuments(
    query: string,
    _runManager?: CallbackManagerForRetrieverRun
  ): Promise<Document[]> {
    const body: Record<string, unknown> = {
      query,
      include_answer: this.includeGeneratedAnswer,
      include_raw_content: this.includeRawContent,
      include_images: this.includeImages,
      max_results: this.k,
      search_depth: this.searchDepth,
      api_key: this.apiKey,
    };
    if (this.includeDomains) {
      body.include_domains = this.includeDomains;
    }
    if (this.excludeDomains) {
      body.exclude_domains = this.excludeDomains;
    }

    const response = await fetch("https://api.tavily.com/search", {
      method: "POST",
      headers: {
        "content-type": "application/json",
      },
      body: JSON.stringify({ ...body, ...this.kwargs }),
    });
    const json = await response.json();
    if (!response.ok) {
      throw new Error(
        `Request failed with status code ${response.status}: ${json.error}`
      );
    }
    if (!Array.isArray(json.results)) {
      throw new Error(`Could not parse Tavily results. Please try again.`);
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const docs: Document[] = json.results.map((result: any) => {
      const pageContent = this.includeRawContent
        ? result.raw_content
        : result.content;
      const metadata = {
        title: result.title,
        source: result.url,
        ...Object.fromEntries(
          Object.entries(result).filter(
            ([k]) => !["content", "title", "url", "raw_content"].includes(k)
          )
        ),
        images: json.images,
      };
      return new Document({ pageContent, metadata });
    });
    // Only append the answer when the response actually carries one, so no
    // Document ends up with a non-string pageContent.
    if (this.includeGeneratedAnswer && typeof json.answer === "string") {
      docs.push(
        new Document({
          pageContent: json.answer,
          metadata: {
            title: "Suggested Answer",
            source: "https://tavily.com/",
          },
        })
      );
    }
    return docs;
  }
}
/**
 * Configuration interface for the ZepRetriever class. Extends the
 * BaseRetrieverInput interface.
 */
export interface ZepRetrieverConfig extends BaseRetrieverInput {
  /** The ID of the Zep session. */
  sessionId: string;
  /** The URL of the Zep API. */
  url: string;
  /** The number of results to return. */
  topK?: number;
  /** The API key for the Zep API. */
  apiKey?: string;
  /** The scope of the search: "messages" or "summary". */
  searchScope?: "messages" | "summary";
  /** The type of search to perform: "similarity" or "mmr". */
  searchType?: "similarity" | "mmr";
  /** The lambda value for the MMR search. */
  mmrLambda?: number;
  /** The metadata filter to apply to the search. */
  filter?: Record<string, unknown>;
}

/**
 * Class for retrieving information from a Zep long-term memory store.
 * Extends the BaseRetriever class.
 * @example
 * ```typescript
 * const retriever = new ZepRetriever({
 *   url: "http://localhost:8000",
 *   sessionId: "session_exampleUUID",
 *   topK: 3,
 * });
 * const query = "Can I drive red cars in France?";
 * const docs = await retriever.getRelevantDocuments(query);
 * ```
 */
export class ZepRetriever extends BaseRetriever {
  static lc_name() {
    return "ZepRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "zep"];

  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      apiKey: "ZEP_API_KEY",
      url: "ZEP_API_URL",
    };
  }

  get lc_aliases(): { [key: string]: string } | undefined {
    return { apiKey: "api_key" };
  }

  /** Pending client initialisation, started by the constructor. */
  zepClientPromise: Promise<ZepClient>;

  private sessionId: string;

  private topK?: number;

  private searchScope?: "messages" | "summary";

  private searchType?: "similarity" | "mmr";

  private mmrLambda?: number;

  private filter?: Record<string, unknown>;

  constructor(config: ZepRetrieverConfig) {
    super(config);
    this.sessionId = config.sessionId;
    this.topK = config.topK;
    this.searchScope = config.searchScope;
    this.searchType = config.searchType;
    this.mmrLambda = config.mmrLambda;
    this.filter = config.filter;
    this.zepClientPromise = ZepClient.init(config.url, config.apiKey);
  }

  /**
   * Converts an array of message search results to an array of Document
   * objects. Results without a `message` are dropped; `dist` is exposed as
   * `score` in the metadata.
   * @param results The array of search results.
   * @returns An array of Document objects representing the search results.
   */
  private searchMessageResultToDoc(results: MemorySearchResult[]): Document[] {
    return results
      .filter((r) => r.message)
      .map(
        ({
          message: { content, metadata: messageMetadata } = {},
          dist,
          ...rest
        }) =>
          new Document({
            pageContent: content ?? "",
            metadata: { score: dist, ...messageMetadata, ...rest },
          })
      );
  }

  /**
   * Converts an array of summary search results to an array of Document
   * objects. Results without a `summary` are dropped; `dist` is exposed as
   * `score` in the metadata.
   * @param results The array of search results.
   * @returns An array of Document objects representing the search results.
   */
  private searchSummaryResultToDoc(results: MemorySearchResult[]): Document[] {
    return results
      .filter((r) => r.summary)
      .map(
        ({
          summary: { content, metadata: summaryMetadata } = {},
          dist,
          ...rest
        }) =>
          new Document({
            pageContent: content ?? "",
            metadata: { score: dist, ...summaryMetadata, ...rest },
          })
      );
  }

  /**
   * Retrieves the relevant documents based on the given query.
   * @param query The query string.
   * @returns A promise that resolves to an array of relevant Document
   * objects; empty when the search raises a `NotFoundError`.
   * @throws Error when the client is not initialized, and any search error
   * other than `NotFoundError`.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    const payload: MemorySearchPayload = {
      text: query,
      metadata: this.filter,
      search_scope: this.searchScope,
      search_type: this.searchType,
      mmr_lambda: this.mmrLambda,
    };
    // Wait for ZepClient to be initialized
    const zepClient = await this.zepClientPromise;
    if (!zepClient) {
      throw new Error("ZepClient is not initialized");
    }
    try {
      const results: MemorySearchResult[] = await zepClient.memory.searchMemory(
        this.sessionId,
        payload,
        this.topK
      );
      return this.searchScope === "summary"
        ? this.searchSummaryResultToDoc(results)
        : this.searchMessageResultToDoc(results);
    } catch (error) {
      // eslint-disable-next-line no-instanceof/no-instanceof
      if (error instanceof NotFoundError) {
        return []; // Return an empty Document array
      }
      // If it's not a NotFoundError, throw the error again
      throw error;
    }
  }
}
/**
 * Configuration accepted by the ConvexKVStore constructor.
 * Only `ctx` is required; every other field falls back to the default named in its comment.
 */
export type ConvexKVStoreConfig<
  DataModel extends GenericDataModel,
  TableName extends TableNamesInDataModel<DataModel>,
  // NOTE(review): the default index "byKey" is used for key lookups, which looks like a
  // regular index rather than a vector index - confirm `VectorIndexNames` is intended here.
  IndexName extends VectorIndexNames<NamedTableInfo<DataModel, TableName>>,
  KeyFieldName extends FieldPaths<NamedTableInfo<DataModel, TableName>>,
  ValueFieldName extends FieldPaths<NamedTableInfo<DataModel, TableName>>,
  UpsertMutation extends FunctionReference<
    "mutation",
    "internal",
    { table: string; document: object }
  >,
  LookupQuery extends FunctionReference<
    "query",
    "internal",
    { table: string; index: string; keyField: string; key: string },
    object[]
  >,
  DeleteManyMutation extends FunctionReference<
    "mutation",
    "internal",
    { table: string; index: string; keyField: string; key: string }
  >
> = {
  /** Convex action context used to run the lookup query and the mutations. */
  readonly ctx: GenericActionCtx<DataModel>;
  /**
   * Defaults to "cache"
   */
  readonly table?: TableName;
  /**
   * Defaults to "byKey"
   */
  readonly index?: IndexName;
  /**
   * Defaults to "key"
   */
  readonly keyField?: KeyFieldName;
  /**
   * Defaults to "value"
   */
  readonly valueField?: ValueFieldName;
  /**
   * Defaults to `internal.langchain.db.upsert`
   */
  readonly upsert?: UpsertMutation;
  /**
   * Defaults to `internal.langchain.db.lookup`
   */
  readonly lookup?: LookupQuery;
  /**
   * Defaults to `internal.langchain.db.deleteMany`
   */
  readonly deleteMany?: DeleteManyMutation;
};
+ */ +export class ConvexKVStore< + T extends Value, + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel, + IndexName extends VectorIndexNames>, + KeyFieldName extends FieldPaths>, + ValueFieldName extends FieldPaths>, + UpsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + >, + LookupQuery extends FunctionReference< + "query", + "internal", + { table: string; index: string; keyField: string; key: string }, + object[] + >, + DeleteManyMutation extends FunctionReference< + "mutation", + "internal", + { table: string; index: string; keyField: string; key: string } + > +> extends BaseStore { + lc_namespace = ["langchain", "storage", "convex"]; + + private readonly ctx: GenericActionCtx; + + private readonly table: TableName; + + private readonly index: IndexName; + + private readonly keyField: KeyFieldName; + + private readonly valueField: ValueFieldName; + + private readonly upsert: UpsertMutation; + + private readonly lookup: LookupQuery; + + private readonly deleteMany: DeleteManyMutation; + + constructor( + config: ConvexKVStoreConfig< + DataModel, + TableName, + IndexName, + KeyFieldName, + ValueFieldName, + UpsertMutation, + LookupQuery, + DeleteManyMutation + > + ) { + super(config); + this.ctx = config.ctx; + this.table = config.table ?? ("cache" as TableName); + this.index = config.index ?? ("byKey" as IndexName); + this.keyField = config.keyField ?? ("key" as KeyFieldName); + this.valueField = config.valueField ?? ("value" as ValueFieldName); + this.upsert = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + config.upsert ?? (makeFunctionReference("langchain/db:upsert") as any); + this.lookup = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + config.lookup ?? (makeFunctionReference("langchain/db:lookup") as any); + this.deleteMany = + config.deleteMany ?? 
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any + (makeFunctionReference("langchain/db:deleteMany") as any); + } + + /** + * Gets multiple keys from the Convex database. + * @param keys Array of keys to be retrieved. + * @returns An array of retrieved values. + */ + async mget(keys: string[]) { + return (await Promise.all( + keys.map(async (key) => { + const found = (await this.ctx.runQuery(this.lookup, { + table: this.table, + index: this.index, + keyField: this.keyField, + key, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any)) as any; + return found.length > 0 ? found[0][this.valueField] : undefined; + }) + )) as (T | undefined)[]; + } + + /** + * Sets multiple keys in the Convex database. + * @param keyValuePairs Array of key-value pairs to be set. + * @returns Promise that resolves when all keys have been set. + */ + async mset(keyValuePairs: [string, T][]): Promise { + // TODO: Remove chunking when Convex handles the concurrent requests correctly + const PAGE_SIZE = 16; + for (let i = 0; i < keyValuePairs.length; i += PAGE_SIZE) { + await Promise.all( + keyValuePairs.slice(i, i + PAGE_SIZE).map(([key, value]) => + this.ctx.runMutation(this.upsert, { + table: this.table, + index: this.index, + keyField: this.keyField, + key, + document: { [this.keyField]: key, [this.valueField]: value }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any) + ) + ); + } + } + + /** + * Deletes multiple keys from the Convex database. + * @param keys Array of keys to be deleted. + * @returns Promise that resolves when all keys have been deleted. + */ + async mdelete(keys: string[]): Promise { + await Promise.all( + keys.map((key) => + this.ctx.runMutation(this.deleteMany, { + table: this.table, + index: this.index, + keyField: this.keyField, + key, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any) + ) + ); + } + + /** + * Yields keys from the Convex database. 
/**
 * Byte store backed by a Redis database, accessed through an `ioredis` client.
 * When a `namespace` is configured every key is stored as `<namespace>/<key>`; when a `ttl`
 * (in seconds) is configured every value written by `mset` expires after that long.
 * @example
 * ```typescript
 * const store = new RedisByteStore({ client: new Redis({}) });
 * await store.mset([
 *   ["message:id:0", new TextEncoder().encode("first")],
 *   ["message:id:1", new TextEncoder().encode("second")],
 * ]);
 * const values = await store.mget(["message:id:0", "message:id:1"]);
 * console.log(values.map((v) => new TextDecoder().decode(v)));
 * const yieldedKeys = [];
 * for await (const key of store.yieldKeys("message:id:")) {
 *   yieldedKeys.push(key);
 * }
 * await store.mdelete(yieldedKeys);
 * ```
 */
export class RedisByteStore extends BaseStore<string, Uint8Array> {
  lc_namespace = ["langchain", "storage"];

  protected client: Redis;

  protected ttl?: number;

  protected namespace?: string;

  protected yieldKeysScanBatchSize = 1000;

  constructor(fields: {
    client: Redis;
    ttl?: number;
    namespace?: string;
    yieldKeysScanBatchSize?: number;
  }) {
    super(fields);
    this.client = fields.client;
    this.ttl = fields.ttl;
    this.namespace = fields.namespace;
    this.yieldKeysScanBatchSize =
      fields.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize;
  }

  /** Returns `<namespace>/<key>` when a namespace is set, otherwise the key unchanged. */
  _getPrefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return `${this.namespace}${delimiter}${key}`;
    }
    return key;
  }

  /** Inverse of `_getPrefixedKey`: drops the leading `<namespace>/` when a namespace is set. */
  _getDeprefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return key.slice(this.namespace.length + delimiter.length);
    }
    return key;
  }

  /**
   * Gets multiple keys from the Redis database.
   * @param keys Array of keys to be retrieved.
   * @returns The stored bytes for each key, in order; `undefined` where a key is missing.
   */
  async mget(keys: string[]) {
    // MGET needs at least one key, so an empty request is answered locally.
    if (keys.length === 0) {
      return [];
    }
    const prefixedKeys = keys.map((key) => this._getPrefixedKey(key));
    const retrievedValues = await this.client.mgetBuffer(prefixedKeys);
    return retrievedValues.map((value) => value ?? undefined);
  }

  /**
   * Sets multiple keys in the Redis database.
   * @param keyValuePairs Array of key-value pairs to be set.
   * @returns Promise that resolves when all keys have been set.
   */
  async mset(keyValuePairs: [string, Uint8Array][]): Promise<void> {
    if (keyValuePairs.length === 0) {
      return;
    }
    const pipeline = this.client.pipeline();
    for (const [key, value] of keyValuePairs) {
      const prefixedKey = this._getPrefixedKey(key);
      // Send the raw bytes. Decoding them to a string first replaces invalid UTF-8 sequences
      // and strips a leading BOM, so arbitrary binary values would not round-trip via mget.
      const payload = Buffer.from(
        value.buffer,
        value.byteOffset,
        value.byteLength
      );
      if (this.ttl) {
        pipeline.set(prefixedKey, payload, "EX", this.ttl);
      } else {
        pipeline.set(prefixedKey, payload);
      }
    }
    await pipeline.exec();
  }

  /**
   * Deletes multiple keys from the Redis database.
   * @param keys Array of keys to be deleted.
   * @returns Promise that resolves when all keys have been deleted.
   */
  async mdelete(keys: string[]): Promise<void> {
    // DEL needs at least one key.
    if (keys.length === 0) {
      return;
    }
    await this.client.del(...keys.map((key) => this._getPrefixedKey(key)));
  }

  /**
   * Yields keys from the Redis database.
   * @param prefix Optional prefix to filter the keys. A trailing `*` is appended unless the
   * prefix already ends with one; without a prefix every key in the namespace is yielded.
   * @returns An AsyncGenerator that yields the matching keys with the namespace removed.
   */
  async *yieldKeys(prefix?: string): AsyncGenerator<string> {
    let wildcardPrefix = "*";
    if (prefix) {
      wildcardPrefix = prefix.endsWith("*") ? prefix : `${prefix}*`;
    }
    const pattern = this._getPrefixedKey(wildcardPrefix);

    // SCAN is complete once the server hands back cursor "0".
    let cursor = "0";
    do {
      const [nextCursor, batch] = await this.client.scan(
        cursor,
        "MATCH",
        pattern,
        "COUNT",
        this.yieldKeysScanBatchSize
      );
      cursor = nextCursor;
      for (const key of batch) {
        yield this._getDeprefixedKey(key);
      }
    } while (cursor !== "0");
  }
}
/**
 * Byte store backed by an Upstash Redis database.
 * When a `namespace` is configured every key is stored as `<namespace>/<key>`; when a
 * `sessionTTL` (in seconds) is configured every value written by `mset` expires after that long.
 * @example
 * ```typescript
 * const store = new UpstashRedisStore({
 *   client: new Redis({
 *     url: "your-upstash-redis-url",
 *     token: "your-upstash-redis-token",
 *   }),
 * });
 * const encoder = new TextEncoder();
 * await store.mset([
 *   ["message:id:0", encoder.encode("encoded-ai-message")],
 *   ["message:id:1", encoder.encode("encoded-human-message")],
 * ]);
 * const retrievedMessages = await store.mget(["message:id:0", "message:id:1"]);
 * const yieldedKeys = [];
 * for await (const key of store.yieldKeys("message:id")) {
 *   yieldedKeys.push(key);
 * }
 * await store.mdelete(yieldedKeys);
 * ```
 */
export class UpstashRedisStore extends BaseStore<string, Uint8Array> {
  lc_namespace = ["langchain", "storage"];

  protected client: UpstashRedis;

  protected namespace?: string;

  protected yieldKeysScanBatchSize = 1000;

  private sessionTTL?: number;

  /**
   * @throws {Error} If neither `client` nor `config` is supplied.
   */
  constructor(fields: UpstashRedisStoreInput) {
    super(fields);
    if (fields.client) {
      this.client = fields.client;
    } else if (fields.config) {
      this.client = new UpstashRedis(fields.config);
    } else {
      throw new Error(
        `Upstash Redis store requires either a config object or a pre-configured client.`
      );
    }
    this.sessionTTL = fields.sessionTTL;
    this.yieldKeysScanBatchSize =
      fields.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize;
    this.namespace = fields.namespace;
  }

  /** Returns `<namespace>/<key>` when a namespace is set, otherwise the key unchanged. */
  _getPrefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return `${this.namespace}${delimiter}${key}`;
    }
    return key;
  }

  /** Inverse of `_getPrefixedKey`: drops the leading `<namespace>/` when a namespace is set. */
  _getDeprefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return key.slice(this.namespace.length + delimiter.length);
    }
    return key;
  }

  /**
   * Gets multiple keys from the Upstash Redis database.
   * @param keys Array of keys to be retrieved.
   * @returns The UTF-8 encoded value for each key, in order; `undefined` where a key is missing.
   */
  async mget(keys: string[]) {
    // MGET needs at least one key, so an empty request is answered locally.
    if (keys.length === 0) {
      return [];
    }
    const encoder = new TextEncoder();
    const prefixedKeys = keys.map((key) => this._getPrefixedKey(key));
    const retrievedValues = await this.client.mget<unknown[]>(...prefixedKeys);
    return retrievedValues.map((value) => {
      // Only null/undefined mean "missing". A truthiness test would also discard stored
      // values that come back as 0, false or "".
      if (value === undefined || value === null) {
        return undefined;
      }
      // Anything that is not already a string (object, number, boolean) is re-serialized.
      return encoder.encode(
        typeof value === "string" ? value : JSON.stringify(value)
      );
    });
  }

  /**
   * Sets multiple keys in the Upstash Redis database.
   * @param keyValuePairs Array of key-value pairs to be set.
   * @returns Promise that resolves when all keys have been set.
   */
  async mset(keyValuePairs: [string, Uint8Array][]): Promise<void> {
    if (keyValuePairs.length === 0) {
      return;
    }
    const decoder = new TextDecoder();
    const pipeline = this.client.pipeline();
    for (const [key, value] of keyValuePairs) {
      const prefixedKey = this._getPrefixedKey(key);
      const decodedValue = decoder.decode(value);
      if (this.sessionTTL) {
        pipeline.setex(prefixedKey, this.sessionTTL, decodedValue);
      } else {
        pipeline.set(prefixedKey, decodedValue);
      }
    }
    await pipeline.exec();
  }

  /**
   * Deletes multiple keys from the Upstash Redis database.
   * @param keys Array of keys to be deleted.
   * @returns Promise that resolves when all keys have been deleted.
   */
  async mdelete(keys: string[]): Promise<void> {
    // DEL needs at least one key.
    if (keys.length === 0) {
      return;
    }
    await this.client.del(...keys.map((key) => this._getPrefixedKey(key)));
  }

  /**
   * Yields keys from the Upstash Redis database.
   * @param prefix Optional prefix to filter the keys. A trailing `*` is appended unless the
   * prefix already ends with one; without a prefix every key in the namespace is yielded.
   * @returns An AsyncGenerator that yields the matching keys with the namespace removed.
   */
  async *yieldKeys(prefix?: string): AsyncGenerator<string> {
    let wildcardPrefix = "*";
    if (prefix) {
      wildcardPrefix = prefix.endsWith("*") ? prefix : `${prefix}*`;
    }
    const pattern = this._getPrefixedKey(wildcardPrefix);
    const scanOptions = {
      match: pattern,
      count: this.yieldKeysScanBatchSize,
    };

    let [cursor, batch] = await this.client.scan(0, scanOptions);
    for (;;) {
      for (const key of batch) {
        yield this._getDeprefixedKey(key);
      }
      // Compare as a string so the loop ends whether the client reports the final cursor
      // as the number 0 or as the string "0".
      if (String(cursor) === "0") {
        break;
      }
      [cursor, batch] = await this.client.scan(cursor, scanOptions);
    }
  }
}
/**
 * Byte store backed by a Vercel KV database.
 * When a `namespace` is configured every key is stored as `<namespace>/<key>`; when a `ttl`
 * (in seconds) is configured every value written by `mset` expires after that long.
 * Without an explicit `client` the default `kv` export of `@vercel/kv` is used.
 * @example
 * ```typescript
 * const store = new VercelKVStore({
 *   client: getClient(),
 * });
 * const encoder = new TextEncoder();
 * await store.mset([
 *   ["message:id:0", encoder.encode("encoded message 0")],
 *   ["message:id:1", encoder.encode("encoded message 1")],
 * ]);
 * const retrievedMessages = await store.mget(["message:id:0", "message:id:1"]);
 * const yieldedKeys = [];
 * for await (const key of store.yieldKeys("message:id:")) {
 *   yieldedKeys.push(key);
 * }
 * await store.mdelete(yieldedKeys);
 * ```
 */
export class VercelKVStore extends BaseStore<string, Uint8Array> {
  lc_namespace = ["langchain", "storage"];

  protected client: VercelKV;

  protected ttl?: number;

  protected namespace?: string;

  protected yieldKeysScanBatchSize = 1000;

  constructor(fields?: {
    client?: VercelKV;
    ttl?: number;
    namespace?: string;
    yieldKeysScanBatchSize?: number;
  }) {
    super(fields);
    this.client = fields?.client ?? kv;
    this.ttl = fields?.ttl;
    this.namespace = fields?.namespace;
    this.yieldKeysScanBatchSize =
      fields?.yieldKeysScanBatchSize ?? this.yieldKeysScanBatchSize;
  }

  /** Returns `<namespace>/<key>` when a namespace is set, otherwise the key unchanged. */
  _getPrefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return `${this.namespace}${delimiter}${key}`;
    }
    return key;
  }

  /** Inverse of `_getPrefixedKey`: drops the leading `<namespace>/` when a namespace is set. */
  _getDeprefixedKey(key: string) {
    if (this.namespace) {
      const delimiter = "/";
      return key.slice(this.namespace.length + delimiter.length);
    }
    return key;
  }

  /**
   * Gets multiple keys from the Vercel KV database.
   * @param keys Array of keys to be retrieved.
   * @returns The UTF-8 encoded value for each key, in order; `undefined` where a key is missing.
   */
  async mget(keys: string[]) {
    // MGET needs at least one key, so an empty request is answered locally.
    if (keys.length === 0) {
      return [];
    }
    const prefixedKeys = keys.map((key) => this._getPrefixedKey(key));
    const retrievedValues = await this.client.mget<(string | undefined)[]>(
      ...prefixedKeys
    );
    const encoder = new TextEncoder();
    return retrievedValues.map((value) => {
      if (value === undefined || value === null) {
        return undefined;
      } else if (typeof value === "object") {
        // Values that come back as objects are re-serialized to JSON before encoding.
        return encoder.encode(JSON.stringify(value));
      } else {
        return encoder.encode(value);
      }
    });
  }

  /**
   * Sets multiple keys in the Vercel KV database.
   * @param keyValuePairs Array of key-value pairs to be set.
   * @returns Promise that resolves when all keys have been set.
   */
  async mset(keyValuePairs: [string, Uint8Array][]): Promise<void> {
    if (keyValuePairs.length === 0) {
      return;
    }
    const decoder = new TextDecoder();
    const pipeline = this.client.pipeline();
    for (const [key, value] of keyValuePairs) {
      const prefixedKey = this._getPrefixedKey(key);
      const decodedValue = decoder.decode(value);
      if (this.ttl) {
        pipeline.setex(prefixedKey, this.ttl, decodedValue);
      } else {
        pipeline.set(prefixedKey, decodedValue);
      }
    }
    await pipeline.exec();
  }

  /**
   * Deletes multiple keys from the Vercel KV database.
   * @param keys Array of keys to be deleted.
   * @returns Promise that resolves when all keys have been deleted.
   */
  async mdelete(keys: string[]): Promise<void> {
    // DEL needs at least one key.
    if (keys.length === 0) {
      return;
    }
    await this.client.del(...keys.map((key) => this._getPrefixedKey(key)));
  }

  /**
   * Yields keys from the Vercel KV database.
   * @param prefix Optional prefix to filter the keys. A trailing `*` is appended unless the
   * prefix already ends with one; without a prefix every key in the namespace is yielded.
   * @returns An AsyncGenerator that yields the matching keys with the namespace removed.
   */
  async *yieldKeys(prefix?: string): AsyncGenerator<string> {
    let wildcardPrefix = "*";
    if (prefix) {
      wildcardPrefix = prefix.endsWith("*") ? prefix : `${prefix}*`;
    }
    const pattern = this._getPrefixedKey(wildcardPrefix);
    const scanOptions = {
      match: pattern,
      count: this.yieldKeysScanBatchSize,
    };

    let [cursor, batch] = await this.client.scan(0, scanOptions);
    for (;;) {
      for (const key of batch) {
        yield this._getDeprefixedKey(key);
      }
      // Compare as a string so the loop ends whether the client reports the final cursor
      // as the number 0 or as the string "0".
      if (String(cursor) === "0") {
        break;
      }
      [cursor, batch] = await this.client.scan(cursor, scanOptions);
    }
  }
}
/**
 * Class for storing and retrieving documents in memory synchronously.
 * Documents are kept in a Map keyed by document ID.
 */
export class SynchronousInMemoryDocstore {
  _docs: Map<string, Document>;

  /**
   * @param docs Optional pre-populated map of ID to document; the map is used as-is, not copied.
   */
  constructor(docs?: Map<string, Document>) {
    this._docs = docs ?? new Map();
  }

  /**
   * Searches for a document in the store based on its ID.
   * @param search The ID of the document to search for.
   * @returns The document with the given ID.
   * @throws {Error} If no document is stored under the ID.
   */
  search(search: string): Document {
    const result = this._docs.get(search);
    if (!result) {
      throw new Error(`ID ${search} not found.`);
    }
    return result;
  }

  /**
   * Adds new documents to the store.
   * @param texts An object where the keys are document IDs and the values are the documents themselves.
   * @returns Void
   * @throws {Error} If any of the IDs is already present; in that case nothing is added.
   */
  add(texts: Record<string, Document>): void {
    // Ask the map directly instead of copying every stored key into an array and scanning
    // it once per incoming ID.
    const overlapping = Object.keys(texts).filter((id) => this._docs.has(id));

    if (overlapping.length > 0) {
      throw new Error(`Tried to add ids that already exist: ${overlapping}`);
    }

    for (const [key, value] of Object.entries(texts)) {
      this._docs.set(key, value);
    }
  }
}
/**
 * Chat message history persisted in a Cassandra table.
 * Each message is stored as one row of `<keyspace>.<table>`, partitioned by session ID and
 * clustered by insertion timestamp. The table is created on first use if it does not exist.
 * @example
 * ```typescript
 * const chatHistory = new CassandraChatMessageHistory({
 *   cloud: {
 *     secureConnectBundle: "",
 *   },
 *   credentials: {
 *     username: "token",
 *     password: "",
 *   },
 *   keyspace: "langchain",
 *   table: "message_history",
 *   sessionId: "",
 * });
 *
 * await chatHistory.addMessage(new HumanMessage("Hi, my name is Ada."));
 * const messages = await chatHistory.getMessages();
 * console.log({ messages });
 * ```
 */
export class CassandraChatMessageHistory extends BaseListChatMessageHistory {
  lc_namespace = ["langchain", "stores", "message", "cassandra"];

  private keyspace: string;

  private table: string;

  private client: Client;

  private sessionId: string;

  private tableExists: boolean;

  private options: CassandraChatMessageHistoryOptions;

  private queries: { insert: string; select: string; delete: string };

  constructor(options: CassandraChatMessageHistoryOptions) {
    super();
    const { keyspace, table, sessionId } = options;
    this.client = new Client(options);
    this.keyspace = keyspace;
    this.table = table;
    this.sessionId = sessionId;
    this.tableExists = false;
    this.options = options;
  }

  /**
   * Reads every message stored for the current session.
   * @returns The session's messages as BaseMessage instances.
   */
  public async getMessages(): Promise<BaseMessage[]> {
    await this.ensureTable();
    const { rows } = await this.client.execute(
      this.queries.select,
      [this.sessionId],
      { prepare: true }
    );
    const storedMessages: StoredMessage[] = rows.map((row) => ({
      type: row.message_type,
      data: JSON.parse(row.data),
    }));
    return mapStoredMessagesToChatMessages(storedMessages);
  }

  /**
   * Appends one message to the current session.
   * @param message The BaseMessage instance to add.
   * @returns A promise that resolves when the message has been written.
   */
  public async addMessage(message: BaseMessage): Promise<void> {
    await this.ensureTable();
    const [{ type, data }] = mapChatMessagesToStoredMessages([message]);
    await this.client.execute(
      this.queries.insert,
      [this.sessionId, type, JSON.stringify(data)],
      { prepare: true, ...this.options }
    );
  }

  /**
   * Deletes every message stored for the current session.
   * @returns A promise that resolves when the rows have been deleted.
   */
  public async clear(): Promise<void> {
    await this.ensureTable();
    await this.client.execute(this.queries.delete, [this.sessionId], {
      prepare: true,
      ...this.options,
    });
  }

  /**
   * Creates the backing table if needed and prepares the CQL statements.
   * Runs its work only once per instance; later calls return immediately.
   * @returns Promise that resolves when the table exists and the queries are set.
   */
  private async ensureTable(): Promise<void> {
    if (!this.tableExists) {
      await this.client.execute(`
      CREATE TABLE IF NOT EXISTS ${this.keyspace}.${this.table} (
        session_id text,
        message_ts timestamp,
        message_type text,
        data text,
        PRIMARY KEY ((session_id), message_ts)
      );
    `);

      const insert = `INSERT INTO ${this.keyspace}.${this.table} (session_id, message_ts, message_type, data) VALUES (?, toTimestamp(now()), ?, ?);`;
      const select = `SELECT message_type, data FROM ${this.keyspace}.${this.table} WHERE session_id = ?;`;
      const remove = `DELETE FROM ${this.keyspace}.${this.table} WHERE session_id = ?;`;
      this.queries = { insert, select, delete: remove };

      this.tableExists = true;
    }
  }
}
BaseMessage, + StoredMessage, + StoredMessageData, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; +/** + * Type definition for the input parameters required when instantiating a + * CloudflareD1MessageHistory object. + */ +export type CloudflareD1MessageHistoryInput = { + tableName?: string; + sessionId: string; + database?: D1Database; +}; + +/** + * Interface for the data transfer object used when selecting stored + * messages from the Cloudflare D1 database. + */ +interface selectStoredMessagesDTO { + id: string; + session_id: string; + type: string; + content: string; + role: string | null; + name: string | null; + additional_kwargs: string; +} + +/** + * Class for storing and retrieving chat message history from a + * Cloudflare D1 database. Extends the BaseListChatMessageHistory class. + * @example + * ```typescript + * const memory = new BufferMemory({ + * returnMessages: true, + * chatHistory: new CloudflareD1MessageHistory({ + * tableName: "stored_message", + * sessionId: "example", + * database: env.DB, + * }), + * }); + * + * const chainInput = { input }; + * + * const res = await memory.chatHistory.invoke(chainInput); + * await memory.saveContext(chainInput, { + * output: res, + * }); + * ``` + */ +export class CloudflareD1MessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "cloudflare_d1"]; + + public database: D1Database; + + private tableName: string; + + private sessionId: string; + + private tableInitialized: boolean; + + constructor(fields: CloudflareD1MessageHistoryInput) { + super(fields); + + const { sessionId, database, tableName } = fields; + + if (database) { + this.database = database; + } else { + throw new Error( + "Either a client or config must be provided to CloudflareD1MessageHistory" + ); + } + + this.tableName = tableName || "langchain_chat_histories"; + this.tableInitialized = false; + this.sessionId = sessionId; + } + + 
/** + * Private method to ensure that the necessary table exists in the + * Cloudflare D1 database before performing any operations. If the table + * does not exist, it is created. + * @returns Promise that resolves to void. + */ + private async ensureTable(): Promise { + if (this.tableInitialized) { + return; + } + + const query = `CREATE TABLE IF NOT EXISTS ${this.tableName} (id TEXT PRIMARY KEY, session_id TEXT, type TEXT, content TEXT, role TEXT, name TEXT, additional_kwargs TEXT);`; + await this.database.prepare(query).bind().all(); + + const idIndexQuery = `CREATE INDEX IF NOT EXISTS id_index ON ${this.tableName} (id);`; + await this.database.prepare(idIndexQuery).bind().all(); + + const sessionIdIndexQuery = `CREATE INDEX IF NOT EXISTS session_id_index ON ${this.tableName} (session_id);`; + await this.database.prepare(sessionIdIndexQuery).bind().all(); + + this.tableInitialized = true; + } + + /** + * Method to retrieve all messages from the Cloudflare D1 database for the + * current session. + * @returns Promise that resolves to an array of BaseMessage objects. + */ + async getMessages(): Promise { + await this.ensureTable(); + + const query = `SELECT * FROM ${this.tableName} WHERE session_id = ?`; + const rawStoredMessages = await this.database + .prepare(query) + .bind(this.sessionId) + .all(); + const storedMessagesObject = + rawStoredMessages.results as unknown as selectStoredMessagesDTO[]; + + const orderedMessages: StoredMessage[] = storedMessagesObject.map( + (message) => { + const data = { + content: message.content, + additional_kwargs: JSON.parse(message.additional_kwargs), + } as StoredMessageData; + + if (message.role) { + data.role = message.role; + } + + if (message.name) { + data.name = message.name; + } + + return { + type: message.type, + data, + }; + } + ); + + return mapStoredMessagesToChatMessages(orderedMessages); + } + + /** + * Method to add a new message to the Cloudflare D1 database for the current + * session. 
+ * @param message The BaseMessage object to be added to the database. + * @returns Promise that resolves to void. + */ + async addMessage(message: BaseMessage): Promise<void> { + await this.ensureTable(); + + const messageToAdd = mapChatMessagesToStoredMessages([message]); + + const query = `INSERT INTO ${this.tableName} (id, session_id, type, content, role, name, additional_kwargs) VALUES(?, ?, ?, ?, ?, ?, ?)`; + + const id = v4(); + + await this.database + .prepare(query) + .bind( + id, + this.sessionId, + messageToAdd[0].type || null, + messageToAdd[0].data.content || null, + messageToAdd[0].data.role || null, + messageToAdd[0].data.name || null, + JSON.stringify(messageToAdd[0].data.additional_kwargs) + ) + .all(); + } + + /** + * Method to delete all messages from the Cloudflare D1 database for the + * current session. The table name is interpolated into the statement (as in the other queries of this class) because `?` placeholders can only bind values, not identifiers; only the session id is bound. + * @returns Promise that resolves to void. + */ + async clear(): Promise<void> { + await this.ensureTable(); + + const query = `DELETE FROM ${this.tableName} WHERE session_id = ?`; + await this.database + .prepare(query) + .bind(this.sessionId) + .all(); + } +} diff --git a/libs/langchain-community/src/stores/message/convex.ts b/libs/langchain-community/src/stores/message/convex.ts new file mode 100644 index 000000000000..9a7f5d120e1c --- /dev/null +++ b/libs/langchain-community/src/stores/message/convex.ts @@ -0,0 +1,210 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +// eslint-disable-next-line import/no-extraneous-dependencies +import { + DocumentByInfo, + DocumentByName, + FieldPaths, + FunctionReference, + GenericActionCtx, + GenericDataModel, + NamedTableInfo, + TableNamesInDataModel, + IndexNames, + makeFunctionReference, +} from "convex/server"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +/** + * Type that defines the config required to initialize the + * 
ConvexChatMessageHistory class. At minimum it needs a sessionId + * and an ActionCtx. + */ +export type ConvexChatMessageHistoryInput< + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel = "messages", + IndexName extends IndexNames< + NamedTableInfo + > = "bySessionId", + SessionIdFieldName extends FieldPaths< + NamedTableInfo + > = "sessionId", + MessageTextFieldName extends FieldPaths< + NamedTableInfo + > = "message", + InsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + > = any, + LookupQuery extends FunctionReference< + "query", + "internal", + { table: string; index: string; keyField: string; key: string }, + object[] + > = any, + DeleteManyMutation extends FunctionReference< + "mutation", + "internal", + { table: string; index: string; keyField: string; key: string } + > = any +> = { + readonly ctx: GenericActionCtx; + readonly sessionId: DocumentByName[SessionIdFieldName]; + /** + * Defaults to "messages" + */ + readonly table?: TableName; + /** + * Defaults to "bySessionId" + */ + readonly index?: IndexName; + /** + * Defaults to "sessionId" + */ + readonly sessionIdField?: SessionIdFieldName; + /** + * Defaults to "message" + */ + readonly messageTextFieldName?: MessageTextFieldName; + /** + * Defaults to `internal.langchain.db.insert` + */ + readonly insert?: InsertMutation; + /** + * Defaults to `internal.langchain.db.lookup` + */ + readonly lookup?: LookupQuery; + /** + * Defaults to `internal.langchain.db.deleteMany` + */ + readonly deleteMany?: DeleteManyMutation; +}; + +export class ConvexChatMessageHistory< + DataModel extends GenericDataModel, + SessionIdFieldName extends FieldPaths< + NamedTableInfo + > = "sessionId", + TableName extends TableNamesInDataModel = "messages", + IndexName extends IndexNames< + NamedTableInfo + > = "bySessionId", + MessageTextFieldName extends FieldPaths< + NamedTableInfo + > = "message", + InsertMutation extends FunctionReference< + 
"mutation", + "internal", + { table: string; document: object } + > = any, + LookupQuery extends FunctionReference< + "query", + "internal", + { table: string; index: string; keyField: string; key: string }, + object[] + > = any, + DeleteManyMutation extends FunctionReference< + "mutation", + "internal", + { table: string; index: string; keyField: string; key: string } + > = any +> extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "convex"]; + + private readonly ctx: GenericActionCtx; + + private readonly sessionId: DocumentByInfo< + NamedTableInfo + >[SessionIdFieldName]; + + private readonly table: TableName; + + private readonly index: IndexName; + + private readonly sessionIdField: SessionIdFieldName; + + private readonly messageTextFieldName: MessageTextFieldName; + + private readonly insert: InsertMutation; + + private readonly lookup: LookupQuery; + + private readonly deleteMany: DeleteManyMutation; + + constructor( + config: ConvexChatMessageHistoryInput< + DataModel, + TableName, + IndexName, + SessionIdFieldName, + MessageTextFieldName, + InsertMutation, + LookupQuery, + DeleteManyMutation + > + ) { + super(); + this.ctx = config.ctx; + this.sessionId = config.sessionId; + this.table = config.table ?? ("messages" as TableName); + this.index = config.index ?? ("bySessionId" as IndexName); + this.sessionIdField = + config.sessionIdField ?? ("sessionId" as SessionIdFieldName); + this.messageTextFieldName = + config.messageTextFieldName ?? ("message" as MessageTextFieldName); + this.insert = + config.insert ?? (makeFunctionReference("langchain/db:insert") as any); + this.lookup = + config.lookup ?? (makeFunctionReference("langchain/db:lookup") as any); + this.deleteMany = + config.deleteMany ?? 
+ (makeFunctionReference("langchain/db:deleteMany") as any); + } + + async getMessages(): Promise { + const convexDocuments: any[] = await this.ctx.runQuery(this.lookup, { + table: this.table, + index: this.index, + keyField: this.sessionIdField, + key: this.sessionId, + } as any); + + return mapStoredMessagesToChatMessages( + convexDocuments.map((doc) => doc[this.messageTextFieldName]) + ); + } + + async addMessage(message: BaseMessage): Promise { + const messages = mapChatMessagesToStoredMessages([message]); + // TODO: Remove chunking when Convex handles the concurrent requests correctly + const PAGE_SIZE = 16; + for (let i = 0; i < messages.length; i += PAGE_SIZE) { + await Promise.all( + messages.slice(i, i + PAGE_SIZE).map((message) => + this.ctx.runMutation(this.insert, { + table: this.table, + document: { + [this.sessionIdField]: this.sessionId, + [this.messageTextFieldName]: message, + }, + } as any) + ) + ); + } + } + + async clear(): Promise { + await this.ctx.runMutation(this.deleteMany, { + table: this.table, + index: this.index, + keyField: this.sessionIdField, + key: this.sessionId, + } as any); + } +} diff --git a/libs/langchain-community/src/stores/message/dynamodb.ts b/libs/langchain-community/src/stores/message/dynamodb.ts new file mode 100644 index 000000000000..6bef93aeb88f --- /dev/null +++ b/libs/langchain-community/src/stores/message/dynamodb.ts @@ -0,0 +1,196 @@ +import { + DynamoDBClient, + DynamoDBClientConfig, + GetItemCommand, + GetItemCommandInput, + UpdateItemCommand, + UpdateItemCommandInput, + DeleteItemCommand, + DeleteItemCommandInput, + AttributeValue, +} from "@aws-sdk/client-dynamodb"; + +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + StoredMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +/** + * Interface defining the fields required to create an instance of + * `DynamoDBChatMessageHistory`. 
It includes the DynamoDB table name, + * session ID, partition key, sort key, message attribute name, and + * DynamoDB client configuration. + */ +export interface DynamoDBChatMessageHistoryFields { + tableName: string; + sessionId: string; + partitionKey?: string; + sortKey?: string; + messageAttributeName?: string; + config?: DynamoDBClientConfig; + key?: Record<string, AttributeValue>; +} + +/** + * Interface defining the structure of a chat message as it is stored in + * DynamoDB. + */ +interface DynamoDBSerializedChatMessage { + M: { + type: { + S: string; + }; + text: { + S: string; + }; + role?: { + S: string; + }; + }; +} + +/** + * Class providing methods to interact with a DynamoDB table to store and + * retrieve chat messages. It extends the `BaseListChatMessageHistory` + * class. Credential fields are mapped to the standard AWS environment variable names in `lc_secrets`. + */ +export class DynamoDBChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "dynamodb"]; + + get lc_secrets(): { [key: string]: string } | undefined { + return { + "config.credentials.accessKeyId": "AWS_ACCESS_KEY_ID", + "config.credentials.secretAccessKey": "AWS_SECRET_ACCESS_KEY", + "config.credentials.sessionToken": "AWS_SESSION_TOKEN", + }; + } + + private tableName: string; + + private sessionId: string; + + private client: DynamoDBClient; + + private partitionKey = "id"; + + private sortKey?: string; + + private messageAttributeName = "messages"; + + private dynamoKey: Record<string, AttributeValue> = {}; + + constructor({ + tableName, + sessionId, + partitionKey, + sortKey, + messageAttributeName, + config, + key = {}, + }: DynamoDBChatMessageHistoryFields) { + super(); + + this.tableName = tableName; + this.sessionId = sessionId; + this.client = new DynamoDBClient(config ?? {}); + this.partitionKey = partitionKey ?? this.partitionKey; + this.sortKey = sortKey; + this.messageAttributeName = + messageAttributeName ?? 
this.messageAttributeName; + this.dynamoKey = key; + + // override dynamoKey with partition key and sort key when key not specified + if (Object.keys(this.dynamoKey).length === 0) { + this.dynamoKey[this.partitionKey] = { S: this.sessionId }; + if (this.sortKey) { + this.dynamoKey[this.sortKey] = { S: this.sortKey }; + } + } + } + + /** + * Retrieves all messages from the DynamoDB table and returns them as an + * array of `BaseMessage` instances. + * @returns Array of stored messages + */ + async getMessages(): Promise { + const params: GetItemCommandInput = { + TableName: this.tableName, + Key: this.dynamoKey, + }; + + const response = await this.client.send(new GetItemCommand(params)); + const items = response.Item + ? response.Item[this.messageAttributeName]?.L ?? [] + : []; + const messages = items + .map((item) => ({ + type: item.M?.type.S, + data: { + role: item.M?.role?.S, + content: item.M?.text.S, + }, + })) + .filter( + (x): x is StoredMessage => + x.type !== undefined && x.data.content !== undefined + ); + return mapStoredMessagesToChatMessages(messages); + } + + /** + * Deletes all messages from the DynamoDB table. + */ + async clear(): Promise { + const params: DeleteItemCommandInput = { + TableName: this.tableName, + Key: this.dynamoKey, + }; + await this.client.send(new DeleteItemCommand(params)); + } + + /** + * Adds a new message to the DynamoDB table. + * @param message The message to be added to the DynamoDB table. 
+ */ + async addMessage(message: BaseMessage) { + const messages = mapChatMessagesToStoredMessages([message]); + + const params: UpdateItemCommandInput = { + TableName: this.tableName, + Key: this.dynamoKey, + ExpressionAttributeNames: { + "#m": this.messageAttributeName, + }, + ExpressionAttributeValues: { + ":empty_list": { + L: [], + }, + ":m": { + L: messages.map((message) => { + const dynamoSerializedMessage: DynamoDBSerializedChatMessage = { + M: { + type: { + S: message.type, + }, + text: { + S: message.data.content, + }, + }, + }; + if (message.data.role) { + dynamoSerializedMessage.M.role = { S: message.data.role }; + } + return dynamoSerializedMessage; + }), + }, + }, + UpdateExpression: + "SET #m = list_append(if_not_exists(#m, :empty_list), :m)", + }; + await this.client.send(new UpdateItemCommand(params)); + } +} diff --git a/libs/langchain-community/src/stores/message/firestore.ts b/libs/langchain-community/src/stores/message/firestore.ts new file mode 100644 index 000000000000..24d4af56c99c --- /dev/null +++ b/libs/langchain-community/src/stores/message/firestore.ts @@ -0,0 +1,193 @@ +import type { AppOptions } from "firebase-admin"; +import { getApps, initializeApp } from "firebase-admin/app"; +import { + getFirestore, + DocumentData, + Firestore, + DocumentReference, + FieldValue, +} from "firebase-admin/firestore"; + +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + StoredMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +/** + * Interface for FirestoreDBChatMessageHistory. It includes the collection + * name, session ID, user ID, and optionally, the app index and + * configuration for the Firebase app. 
+ */ +export interface FirestoreDBChatMessageHistory { + collectionName: string; + sessionId: string; + userId: string; + appIdx?: number; + config?: AppOptions; +} +/** + * Class for managing chat message history using Google's Firestore as a + * storage backend. Extends the BaseListChatMessageHistory class. + * @example + * ```typescript + * const chatHistory = new FirestoreChatMessageHistory({ + * collectionName: "langchain", + * sessionId: "lc-example", + * userId: "a@example.com", + * config: { projectId: "your-project-id" }, + * }); + * + * const chain = new ConversationChain({ + * llm: new ChatOpenAI(), + * memory: new BufferMemory({ chatHistory }), + * }); + * + * const response = await chain.invoke({ + * input: "What did I just say my name was?", + * }); + * console.log({ response }); + * ``` + */ +export class FirestoreChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "firestore"]; + + private collectionName: string; + + private sessionId: string; + + private userId: string; + + private appIdx: number; + + private config: AppOptions; + + private firestoreClient: Firestore; + + private document: DocumentReference | null; + + constructor({ + collectionName, + sessionId, + userId, + appIdx = 0, + config, + }: FirestoreDBChatMessageHistory) { + super(); + this.collectionName = collectionName; + this.sessionId = sessionId; + this.userId = userId; + this.document = null; + this.appIdx = appIdx; + if (config) this.config = config; + + try { + this.ensureFirestore(); + } catch (error) { + throw new Error(`Unknown response type`); + } + } + + private ensureFirestore(): void { + let app; + // Check if the app is already initialized else get appIdx + if (!getApps().length) app = initializeApp(this.config); + else app = getApps()[this.appIdx]; + + this.firestoreClient = getFirestore(app); + + this.document = this.firestoreClient + .collection(this.collectionName) + .doc(this.sessionId); + } + + /** + * Method 
to retrieve all messages from the Firestore collection + * associated with the current session. Returns an array of BaseMessage + * objects. + * @returns Array of stored messages + */ + async getMessages(): Promise { + if (!this.document) { + throw new Error("Document not initialized"); + } + + const querySnapshot = await this.document + .collection("messages") + .orderBy("createdAt", "asc") + .get() + .catch((err) => { + throw new Error(`Unknown response type: ${err.toString()}`); + }); + + const response: StoredMessage[] = []; + querySnapshot.forEach((doc) => { + const { type, data } = doc.data(); + response.push({ type, data }); + }); + + return mapStoredMessagesToChatMessages(response); + } + + /** + * Method to add a new message to the Firestore collection. The message is + * passed as a BaseMessage object. + * @param message The message to be added as a BaseMessage object. + */ + public async addMessage(message: BaseMessage) { + const messages = mapChatMessagesToStoredMessages([message]); + await this.upsertMessage(messages[0]); + } + + private async upsertMessage(message: StoredMessage): Promise { + if (!this.document) { + throw new Error("Document not initialized"); + } + await this.document.set( + { + id: this.sessionId, + user_id: this.userId, + }, + { merge: true } + ); + await this.document + .collection("messages") + .add({ + type: message.type, + data: message.data, + createdBy: this.userId, + createdAt: FieldValue.serverTimestamp(), + }) + .catch((err) => { + throw new Error(`Unknown response type: ${err.toString()}`); + }); + } + + /** + * Method to delete all messages from the Firestore collection associated + * with the current session. 
+ */ + public async clear(): Promise<void> { + if (!this.document) { + throw new Error("Document not initialized"); + } + await this.document + .collection("messages") + .get() + // Await every per-message delete so a failure rejects clear() instead of + // becoming an unhandled rejection, and the session document is removed last. + .then((querySnapshot) => + Promise.all( + querySnapshot.docs.map((snapshot) => snapshot.ref.delete()) + ) + ) + .catch((err) => { + throw new Error(`Unknown response type: ${err.toString()}`); + }); + await this.document.delete().catch((err) => { + throw new Error(`Unknown response type: ${err.toString()}`); + }); + } +} diff --git a/libs/langchain-community/src/stores/message/ioredis.ts b/libs/langchain-community/src/stores/message/ioredis.ts new file mode 100644 index 000000000000..3772250ee2ac --- /dev/null +++ b/libs/langchain-community/src/stores/message/ioredis.ts @@ -0,0 +1,103 @@ +import { Redis, RedisOptions } from "ioredis"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +/** + * Type for the input parameter of the RedisChatMessageHistory + * constructor. It includes fields for the session ID, session TTL, Redis + * URL, Redis configuration, and Redis client. + */ +export type RedisChatMessageHistoryInput = { + sessionId: string; + sessionTTL?: number; + url?: string; + config?: RedisOptions; + client?: Redis; +}; + +/** + * Class used to store chat message history in Redis. It provides methods + * to add, retrieve, and clear messages from the chat history. 
+ * @example + * ```typescript + * const chatHistory = new RedisChatMessageHistory({ + * sessionId: new Date().toISOString(), + * sessionTTL: 300, + * url: "redis://localhost:6379", + * }); + * + * const chain = new ConversationChain({ + * llm: new ChatOpenAI({ temperature: 0 }), + * memory: { chatHistory }, + * }); + * + * const response = await chain.invoke({ + * input: "What did I just say my name was?", + * }); + * console.log({ response }); + * ``` + */ +export class RedisChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "ioredis"]; + + get lc_secrets() { + return { + url: "REDIS_URL", + "config.username": "REDIS_USERNAME", + "config.password": "REDIS_PASSWORD", + }; + } + + public client: Redis; + + private sessionId: string; + + private sessionTTL?: number; + + constructor(fields: RedisChatMessageHistoryInput) { + super(fields); + + const { sessionId, sessionTTL, url, config, client } = fields; + this.client = (client ?? + (url ? new Redis(url) : new Redis(config ?? {}))) as Redis; + this.sessionId = sessionId; + this.sessionTTL = sessionTTL; + } + + /** + * Retrieves all messages from the chat history. The Redis list is read in full and reversed, because addMessage prepends with LPUSH, so the result is oldest first. + * @returns Promise that resolves with an array of BaseMessage instances. + */ + async getMessages(): Promise { + const rawStoredMessages = await this.client.lrange(this.sessionId, 0, -1); + const orderedMessages = rawStoredMessages + .reverse() + .map((message) => JSON.parse(message)); + return mapStoredMessagesToChatMessages(orderedMessages); + } + + /** + * Adds a message to the chat history. + * @param message The message to add to the chat history. + * @returns Promise that resolves when the message has been added. 
+ */ + async addMessage(message: BaseMessage): Promise { + const messageToAdd = mapChatMessagesToStoredMessages([message]); + await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0])); + if (this.sessionTTL) { + await this.client.expire(this.sessionId, this.sessionTTL); + } + } + + /** + * Clears all messages from the chat history. + * @returns Promise that resolves when the chat history has been cleared. + */ + async clear(): Promise { + await this.client.del(this.sessionId); + } +} diff --git a/libs/langchain-community/src/stores/message/momento.ts b/libs/langchain-community/src/stores/message/momento.ts new file mode 100644 index 000000000000..a5c9fe323b3f --- /dev/null +++ b/libs/langchain-community/src/stores/message/momento.ts @@ -0,0 +1,196 @@ +/* eslint-disable no-instanceof/no-instanceof */ +import { + CacheDelete, + CacheListFetch, + CacheListPushBack, + ICacheClient, + InvalidArgumentError, + CollectionTtl, +} from "@gomomento/sdk-core"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + StoredMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; +import { ensureCacheExists } from "../../utils/momento.js"; + +/** + * The settings to instantiate the Momento chat message history. + */ +export interface MomentoChatMessageHistoryProps { + /** + * The session ID to use to store the data. + */ + sessionId: string; + /** + * The Momento cache client. + */ + client: ICacheClient; + /** + * The name of the cache to use to store the data. + */ + cacheName: string; + /** + * The time to live for the cache items in seconds. + * If not specified, the cache client default is used. + */ + sessionTtl?: number; + /** + * If true, ensure that the cache exists before returning. + * If false, the cache is not checked for existence. + * Defaults to true. 
+ */ + ensureCacheExists?: boolean; +} + +/** + * A class that stores chat message history using Momento Cache. It + * interacts with a Momento cache client to perform operations like + * fetching, adding, and deleting messages. + * @example + * ```typescript + * const chatHistory = await MomentoChatMessageHistory.fromProps({ + * client: new CacheClient({ + * configuration: Configurations.Laptop.v1(), + * credentialProvider: CredentialProvider.fromEnvironmentVariable({ + * environmentVariableName: "MOMENTO_API_KEY", + * }), + * defaultTtlSeconds: 60 * 60 * 24, + * }), + * cacheName: "langchain", + * sessionId: new Date().toISOString(), + * sessionTtl: 300, + * }); + * + * const messages = await chatHistory.getMessages(); + * console.log({ messages }); + * ``` + */ +export class MomentoChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "momento"]; + + private readonly sessionId: string; + + private readonly client: ICacheClient; + + private readonly cacheName: string; + + private readonly sessionTtl: CollectionTtl; + + private constructor(props: MomentoChatMessageHistoryProps) { + super(); + this.sessionId = props.sessionId; + this.client = props.client; + this.cacheName = props.cacheName; + + this.validateTtlSeconds(props.sessionTtl); + this.sessionTtl = + props.sessionTtl !== undefined + ? CollectionTtl.of(props.sessionTtl) + : CollectionTtl.fromCacheTtl(); + } + + /** + * Create a new chat message history backed by Momento. + * + * @param {MomentoChatMessageHistoryProps} props The settings to instantiate the Momento chat message history. + * @param {string} props.sessionId The session ID to use to store the data. + * @param {ICacheClient} props.client The Momento cache client. + * @param {string} props.cacheName The name of the cache to use to store the data. + * @param {number} props.sessionTtl The time to live for the cache items in seconds. + * If not specified, the cache client default is used. 
+ * @param {boolean} props.ensureCacheExists If true, ensure that the cache exists before returning. + * If false, the cache is not checked for existence. + * @throws {InvalidArgumentError} If {@link props.sessionTtl} is not strictly positive. + * @returns A new chat message history backed by Momento. + */ + public static async fromProps( + props: MomentoChatMessageHistoryProps + ): Promise { + const instance = new MomentoChatMessageHistory(props); + if (props.ensureCacheExists || props.ensureCacheExists === undefined) { + await ensureCacheExists(props.client, props.cacheName); + } + return instance; + } + + /** + * Validate the user-specified TTL, if provided, is strictly positive. + * @param ttlSeconds The TTL to validate. + */ + private validateTtlSeconds(ttlSeconds?: number): void { + if (ttlSeconds !== undefined && ttlSeconds <= 0) { + throw new InvalidArgumentError("ttlSeconds must be positive."); + } + } + + /** + * Fetches messages from the cache. + * @returns A Promise that resolves to an array of BaseMessage instances. + */ + public async getMessages(): Promise { + const fetchResponse = await this.client.listFetch( + this.cacheName, + this.sessionId + ); + + let messages: StoredMessage[] = []; + if (fetchResponse instanceof CacheListFetch.Hit) { + messages = fetchResponse + .valueList() + .map((serializedStoredMessage) => JSON.parse(serializedStoredMessage)); + } else if (fetchResponse instanceof CacheListFetch.Miss) { + // pass + } else if (fetchResponse instanceof CacheListFetch.Error) { + throw fetchResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${fetchResponse.toString()}`); + } + return mapStoredMessagesToChatMessages(messages); + } + + /** + * Adds a message to the cache. + * @param message The BaseMessage instance to add to the cache. + * @returns A Promise that resolves when the message has been added. 
+ */ + public async addMessage(message: BaseMessage): Promise { + const messageToAdd = JSON.stringify( + mapChatMessagesToStoredMessages([message])[0] + ); + + const pushResponse = await this.client.listPushBack( + this.cacheName, + this.sessionId, + messageToAdd, + { ttl: this.sessionTtl } + ); + if (pushResponse instanceof CacheListPushBack.Success) { + // pass + } else if (pushResponse instanceof CacheListPushBack.Error) { + throw pushResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${pushResponse.toString()}`); + } + } + + /** + * Deletes all messages from the cache. + * @returns A Promise that resolves when all messages have been deleted. + */ + public async clear(): Promise { + const deleteResponse = await this.client.delete( + this.cacheName, + this.sessionId + ); + if (deleteResponse instanceof CacheDelete.Success) { + // pass + } else if (deleteResponse instanceof CacheDelete.Error) { + throw deleteResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${deleteResponse.toString()}`); + } + } +} diff --git a/libs/langchain-community/src/stores/message/mongodb.ts b/libs/langchain-community/src/stores/message/mongodb.ts new file mode 100644 index 000000000000..e68ff65f4121 --- /dev/null +++ b/libs/langchain-community/src/stores/message/mongodb.ts @@ -0,0 +1,60 @@ +import { Collection, Document as MongoDBDocument, ObjectId } from "mongodb"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +export interface MongoDBChatMessageHistoryInput { + collection: Collection; + sessionId: string; +} + +/** + * @example + * ```typescript + * const chatHistory = new MongoDBChatMessageHistory({ + * collection: myCollection, + * sessionId: 'unique-session-id', + * }); + * const messages = await chatHistory.getMessages(); + * await chatHistory.clear(); + * ``` + */ 
+export class MongoDBChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "mongodb"]; + + private collection: Collection; + + private sessionId: string; + + constructor({ collection, sessionId }: MongoDBChatMessageHistoryInput) { + super(); + this.collection = collection; + this.sessionId = sessionId; + } + + async getMessages(): Promise { + const document = await this.collection.findOne({ + _id: new ObjectId(this.sessionId), + }); + const messages = document?.messages || []; + return mapStoredMessagesToChatMessages(messages); + } + + async addMessage(message: BaseMessage): Promise { + const messages = mapChatMessagesToStoredMessages([message]); + await this.collection.updateOne( + { _id: new ObjectId(this.sessionId) }, + { + $push: { messages: { $each: messages } }, + }, + { upsert: true } + ); + } + + async clear(): Promise { + await this.collection.deleteOne({ _id: new ObjectId(this.sessionId) }); + } +} diff --git a/libs/langchain-community/src/stores/message/planetscale.ts b/libs/langchain-community/src/stores/message/planetscale.ts new file mode 100644 index 000000000000..ef72166ab6d7 --- /dev/null +++ b/libs/langchain-community/src/stores/message/planetscale.ts @@ -0,0 +1,208 @@ +import { + Client as PlanetScaleClient, + Config as PlanetScaleConfig, + Connection as PlanetScaleConnection, +} from "@planetscale/database"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + StoredMessage, + StoredMessageData, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +/** + * Type definition for the input parameters required when instantiating a + * PlanetScaleChatMessageHistory object. 
/**
 * Chat message history persisted in a PlanetScale database table, one
 * row per message, scoped by session id. Extends
 * BaseListChatMessageHistory.
 *
 * The backing table is created lazily: the first read, write or clear
 * issues a `CREATE TABLE IF NOT EXISTS` followed by an `ALTER TABLE`
 * that gives `id` a generated UUID default.
 *
 * NOTE(review): every text column (including `content` and
 * `additional_kwargs`) is declared VARCHAR(255) and the SELECT carries no
 * ORDER BY clause — confirm long messages and row ordering behave as
 * intended against a real database.
 * @example
 * ```typescript
 * const chatHistory = new PlanetScaleChatMessageHistory({
 *   tableName: "stored_message",
 *   sessionId: "lc-example",
 *   config: {
 *     url: "ADD_YOURS_HERE",
 *   },
 * });
 * const chain = new ConversationChain({
 *   llm: new ChatOpenAI(),
 *   memory: new BufferMemory({ chatHistory }),
 * });
 * const response = await chain.invoke({
 *   input: "What did I just say my name was?",
 * });
 * console.log({ response });
 * ```
 */
export class PlanetScaleChatMessageHistory extends BaseListChatMessageHistory {
  lc_namespace = ["langchain", "stores", "message", "planetscale"];

  get lc_secrets() {
    return {
      "config.host": "PLANETSCALE_HOST",
      "config.username": "PLANETSCALE_USERNAME",
      "config.password": "PLANETSCALE_PASSWORD",
      "config.url": "PLANETSCALE_DATABASE_URL",
    };
  }

  public client: PlanetScaleClient;

  private connection: PlanetScaleConnection;

  private tableName: string;

  private sessionId: string;

  private tableInitialized: boolean;

  /**
   * @param fields Session id plus either a ready-made client or a config
   * object to build one from; `tableName` is optional.
   * @throws Error when neither `client` nor `config` is supplied.
   */
  constructor(fields: PlanetScaleChatMessageHistoryInput) {
    super(fields);

    // A caller-supplied client wins; config is only used to build one
    // when no client was given.
    let resolvedClient: PlanetScaleClient | undefined = fields.client;
    if (!resolvedClient && fields.config) {
      resolvedClient = new PlanetScaleClient(fields.config);
    }
    if (!resolvedClient) {
      throw new Error(
        "Either a client or config must be provided to PlanetScaleChatMessageHistory"
      );
    }
    this.client = resolvedClient;
    this.connection = this.client.connection();

    this.sessionId = fields.sessionId;
    this.tableName = fields.tableName || "langchain_chat_histories";
    this.tableInitialized = false;
  }

  /**
   * Creates the message table (if missing) and sets the UUID default on
   * its `id` column. Runs at most once per instance; later calls are
   * no-ops.
   * @returns Promise that resolves once the table is ready.
   */
  private async ensureTable(): Promise<void> {
    if (!this.tableInitialized) {
      const statements = [
        `CREATE TABLE IF NOT EXISTS ${this.tableName} (id BINARY(16) PRIMARY KEY, session_id VARCHAR(255), type VARCHAR(255), content VARCHAR(255), role VARCHAR(255), name VARCHAR(255), additional_kwargs VARCHAR(255));`,
        `ALTER TABLE ${this.tableName} MODIFY id BINARY(16) DEFAULT (UUID_TO_BIN(UUID()));`,
      ];
      // Sequential on purpose: the ALTER needs the table to exist.
      for (const statement of statements) {
        await this.connection.execute(statement);
      }
      this.tableInitialized = true;
    }
  }

  /**
   * Loads every stored row for the current session and converts it back
   * into chat messages.
   * @returns Promise resolving to the session's BaseMessage objects.
   */
  async getMessages(): Promise<BaseMessage[]> {
    await this.ensureTable();

    const result = await this.connection.execute(
      `SELECT * FROM ${this.tableName} WHERE session_id = :session_id`,
      { session_id: this.sessionId }
    );
    const rows = result.rows as unknown as selectStoredMessagesDTO[];

    const stored: StoredMessage[] = [];
    for (const row of rows) {
      const data = {
        content: row.content,
        additional_kwargs: JSON.parse(row.additional_kwargs),
      } as StoredMessageData;

      // role / name are nullable columns; only copy them when set.
      if (row.role) {
        data.role = row.role;
      }
      if (row.name) {
        data.name = row.name;
      }

      stored.push({ type: row.type, data });
    }
    return mapStoredMessagesToChatMessages(stored);
  }

  /**
   * Inserts one message as a new row for the current session.
   * @param message The BaseMessage to persist.
   * @returns Promise that resolves when the insert has completed.
   */
  async addMessage(message: BaseMessage): Promise<void> {
    await this.ensureTable();

    const [stored] = mapChatMessagesToStoredMessages([message]);

    await this.connection.execute(
      `INSERT INTO ${this.tableName} (session_id, type, content, role, name, additional_kwargs) VALUES (:session_id, :type, :content, :role, :name, :additional_kwargs)`,
      {
        session_id: this.sessionId,
        type: stored.type,
        content: stored.data.content,
        role: stored.data.role,
        name: stored.data.name,
        additional_kwargs: JSON.stringify(stored.data.additional_kwargs),
      }
    );
  }

  /**
   * Deletes every row belonging to the current session.
   * @returns Promise that resolves when the delete has completed.
   */
  async clear(): Promise<void> {
    await this.ensureTable();

    await this.connection.execute(
      `DELETE FROM ${this.tableName} WHERE session_id = :session_id`,
      { session_id: this.sessionId }
    );
  }
}
/**
 * Chat message history stored in a Redis list keyed by the session id.
 * Extends `BaseListChatMessageHistory`.
 *
 * Messages are pushed onto the head of the list, so reads reverse the
 * list to return them oldest-first. When `sessionTTL` is set, the key's
 * expiry is refreshed on every write.
 * @example
 * ```typescript
 * const chatHistory = new RedisChatMessageHistory({
 *   sessionId: new Date().toISOString(),
 *   sessionTTL: 300,
 *   config: { url: "redis://localhost:6379" },
 * });
 *
 * const chain = new ConversationChain({
 *   llm: new ChatOpenAI({ modelName: "gpt-3.5-turbo", temperature: 0 }),
 *   memory: new BufferMemory({ chatHistory }),
 * });
 *
 * const response = await chain.invoke({
 *   input: "What did I just say my name was?",
 * });
 * console.log({ response });
 * ```
 */
export class RedisChatMessageHistory extends BaseListChatMessageHistory {
  lc_namespace = ["langchain", "stores", "message", "redis"];

  get lc_secrets() {
    return {
      "config.url": "REDIS_URL",
      "config.username": "REDIS_USERNAME",
      "config.password": "REDIS_PASSWORD",
    };
  }

  public client: RedisClientType<RedisModules, RedisFunctions, RedisScripts>;

  private sessionId: string;

  private sessionTTL?: number;

  /**
   * @param fields Session id, optional TTL, and either an existing
   * client or a config used to create one (an empty config is used when
   * neither is given).
   */
  constructor(fields: RedisChatMessageHistoryInput) {
    super(fields);

    this.sessionId = fields.sessionId;
    this.sessionTTL = fields.sessionTTL;
    this.client = (fields.client ??
      createClient(fields.config ?? {})) as RedisClientType<
      RedisModules,
      RedisFunctions,
      RedisScripts
    >;
  }

  /**
   * Connects the client if it does not report itself as ready.
   * @returns Promise resolving to true once the client can be used.
   */
  async ensureReadiness() {
    if (this.client.isReady) {
      return true;
    }
    await this.client.connect();
    return true;
  }

  /**
   * Reads the whole list for this session and decodes it.
   * @returns Promise resolving to the `BaseMessage` instances, oldest
   * first.
   */
  async getMessages(): Promise<BaseMessage[]> {
    await this.ensureReadiness();
    const serialized = await this.client.lRange(this.sessionId, 0, -1);

    // The list head holds the newest entry, so walk it back to front.
    const decoded = [];
    for (let i = serialized.length - 1; i >= 0; i -= 1) {
      decoded.push(JSON.parse(serialized[i]));
    }
    return mapStoredMessagesToChatMessages(decoded);
  }

  /**
   * Serializes one message and pushes it onto the session's list.
   * @param message The `BaseMessage` instance to add.
   * @returns Promise resolving when the message has been stored.
   */
  async addMessage(message: BaseMessage): Promise<void> {
    await this.ensureReadiness();
    const [stored] = mapChatMessagesToStoredMessages([message]);
    await this.client.lPush(this.sessionId, JSON.stringify(stored));
    if (this.sessionTTL) {
      await this.client.expire(this.sessionId, this.sessionTTL);
    }
  }

  /**
   * Removes the session's key, dropping all of its messages.
   * @returns Promise resolving when the key has been deleted.
   */
  async clear(): Promise<void> {
    await this.ensureReadiness();
    await this.client.del(this.sessionId);
  }
}
/**
 * Chat message history stored in an Upstash Redis list keyed by the
 * session id. Offers the usual add / get / clear operations; when
 * `sessionTTL` is set, the key's expiry is refreshed on every write.
 */
export class UpstashRedisChatMessageHistory extends BaseListChatMessageHistory {
  lc_namespace = ["langchain", "stores", "message", "upstash_redis"];

  get lc_secrets() {
    return {
      "config.url": "UPSTASH_REDIS_REST_URL",
      "config.token": "UPSTASH_REDIS_REST_TOKEN",
    };
  }

  public client: Redis;

  private sessionId: string;

  private sessionTTL?: number;

  /**
   * @param fields Session id, optional TTL, and either a pre-configured
   * client or a config to build one from.
   * @throws Error when neither `client` nor `config` is supplied.
   */
  constructor(fields: UpstashRedisChatMessageHistoryInput) {
    super(fields);

    let resolvedClient: Redis | undefined = fields.client;
    if (!resolvedClient && fields.config) {
      resolvedClient = new Redis(fields.config);
    }
    if (!resolvedClient) {
      throw new Error(
        `Upstash Redis message stores require either a config object or a pre-configured client.`
      );
    }
    this.client = resolvedClient;
    this.sessionId = fields.sessionId;
    this.sessionTTL = fields.sessionTTL;
  }

  /**
   * Reads the session's list and converts the usable entries into chat
   * messages. Entries lacking a `type` or `data.content` are skipped.
   * @returns An array of BaseMessage instances, oldest first.
   */
  async getMessages(): Promise<BaseMessage[]> {
    const entries: StoredMessage[] = await this.client.lrange<StoredMessage>(
      this.sessionId,
      0,
      -1
    );

    // Writes push to the list head, so iterate from the tail to restore
    // insertion order.
    const usable: StoredMessage[] = [];
    for (let i = entries.length - 1; i >= 0; i -= 1) {
      const entry = entries[i];
      if (entry.type !== undefined && entry.data.content !== undefined) {
        usable.push(entry);
      }
    }
    return mapStoredMessagesToChatMessages(usable);
  }

  /**
   * Serializes one message and pushes it onto the session's list.
   * @param message The message to be added to the chat history.
   * @returns Promise resolving to void.
   */
  async addMessage(message: BaseMessage): Promise<void> {
    const [stored] = mapChatMessagesToStoredMessages([message]);
    await this.client.lpush(this.sessionId, JSON.stringify(stored));
    if (this.sessionTTL) {
      await this.client.expire(this.sessionId, this.sessionTTL);
    }
  }

  /**
   * Removes the session's key, dropping all of its messages.
   * @returns Promise resolving to void.
   */
  async clear(): Promise<void> {
    await this.client.del(this.sessionId);
  }
}
/**
 * Chat message history backed by a Xata table. Extends
 * BaseListChatMessageHistory and offers get / add / clear operations
 * scoped to one session id.
 *
 * Unless `createTable` is explicitly `false`, the first operation checks
 * through the Xata API client whether the table exists and creates it
 * with the chat-memory columns when it does not. That check needs an
 * API key, taken from `apiKey` or `config.apiKey`.
 * @example
 * ```typescript
 * const chatHistory = new XataChatMessageHistory({
 *   table: "messages",
 *   sessionId: new Date().toISOString(),
 *   client: new BaseClient({
 *     databaseURL: process.env.XATA_DB_URL,
 *     apiKey: process.env.XATA_API_KEY,
 *     branch: "main",
 *   }),
 *   apiKey: process.env.XATA_API_KEY,
 * });
 *
 * const chain = new ConversationChain({
 *   llm: new ChatOpenAI(),
 *   memory: new BufferMemory({ chatHistory }),
 * });
 *
 * const response = await chain.invoke({
 *   input: "What did I just say my name was?",
 * });
 * console.log({ response });
 * ```
 */
export class XataChatMessageHistory<
  XataClient extends BaseClient
> extends BaseListChatMessageHistory {
  lc_namespace = ["langchain", "stores", "message", "xata"];

  public client: XataClient;

  private sessionId: string;

  private table: string;

  private tableInitialized: boolean;

  private createTable: boolean;

  // Only assigned when createTable is enabled (see constructor).
  private apiClient: XataApiClient;

  /**
   * @param fields Session id, plus a client or a config to build one
   * from; optional table name (defaults to "memory"), `createTable`
   * flag (defaults to enabled) and API key.
   * @throws Error when neither `client` nor `config` is supplied, or
   * when table creation is enabled and no API key can be found.
   */
  constructor(fields: XataChatMessageHistoryInput<XataClient>) {
    super(fields);

    const { sessionId, config, client, table } = fields;
    this.sessionId = sessionId;
    this.table = table || "memory";
    if (client) {
      this.client = client;
    } else if (config) {
      this.client = new BaseClient(config) as XataClient;
    } else {
      throw new Error(
        "Either a client or a config must be provided to XataChatMessageHistoryInput"
      );
    }
    if (fields.createTable !== false) {
      this.createTable = true;
      const apiKey = fields.apiKey || fields.config?.apiKey;
      if (!apiKey) {
        throw new Error(
          "If createTable is set, an apiKey must be provided to XataChatMessageHistoryInput, either directly or through the config object"
        );
      }
      this.apiClient = new XataApiClient({ apiKey });
    } else {
      this.createTable = false;
    }
    this.tableInitialized = false;
  }

  /**
   * Retrieves all messages associated with the session ID, sorted by
   * `xata.createdAt` ascending.
   * @returns A promise that resolves to an array of BaseMessage instances.
   */
  async getMessages(): Promise<BaseMessage[]> {
    await this.ensureTable();
    const records = await this.client.db[this.table]
      .filter({ sessionId: this.sessionId })
      .sort("xata.createdAt", "asc")
      .getAll();

    const rawStoredMessages = records as unknown as storedMessagesDTO[];
    const orderedMessages: StoredMessage[] = rawStoredMessages.map(
      (message: storedMessagesDTO) => {
        const data = {
          content: message.content,
          additional_kwargs: JSON.parse(message.additionalKwargs),
        } as StoredMessageData;
        // role / name are optional columns; only copy them when set.
        if (message.role) {
          data.role = message.role;
        }
        if (message.name) {
          data.name = message.name;
        }

        return {
          type: message.type,
          data,
        };
      }
    );
    return mapStoredMessagesToChatMessages(orderedMessages);
  }

  /**
   * Adds a new message to the database.
   * @param message The BaseMessage instance to be added.
   * @returns A promise that resolves when the message has been added.
   */
  async addMessage(message: BaseMessage): Promise<void> {
    await this.ensureTable();
    const messageToAdd = mapChatMessagesToStoredMessages([message]);
    await this.client.db[this.table].create({
      sessionId: this.sessionId,
      type: messageToAdd[0].type,
      content: messageToAdd[0].data.content,
      role: messageToAdd[0].data.role,
      name: messageToAdd[0].data.name,
      additionalKwargs: JSON.stringify(messageToAdd[0].data.additional_kwargs),
    });
  }

  /**
   * Deletes all messages associated with the session ID.
   * @returns A promise that resolves when the messages have been deleted.
   */
  async clear(): Promise<void> {
    await this.ensureTable();
    const records = await this.client.db[this.table]
      .select(["id"])
      .filter({ sessionId: this.sessionId })
      .getAll();
    const ids = records.map((m) => m.id);
    await this.client.db[this.table].delete(ids);
  }

  /**
   * Checks if the table exists and creates it if it doesn't. This method
   * is called before any operation on the table and does the remote
   * check only once per instance; it is a no-op when `createTable` is
   * disabled.
   * @returns A promise that resolves when the table has been ensured.
   * @throws Error when the client's databaseURL cannot be parsed.
   */
  private async ensureTable(): Promise<void> {
    if (!this.createTable) {
      return;
    }
    if (this.tableInitialized) {
      return;
    }

    const { databaseURL, branch } = await this.client.getConfig();
    // databaseURL is split on "/" into scheme, "", host, <segment>,
    // database; only host and database are needed here.
    const [, , host, , database] = databaseURL.split("/");
    const urlParts = parseWorkspacesUrlParts(host);
    if (urlParts == null) {
      throw new Error("Invalid databaseURL");
    }
    const { workspace, region } = urlParts;
    const tableParams = {
      workspace,
      region,
      database,
      branch,
      table: this.table,
    };

    let schema: GetTableSchemaResponse | null = null;
    try {
      schema = await this.apiClient.tables.getTableSchema(tableParams);
    } catch (e) {
      // Deliberate best-effort: a failed lookup is treated as "table
      // does not exist" and falls through to creation below.
    }
    if (schema == null) {
      await this.apiClient.tables.createTable(tableParams);
      await this.apiClient.tables.setTableSchema({
        ...tableParams,
        schema: {
          columns: chatMemoryColumns,
        },
      });
    }

    // Remember the successful check so later operations skip the extra
    // API round-trips. Left unset on failure so the next call retries.
    this.tableInitialized = true;
  }
}
/**
 * Tool built from an AI plugin manifest. Calling the tool returns the
 * plugin's usage guide and API specification text, whatever the input.
 * Extends Tool and implements AIPluginToolParams.
 */
export class AIPluginTool extends Tool implements AIPluginToolParams {
  static lc_name() {
    return "AIPluginTool";
  }

  private _name: string;

  private _description: string;

  apiSpec: string;

  get name() {
    return this._name;
  }

  get description() {
    return this._description;
  }

  /**
   * @param params Tool name, description and the spec text the tool
   * returns when called.
   */
  constructor(params: AIPluginToolParams) {
    super(params);
    const { name, description, apiSpec } = params;
    this._name = name;
    this._description = description;
    this.apiSpec = apiSpec;
  }

  /** @ignore */
  async _call(_input: string) {
    return this.apiSpec;
  }

  /**
   * Builds an AIPluginTool from a plugin manifest URL. Fetches the
   * manifest as JSON, then fetches the API specification from the
   * manifest's `api.url` as plain text.
   * @param url The URL of the AI plugin.
   * @returns A new instance of AIPluginTool.
   * @throws Error when either fetch returns a non-OK status.
   */
  static async fromPluginUrl(url: string) {
    const manifestRes = await fetch(url);
    if (!manifestRes.ok) {
      throw new Error(
        `Failed to fetch plugin from ${url} with status ${manifestRes.status}`
      );
    }
    const manifest = await manifestRes.json();

    const specRes = await fetch(manifest.api.url);
    if (!specRes.ok) {
      throw new Error(
        `Failed to fetch API spec from ${manifest.api.url} with status ${specRes.status}`
      );
    }
    // Kept as raw text: the spec is embedded verbatim into apiSpec.
    const specText = await specRes.text();

    const description = `Call this tool to get the OpenAPI spec (and usage guide) for interacting with the ${manifest.name_for_human} API. You should only call this ONCE! What is the ${manifest.name_for_human} API useful for? ${manifest.description_for_human}`;
    const apiSpec = `Usage Guide: ${manifest.description_for_model}

OpenAPI Spec in JSON or YAML format:\n${specText}`;

    return new AIPluginTool({
      name: manifest.name_for_model,
      description,
      apiSpec,
    });
  }
}
/**
 * Starts an execution of the configured state machine, using the tool
 * input (a JSON string) as the execution input.
 *
 * Input that is not valid JSON is rejected before any request is sent.
 * Previously the parse failure was reported but the request was still
 * issued with no input.
 * @param input JSON string to pass to the state machine.
 * @returns The execution ARN when the service returns one,
 * "request completed." when it does not, or
 * "failed to complete request" on invalid input or a failed request.
 * @ignore
 */
async _call(input: string): Promise<string> {
  let payload: unknown;
  try {
    payload = JSON.parse(input);
  } catch (e) {
    console.error("Error starting state machine execution:", e);
    // Stop here: nothing must be sent for malformed input.
    return "failed to complete request";
  }

  const clientConstructorArgs: SfnClientConstructorArgs =
    getClientConstructorArgs(this.sfnConfig);
  const sfnClient = new Client(clientConstructorArgs);

  const command = new Invoker({
    stateMachineArn: this.sfnConfig.stateMachineArn,
    // Re-serialize so the service receives normalized JSON.
    input: JSON.stringify(payload),
  });

  try {
    const response = await sfnClient.send(command);
    return response.executionArn
      ? response.executionArn
      : "request completed.";
  } catch (error) {
    console.error("Error starting state machine execution:", error);
    return "failed to complete request";
  }
}
/**
 * Looks up one execution by ARN and summarizes it as a JSON string.
 * @param input The execution ARN to describe.
 * @returns A stringified object with executionArn, name, status,
 * startDate, stopDate, input, output, error and cause; "{}" when the
 * response carries no executionArn; or "failed to complete request"
 * when the request fails.
 * @ignore
 */
async _call(input: string) {
  const sfnClient = new Client(getClientConstructorArgs(this.sfnConfig));
  const command = new Describer({ executionArn: input });

  // Issue the request outside the try block so that only failures of
  // the returned promise are turned into the failure message.
  const pending = sfnClient.send(command);
  try {
    const execution = await pending;
    if (!execution.executionArn) {
      return "{}";
    }
    const summary = {
      executionArn: execution.executionArn,
      name: execution.name,
      status: execution.status,
      startDate: execution.startDate,
      stopDate: execution.stopDate,
      input: execution.input,
      output: execution.output,
      error: execution.error,
      cause: execution.cause,
    };
    return JSON.stringify(summary);
  } catch (error) {
    console.error("Error describing state machine execution:", error);
    return "failed to complete request";
  }
}
/**
 * Builds the constructor arguments for the AWS SFN client from a tool
 * config. `region` is copied only when set; `credentials` is included
 * only when both the access key id and the secret access key are set.
 * @param config Full or partial SFN tool configuration.
 * @returns The arguments object, containing only the populated keys.
 */
function getClientConstructorArgs(
  config: Partial<SfnConfig>
): SfnClientConstructorArgs {
  const { region, accessKeyId, secretAccessKey } = config;

  return {
    ...(region ? { region } : {}),
    // Half a credential pair is useless, so it is dropped entirely.
    ...(accessKeyId && secretAccessKey
      ? { credentials: { accessKeyId, secretAccessKey } }
      : {}),
  };
}
/**
 * A tool for web search using Bing's search engine. It extends the base
 * `Tool` class and implements `_call` to perform the search. Requires a
 * Bing API key, passed in or read from the `BingApiKey` environment
 * variable. Additional query parameters given at construction time are
 * sent with every search.
 */
class BingSerpAPI extends Tool {
  static lc_name() {
    return "BingSerpAPI";
  }

  /**
   * Not implemented. Will throw an error if called.
   */
  toJSON() {
    return this.toJSONNotImplemented();
  }

  name = "bing-search";

  description =
    "a search engine. useful for when you need to answer questions about current events. input should be a search query.";

  key: string;

  params: Record<string, string>;

  /**
   * @param apiKey Bing subscription key; defaults to the `BingApiKey`
   * environment variable.
   * @param params Extra query-string parameters for every search. They
   * may override the default `textDecorations` / `textFormat` values
   * but never the query itself.
   * @throws Error when no API key is available.
   */
  constructor(
    apiKey: string | undefined = getEnvironmentVariable("BingApiKey"),
    params: Record<string, string> = {}
  ) {
    super(...arguments);

    if (!apiKey) {
      throw new Error(
        "BingSerpAPI API key not set. You can set it as BingApiKey in your .env file."
      );
    }

    this.key = apiKey;
    this.params = params;
  }

  /**
   * Runs one web search and joins the result snippets with spaces.
   * @param input The search query.
   * @returns The joined snippets, or "No good results found." when the
   * response has no web results.
   * @throws Error when the HTTP response status is not OK.
   * @ignore
   */
  async _call(input: string): Promise<string> {
    const headers = { "Ocp-Apim-Subscription-Key": this.key };
    const params: Record<string, string> = {
      q: input,
      textDecorations: "true",
      textFormat: "HTML",
      // Caller-supplied parameters were previously stored but never sent.
      ...this.params,
    };
    // The query always comes from the tool input, not from stored params.
    params.q = input;

    const searchUrl = new URL("https://api.bing.microsoft.com/v7.0/search");
    for (const [key, value] of Object.entries(params)) {
      searchUrl.searchParams.append(key, value);
    }

    const response = await fetch(searchUrl, { headers });

    if (!response.ok) {
      throw new Error(`HTTP error ${response.status}`);
    }

    const res = await response.json();
    // `webPages` can be missing from the payload; treat that like an
    // empty result list instead of throwing a TypeError.
    const results: { snippet: string }[] = res.webPages?.value ?? [];

    if (results.length === 0) {
      return "No good results found.";
    }

    return results.map((result) => result.snippet).join(" ");
  }
}

export { BingSerpAPI };
/**
 * Interface for the parameters required to instantiate a BraveSearch
 * instance.
 */
export interface BraveSearchParams {
  /** Brave Search API key. When omitted, `BRAVE_SEARCH_API_KEY` is read from the environment. */
  apiKey?: string;
}

/**
 * Class for interacting with the Brave Search engine. It extends the Tool
 * class and requires an API key to function. The API key can be passed in
 * during instantiation or set as an environment variable named
 * 'BRAVE_SEARCH_API_KEY'.
 */
export class BraveSearch extends Tool {
  static lc_name() {
    return "BraveSearch";
  }

  name = "brave-search";

  description =
    "a search engine. useful for when you need to answer questions about current events. input should be a search query.";

  apiKey: string;

  /**
   * @param fields Optional parameters. A missing `apiKey` falls back to the
   * `BRAVE_SEARCH_API_KEY` environment variable.
   * @throws Error if no API key is available from either source.
   */
  constructor(fields: BraveSearchParams = {}) {
    super();

    // Resolve the fallback per field rather than through the parameter
    // default, so `new BraveSearch({})` still picks up the environment key.
    const apiKey =
      fields.apiKey ?? getEnvironmentVariable("BRAVE_SEARCH_API_KEY");

    if (!apiKey) {
      throw new Error(
        `Brave API key not set. Please pass it in or set it as an environment variable named "BRAVE_SEARCH_API_KEY".`
      );
    }

    this.apiKey = apiKey;
  }

  /**
   * Queries the Brave web search endpoint.
   * @param input The search query.
   * @returns A JSON string holding an array of `{ title, link, snippet }`
   * objects (an empty array when the response has no web results).
   * @throws Error if the endpoint answers with a non-2xx status.
   * @ignore
   */
  async _call(input: string): Promise<string> {
    const headers = {
      "X-Subscription-Token": this.apiKey,
      Accept: "application/json",
    };
    const searchUrl = new URL(
      `https://api.search.brave.com/res/v1/web/search?q=${encodeURIComponent(
        input
      )}`
    );

    const response = await fetch(searchUrl, { headers });

    if (!response.ok) {
      throw new Error(`HTTP error ${response.status}`);
    }

    const parsedResponse = await response.json();
    const webSearchResults = parsedResponse.web?.results;
    const finalResults = Array.isArray(webSearchResults)
      ? webSearchResults.map(
          (item: { title?: string; url?: string; description?: string }) => ({
            title: item.title,
            link: item.url,
            snippet: item.description,
          })
        )
      : [];
    return JSON.stringify(finalResults);
  }
}

// ---------------------------------------------------------------------------
// libs/langchain-community/src/tools/connery.ts
// ---------------------------------------------------------------------------
import {
  AsyncCaller,
  AsyncCallerParams,
} from "@langchain/core/utils/async_caller";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Tool } from "@langchain/core/tools";

/**
 * An object containing configuration parameters for the ConneryService class.
 * @extends AsyncCallerParams
 */
export interface ConneryServiceParams extends AsyncCallerParams {
  runnerUrl: string;
  apiKey: string;
}

/** Envelope of a successful Connery runner response. */
type ApiResponse<T> = {
  status: "success";
  data: T;
};

/** Envelope of a failed Connery runner response. */
type ApiErrorResponse = {
  status: "error";
  error: {
    message: string;
  };
};

/** Description of one input or output parameter of an action. */
type Parameter = {
  key: string;
  title: string;
  description: string;
  type: string;
  validation?: {
    required?: boolean;
  };
};

/** An action as listed by the Connery runner. */
type Action = {
  id: string;
  key: string;
  title: string;
  description: string;
  type: string;
  inputParameters: Parameter[];
  outputParameters: Parameter[];
  pluginId: string;
};

type Input = {
  [key: string]: string;
};

type Output = {
  [key: string]: string;
};

/** Payload returned by the runner after an action has been executed. */
type RunActionResult = {
  output: Output;
  used: {
    actionId: string;
    input: Input;
  };
};

/**
 * A LangChain Tool object wrapping a Connery action.
 * @extends Tool
 */
export class ConneryAction extends Tool {
  name: string;

  description: string;

  /**
   * Creates a ConneryAction instance based on the provided Connery action.
   * @param _action The Connery action.
   * @param _service The ConneryService instance.
   * @returns A ConneryAction instance.
   */
  constructor(protected _action: Action, protected _service: ConneryService) {
    super();

    this.name = this._action.title;
    this.description = this.getDescription();
  }

  /**
   * Runs the Connery action.
   * @param prompt This is a plain English prompt with all the information needed to run the action.
   * @returns A promise that resolves to a JSON string containing the output of the action.
   */
  protected _call(prompt: string): Promise<string> {
    return this._service.runAction(this._action.id, prompt);
  }

  /**
   * Returns the description of the Connery action.
   * @returns A string containing the description of the Connery action together with the instructions on how to use it.
   */
  protected getDescription(): string {
    const { title, description } = this._action;
    const inputParameters = this.prepareJsonForTemplate(
      this._action.inputParameters
    );
    // Hard-coded sample schema used for the worked example in the prompt.
    const example1InputParametersSchema = this.prepareJsonForTemplate([
      {
        key: "recipient",
        title: "Email Recipient",
        description: "Email address of the email recipient.",
        type: "string",
        validation: {
          required: true,
        },
      },
      {
        key: "subject",
        title: "Email Subject",
        description: "Subject of the email.",
        type: "string",
        validation: {
          required: true,
        },
      },
      {
        key: "body",
        title: "Email Body",
        description: "Body of the email.",
        type: "string",
        validation: {
          required: true,
        },
      },
    ]);

    const descriptionTemplate =
      "# Instructions about tool input:\n" +
      "The input to this tool is a plain English prompt with all the input parameters needed to call it. " +
      "The input parameters schema of this tool is provided below. " +
      "Use the input parameters schema to construct the prompt for the tool. " +
      "If the input parameter is required in the schema, it must be provided in the prompt. " +
      "Do not come up with the values for the input parameters yourself. " +
      "If you do not have enough information to fill in the input parameter, ask the user to provide it. " +
      "See examples below on how to construct the prompt based on the provided tool information. " +
      "\n\n" +
      "# Instructions about tool output:\n" +
      "The output of this tool is a JSON string. " +
      "Retrieve the output parameters from the JSON string and use them in the next tool. " +
      "Do not return the JSON string as the output of the tool. " +
      "\n\n" +
      "# Example:\n" +
      "Tool information:\n" +
      "- Title: Send email\n" +
      "- Description: Send an email to a recipient.\n" +
      `- Input parameters schema in JSON format: ${example1InputParametersSchema}\n` +
      "The tool input prompt:\n" +
      "recipient: test@example.com, subject: 'Test email', body: 'This is a test email sent from Langchain Connery tool.'\n" +
      "\n\n" +
      "# The tool information\n" +
      `- Title: ${title}\n` +
      `- Description: ${description}\n` +
      `- Input parameters schema in JSON format: ${inputParameters}\n`;

    return descriptionTemplate;
  }

  /**
   * Converts the provided object to a JSON string and escapes '{' and '}' characters.
   * @param obj The object to convert to a JSON string.
   * @returns A string containing the JSON representation of the provided object with '{' and '}' characters escaped.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  protected prepareJsonForTemplate(obj: any): string {
    // Double every brace so the JSON survives being embedded in a template.
    return JSON.stringify(obj).replace(/{/g, "{{").replace(/}/g, "}}");
  }
}
/**
 * A service for working with Connery actions.
 *
 * Connery is an open-source plugin infrastructure for AI.
 * Source code: https://github.com/connery-io/connery-platform
 */
export class ConneryService {
  protected runnerUrl: string;

  protected apiKey: string;

  protected asyncCaller: AsyncCaller;

  /**
   * Creates a ConneryService instance.
   * @param params A ConneryServiceParams object.
   * If not provided, the values are retrieved from the CONNERY_RUNNER_URL
   * and CONNERY_RUNNER_API_KEY environment variables.
   * @returns A ConneryService instance.
   * @throws An error if the runner URL or the API key cannot be resolved.
   */
  constructor(params?: ConneryServiceParams) {
    const runnerUrl =
      params?.runnerUrl ?? getEnvironmentVariable("CONNERY_RUNNER_URL");
    const apiKey =
      params?.apiKey ?? getEnvironmentVariable("CONNERY_RUNNER_API_KEY");

    if (!runnerUrl || !apiKey) {
      throw new Error(
        "CONNERY_RUNNER_URL and CONNERY_RUNNER_API_KEY environment variables must be set."
      );
    }

    this.runnerUrl = runnerUrl;
    this.apiKey = apiKey;

    this.asyncCaller = new AsyncCaller(params ?? {});
  }

  /**
   * Returns the list of Connery actions wrapped as LangChain Tool objects.
   * @returns A promise that resolves to an array of ConneryAction objects.
   */
  async listActions(): Promise<ConneryAction[]> {
    const actions = await this._listActions();
    return actions.map((action) => new ConneryAction(action, this));
  }

  /**
   * Returns the specified Connery action wrapped as a LangChain Tool object.
   * @param actionId The ID of the action to return.
   * @returns A promise that resolves to a ConneryAction object.
   */
  async getAction(actionId: string): Promise<ConneryAction> {
    const action = await this._getAction(actionId);
    return new ConneryAction(action, this);
  }

  /**
   * Runs the specified Connery action with the provided input.
   * @param actionId The ID of the action to run.
   * @param prompt This is a plain English prompt with all the information needed to run the action.
   * @param input The input expected by the action.
   * If provided together with the prompt, the input takes precedence over the input specified in the prompt.
   * @returns A promise that resolves to a JSON string containing the output of the action.
   */
  async runAction(
    actionId: string,
    prompt?: string,
    input?: Input
  ): Promise<string> {
    const result = await this._runAction(actionId, prompt, input);
    return JSON.stringify(result);
  }

  /**
   * Returns the list of actions available in the Connery runner.
   * @returns A promise that resolves to an array of Action objects.
   */
  protected async _listActions(): Promise<Action[]> {
    const response = await this.asyncCaller.call(
      fetch,
      `${this.runnerUrl}/v1/actions`,
      {
        method: "GET",
        headers: this._getHeaders(),
      }
    );
    await this._handleError(response, "Failed to list actions");

    const apiResponse: ApiResponse<Action[]> = await response.json();
    return apiResponse.data;
  }

  /**
   * Returns the specified action available in the Connery runner.
   * @param actionId The ID of the action to return.
   * @returns A promise that resolves to an Action object.
   * @throws An error if the action with the specified ID is not found.
   */
  protected async _getAction(actionId: string): Promise<Action> {
    // Looked up in the full action list fetched by _listActions.
    const actions = await this._listActions();
    const action = actions.find((a) => a.id === actionId);
    if (!action) {
      throw new Error(
        `The action with ID "${actionId}" was not found in the list of available actions in the Connery runner.`
      );
    }
    return action;
  }

  /**
   * Runs the specified Connery action with the provided input.
   * @param actionId The ID of the action to run.
   * @param prompt This is a plain English prompt with all the information needed to run the action.
   * @param input The input object expected by the action.
   * If provided together with the prompt, the input takes precedence over the input specified in the prompt.
   * @returns A promise that resolves to the `output` part of the runner's RunActionResult.
   */
  protected async _runAction(
    actionId: string,
    prompt?: string,
    input?: Input
  ): Promise<Output> {
    const response = await this.asyncCaller.call(
      fetch,
      `${this.runnerUrl}/v1/actions/${actionId}/run`,
      {
        method: "POST",
        headers: this._getHeaders(),
        body: JSON.stringify({
          prompt,
          input,
        }),
      }
    );
    await this._handleError(response, "Failed to run action");

    const apiResponse: ApiResponse<RunActionResult> = await response.json();
    return apiResponse.data.output;
  }

  /**
   * Returns a standard set of HTTP headers to be used in API calls to the Connery runner.
   * @returns An object containing the standard set of HTTP headers.
   */
  protected _getHeaders(): Record<string, string> {
    return {
      "Content-Type": "application/json",
      "x-api-key": this.apiKey,
    };
  }

  /**
   * Shared error handler for API calls to the Connery runner.
   * If the response is not ok, an error is thrown containing the error message returned by the Connery runner
   * (or the HTTP status text when the body carries no such message).
   * Otherwise, the promise resolves to void.
   * @param response The response object returned by the Connery runner.
   * @param errorMessage The error message to be used in the error thrown if the response is not ok.
   * @returns A promise that resolves to void.
   * @throws An error containing the error message returned by the Connery runner.
   */
  protected async _handleError(
    response: Response,
    errorMessage: string
  ): Promise<void> {
    if (response.ok) return;

    // The error body is not guaranteed to be JSON. Parsing defensively keeps
    // the HTTP status in the thrown error instead of masking it with a
    // SyntaxError from `response.json()`.
    let detail = response.statusText;
    try {
      const apiErrorResponse: ApiErrorResponse = await response.json();
      detail = apiErrorResponse?.error?.message ?? detail;
    } catch {
      // Body was not valid JSON; keep the status text.
    }

    throw new Error(
      `${errorMessage}. Status code: ${response.status}. Error message: ${detail}`
    );
  }
}

// ---------------------------------------------------------------------------
// libs/langchain-community/src/tools/dadjokeapi.ts
// (the two imports below are also the imports of dataforseo_api_search.ts)
// ---------------------------------------------------------------------------
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Tool } from "@langchain/core/tools";

/**
 * The DadJokeAPI class is a tool for generating dad jokes based on a
 * specific topic. It fetches jokes from an external API and returns a
 * random joke from the results. If no jokes are found for the given
 * search term, it returns a message indicating that no jokes were found.
 */
class DadJokeAPI extends Tool {
  static lc_name() {
    return "DadJokeAPI";
  }

  name = "dadjoke";

  description =
    "a dad joke generator. get a dad joke about a specific topic. input should be a search term.";

  /**
   * Searches icanhazdadjoke.com and returns one matching joke at random.
   * @param input The search term.
   * @returns A joke, or a "No dad jokes found" message when nothing matches.
   * @throws Error if the API answers with a non-2xx status.
   * @ignore
   */
  async _call(input: string): Promise<string> {
    const headers = { Accept: "application/json" };
    // Encode the term so spaces, '&', '#', etc. cannot corrupt the query string.
    const searchUrl = `https://icanhazdadjoke.com/search?term=${encodeURIComponent(
      input
    )}`;

    const response = await fetch(searchUrl, { headers });

    if (!response.ok) {
      throw new Error(`HTTP error ${response.status}`);
    }

    const data = await response.json();
    // Treat a missing or malformed `results` field as "no jokes".
    const jokes: Array<{ joke: string }> = Array.isArray(data?.results)
      ? data.results
      : [];

    if (jokes.length === 0) {
      return `No dad jokes found about ${input}`;
    }

    const randomIndex = Math.floor(Math.random() * jokes.length);
    return jokes[randomIndex].joke;
  }
}

export { DadJokeAPI };
/**
 * @interface DataForSeoApiConfig
 * @description Represents the configuration object used to set up a DataForSeoAPISearch instance.
 */
export interface DataForSeoApiConfig {
  /**
   * @property apiLogin
   * @type {string}
   * @description The API login credential for DataForSEO. If not provided, it will be fetched from environment variables.
   */
  apiLogin?: string;

  /**
   * @property apiPassword
   * @type {string}
   * @description The API password credential for DataForSEO. If not provided, it will be fetched from environment variables.
   */
  apiPassword?: string;

  /**
   * @property params
   * @type {Record<string, string | number | boolean>}
   * @description Additional parameters to customize the API request.
   */
  params?: Record<string, string | number | boolean>;

  /**
   * @property useJsonOutput
   * @type {boolean}
   * @description Determines if the output should be in JSON format.
   */
  useJsonOutput?: boolean;

  /**
   * @property jsonResultTypes
   * @type {Array<string>}
   * @description Specifies the types of results to include in the output.
   */
  jsonResultTypes?: Array<string>;

  /**
   * @property jsonResultFields
   * @type {Array<string>}
   * @description Specifies the fields to include in each result object.
   */
  jsonResultFields?: Array<string>;

  /**
   * @property topCount
   * @type {number}
   * @description Specifies the maximum number of results to return.
   */
  topCount?: number;
}

/**
 * Represents a task in the API response.
 */
type Task = {
  id: string;
  status_code: number;
  status_message: string;
  time: string;
  result: Result[];
};

/**
 * Represents a result in the API response.
 */
type Result = {
  keyword: string;
  check_url: string;
  datetime: string;
  spell?: string;
  item_types: string[];
  se_results_count: number;
  items_count: number;
  // Items are heterogeneous (their shape depends on `type`), hence `any`.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  items: any[];
};
/**
 * Represents the API response.
 */
type ApiResponse = {
  status_code: number;
  status_message: string;
  tasks: Task[];
};

/**
 * @class DataForSeoAPISearch
 * @extends {Tool}
 * @description Represents a wrapper class to work with DataForSEO SERP API.
 */
export class DataForSeoAPISearch extends Tool {
  static lc_name() {
    return "DataForSeoAPISearch";
  }

  name = "dataforseo-api-wrapper";

  description =
    "A robust Google Search API provided by DataForSeo. This tool is handy when you need information about trending topics or current events.";

  protected apiLogin: string;

  protected apiPassword: string;

  /**
   * @property defaultParams
   * @type {Record<string, string | number | boolean>}
   * @description These are the default parameters to be used when making an API request.
   */
  protected defaultParams: Record<string, string | number | boolean> = {
    location_name: "United States",
    language_code: "en",
    depth: 10,
    se_name: "google",
    se_type: "organic",
  };

  protected params: Record<string, string | number | boolean> = {};

  protected jsonResultTypes: Array<string> | undefined;

  protected jsonResultFields: Array<string> | undefined;

  protected topCount: number | undefined;

  protected useJsonOutput = false;

  /**
   * @constructor
   * @param {DataForSeoApiConfig} config
   * @description Sets up the class, throws an error if the API login/password isn't provided.
   */
  constructor(config: DataForSeoApiConfig = {}) {
    super();
    const apiLogin =
      config.apiLogin ?? getEnvironmentVariable("DATAFORSEO_LOGIN");
    const apiPassword =
      config.apiPassword ?? getEnvironmentVariable("DATAFORSEO_PASSWORD");
    const params = config.params ?? {};
    if (!apiLogin || !apiPassword) {
      throw new Error(
        "DataForSEO login or password not set. You can set it as DATAFORSEO_LOGIN and DATAFORSEO_PASSWORD in your .env file, or pass it to DataForSeoAPISearch."
      );
    }
    // Caller-supplied params override the defaults key by key.
    this.params = { ...this.defaultParams, ...params };
    this.apiLogin = apiLogin;
    this.apiPassword = apiPassword;
    this.jsonResultTypes = config.jsonResultTypes;
    this.jsonResultFields = config.jsonResultFields;
    this.useJsonOutput = config.useJsonOutput ?? false;
    this.topCount = config.topCount;
  }

  /**
   * @method _call
   * @param {string} keyword
   * @returns {Promise<string>}
   * @description Initiates a call to the API and processes the response.
   * Returns the filtered items as a JSON string when `useJsonOutput` is set,
   * otherwise a single text answer extracted by `processResponse`.
   */
  async _call(keyword: string): Promise<string> {
    return this.useJsonOutput
      ? JSON.stringify(await this.results(keyword))
      : this.processResponse(await this.getResponseJson(keyword));
  }

  /**
   * @method results
   * @param {string} keyword
   * @returns {Promise<Array<any>>}
   * @description Fetches the results from the API for the given keyword.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  async results(keyword: string): Promise<Array<any>> {
    const res = await this.getResponseJson(keyword);
    return this.filterResults(res, this.jsonResultTypes);
  }

  /**
   * @method prepareRequest
   * @param {string} keyword
   * @returns {{url: string; headers: HeadersInit; data: BodyInit}}
   * @description Prepares the request details for the API call.
   */
  protected prepareRequest(keyword: string): {
    url: string;
    headers: HeadersInit;
    data: BodyInit;
  } {
    if (this.apiLogin === undefined || this.apiPassword === undefined) {
      throw new Error("api_login or api_password is not provided");
    }

    // HTTP Basic credentials: base64("login:password").
    const credentials = Buffer.from(
      `${this.apiLogin}:${this.apiPassword}`,
      "utf-8"
    ).toString("base64");
    const headers = {
      Authorization: `Basic ${credentials}`,
      "Content-Type": "application/json",
    };

    const params = { ...this.params };
    // A `keyword` already present in the configured params wins over the argument.
    params.keyword ??= keyword;
    const data = [params];

    return {
      url: `https://api.dataforseo.com/v3/serp/${params.se_name}/${params.se_type}/live/advanced`,
      headers,
      data: JSON.stringify(data),
    };
  }

  /**
   * @method getResponseJson
   * @param {string} keyword
   * @returns {Promise<ApiResponse>}
   * @description Executes a POST request to the provided URL and returns a parsed JSON response.
   */
  protected async getResponseJson(keyword: string): Promise<ApiResponse> {
    const requestDetails = this.prepareRequest(keyword);
    const response = await fetch(requestDetails.url, {
      method: "POST",
      headers: requestDetails.headers,
      body: requestDetails.data,
    });

    if (!response.ok) {
      throw new Error(
        `Got ${response.status} error from DataForSEO: ${response.statusText}`
      );
    }

    const result: ApiResponse = await response.json();
    return this.checkResponse(result);
  }

  /**
   * @method checkResponse
   * @param {ApiResponse} response
   * @returns {ApiResponse}
   * @description Checks the response status code. Throws unless the response
   * and every task in it carry status code 20000.
   */
  private checkResponse(response: ApiResponse): ApiResponse {
    if (response.status_code !== 20000) {
      throw new Error(
        `Got error from DataForSEO SERP API: ${response.status_message}`
      );
    }
    // Guarded like the other traversals: a missing task list means nothing to check.
    for (const task of response.tasks || []) {
      if (task.status_code !== 20000) {
        throw new Error(
          `Got error from DataForSEO SERP API: ${task.status_message}`
        );
      }
    }
    return response;
  }

  /* eslint-disable @typescript-eslint/no-explicit-any */
  /**
   * @method filterResults
   * @param {ApiResponse} res
   * @param {Array<string> | undefined} types
   * @returns {Array<any>}
   * @description Filters the results based on the specified result types.
   * An undefined or empty `types` keeps every item type. At most `topCount`
   * items are returned in total, across all tasks and results.
   */
  private filterResults(
    res: ApiResponse,
    types: Array<string> | undefined
  ): Array<any> {
    const output: Array<any> = [];
    const limit = this.topCount;
    const wanted = types !== undefined && types.length > 0 ? types : undefined;

    for (const task of res.tasks || []) {
      for (const result of task.result || []) {
        for (const item of result.items || []) {
          // Check the cap before adding and leave all loops at once, so it
          // holds across results and tasks (and for topCount === 0).
          if (limit !== undefined && output.length >= limit) {
            return output;
          }
          if (wanted !== undefined && !wanted.includes(item.type)) {
            continue;
          }
          const newItem = this.cleanupUnnecessaryItems(item);
          if (Object.keys(newItem).length !== 0) {
            output.push(newItem);
          }
        }
      }
    }
    return output;
  }

  /* eslint-disable @typescript-eslint/no-explicit-any */
  /* eslint-disable no-param-reassign */
  /**
   * @method cleanupUnnecessaryItems
   * @param {any} d
   * @description Removes unnecessary items from the response: recursively
   * drops the `xpath`, `position` and `rectangle` keys and, when
   * `jsonResultFields` is set, every key that is not listed in it.
   */
  private cleanupUnnecessaryItems(d: any): any {
    if (Array.isArray(d)) {
      return d.map((item) => this.cleanupUnnecessaryItems(item));
    }

    const toRemove = ["xpath", "position", "rectangle"];
    if (typeof d === "object" && d !== null) {
      return Object.keys(d).reduce((newObj: any, key: string) => {
        if (
          (this.jsonResultFields === undefined ||
            this.jsonResultFields.includes(key)) &&
          !toRemove.includes(key)
        ) {
          if (typeof d[key] === "object" && d[key] !== null) {
            newObj[key] = this.cleanupUnnecessaryItems(d[key]);
          } else {
            newObj[key] = d[key];
          }
        }
        return newObj;
      }, {});
    }

    return d;
  }

  /**
   * @method processResponse
   * @param {ApiResponse} res
   * @returns {string}
   * @description Processes the response to extract meaningful data. For each
   * result, the highest-priority item type listed in `item_types` is looked
   * up and its text field returned; results that yield no text are skipped.
   * Falls back to "No good search result found".
   */
  protected processResponse(res: ApiResponse): string {
    // Item type -> field that holds its displayable text, in priority order.
    const textFieldByType: Array<[string, string]> = [
      ["answer_box", "text"],
      ["knowledge_graph", "description"],
      ["featured_snippet", "description"],
      ["shopping", "price"],
      ["organic", "description"],
    ];

    for (const task of res.tasks || []) {
      for (const result of task.result || []) {
        const itemTypes = result.item_types || [];
        const items = result.items || [];

        const entry = textFieldByType.find(([type]) =>
          itemTypes.includes(type)
        );
        if (entry === undefined) {
          continue;
        }

        const [type, field] = entry;
        // `item_types` may advertise a type whose item is missing from
        // `items`; optional chaining turns that into "no answer here"
        // instead of a TypeError.
        const match = items.find(
          (item: { type?: string }) => item?.type === type
        );
        const text = match?.[field];
        if (text) {
          return text;
        }
      }
    }
    return "No good search result found";
  }
}
a/libs/langchain-community/src/tools/fixtures/wordoftheday.html b/libs/langchain-community/src/tools/fixtures/wordoftheday.html new file mode 100644 index 000000000000..09baddceca01 --- /dev/null +++ b/libs/langchain-community/src/tools/fixtures/wordoftheday.html @@ -0,0 +1,3892 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Word of the Day: Foible | Merriam-Webster + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+
+ +
+
+
+ +
+
+
+ +
+
+ +
+
+
+
+
+

Word of the Day

+ : April 10, 2023 +
+ +
+ +
+
+
+

foible

+ play +
+
+ + + +
+ noun + + FOY-bul + +
+ + + + +
+
+ +
+ +
+
+
+ +
+

What It Means

+

+ Foibles are minor flaws or shortcomings in + character or behavior. In fencing, + foible refers to the part of a sword's blade + between the middle and point, which is considered the + weakest part. +

+ +

+ // He was amused daily by the foibles of his + eccentric neighbor. +

+ +

+ See the entry > +

+ + +

+ foible in + Context +

+
+
+

+ "Films about important historical moments are often + marked by a heavy solemnity, a sometimes suffocating + respectfulness that can make one forget that these + events involved real people, human beings with + passions and foibles." — Michael Ordoña, + The Los Angeles Times, 20 Jan. 2023 +

+ +
+
+
+ +
+

+ Build your vocabulary! Get Word of the Day in + your inbox every day. +

+
+ + +
+
+ +
+ +
+
+
+
+ + +
+

+ +

+ What Did You Just Call Me? +

+
+

+
+
+ +
+
+
    +
  • +
    + brown chihuahua sitting on the floor with squinting eyes looking at the camera +
    +
  • +
  • + Before we went to her house, Hannah told + us her aunt was a + flibbertigibbet. +
  • +
+
+ +
+
+ +
+
+ Name That Thing +
+
+

+ You know what it looks like… but what is + it called? +

+ TAKE THE QUIZ +
+
+ +
+
+ Solve today's spelling word game by finding as many words as you can with using just 7 letters. Longer words score more points. +
+
+

+ Can you make 12 words with 7 letters? +

+ PLAY +
+
+
+
+
+
+
+
+ +
+ + +
+

Did You Know?

+

+ Many word lovers agree that the pen is mightier than the + sword. But be they + honed + in wit or form, even the sharpest tools in the shed have + their flaws. That’s where foible comes in + handy. Borrowed from French in the 1600s, the word + originally referred to the weakest part of a fencing + sword, that part being the portion between the middle + and the pointed tip. The English foible soon + came to be applied not only to weaknesses in blades but + also to minor failings in character. The French source + of foible is also at a remove from the fencing + arena; the French foible means "weak," and it + comes from the same Old French term, feble, + that gave us + feeble. +

+ +
+
+ + +
+ +
+

Test Your Vocabulary

+

+ Unscramble the letters to create a word that refers to a + particular kind of fencing sword: BRASE. +

+ VIEW THE ANSWER +
+ +
+ + +
+
+

Podcast

+
+ + + +
+
+
+
+ +
+ + +
+
+

More Words of the Day

+
+ + + + +
+ +
+
+ + + + +
+ + + +
+
+ Love words? Need even more definitions? +
+

+ Subscribe to America's largest dictionary and get thousands + more definitions and advanced search—ad free! +

+ +
+
+
+
+ + + + + + + + + + +
+
+ + + + + + + + + + + + +
+
+
+
+ Do Not Sell Or Share My Personal Information +
+ You have chosen to opt-out of the sale or sharing of your information + from this site and any of its affiliates. To opt back in please click + the "Customize my ad experience" link.
+
This site collects information through the use of cookies and + other tracking tools. Cookies and these tools do not contain any + information that personally identifies a user, but personal + information that would be stored about you may be linked to the + information stored in and obtained from them. This information would + be used and shared for Analytics, Ad Serving, Interest Based + Advertising, among other purposes.
+
For more information please visit this site's Privacy + Policy.
+
+
+ CANCEL +
+
+ CONTINUE +
+
+
+
+ + + + + + + + + + + diff --git a/langchain/src/tools/gmail/base.ts b/libs/langchain-community/src/tools/gmail/base.ts similarity index 93% rename from langchain/src/tools/gmail/base.ts rename to libs/langchain-community/src/tools/gmail/base.ts index 7977f53387d3..fa9e122d2dfd 100644 --- a/langchain/src/tools/gmail/base.ts +++ b/libs/langchain-community/src/tools/gmail/base.ts @@ -1,7 +1,7 @@ import { gmail_v1, google } from "googleapis"; import { z } from "zod"; -import { StructuredTool } from "../base.js"; -import { getEnvironmentVariable } from "../../util/env.js"; +import { StructuredTool } from "@langchain/core/tools"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; export interface GmailBaseToolParams { credentials?: { diff --git a/langchain/src/tools/gmail/create_draft.ts b/libs/langchain-community/src/tools/gmail/create_draft.ts similarity index 100% rename from langchain/src/tools/gmail/create_draft.ts rename to libs/langchain-community/src/tools/gmail/create_draft.ts diff --git a/langchain/src/tools/gmail/descriptions.ts b/libs/langchain-community/src/tools/gmail/descriptions.ts similarity index 100% rename from langchain/src/tools/gmail/descriptions.ts rename to libs/langchain-community/src/tools/gmail/descriptions.ts diff --git a/langchain/src/tools/gmail/get_message.ts b/libs/langchain-community/src/tools/gmail/get_message.ts similarity index 100% rename from langchain/src/tools/gmail/get_message.ts rename to libs/langchain-community/src/tools/gmail/get_message.ts diff --git a/langchain/src/tools/gmail/get_thread.ts b/libs/langchain-community/src/tools/gmail/get_thread.ts similarity index 100% rename from langchain/src/tools/gmail/get_thread.ts rename to libs/langchain-community/src/tools/gmail/get_thread.ts diff --git a/libs/langchain-community/src/tools/gmail/index.ts b/libs/langchain-community/src/tools/gmail/index.ts new file mode 100644 index 000000000000..d2f854da54a4 --- /dev/null +++ 
b/libs/langchain-community/src/tools/gmail/index.ts @@ -0,0 +1,12 @@ +export { GmailCreateDraft } from "./create_draft.js"; +export { GmailGetMessage } from "./get_message.js"; +export { GmailGetThread } from "./get_thread.js"; +export { GmailSearch } from "./search.js"; +export { GmailSendMessage } from "./send_message.js"; + +export type { GmailBaseToolParams } from "./base.js"; +export type { CreateDraftSchema } from "./create_draft.js"; +export type { GetMessageSchema } from "./get_message.js"; +export type { GetThreadSchema } from "./get_thread.js"; +export type { SearchSchema } from "./search.js"; +export type { SendMessageSchema } from "./send_message.js"; diff --git a/langchain/src/tools/gmail/search.ts b/libs/langchain-community/src/tools/gmail/search.ts similarity index 100% rename from langchain/src/tools/gmail/search.ts rename to libs/langchain-community/src/tools/gmail/search.ts diff --git a/langchain/src/tools/gmail/send_message.ts b/libs/langchain-community/src/tools/gmail/send_message.ts similarity index 100% rename from langchain/src/tools/gmail/send_message.ts rename to libs/langchain-community/src/tools/gmail/send_message.ts diff --git a/libs/langchain-community/src/tools/google_custom_search.ts b/libs/langchain-community/src/tools/google_custom_search.ts new file mode 100644 index 000000000000..ef6cfd25d184 --- /dev/null +++ b/libs/langchain-community/src/tools/google_custom_search.ts @@ -0,0 +1,83 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; + +/** + * Interface for parameters required by GoogleCustomSearch class. + */ +export interface GoogleCustomSearchParams { + apiKey?: string; + googleCSEId?: string; +} + +/** + * Class that uses the Google Search API to perform custom searches. + * Requires environment variables `GOOGLE_API_KEY` and `GOOGLE_CSE_ID` to + * be set. 
+ */ +export class GoogleCustomSearch extends Tool { + static lc_name() { + return "GoogleCustomSearch"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "GOOGLE_API_KEY", + }; + } + + name = "google-custom-search"; + + protected apiKey: string; + + protected googleCSEId: string; + + description = + "a custom search engine. useful for when you need to answer questions about current events. input should be a search query. outputs a JSON array of results."; + + constructor( + fields: GoogleCustomSearchParams = { + apiKey: getEnvironmentVariable("GOOGLE_API_KEY"), + googleCSEId: getEnvironmentVariable("GOOGLE_CSE_ID"), + } + ) { + super(...arguments); + if (!fields.apiKey) { + throw new Error( + `Google API key not set. You can set it as "GOOGLE_API_KEY" in your environment variables.` + ); + } + if (!fields.googleCSEId) { + throw new Error( + `Google custom search engine id not set. You can set it as "GOOGLE_CSE_ID" in your environment variables.` + ); + } + this.apiKey = fields.apiKey; + this.googleCSEId = fields.googleCSEId; + } + + async _call(input: string) { + const res = await fetch( + `https://www.googleapis.com/customsearch/v1?key=${this.apiKey}&cx=${ + this.googleCSEId + }&q=${encodeURIComponent(input)}` + ); + + if (!res.ok) { + throw new Error( + `Got ${res.status} error from Google custom search: ${res.statusText}` + ); + } + + const json = await res.json(); + + const results = + json?.items?.map( + (item: { title?: string; link?: string; snippet?: string }) => ({ + title: item.title, + link: item.link, + snippet: item.snippet, + }) + ) ?? 
[]; + return JSON.stringify(results); + } +} diff --git a/libs/langchain-community/src/tools/google_places.ts b/libs/langchain-community/src/tools/google_places.ts new file mode 100644 index 000000000000..6a702e5d6c58 --- /dev/null +++ b/libs/langchain-community/src/tools/google_places.ts @@ -0,0 +1,96 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; + +/** + * Interface for parameters required by GooglePlacesAPI class. + */ +export interface GooglePlacesAPIParams { + apiKey?: string; +} + +/** + * Tool that queries the Google Places API + */ +export class GooglePlacesAPI extends Tool { + static lc_name() { + return "GooglePlacesAPI"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "GOOGLE_PLACES_API_KEY", + }; + } + + name = "google_places"; + + protected apiKey: string; + + description = `A wrapper around Google Places API. Useful for when you need to validate or + discover addresses from ambiguous text. Input should be a search query.`; + + constructor(fields?: GooglePlacesAPIParams) { + super(...arguments); + const apiKey = + fields?.apiKey ?? getEnvironmentVariable("GOOGLE_PLACES_API_KEY"); + if (apiKey === undefined) { + throw new Error( + `Google Places API key not set. 
You can set it as "GOOGLE_PLACES_API_KEY" in your environment variables.` + ); + } + this.apiKey = apiKey; + } + + async _call(input: string) { + const res = await fetch( + `https://places.googleapis.com/v1/places:searchText`, + { + method: "POST", + body: JSON.stringify({ + textQuery: input, + languageCode: "en", + }), + headers: { + "X-Goog-Api-Key": this.apiKey, + "X-Goog-FieldMask": + "places.displayName,places.formattedAddress,places.id,places.internationalPhoneNumber,places.websiteUri", + "Content-Type": "application/json", + }, + } + ); + + if (!res.ok) { + let message; + try { + const json = await res.json(); + message = json.error.message; + } catch (e) { + message = + "Unable to parse error message: Google did not return a JSON response."; + } + throw new Error( + `Got ${res.status}: ${res.statusText} error from Google Places API: ${message}` + ); + } + + const json = await res.json(); + + const results = + json?.places?.map( + (place: { + id?: string; + internationalPhoneNumber?: string; + formattedAddress?: string; + websiteUri?: string; + displayName?: { text?: string }; + }) => ({ + name: place.displayName?.text, + id: place.id, + address: place.formattedAddress, + phoneNumber: place.internationalPhoneNumber, + website: place.websiteUri, + }) + ) ?? []; + return JSON.stringify(results); + } +} diff --git a/libs/langchain-community/src/tools/ifttt.ts b/libs/langchain-community/src/tools/ifttt.ts new file mode 100644 index 000000000000..44df4d143d30 --- /dev/null +++ b/libs/langchain-community/src/tools/ifttt.ts @@ -0,0 +1,79 @@ +/** From https://github.com/SidU/teams-langchain-js/wiki/Connecting-IFTTT-Services. + +# Creating a webhook +- Go to https://ifttt.com/create + +# Configuring the "If This" +- Click on the "If This" button in the IFTTT interface. +- Search for "Webhooks" in the search bar. +- Choose the first option for "Receive a web request with a JSON payload." +- Choose an Event Name that is specific to the service you plan to connect to. 
+This will make it easier for you to manage the webhook URL. +For example, if you're connecting to Spotify, you could use "Spotify" as your +Event Name. +- Click the "Create Trigger" button to save your settings and create your webhook. + +# Configuring the "Then That" +- Tap on the "Then That" button in the IFTTT interface. +- Search for the service you want to connect, such as Spotify. +- Choose an action from the service, such as "Add track to a playlist". +- Configure the action by specifying the necessary details, such as the playlist name, +e.g., "Songs from AI". +- Reference the JSON Payload received by the Webhook in your action. For the Spotify +scenario, choose "{{JsonPayload}}" as your search query. +- Tap the "Create Action" button to save your action settings. +- Once you have finished configuring your action, click the "Finish" button to +complete the setup. +- Congratulations! You have successfully connected the Webhook to the desired +service, and you're ready to start receiving data and triggering actions 🎉 + +# Finishing up +- To get your webhook URL go to https://ifttt.com/maker_webhooks/settings +- Copy the IFTTT key value from there. The URL is of the form +https://maker.ifttt.com/use/YOUR_IFTTT_KEY. Grab the YOUR_IFTTT_KEY value. + */ +import { Tool } from "@langchain/core/tools"; + +/** + * Represents a tool for creating and managing webhooks with the IFTTT (If + * This Then That) service. The IFTTT service allows users to create + * chains of simple conditional statements, called applets, which are + * triggered based on changes to other web services. 
+ */ +export class IFTTTWebhook extends Tool { + static lc_name() { + return "IFTTTWebhook"; + } + + private url: string; + + name: string; + + description: string; + + constructor(url: string, name: string, description: string) { + super(...arguments); + this.url = url; + this.name = name; + this.description = description; + } + + /** POSTs `{ this: input }` as JSON to the webhook URL and resolves with the response body text; throws on a non-OK status. @ignore */ + async _call(input: string): Promise<string> { + const headers = { "Content-Type": "application/json" }; + const body = JSON.stringify({ this: input }); + + const response = await fetch(this.url, { + method: "POST", + headers, + body, + }); + + if (!response.ok) { + throw new Error(`HTTP error ${response.status}`); + } + + const result = await response.text(); + return result; + } +} diff --git a/libs/langchain-community/src/tools/searchapi.ts b/libs/langchain-community/src/tools/searchapi.ts new file mode 100644 index 000000000000..c6731b7fc682 --- /dev/null +++ b/libs/langchain-community/src/tools/searchapi.ts @@ -0,0 +1,204 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; + +type JSONPrimitive = string | number | boolean | null; +type JSONValue = JSONPrimitive | JSONObject | JSONArray; +interface JSONObject { + [key: string]: JSONValue; +} +interface JSONArray extends Array<JSONValue> {} + +function isJSONObject(value: JSONValue): value is JSONObject { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +/** + * SearchApiParameters Type Definition. + * + * For more parameters and supported search engines, refer to the specific engine's documentation: + * Google - https://www.searchapi.io/docs/google + * Google News - https://www.searchapi.io/docs/google-news + * Google Scholar - https://www.searchapi.io/docs/google-scholar + * YouTube Transcripts - https://www.searchapi.io/docs/youtube-transcripts + * and others. 
+ * + */ +export type SearchApiParameters = { + [key: string]: JSONValue; +}; + +/** + * SearchApi Class Definition. 
+ * + * Provides a wrapper around the SearchApi. + * + * Ensure you've set the SEARCHAPI_API_KEY environment variable for authentication. + * You can obtain a free API key from https://www.searchapi.io/. + * @example + * ```typescript + * const searchApi = new SearchApi("your-api-key", { + * engine: "google_news", + * }); + * const agent = RunnableSequence.from([ + * ChatPromptTemplate.fromMessages([ + * ["ai", "Answer the following questions using a bulleted list markdown format."], + * ["human", "{input}"], + * ]), + * new ChatOpenAI({ temperature: 0 }), + * (input: BaseMessageChunk) => ({ + * log: "test", + * returnValues: { + * output: input, + * }, + * }), + * ]); + * const executor = AgentExecutor.fromAgentAndTools({ + * agent, + * tools: [searchApi], + * }); + * const res = await executor.invoke({ + * input: "What's happening in Ukraine today?", + * }); + * console.log(res); + * ``` + */ +export class SearchApi extends Tool { + static lc_name() { + return "SearchApi"; + } + + /** + * Converts the SearchApi instance to JSON. This method is not implemented + * and will throw an error if called. + * @returns Throws an error. + */ + toJSON() { + return this.toJSONNotImplemented(); + } + + protected apiKey: string; + + protected params: Partial<SearchApiParameters>; + + constructor( + apiKey: string | undefined = getEnvironmentVariable("SEARCHAPI_API_KEY"), + params: Partial<SearchApiParameters> = {} + ) { + super(...arguments); + + if (!apiKey) { + throw new Error( + "SearchApi requires an API key. Please set it as SEARCHAPI_API_KEY in your .env file, or pass it as a parameter to the SearchApi constructor." + ); + } + + this.apiKey = apiKey; + this.params = params; + } + + name = "search"; + + /** + * Builds a URL for the SearchApi request. + * @param searchQuery The search query, sent as the `q` parameter. + * @returns A string representing the built URL. 
+ */ + protected buildUrl(searchQuery: string): string { + const preparedParams: [string, string][] = Object.entries({ + engine: "google", + api_key: this.apiKey, + ...this.params, + q: searchQuery, + }) + .filter( + ([key, value]) => + value !== undefined && value !== null && key !== "apiKey" + ) + .map(([key, value]) => [key, `${value}`]); + + const searchParams = new URLSearchParams(preparedParams); + return `https://www.searchapi.io/api/v1/search?${searchParams}`; + } + + /** @ignore */ + /** + * Calls the SearchAPI. + * + * Accepts an input query and fetches the result from SearchApi. + * + * @param {string} input - Search query. + * @returns {string} - Formatted search results or an error message. + * + * NOTE: This method is the core search handler and processes various types + * of search results including Google organic results, videos, jobs, and images. + */ + async _call(input: string) { + const resp = await fetch(this.buildUrl(input)); + + const json = await resp.json(); + + if (json.error) { + throw new Error( + `Failed to load search results from SearchApi due to: ${json.error}` + ); + } + + // Google Search results + if (json.answer_box?.answer) { + return json.answer_box.answer; + } + + if (json.answer_box?.snippet) { + return json.answer_box.snippet; + } + + if (json.knowledge_graph?.description) { + return json.knowledge_graph.description; + } + + // Organic results (Google, Google News) + if (json.organic_results) { + const snippets = json.organic_results + .filter((r: JSONObject) => r.snippet) + .map((r: JSONObject) => r.snippet); + return snippets.join("\n"); + } + + // Google Jobs results + if (json.jobs) { + const jobDescriptions = json.jobs + .slice(0, 1) + .filter((r: JSONObject) => r.description) + .map((r: JSONObject) => r.description); + return jobDescriptions.join("\n"); + } + + // Google Videos results + if (json.videos) { + const videoInfo = json.videos + .filter((r: JSONObject) => r.title && r.link) + .map((r: JSONObject) => `Title: 
"${r.title}" Link: ${r.link}`); + return videoInfo.join("\n"); + } + + // Google Images results + if (json.images) { + const image_results = json.images.slice(0, 15); + const imageInfo = image_results + .filter( + (r: JSONObject) => + r.title && r.original && isJSONObject(r.original) && r.original.link + ) + .map( + (r: JSONObject) => + `Title: "${r.title}" Link: ${(r.original as JSONObject).link}` + ); + return imageInfo.join("\n"); + } + + return "No good search result found"; + } + + description = + "a search engine. useful for when you need to answer questions about current events. input should be a search query."; +} diff --git a/libs/langchain-community/src/tools/searxng_search.ts b/libs/langchain-community/src/tools/searxng_search.ts new file mode 100644 index 000000000000..eec2b90df82b --- /dev/null +++ b/libs/langchain-community/src/tools/searxng_search.ts @@ -0,0 +1,258 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; + +/** + * Interface for the results returned by the Searxng search. + */ +interface SearxngResults { + query: string; + number_of_results: number; + results: Array<{ + url: string; + title: string; + content: string; + img_src: string; + engine: string; + parsed_url: Array; + template: string; + engines: Array; + positions: Array; + score: number; + category: string; + pretty_url: string; + open_group?: boolean; + close_group?: boolean; + }>; + answers: Array; + corrections: Array; + infoboxes: Array<{ + infobox: string; + content: string; + engine: string; + engines: Array; + }>; + suggestions: Array; + unresponsive_engines: Array; +} + +/** + * Interface for custom headers used in the Searxng search. 
+ */ +interface SearxngCustomHeaders { + [key: string]: string; +} + +interface SearxngSearchParams { + /** + * @default 10 + * Number of results included in results + */ + numResults?: number; + /** Comma separated list, specifies the active search categories + * https://docs.searxng.org/user/configured_engines.html#configured-engines + */ + categories?: string; + + /** Comma separated list, specifies the active search engines + * https://docs.searxng.org/user/configured_engines.html#configured-engines + */ + engines?: string; + + /** Code of the language. */ + language?: string; + /** Search page number. */ + pageNumber?: number; + /** + * day / month / year + * + * Time range of search for engines which support it. See if an engine supports time range search in the preferences page of an instance. + */ + timeRange?: number; + + /** + * Throws Error if format is set anything other than "json" + * Output format of results. Format needs to be activated in search: + */ + format?: "json"; + /** Open search results on new tab. */ + resultsOnNewTab?: 0 | 1; + /** Proxy image results through SearXNG. */ + imageProxy?: boolean; + autocomplete?: string; + /** + * Filter search results of engines which support safe search. See if an engine supports safe search in the preferences page of an instance. + */ + safesearch?: 0 | 1 | 2; +} + +/** + * SearxngSearch class represents a meta search engine tool. + * Use this class when you need to answer questions about current events. + * The input should be a search query, and the output is a JSON array of the query results. 
+ * + * note: works best with *agentType*: `structured-chat-zero-shot-react-description` + * https://github.com/searxng/searxng + * @example + * ```typescript + * const executor = AgentExecutor.fromAgentAndTools({ + * agent, + * tools: [ + * new SearxngSearch({ + * params: { + * format: "json", + * engines: "google", + * }, + * headers: {}, + * }), + * ], + * }); + * const result = await executor.invoke({ + * input: `What is Langchain? Describe in 50 words`, + * }); + * ``` + */ +export class SearxngSearch extends Tool { + static lc_name() { + return "SearxngSearch"; + } + + name = "searxng-search"; + + description = + "A meta search engine. Useful for when you need to answer questions about current events. Input should be a search query. Output is a JSON array of the query results"; + + protected apiBase?: string; + + protected params?: SearxngSearchParams = { + numResults: 10, + pageNumber: 1, + format: "json", + imageProxy: true, + safesearch: 0, + }; + + protected headers?: SearxngCustomHeaders; + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiBase: "SEARXNG_API_BASE", + }; + } + + /** + * Constructor for the SearxngSearch class + * @param apiBase Base URL of the Searxng instance; a truthy SEARXNG_API_BASE environment variable takes precedence over this argument + * @param params SearxNG parameters, merged over the class defaults + * @param headers Custom headers, merged over the default JSON content-type header + */ + constructor({ + apiBase, + params, + headers, + }: { + /** Base URL of Searxng instance */ + apiBase?: string; + + /** SearxNG Parameters + * + * https://docs.searxng.org/dev/search_api.html check here for more details + */ + params?: SearxngSearchParams; + + /** + * Custom headers + * Set custom headers if you're using an API from RapidAPI (https://rapidapi.com/iamrony777/api/searxng) + * No headers needed for a locally self-hosted instance + */ + headers?: SearxngCustomHeaders; + }) { + super(...arguments); + + this.apiBase = 
getEnvironmentVariable("SEARXNG_API_BASE") || apiBase; + this.headers = { "content-type": "application/json", ...headers }; + + if (!this.apiBase) { +
throw new Error( + `SEARXNG_API_BASE not set. You can set it as "SEARXNG_API_BASE" in your environment variables.` + ); + } + + if (params) { + this.params = { ...this.params, ...params }; + } + } + + /** + * Builds the URL for the Searxng search. + * @param path The path for the URL. + * @param parameters Query parameters; entries whose value is `undefined` are dropped and the rest are stringified. + * @param baseUrl The base URL. + * @returns The complete URL as a string. + */ + protected buildUrl<P extends SearxngSearchParams>( + path: string, + parameters: P, + baseUrl: string + ): string { + const nonUndefinedParams: [string, string][] = Object.entries(parameters) + .filter(([_, value]) => value !== undefined) + .map(([key, value]) => [key, value.toString()]); // URLSearchParams takes string pairs + const searchParams = new URLSearchParams(nonUndefinedParams); + return `${baseUrl}/${path}?${searchParams}`; + } + + async _call(input: string): Promise<string> { + const queryParams = { + q: input, + ...this.params, + }; + const url = this.buildUrl("search", queryParams, this.apiBase as string); + + const resp = await fetch(url, { + method: "POST", + headers: this.headers, + signal: AbortSignal.timeout(5 * 1000), // 5 seconds + }); + + if (!resp.ok) { + throw new Error(resp.statusText); + } + + const res: SearxngResults = await resp.json(); + + if ( + !res.results.length && + !res.answers.length && + !res.infoboxes.length && + !res.suggestions.length + ) { + return "No good results found."; + } else if (res.results.length) { + const response: string[] = []; + + res.results.forEach((r) => { + response.push( + JSON.stringify({ + title: r.title || "", + link: r.url || "", + snippet: r.content || "", + }) + ); + }); + + return response.slice(0, this.params?.numResults).toString(); + } else if (res.answers.length) { + return res.answers[0]; + } else if (res.infoboxes.length) { + return res.infoboxes[0]?.content.replaceAll(/<[^>]+>/gi, ""); + } else if (res.suggestions.length) { + 
let suggestions = "Suggestions: "; + res.suggestions.forEach((s) => { + suggestions += `${s}, `; + }); + return suggestions; + } else { + return "No good results found."; + } + } +} diff --git a/libs/langchain-community/src/tools/serpapi.ts b/libs/langchain-community/src/tools/serpapi.ts new file mode 100644 index 000000000000..f6a0f4dc099b --- /dev/null +++ b/libs/langchain-community/src/tools/serpapi.ts @@ -0,0 +1,505 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; 
+ +/** + * This does not use the `serpapi` package because it appears to cause issues + * when used in `jest` tests. Part of the issue seems to be that the `serpapi` + * package imports a wasm module to use instead of native `fetch`, which we + * don't want anyway. + * + * NOTE: you must provide location, gl and hl, or your region and language + * may not match your location, and results will not be deterministic. + */ + +// Copied over from `serpapi` package +interface BaseParameters { + /** + * Parameter defines the device to use to get the results. It can be set to + * `desktop` (default) to use a regular browser, `tablet` to use a tablet browser + * (currently using iPads), or `mobile` to use a mobile browser (currently + * using iPhones). + */ + device?: "desktop" | "tablet" | "mobile"; + /** + * Parameter will force SerpApi to fetch the Google results even if a cached + * version is already present. A cache is served only if the query and all + * parameters are exactly the same. Cache expires after 1h. Cached searches + * are free, and are not counted towards your searches per month. It can be set + * to `false` (default) to allow results from the cache, or `true` to disallow + * results from the cache. `no_cache` and `async` parameters should not be used together. + */ + no_cache?: boolean; + /** + * Specify the client-side timeout of the request. In milliseconds. + */ + timeout?: number; +} + +export interface SerpAPIParameters extends BaseParameters { + /** + * Search Query + * Parameter defines the query you want to search. You can use anything that you + * would use in a regular Google search. e.g. `inurl:`, `site:`, `intitle:`. We + * also support advanced search query parameters such as as_dt and as_eq. See the + * [full list](https://serpapi.com/advanced-google-query-parameters) of supported + * advanced search query parameters. + */ + q: string; + /** + * Location + * Parameter defines from where you want the search to originate. 
If several + * locations match the location requested, we'll pick the most popular one. Head to + * [/locations.json API](https://serpapi.com/locations-api) if you need more + * precise control. location and uule parameters can't be used together. Avoid + * utilizing location when setting the location outside the U.S. when using Google + * Shopping and/or Google Product API. + */ + location?: string; + /** + * Encoded Location + * Parameter is the Google encoded location you want to use for the search. uule + * and location parameters can't be used together. + */ + uule?: string; + /** + * Google Place ID + * Parameter defines the id (`CID`) of the Google My Business listing you want to + * scrape. Also known as Google Place ID. + */ + ludocid?: string; + /** + * Additional Google Place ID + * Parameter that you might have to use to force the knowledge graph map view to + * show up. You can find the lsig ID by using our [Local Pack + * API](https://serpapi.com/local-pack) or [Local Places Results + * API](https://serpapi.com/local-results). + * lsig ID is also available via a redirect Google uses within [Google My + * Business](https://www.google.com/business/). + */ + lsig?: string; + /** + * Google Knowledge Graph ID + * Parameter defines the id (`KGMID`) of the Google Knowledge Graph listing you + * want to scrape. Also known as Google Knowledge Graph ID. Searches with kgmid + * parameter will return results for the originally encrypted search parameters. + * For some searches, kgmid may override all other parameters except start, and num + * parameters. + */ + kgmid?: string; + /** + * Google Cached Search Parameters ID + * Parameter defines the cached search parameters of the Google Search you want to + * scrape. Searches with si parameter will return results for the originally + * encrypted search parameters. For some searches, si may override all other + * parameters except start, and num parameters. si can be used to scrape Google + * Knowledge Graph Tabs. 
+ */ + si?: string; + /** + * Domain + * Parameter defines the Google domain to use. It defaults to `google.com`. Head to + * the [Google domains page](https://serpapi.com/google-domains) for a full list of + * supported Google domains. + */ + google_domain?: string; + /** + * Country + * Parameter defines the country to use for the Google search. It's a two-letter + * country code. (e.g., `us` for the United States, `uk` for United Kingdom, or + * `fr` for France). Head to the [Google countries + * page](https://serpapi.com/google-countries) for a full list of supported Google + * countries. + */ + gl?: string; + /** + * Language + * Parameter defines the language to use for the Google search. It's a two-letter + * language code. (e.g., `en` for English, `es` for Spanish, or `fr` for French). + * Head to the [Google languages page](https://serpapi.com/google-languages) for a + * full list of supported Google languages. + */ + hl?: string; + /** + * Set Multiple Languages + * Parameter defines one or multiple languages to limit the search to. It uses + * `lang_{two-letter language code}` to specify languages and `|` as a delimiter. + * (e.g., `lang_fr|lang_de` will only search French and German pages). Head to the + * [Google lr languages page](https://serpapi.com/google-lr-languages) for a full + * list of supported languages. + */ + lr?: string; + /** + * as_dt + * Parameter controls whether to include or exclude results from the site named in + * the as_sitesearch parameter. + */ + as_dt?: string; + /** + * as_epq + * Parameter identifies a phrase that all documents in the search results must + * contain. You can also use the [phrase + * search](https://developers.google.com/custom-search/docs/xml_results#PhraseSearchqt) + * query term to search for a phrase. + */ + as_epq?: string; + /** + * as_eq + * Parameter identifies a word or phrase that should not appear in any documents in + * the search results. 
You can also use the [exclude + * query](https://developers.google.com/custom-search/docs/xml_results#Excludeqt) + * term to ensure that a particular word or phrase will not appear in the documents + * in a set of search results. + */ + as_eq?: string; + /** + * as_lq + * Parameter specifies that all search results should contain a link to a + * particular URL. You can also use the + * [link:](https://developers.google.com/custom-search/docs/xml_results#BackLinksqt) + * query term for this type of query. + */ + as_lq?: string; + /** + * as_nlo + * Parameter specifies the starting value for a search range. Use as_nlo and as_nhi + * to append an inclusive search range. + */ + as_nlo?: string; + /** + * as_nhi + * Parameter specifies the ending value for a search range. Use as_nlo and as_nhi + * to append an inclusive search range. + */ + as_nhi?: string; + /** + * as_oq + * Parameter provides additional search terms to check for in a document, where + * each document in the search results must contain at least one of the additional + * search terms. You can also use the [Boolean + * OR](https://developers.google.com/custom-search/docs/xml_results#BooleanOrqt) + * query term for this type of query. + */ + as_oq?: string; + /** + * as_q + * Parameter provides search terms to check for in a document. This parameter is + * also commonly used to allow users to specify additional terms to search for + * within a set of search results. + */ + as_q?: string; + /** + * as_qdr + * Parameter requests search results from a specified time period (quick date + * range). The following values are supported: + * `d[number]`: requests results from the specified number of past days. Example + * for the past 10 days: `as_qdr=d10` + * `w[number]`: requests results from the specified number of past weeks. + * `m[number]`: requests results from the specified number of past months. + * `y[number]`: requests results from the specified number of past years. 
Example + * for the past year: `as_qdr=y` + */ + as_qdr?: string; + /** + * as_rq + * Parameter specifies that all search results should be pages that are related to + * the specified URL. The parameter value should be a URL. You can also use the + * [related:](https://developers.google.com/custom-search/docs/xml_results#RelatedLinksqt) + * query term for this type of query. + */ + as_rq?: string; + /** + * as_sitesearch + * Parameter allows you to specify that all search results should be pages from a + * given site. By setting the as_dt parameter, you can also use it to exclude pages + * from a given site from your search results. + */ + as_sitesearch?: string; + /** + * Advanced Search Parameters + * (to be searched) parameter defines advanced search parameters that aren't + * possible in the regular query field. (e.g., advanced search for patents, dates, + * news, videos, images, apps, or text contents). + */ + tbs?: string; + /** + * Adult Content Filtering + * Parameter defines the level of filtering for adult content. It can be set to + * `active`, or `off` (default). + */ + safe?: string; + /** + * Exclude Auto-corrected Results + * Parameter defines the exclusion of results from an auto-corrected query that is + * spelled wrong. It can be set to `1` to exclude these results, or `0` to include + * them (default). + */ + nfpr?: string; + /** + * Results Filtering + * Parameter defines if the filters for 'Similar Results' and 'Omitted Results' are + * on or off. It can be set to `1` (default) to enable these filters, or `0` to + * disable these filters. + */ + filter?: string; + /** + * Search Type + * (to be matched) parameter defines the type of search you want to do. 
+ * It can be set to: + * `(no tbm parameter)`: regular Google Search, + * `isch`: [Google Images API](https://serpapi.com/images-results), + * `lcl`: [Google Local API](https://serpapi.com/local-results), + * `vid`: [Google Videos API](https://serpapi.com/videos-results), + * `nws`: [Google News API](https://serpapi.com/news-results), + * `shop`: [Google Shopping API](https://serpapi.com/shopping-results), + * or any other Google service. + */ + tbm?: string; + /** + * Result Offset + * Parameter defines the result offset. It skips the given number of results. It's + * used for pagination. (e.g., `0` (default) is the first page of results, `10` is + * the 2nd page of results, `20` is the 3rd page of results, etc.). + * Google Local Results only accepts multiples of `20`(e.g. `20` for the second + * page results, `40` for the third page results, etc.) as the start value. + */ + start?: number; + /** + * Number of Results + * Parameter defines the maximum number of results to return. (e.g., `10` (default) + * returns 10 results, `40` returns 40 results, and `100` returns 100 results). + */ + num?: string; + /** + * Page Number (images) + * Parameter defines the page number for [Google + * Images](https://serpapi.com/images-results). There are 100 images per page. This + * parameter is equivalent to start (offset) = ijn * 100. This parameter works only + * for [Google Images](https://serpapi.com/images-results) (set tbm to `isch`). + */ + ijn?: string; +} + +type UrlParameters = Record< + string, + string | number | boolean | undefined | null +>; + +/** + * Wrapper around SerpAPI. + * + * To use, set the SERPAPI_API_KEY environment variable or pass the key to the constructor; the `serpapi` package is not used (requests go through `fetch`). 
+ */ +export class SerpAPI extends Tool { + static lc_name() { + return "SerpAPI"; + } + + toJSON() { + return this.toJSONNotImplemented(); + } + + protected key: string; + + protected params: Partial<SerpAPIParameters>; + + protected baseUrl: string; + + constructor( + apiKey: string | undefined = getEnvironmentVariable("SERPAPI_API_KEY"), + params: Partial<SerpAPIParameters> = {}, + baseUrl = "https://serpapi.com" + ) { + super(...arguments); + + if (!apiKey) { + throw new Error( + "SerpAPI API key not set. You can set it as SERPAPI_API_KEY in your .env file, or pass it to SerpAPI." + ); + } + + this.key = apiKey; + this.params = params; + this.baseUrl = baseUrl; + } + + name = "search"; + + /** + * Builds a URL for the SerpAPI request. + * @param path The path for the request. + * @param parameters Query parameters; entries whose value is `undefined` are dropped and the rest are stringified. + * @param baseUrl The base URL for the request. + * @returns A string representing the built URL. + */ + protected buildUrl<P extends UrlParameters>( + path: string, + parameters: P, + baseUrl: string + ): string { + const nonUndefinedParams: [string, string][] = Object.entries(parameters) + .filter(([_, value]) => value !== undefined) + .map(([key, value]) => [key, `${value}`]); + const searchParams = new URLSearchParams(nonUndefinedParams); + return `${baseUrl}/${path}?${searchParams}`; + } + + /** @ignore */ + async _call(input: string) { + const { timeout, ...params } = this.params; + const resp = await fetch( + this.buildUrl( + "search", + { + ...params, + api_key: this.key, + q: input, + }, + this.baseUrl + ), + { + signal: timeout ? AbortSignal.timeout(timeout) : undefined, + } + ); + + const res = await resp.json(); + + if (res.error) { + throw new Error(`Got error from serpAPI: ${res.error}`); + } + + const answer_box = res.answer_box_list + ? 
res.answer_box_list[0] + : res.answer_box; + if (answer_box) { + if (answer_box.result) { + return answer_box.result; + } else if (answer_box.answer) { + return answer_box.answer; + } else if (answer_box.snippet) { + return answer_box.snippet; + } else if (answer_box.snippet_highlighted_words) { + return answer_box.snippet_highlighted_words.toString(); + } else { + const answer: { [key: string]: string } = {}; + Object.keys(answer_box) + .filter( + (k) => + !Array.isArray(answer_box[k]) && + typeof answer_box[k] !== "object" && + !( + typeof answer_box[k] === "string" && + answer_box[k].startsWith("http") + ) + ) + .forEach((k) => { + answer[k] = answer_box[k]; + }); + return JSON.stringify(answer); + } + } + + if (res.events_results) { + return JSON.stringify(res.events_results); + } + + if (res.sports_results) { + return JSON.stringify(res.sports_results); + } + + if (res.top_stories) { + return JSON.stringify(res.top_stories); + } + + if (res.news_results) { + return JSON.stringify(res.news_results); + } + + if (res.jobs_results?.jobs) { + return JSON.stringify(res.jobs_results.jobs); + } + + if (res.questions_and_answers) { + return 
JSON.stringify(res.questions_and_answers); + } + + if (res.popular_destinations?.destinations) { + return JSON.stringify(res.popular_destinations.destinations); + } + + if (res.top_sights?.sights) { + const sights: Array<{ [key: string]: string }> = res.top_sights.sights + .map((s: { [key: string]: string }) => ({ + title: s.title, + description: s.description, + price: s.price, + })) + .slice(0, 8); + return JSON.stringify(sights); + } + + if (res.shopping_results && res.shopping_results[0]?.title) { + return JSON.stringify(res.shopping_results.slice(0, 3)); + } + + if (res.images_results && res.images_results[0]?.thumbnail) { + return res.images_results + .map((ir: { thumbnail: string }) => ir.thumbnail) + .slice(0, 10) + .toString(); + } + + const snippets = []; + if (res.knowledge_graph) { + if (res.knowledge_graph.description) { + snippets.push(res.knowledge_graph.description); + } + + const title = res.knowledge_graph.title || ""; + Object.keys(res.knowledge_graph) + .filter( + (k) => + typeof res.knowledge_graph[k] === "string" && + k !== "title" && + k !== "description" && + !k.endsWith("_stick") && + !k.endsWith("_link") && + !k.startsWith("http") + ) + .forEach((k) => + snippets.push(`${title} ${k}: ${res.knowledge_graph[k]}`) + ); + } + + const first_organic_result = res.organic_results?.[0]; + if (first_organic_result) { + if (first_organic_result.snippet) { + snippets.push(first_organic_result.snippet); + } else if (first_organic_result.snippet_highlighted_words) { + snippets.push(first_organic_result.snippet_highlighted_words); + } else if (first_organic_result.rich_snippet) { + snippets.push(first_organic_result.rich_snippet); + } else if (first_organic_result.rich_snippet_table) { + snippets.push(first_organic_result.rich_snippet_table); + } else if (first_organic_result.link) { + snippets.push(first_organic_result.link); + } + } + + if (res.buying_guide) { + snippets.push(res.buying_guide); + } + + if (res.local_results?.places) { + 
snippets.push(res.local_results.places); + } + + if (snippets.length > 0) { + return JSON.stringify(snippets); + } else { + return "No good search result found"; + } + } + + description = + "a search engine. useful for when you need to answer questions about current events. input should be a search query."; +} diff --git a/libs/langchain-community/src/tools/serper.ts b/libs/langchain-community/src/tools/serper.ts new file mode 100644 index 000000000000..ca444d769187 --- /dev/null +++ b/libs/langchain-community/src/tools/serper.ts @@ -0,0 +1,107 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { Tool } from "@langchain/core/tools"; + +/** + * Defines the parameters that can be passed to the Serper class during + * instantiation. It includes `gl` and `hl` which are optional. + */ +export type SerperParameters = { + gl?: string; + hl?: string; +}; + +/** + * Wrapper around serper. + * + * You can create a free API key at https://serper.dev. + * + * To use, you should have the SERPER_API_KEY environment variable set. + */ +export class Serper extends Tool { + static lc_name() { + return "Serper"; + } + + /** + * Converts the Serper instance to JSON. This method is not implemented + * and will throw an error if called. + * @returns Throws an error. + */ + toJSON() { + return this.toJSONNotImplemented(); + } + + protected key: string; + + protected params: Partial; + + constructor( + apiKey: string | undefined = getEnvironmentVariable("SERPER_API_KEY"), + params: Partial = {} + ) { + super(); + + if (!apiKey) { + throw new Error( + "Serper API key not set. You can set it as SERPER_API_KEY in your .env file, or pass it to Serper." 
+ ); + } + + this.key = apiKey; + this.params = params; + } + + name = "search"; + + /** @ignore */ + async _call(input: string) { + const options = { + method: "POST", + headers: { + "X-API-KEY": this.key, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + q: input, + ...this.params, + }), + }; + + const res = await fetch("https://google.serper.dev/search", options); + + if (!res.ok) { + throw new Error(`Got ${res.status} error from serper: ${res.statusText}`); + } + + const json = await res.json(); + + if (json.answerBox?.answer) { + return json.answerBox.answer; + } + + if (json.answerBox?.snippet) { + return json.answerBox.snippet; + } + + if (json.answerBox?.snippet_highlighted_words) { + return json.answerBox.snippet_highlighted_words[0]; + } + + if (json.sportsResults?.game_spotlight) { + return json.sportsResults.game_spotlight; + } + + if (json.knowledgeGraph?.description) { + return json.knowledgeGraph.description; + } + + if (json.organic?.[0]?.snippet) { + return json.organic[0].snippet; + } + + return "No good search result found"; + } + + description = + "a search engine. useful for when you need to answer questions about current events. input should be a search query."; +} diff --git a/libs/langchain-community/src/tools/wikipedia_query_run.ts b/libs/langchain-community/src/tools/wikipedia_query_run.ts new file mode 100644 index 000000000000..d0166e5c58d0 --- /dev/null +++ b/libs/langchain-community/src/tools/wikipedia_query_run.ts @@ -0,0 +1,181 @@ +import { Tool } from "@langchain/core/tools"; + +/** + * Interface for the parameters that can be passed to the + * WikipediaQueryRun constructor. + */ +export interface WikipediaQueryRunParams { + topKResults?: number; + maxDocContentLength?: number; + baseUrl?: string; +} + +/** + * Type alias for URL parameters. Represents a record where keys are + * strings and values can be string, number, boolean, undefined, or null. 
+ */ +type UrlParameters = Record< + string, + string | number | boolean | undefined | null +>; + +/** + * Interface for the structure of search results returned by the Wikipedia + * API. + */ +interface SearchResults { + query: { + search: Array<{ + title: string; + }>; + }; +} + +/** + * Interface for the structure of a page returned by the Wikipedia API. + */ +interface Page { + pageid: number; + ns: number; + title: string; + extract: string; +} + +/** + * Interface for the structure of a page result returned by the Wikipedia + * API. + */ +interface PageResult { + batchcomplete: string; + query: { + pages: Record; + }; +} + +/** + * Class for interacting with and fetching data from the Wikipedia API. It + * extends the Tool class. + * @example + * ```typescript + * const wikipediaQuery = new WikipediaQueryRun({ + * topKResults: 3, + * maxDocContentLength: 4000, + * }); + * const result = await wikipediaQuery.call("Langchain"); + * ``` + */ +export class WikipediaQueryRun extends Tool { + static lc_name() { + return "WikipediaQueryRun"; + } + + name = "wikipedia-api"; + + description = + "A tool for interacting with and fetching data from the Wikipedia API."; + + protected topKResults = 3; + + protected maxDocContentLength = 4000; + + protected baseUrl = "https://en.wikipedia.org/w/api.php"; + + constructor(params: WikipediaQueryRunParams = {}) { + super(); + + this.topKResults = params.topKResults ?? this.topKResults; + this.maxDocContentLength = + params.maxDocContentLength ?? this.maxDocContentLength; + this.baseUrl = params.baseUrl ?? 
this.baseUrl; + } + + async _call(query: string): Promise { + const searchResults = await this._fetchSearchResults(query); + const summaries: string[] = []; + + for ( + let i = 0; + i < Math.min(this.topKResults, searchResults.query.search.length); + i += 1 + ) { + const page = searchResults.query.search[i].title; + const pageDetails = await this._fetchPage(page, true); + + if (pageDetails) { + const summary = `Page: ${page}\nSummary: ${pageDetails.extract}`; + summaries.push(summary); + } + } + + if (summaries.length === 0) { + return "No good Wikipedia Search Result was found"; + } else { + return summaries.join("\n\n").slice(0, this.maxDocContentLength); + } + } + + /** + * Fetches the content of a specific Wikipedia page. It returns the + * extracted content as a string. + * @param page The specific Wikipedia page to fetch its content. + * @param redirect A boolean value to indicate whether to redirect or not. + * @returns The extracted content of the specific Wikipedia page as a string. + */ + public async content(page: string, redirect = true): Promise { + try { + const result = await this._fetchPage(page, redirect); + return result.extract; + } catch (error) { + throw new Error(`Failed to fetch content for page "${page}": ${error}`); + } + } + + /** + * Builds a URL for the Wikipedia API using the provided parameters. + * @param parameters The parameters to be used in building the URL. + * @returns A string representing the built URL. + */ + protected buildUrl

(parameters: P): string { + const nonUndefinedParams: [string, string][] = Object.entries(parameters) + .filter(([_, value]) => value !== undefined) + .map(([key, value]) => [key, `${value}`]); + const searchParams = new URLSearchParams(nonUndefinedParams); + return `${this.baseUrl}?${searchParams}`; + } + + private async _fetchSearchResults(query: string): Promise { + const searchParams = new URLSearchParams({ + action: "query", + list: "search", + srsearch: query, + format: "json", + }); + + const response = await fetch(`${this.baseUrl}?${searchParams.toString()}`); + if (!response.ok) throw new Error("Network response was not ok"); + + const data: SearchResults = await response.json(); + + return data; + } + + private async _fetchPage(page: string, redirect: boolean): Promise { + const params = new URLSearchParams({ + action: "query", + prop: "extracts", + explaintext: "true", + redirects: redirect ? "1" : "0", + format: "json", + titles: page, + }); + + const response = await fetch(`${this.baseUrl}?${params.toString()}`); + if (!response.ok) throw new Error("Network response was not ok"); + + const data: PageResult = await response.json(); + const { pages } = data.query; + const pageId = Object.keys(pages)[0]; + + return pages[pageId]; + } +} diff --git a/libs/langchain-community/src/tools/wolframalpha.ts b/libs/langchain-community/src/tools/wolframalpha.ts new file mode 100644 index 000000000000..dfa58807d5f6 --- /dev/null +++ b/libs/langchain-community/src/tools/wolframalpha.ts @@ -0,0 +1,37 @@ +import { Tool, type ToolParams } from "@langchain/core/tools"; + +/** + * @example + * ```typescript + * const tool = new WolframAlphaTool({ + * appid: "YOUR_APP_ID", + * }); + * const res = await tool.invoke("What is 2 * 2?"); + * ``` + */ +export class WolframAlphaTool extends Tool { + appid: string; + + name = "wolfram_alpha"; + + description = `A wrapper around Wolfram Alpha. 
Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Input should be a search query.`; + + constructor(fields: ToolParams & { appid: string }) { + super(fields); + + this.appid = fields.appid; + } + + static lc_name() { + return "WolframAlphaTool"; + } + + async _call(query: string): Promise { + const url = `https://www.wolframalpha.com/api/v1/llm-api?appid=${ + this.appid + }&input=${encodeURIComponent(query)}`; + const res = await fetch(url); + + return res.text(); + } +} diff --git a/libs/langchain-community/src/types/expression-parser.d.ts b/libs/langchain-community/src/types/expression-parser.d.ts new file mode 100644 index 000000000000..87e7e2ef57e7 --- /dev/null +++ b/libs/langchain-community/src/types/expression-parser.d.ts @@ -0,0 +1,91 @@ +declare interface ParseOptions { + filename?: string; + startRule?: "Start"; + tracer?: any; + [key: string]: any; +} +declare type ParseFunction = ( + input: string, + options?: Options +) => Options extends { startRule: infer StartRule } + ? StartRule extends "Start" + ? 
Start + : Start + : Start; + +// These types were autogenerated by ts-pegjs +declare type Start = Program; +declare type Identifier = IdentifierName; +declare type IdentifierName = { type: "Identifier"; name: string }; +declare type Literal = + | NullLiteral + | BooleanLiteral + | NumericLiteral + | StringLiteral; +declare type NullLiteral = { type: "NullLiteral"; value: null }; +declare type BooleanLiteral = + | { type: "BooleanLiteral"; value: true } + | { type: "BooleanLiteral"; value: false }; +declare type NumericLiteral = DecimalLiteral; +declare type DecimalLiteral = { type: "NumericLiteral"; value: number }; +declare type StringLiteral = { type: "StringLiteral"; value: string }; +declare type PrimaryExpression = + | Identifier + | Literal + | ArrayExpression + | ObjectExpression + | Expression; +declare type ArrayExpression = { + type: "ArrayExpression"; + elements: ElementList; +}; +declare type ElementList = PrimaryExpression[]; +declare type ObjectExpression = + | { type: "ObjectExpression"; properties: [] } + | { type: "ObjectExpression"; properties: PropertyNameAndValueList }; +declare type PropertyNameAndValueList = PrimaryExpression[]; +declare type PropertyAssignment = { + type: "PropertyAssignment"; + key: PropertyName; + value: Expression; + kind: "init"; +}; +declare type PropertyName = IdentifierName | StringLiteral | NumericLiteral; +declare type MemberExpression = + | { + type: "MemberExpression"; + property: StringLiteral; + computed: true; + object: MemberExpression | Identifier | StringLiteral; + } + | { + type: "MemberExpression"; + property: Identifier; + computed: false; + object: MemberExpression | Identifier | StringLiteral; + }; +declare type CallExpression = { + type: "CallExpression"; + arguments: Arguments; + callee: MemberExpression | Identifier; +}; +declare type Arguments = PrimaryExpression[]; +declare type Expression = CallExpression | MemberExpression; +declare type ExpressionStatement = { + type: "ExpressionStatement"; + 
expression: Expression; +}; +declare type Program = { type: "Program"; body: ExpressionStatement }; +declare type ExpressionNode = + | Program + | ExpressionStatement + | ArrayExpression + | BooleanLiteral + | CallExpression + | Identifier + | MemberExpression + | NumericLiteral + | ObjectExpression + | PropertyAssignment + | NullLiteral + | StringLiteral; diff --git a/libs/langchain-community/src/types/googlevertexai-types.ts b/libs/langchain-community/src/types/googlevertexai-types.ts new file mode 100644 index 000000000000..f65694cacd49 --- /dev/null +++ b/libs/langchain-community/src/types/googlevertexai-types.ts @@ -0,0 +1,89 @@ +import type { BaseLLMParams } from "@langchain/core/language_models/llms"; + +export interface GoogleConnectionParams { + authOptions?: AuthOptions; +} + +export interface GoogleVertexAIConnectionParams + extends GoogleConnectionParams { + /** Hostname for the API call */ + endpoint?: string; + + /** Region where the LLM is stored */ + location?: string; + + /** The version of the API functions. Part of the path. */ + apiVersion?: string; +} + +export interface GoogleVertexAIModelParams { + /** Model to use */ + model?: string; + + /** Sampling temperature to use */ + temperature?: number; + + /** + * Maximum number of tokens to generate in the completion. + */ + maxOutputTokens?: number; + + /** + * Top-p changes how the model selects tokens for output. + * + * Tokens are selected from most probable to least until the sum + * of their probabilities equals the top-p value. + * + * For example, if tokens A, B, and C have a probability of + * .3, .2, and .1 and the top-p value is .5, then the model will + * select either A or B as the next token (using temperature). + */ + topP?: number; + + /** + * Top-k changes how the model selects tokens for output. 
+ * + * A top-k of 1 means the selected token is the most probable among + * all tokens in the model’s vocabulary (also called greedy decoding), + * while a top-k of 3 means that the next token is selected from + * among the 3 most probable tokens (using temperature). + */ + topK?: number; +} + +export interface GoogleVertexAIBaseLLMInput + extends BaseLLMParams, + GoogleVertexAIConnectionParams, + GoogleVertexAIModelParams {} + +export interface GoogleResponse { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + data: any; +} + +export interface GoogleVertexAIBasePrediction { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + safetyAttributes?: any; +} + +export interface GoogleVertexAILLMPredictions< + PredictionType extends GoogleVertexAIBasePrediction +> { + predictions: PredictionType[]; +} + +export type GoogleAbstractedClientOpsMethod = "GET" | "POST"; + +export type GoogleAbstractedClientOpsResponseType = "json" | "stream"; + +export type GoogleAbstractedClientOps = { + url?: string; + method?: GoogleAbstractedClientOpsMethod; + data?: unknown; + responseType?: GoogleAbstractedClientOpsResponseType; +}; + +export interface GoogleAbstractedClient { + request: (opts: GoogleAbstractedClientOps) => unknown; + getProjectId: () => Promise; +} diff --git a/libs/langchain-community/src/types/pdf-parse.d.ts b/libs/langchain-community/src/types/pdf-parse.d.ts new file mode 100644 index 000000000000..675c403a559c --- /dev/null +++ b/libs/langchain-community/src/types/pdf-parse.d.ts @@ -0,0 +1,1409 @@ +/** + * Type definitions adapted from pdfjs-dist + * https://github.com/mozilla/pdfjs-dist/blob/master/types/src/display/api.d.ts + */ + +declare module "pdf-parse/lib/pdf.js/v1.10.100/build/pdf.js" { + export type TypedArray = + | Int8Array + | Uint8Array + | Uint8ClampedArray + | Int16Array + | Uint16Array + | Int32Array + | Uint32Array + | Float32Array + | Float64Array; + export type BinaryData = TypedArray | ArrayBuffer | Array | 
string; + export type RefProxy = { + num: number; + gen: number; + }; + /** + * Document initialization / loading parameters object. + */ + export type DocumentInitParameters = { + /** + * - The URL of the PDF. + */ + url?: string | URL | undefined; + /** + * - Binary PDF data. + * Use TypedArrays (Uint8Array) to improve the memory usage. If PDF data is + * BASE64-encoded, use `atob()` to convert it to a binary string first. + * + * NOTE: If TypedArrays are used they will generally be transferred to the + * worker-thread. This will help reduce main-thread memory usage, however + * it will take ownership of the TypedArrays. + */ + data?: BinaryData | undefined; + /** + * - Basic authentication headers. + */ + httpHeaders?: Object | undefined; + /** + * - Indicates whether or not + * cross-site Access-Control requests should be made using credentials such + * as cookies or authorization headers. The default is `false`. + */ + withCredentials?: boolean | undefined; + /** + * - For decrypting password-protected PDFs. + */ + password?: string | undefined; + /** + * - The PDF file length. It's used for progress + * reports and range requests operations. + */ + length?: number | undefined; + /** + * - Allows for using a custom range + * transport implementation. + */ + range?: PDFDataRangeTransport | undefined; + /** + * - Specify maximum number of bytes fetched + * per range request. The default value is {@link DEFAULT_RANGE_CHUNK_SIZE }. + */ + rangeChunkSize?: number | undefined; + /** + * - The worker that will be used for loading and + * parsing the PDF data. + */ + worker?: PDFWorker | undefined; + /** + * - Controls the logging level; the constants + * from {@link VerbosityLevel } should be used. + */ + verbosity?: number | undefined; + /** + * - The base URL of the document, used when + * attempting to recover valid absolute URLs for annotations, and outline + * items, that (incorrectly) only specify relative URLs. 
+ */ + docBaseUrl?: string | undefined; + /** + * - The URL where the predefined Adobe CMaps are + * located. Include the trailing slash. + */ + cMapUrl?: string | undefined; + /** + * - Specifies if the Adobe CMaps are binary + * packed or not. The default value is `true`. + */ + cMapPacked?: boolean | undefined; + /** + * - The factory that will be used when + * reading built-in CMap files. Providing a custom factory is useful for + * environments without Fetch API or `XMLHttpRequest` support, such as + * Node.js. The default value is {DOMCMapReaderFactory}. + */ + CMapReaderFactory?: Object | undefined; + /** + * - When `true`, fonts that aren't + * embedded in the PDF document will fallback to a system font. + * The default value is `true` in web environments and `false` in Node.js; + * unless `disableFontFace === true` in which case this defaults to `false` + * regardless of the environment (to prevent completely broken fonts). + */ + useSystemFonts?: boolean | undefined; + /** + * - The URL where the standard font + * files are located. Include the trailing slash. + */ + standardFontDataUrl?: string | undefined; + /** + * - The factory that will be used + * when reading the standard font files. Providing a custom factory is useful + * for environments without Fetch API or `XMLHttpRequest` support, such as + * Node.js. The default value is {DOMStandardFontDataFactory}. + */ + StandardFontDataFactory?: Object | undefined; + /** + * - Enable using the Fetch API in the + * worker-thread when reading CMap and standard font files. When `true`, + * the `CMapReaderFactory` and `StandardFontDataFactory` options are ignored. + * The default value is `true` in web environments and `false` in Node.js. + */ + useWorkerFetch?: boolean | undefined; + /** + * - Reject certain promises, e.g. 
+ * `getOperatorList`, `getTextContent`, and `RenderTask`, when the associated + * PDF data cannot be successfully parsed, instead of attempting to recover + * whatever possible of the data. The default value is `false`. + */ + stopAtErrors?: boolean | undefined; + /** + * - The maximum allowed image size in total + * pixels, i.e. width * height. Images above this value will not be rendered. + * Use -1 for no limit, which is also the default value. + */ + maxImageSize?: number | undefined; + /** + * - Determines if we can evaluate strings + * as JavaScript. Primarily used to improve performance of font rendering, and + * when parsing PDF functions. The default value is `true`. + */ + isEvalSupported?: boolean | undefined; + /** + * - Determines if we can use + * `OffscreenCanvas` in the worker. Primarily used to improve performance of + * image conversion/rendering. + * The default value is `true` in web environments and `false` in Node.js. + */ + isOffscreenCanvasSupported?: boolean | undefined; + /** + * - The integer value is used to + * know when an image must be resized (uses `OffscreenCanvas` in the worker). + * If it's -1 then a possibly slow algorithm is used to guess the max value. + */ + canvasMaxAreaInBytes?: boolean | undefined; + /** + * - By default fonts are converted to + * OpenType fonts and loaded via the Font Loading API or `@font-face` rules. + * If disabled, fonts will be rendered using a built-in font renderer that + * constructs the glyphs with primitive path commands. + * The default value is `false` in web environments and `true` in Node.js. + */ + disableFontFace?: boolean | undefined; + /** + * - Include additional properties, + * which are unused during rendering of PDF documents, when exporting the + * parsed font data from the worker-thread. This may be useful for debugging + * purposes (and backwards compatibility), but note that it will lead to + * increased memory usage. The default value is `false`. 
+ */ + fontExtraProperties?: boolean | undefined; + /** + * - Render Xfa forms if any. + * The default value is `false`. + */ + enableXfa?: boolean | undefined; + /** + * - Specify an explicit document + * context to create elements with and to load resources, such as fonts, + * into. Defaults to the current document. + */ + ownerDocument?: HTMLDocument | undefined; + /** + * - Disable range request loading of PDF + * files. When enabled, and if the server supports partial content requests, + * then the PDF will be fetched in chunks. The default value is `false`. + */ + disableRange?: boolean | undefined; + /** + * - Disable streaming of PDF file data. + * By default PDF.js attempts to load PDF files in chunks. The default value + * is `false`. + */ + disableStream?: boolean | undefined; + /** + * - Disable pre-fetching of PDF file + * data. When range requests are enabled PDF.js will automatically keep + * fetching more data even if it isn't needed to display the current page. + * The default value is `false`. + * + * NOTE: It is also necessary to disable streaming, see above, in order for + * disabling of pre-fetching to work correctly. + */ + disableAutoFetch?: boolean | undefined; + /** + * - Enables special hooks for debugging PDF.js + * (see `web/debugger.js`). The default value is `false`. + */ + pdfBug?: boolean | undefined; + /** + * - The factory instance that will be used + * when creating canvases. The default value is {new DOMCanvasFactory()}. + */ + canvasFactory?: Object | undefined; + /** + * - A factory instance that will be used + * to create SVG filters when rendering some images on the main canvas. + */ + filterFactory?: Object | undefined; + }; + export type OnProgressParameters = { + /** + * - Currently loaded number of bytes. + */ + loaded: number; + /** + * - Total number of bytes in the PDF file. + */ + total: number; + }; + /** + * Page getViewport parameters. 
+ */ + export type GetViewportParameters = { + /** + * - The desired scale of the viewport. + */ + scale: number; + /** + * - The desired rotation, in degrees, of + * the viewport. If omitted it defaults to the page rotation. + */ + rotation?: number | undefined; + /** + * - The horizontal, i.e. x-axis, offset. + * The default value is `0`. + */ + offsetX?: number | undefined; + /** + * - The vertical, i.e. y-axis, offset. + * The default value is `0`. + */ + offsetY?: number | undefined; + /** + * - If true, the y-axis will not be + * flipped. The default value is `false`. + */ + dontFlip?: boolean | undefined; + }; + /** + * Page getTextContent parameters. + */ + export type getTextContentParameters = { + /** + * - When true include marked + * content items in the items array of TextContent. The default is `false`. + */ + includeMarkedContent?: boolean | undefined; + }; + /** + * Page text content. + */ + export type TextContent = { + /** + * - Array of + * {@link TextItem } and {@link TextMarkedContent } objects. TextMarkedContent + * items are included when includeMarkedContent is true. + */ + items: Array; + /** + * - {@link TextStyle } objects, + * indexed by font name. + */ + styles: { + [x: string]: TextStyle; + }; + }; + /** + * Page text content part. + */ + export type TextItem = { + /** + * - Text content. + */ + str: string; + /** + * - Text direction: 'ttb', 'ltr' or 'rtl'. + */ + dir: string; + /** + * - Transformation matrix. + */ + transform: Array; + /** + * - Width in device space. + */ + width: number; + /** + * - Height in device space. + */ + height: number; + /** + * - Font name used by PDF.js for converted font. + */ + fontName: string; + /** + * - Indicating if the text content is followed by a + * line-break. + */ + hasEOL: boolean; + }; + /** + * Page text marked content part. + */ + export type TextMarkedContent = { + /** + * - Either 'beginMarkedContent', + * 'beginMarkedContentProps', or 'endMarkedContent'. 
+     */
+    type: string;
+    /**
+     * - The marked content identifier. Only used for type
+     * 'beginMarkedContentProps'.
+     */
+    id: string;
+  };
+  /**
+   * Text style.
+   */
+  export type TextStyle = {
+    /**
+     * - Font ascent.
+     */
+    ascent: number;
+    /**
+     * - Font descent.
+     */
+    descent: number;
+    /**
+     * - Whether or not the text is in vertical mode.
+     */
+    vertical: boolean;
+    /**
+     * - The possible font family.
+     */
+    fontFamily: string;
+  };
+  /**
+   * Page annotation parameters.
+   */
+  export type GetAnnotationsParameters = {
+    /**
+     * - Determines the annotations that are fetched,
+     * can be 'display' (viewable annotations), 'print' (printable annotations),
+     * or 'any' (all annotations). The default value is 'display'.
+     */
+    intent?: string | undefined;
+  };
+  /**
+   * Page render parameters.
+   */
+  export type RenderParameters = {
+    /**
+     * - A 2D context of a DOM
+     * Canvas object.
+     */
+    canvasContext: CanvasRenderingContext2D;
+    /**
+     * - Rendering viewport obtained by calling
+     * the `PDFPageProxy.getViewport` method.
+     */
+    viewport: PageViewport;
+    /**
+     * - Rendering intent, can be 'display', 'print',
+     * or 'any'. The default value is 'display'.
+     */
+    intent?: string | undefined;
+    /**
+     * Controls which annotations are rendered
+     * onto the canvas, for annotations with appearance-data; the values from
+     * {@link AnnotationMode} should be used. The following values are supported:
+     * - `AnnotationMode.DISABLE`, which disables all annotations.
+     * - `AnnotationMode.ENABLE`, which includes all possible annotations (thus
+     *   it also depends on the `intent`-option, see above).
+     * - `AnnotationMode.ENABLE_FORMS`, which excludes annotations that contain
+     *   interactive form elements (those will be rendered in the display layer).
+     * - `AnnotationMode.ENABLE_STORAGE`, which includes all possible annotations
+     *   (as above) but where interactive form elements are updated with data
+     *   from the {@link AnnotationStorage}-instance; useful e.g. for printing.
+     * The default value is `AnnotationMode.ENABLE`.
+     */
+    annotationMode?: number | undefined;
+    /**
+     * - Additional transform, applied just
+     * before viewport transform.
+     */
+    transform?: any[] | undefined;
+    /**
+     * - Background
+     * to use for the canvas.
+     * Any valid `canvas.fillStyle` can be used: a `DOMString` parsed as CSS
+     * value, a `CanvasGradient` object (a linear or radial gradient) or
+     * a `CanvasPattern` object (a repetitive image). The default value is
+     * 'rgb(255,255,255)'.
+     *
+     * NOTE: This option may be partially, or completely, ignored when the
+     * `pageColors`-option is used.
+     */
+    background?: string | CanvasGradient | CanvasPattern | undefined;
+    /**
+     * - Overwrites background and foreground colors
+     * with user defined ones in order to improve readability in high contrast
+     * mode.
+     */
+    pageColors?: Object | undefined;
+    /**
+     * -
+     * A promise that should resolve with an {@link OptionalContentConfig} created from `PDFDocumentProxy.getOptionalContentConfig`. If `null`,
+     * the configuration will be fetched automatically with the default visibility
+     * states set.
+     */
+    optionalContentConfigPromise?: Promise<OptionalContentConfig> | undefined;
+    /**
+     * - Map some
+     * annotation ids with canvases used to render them.
+     */
+    annotationCanvasMap?: Map<string, HTMLCanvasElement> | undefined;
+    printAnnotationStorage?: PrintAnnotationStorage | undefined;
+  };
+  /**
+   * Page getOperatorList parameters.
+   */
+  export type GetOperatorListParameters = {
+    /**
+     * - Rendering intent, can be 'display', 'print',
+     * or 'any'. The default value is 'display'.
+     */
+    intent?: string | undefined;
+    /**
+     * Controls which annotations are included
+     * in the operatorList, for annotations with appearance-data; the values from
+     * {@link AnnotationMode} should be used. The following values are supported:
+     * - `AnnotationMode.DISABLE`, which disables all annotations.
+     * - `AnnotationMode.ENABLE`, which includes all possible annotations (thus
+     *   it also depends on the `intent`-option, see above).
+     * - `AnnotationMode.ENABLE_FORMS`, which excludes annotations that contain
+     *   interactive form elements (those will be rendered in the display layer).
+     * - `AnnotationMode.ENABLE_STORAGE`, which includes all possible annotations
+     *   (as above) but where interactive form elements are updated with data
+     *   from the {@link AnnotationStorage}-instance; useful e.g. for printing.
+     * The default value is `AnnotationMode.ENABLE`.
+     */
+    annotationMode?: number | undefined;
+    printAnnotationStorage?: PrintAnnotationStorage | undefined;
+  };
+  /**
+   * Structure tree node. The root node will have a role "Root".
+   */
+  export type StructTreeNode = {
+    /**
+     * - Array of
+     * {@link StructTreeNode} and {@link StructTreeContent} objects.
+     */
+    children: Array<StructTreeNode | StructTreeContent>;
+    /**
+     * - element's role, already mapped if a role map exists
+     * in the PDF.
+     */
+    role: string;
+  };
+  /**
+   * Structure tree content.
+   */
+  export type StructTreeContent = {
+    /**
+     * - either "content" for page and stream structure
+     * elements or "object" for object references.
+     */
+    type: string;
+    /**
+     * - unique id that will map to the text layer.
+     */
+    id: string;
+  };
+  /**
+   * PDF page operator list.
+   */
+  export type PDFOperatorList = {
+    /**
+     * - Array containing the operator functions.
+     */
+    fnArray: Array<number>;
+    /**
+     * - Array containing the arguments of the
+     * functions.
+     */
+    argsArray: Array<any>;
+  };
+  export type PDFWorkerParameters = {
+    /**
+     * - The name of the worker.
+     */
+    name?: string | undefined;
+    /**
+     * - The `workerPort` object.
+     */
+    port?: Worker | undefined;
+    /**
+     * - Controls the logging level;
+     * the constants from {@link VerbosityLevel} should be used.
+     */
+    verbosity?: number | undefined;
+  };
+  /** @type {string} */
+  export const build: string;
+  export let DefaultCanvasFactory: typeof DOMCanvasFactory;
+  export let DefaultCMapReaderFactory: typeof DOMCMapReaderFactory;
+  export let DefaultFilterFactory: typeof DOMFilterFactory;
+  export let DefaultStandardFontDataFactory: typeof DOMStandardFontDataFactory;
+  /**
+   * @typedef { Int8Array | Uint8Array | Uint8ClampedArray |
+   *   Int16Array | Uint16Array |
+   *   Int32Array | Uint32Array | Float32Array |
+   *   Float64Array
+   * } TypedArray
+   */
+  /**
+   * @typedef { TypedArray | ArrayBuffer | Array<number> | string } BinaryData
+   */
+  /**
+   * @typedef {Object} RefProxy
+   * @property {number} num
+   * @property {number} gen
+   */
+  /**
+   * Document initialization / loading parameters object.
+   *
+   * @typedef {Object} DocumentInitParameters
+   * @property {string | URL} [url] - The URL of the PDF.
+   * @property {BinaryData} [data] - Binary PDF data.
+   *   Use TypedArrays (Uint8Array) to improve the memory usage. If PDF data is
+   *   BASE64-encoded, use `atob()` to convert it to a binary string first.
+   *
+   *   NOTE: If TypedArrays are used they will generally be transferred to the
+   *   worker-thread. This will help reduce main-thread memory usage, however
+   *   it will take ownership of the TypedArrays.
+   * @property {Object} [httpHeaders] - Basic authentication headers.
+   * @property {boolean} [withCredentials] - Indicates whether or not
+   *   cross-site Access-Control requests should be made using credentials such
+   *   as cookies or authorization headers. The default is `false`.
+   * @property {string} [password] - For decrypting password-protected PDFs.
+   * @property {number} [length] - The PDF file length. It's used for progress
+   *   reports and range requests operations.
+   * @property {PDFDataRangeTransport} [range] - Allows for using a custom range
+   *   transport implementation.
+   * @property {number} [rangeChunkSize] - Specify maximum number of bytes fetched
+   *   per range request.
The default value is {@link DEFAULT_RANGE_CHUNK_SIZE}. + * @property {PDFWorker} [worker] - The worker that will be used for loading and + * parsing the PDF data. + * @property {number} [verbosity] - Controls the logging level; the constants + * from {@link VerbosityLevel} should be used. + * @property {string} [docBaseUrl] - The base URL of the document, used when + * attempting to recover valid absolute URLs for annotations, and outline + * items, that (incorrectly) only specify relative URLs. + * @property {string} [cMapUrl] - The URL where the predefined Adobe CMaps are + * located. Include the trailing slash. + * @property {boolean} [cMapPacked] - Specifies if the Adobe CMaps are binary + * packed or not. The default value is `true`. + * @property {Object} [CMapReaderFactory] - The factory that will be used when + * reading built-in CMap files. Providing a custom factory is useful for + * environments without Fetch API or `XMLHttpRequest` support, such as + * Node.js. The default value is {DOMCMapReaderFactory}. + * @property {boolean} [useSystemFonts] - When `true`, fonts that aren't + * embedded in the PDF document will fallback to a system font. + * The default value is `true` in web environments and `false` in Node.js; + * unless `disableFontFace === true` in which case this defaults to `false` + * regardless of the environment (to prevent completely broken fonts). + * @property {string} [standardFontDataUrl] - The URL where the standard font + * files are located. Include the trailing slash. + * @property {Object} [StandardFontDataFactory] - The factory that will be used + * when reading the standard font files. Providing a custom factory is useful + * for environments without Fetch API or `XMLHttpRequest` support, such as + * Node.js. The default value is {DOMStandardFontDataFactory}. + * @property {boolean} [useWorkerFetch] - Enable using the Fetch API in the + * worker-thread when reading CMap and standard font files. 
When `true`, + * the `CMapReaderFactory` and `StandardFontDataFactory` options are ignored. + * The default value is `true` in web environments and `false` in Node.js. + * @property {boolean} [stopAtErrors] - Reject certain promises, e.g. + * `getOperatorList`, `getTextContent`, and `RenderTask`, when the associated + * PDF data cannot be successfully parsed, instead of attempting to recover + * whatever possible of the data. The default value is `false`. + * @property {number} [maxImageSize] - The maximum allowed image size in total + * pixels, i.e. width * height. Images above this value will not be rendered. + * Use -1 for no limit, which is also the default value. + * @property {boolean} [isEvalSupported] - Determines if we can evaluate strings + * as JavaScript. Primarily used to improve performance of font rendering, and + * when parsing PDF functions. The default value is `true`. + * @property {boolean} [isOffscreenCanvasSupported] - Determines if we can use + * `OffscreenCanvas` in the worker. Primarily used to improve performance of + * image conversion/rendering. + * The default value is `true` in web environments and `false` in Node.js. + * @property {number} [canvasMaxAreaInBytes] - The integer value is used to + * know when an image must be resized (uses `OffscreenCanvas` in the worker). + * If it's -1 then a possibly slow algorithm is used to guess the max value. + * @property {boolean} [disableFontFace] - By default fonts are converted to + * OpenType fonts and loaded via the Font Loading API or `@font-face` rules. + * If disabled, fonts will be rendered using a built-in font renderer that + * constructs the glyphs with primitive path commands. + * The default value is `false` in web environments and `true` in Node.js. + * @property {boolean} [fontExtraProperties] - Include additional properties, + * which are unused during rendering of PDF documents, when exporting the + * parsed font data from the worker-thread.
This may be useful for debugging + * purposes (and backwards compatibility), but note that it will lead to + * increased memory usage. The default value is `false`. + * @property {boolean} [enableXfa] - Render Xfa forms if any. + * The default value is `false`. + * @property {HTMLDocument} [ownerDocument] - Specify an explicit document + * context to create elements with and to load resources, such as fonts, + * into. Defaults to the current document. + * @property {boolean} [disableRange] - Disable range request loading of PDF + * files. When enabled, and if the server supports partial content requests, + * then the PDF will be fetched in chunks. The default value is `false`. + * @property {boolean} [disableStream] - Disable streaming of PDF file data. + * By default PDF.js attempts to load PDF files in chunks. The default value + * is `false`. + * @property {boolean} [disableAutoFetch] - Disable pre-fetching of PDF file + * data. When range requests are enabled PDF.js will automatically keep + * fetching more data even if it isn't needed to display the current page. + * The default value is `false`. + * + * NOTE: It is also necessary to disable streaming, see above, in order for + * disabling of pre-fetching to work correctly. + * @property {boolean} [pdfBug] - Enables special hooks for debugging PDF.js + * (see `web/debugger.js`). The default value is `false`. + * @property {Object} [canvasFactory] - The factory instance that will be used + * when creating canvases. The default value is {new DOMCanvasFactory()}. + * @property {Object} [filterFactory] - A factory instance that will be used + * to create SVG filters when rendering some images on the main canvas. + */ + /** + * This is the main entry point for loading a PDF and interacting with it. + * + * NOTE: If a URL is used to fetch the PDF data a standard Fetch API call (or + * XHR as fallback) is used, which means it must follow same origin rules, + * e.g. no cross-domain requests without CORS. 
+   *
+   * @param {string | URL | TypedArray | ArrayBuffer | DocumentInitParameters}
+   *   src - Can be a URL where a PDF file is located, a typed array (Uint8Array)
+   *   already populated with data, or a parameter object.
+   * @returns {PDFDocumentLoadingTask}
+   */
+  export function getDocument(
+    src: string | URL | TypedArray | ArrayBuffer | DocumentInitParameters
+  ): PDFDocumentLoadingTask;
+  export class LoopbackPort {
+    postMessage(obj: any, transfer: any): void;
+    addEventListener(name: any, listener: any): void;
+    removeEventListener(name: any, listener: any): void;
+    terminate(): void;
+    #private;
+  }
+  /**
+   * @typedef {Object} OnProgressParameters
+   * @property {number} loaded - Currently loaded number of bytes.
+   * @property {number} total - Total number of bytes in the PDF file.
+   */
+  /**
+   * The loading task controls the operations required to load a PDF document
+   * (such as network requests) and provides a way to listen for completion,
+   * after which individual pages can be rendered.
+   */
+  export class PDFDocumentLoadingTask {
+    static "__#16@#docId": number;
+    _capability: import("../shared/util.js").PromiseCapability;
+    _transport: any;
+    _worker: any;
+    /**
+     * Unique identifier for the document loading task.
+     * @type {string}
+     */
+    docId: string;
+    /**
+     * Whether the loading task is destroyed or not.
+     * @type {boolean}
+     */
+    destroyed: boolean;
+    /**
+     * Callback to request a password if a wrong or no password was provided.
+     * The callback receives two parameters: a function that should be called
+     * with the new password, and a reason (see {@link PasswordResponses}).
+     * @type {function}
+     */
+    onPassword: Function;
+    /**
+     * Callback to be able to monitor the loading progress of the PDF file
+     * (necessary to implement e.g. a loading bar).
+     * The callback receives an {@link OnProgressParameters} argument.
+     * @type {function}
+     */
+    onProgress: Function;
+    /**
+     * Promise for document loading task completion.
+     * @type {Promise<PDFDocumentProxy>}
+     */
+    get promise(): Promise<PDFDocumentProxy>;
+    /**
+     * Abort all network requests and destroy the worker.
+     * @returns {Promise<void>} A promise that is resolved when destruction is
+     *   completed.
+     */
+    destroy(): Promise<void>;
+  }
+  /**
+   * Proxy to a `PDFDocument` in the worker thread.
+   */
+  export class PDFDocumentProxy {
+    constructor(pdfInfo: any, transport: any);
+    _pdfInfo: any;
+    _transport: any;
+    /**
+     * @type {AnnotationStorage} Storage for annotation data in forms.
+     */
+    get annotationStorage(): AnnotationStorage;
+    /**
+     * @type {Object} The filter factory instance.
+     */
+    get filterFactory(): Object;
+    /**
+     * @type {number} Total number of pages in the PDF file.
+     */
+    get numPages(): number;
+    /**
+     * @type {Array<string>} A (not guaranteed to be) unique ID to
+     *   identify the PDF document.
+     *   NOTE: The first element will always be defined for all PDF documents,
+     *   whereas the second element is only defined for *modified* PDF documents.
+     */
+    get fingerprints(): string[];
+    /**
+     * @type {boolean} True if only XFA form.
+     */
+    get isPureXfa(): boolean;
+    /**
+     * NOTE: This is (mostly) intended to support printing of XFA forms.
+     *
+     * @type {Object | null} An object representing a HTML tree structure
+     *   to render the XFA, or `null` when no XFA form exists.
+     */
+    get allXfaHtml(): Object | null;
+    /**
+     * @param {number} pageNumber - The page number to get. The first page is 1.
+     * @returns {Promise<PDFPageProxy>} A promise that is resolved with
+     *   a {@link PDFPageProxy} object.
+     */
+    getPage(pageNumber: number): Promise<PDFPageProxy>;
+    /**
+     * @param {RefProxy} ref - The page reference.
+     * @returns {Promise<number>} A promise that is resolved with the page index,
+     *   starting from zero, that is associated with the reference.
+     */
+    getPageIndex(ref: RefProxy): Promise<number>;
+    /**
+     * @returns {Promise<Object<string, Array<any>>>} A promise that is resolved
+     *   with a mapping from named destinations to references.
+     *
+     * This can be slow for large documents. Use `getDestination` instead.
+     */
+    getDestinations(): Promise<{
+      [x: string]: Array<any>;
+    }>;
+    /**
+     * @param {string} id - The named destination to get.
+     * @returns {Promise<Array<any> | null>} A promise that is resolved with all
+     *   information of the given named destination, or `null` when the named
+     *   destination is not present in the PDF file.
+     */
+    getDestination(id: string): Promise<Array<any> | null>;
+    /**
+     * @returns {Promise<Array<string> | null>} A promise that is resolved with
+     *   an {Array} containing the page labels that correspond to the page
+     *   indexes, or `null` when no page labels are present in the PDF file.
+     */
+    getPageLabels(): Promise<Array<string> | null>;
+    /**
+     * @returns {Promise<string>} A promise that is resolved with a {string}
+     *   containing the page layout name.
+     */
+    getPageLayout(): Promise<string>;
+    /**
+     * @returns {Promise<string>} A promise that is resolved with a {string}
+     *   containing the page mode name.
+     */
+    getPageMode(): Promise<string>;
+    /**
+     * @returns {Promise<Object | null>} A promise that is resolved with an
+     *   {Object} containing the viewer preferences, or `null` when no viewer
+     *   preferences are present in the PDF file.
+     */
+    getViewerPreferences(): Promise<Object | null>;
+    /**
+     * @returns {Promise<any | null>} A promise that is resolved with an {Array}
+     *   containing the destination, or `null` when no open action is present
+     *   in the PDF.
+     */
+    getOpenAction(): Promise<any | null>;
+    /**
+     * @returns {Promise<any>} A promise that is resolved with a lookup table
+     *   for mapping named attachments to their content.
+     */
+    getAttachments(): Promise<any>;
+    /**
+     * @returns {Promise<Array<string> | null>} A promise that is resolved with
+     *   an {Array} of all the JavaScript strings in the name tree, or `null`
+     *   if no JavaScript exists.
+     */
+    getJavaScript(): Promise<Array<string> | null>;
+    /**
+     * @returns {Promise<Object | null>} A promise that is resolved with
+     *   an {Object} with the JavaScript actions:
+     *     - from the name tree (like getJavaScript);
+     *     - from A or AA entries in the catalog dictionary.
+     *   , or `null` if no JavaScript exists.
+     */
+    getJSActions(): Promise<Object | null>;
+    /**
+     * @typedef {Object} OutlineNode
+     * @property {string} title
+     * @property {boolean} bold
+     * @property {boolean} italic
+     * @property {Uint8ClampedArray} color - The color in RGB format to use for
+     *   display purposes.
+     * @property {string | Array<any> | null} dest
+     * @property {string | null} url
+     * @property {string | undefined} unsafeUrl
+     * @property {boolean | undefined} newWindow
+     * @property {number | undefined} count
+     * @property {Array<OutlineNode>} items
+     */
+    /**
+     * @returns {Promise<Array<OutlineNode>>} A promise that is resolved with an
+     *   {Array} that is a tree outline (if it has one) of the PDF file.
+     */
+    getOutline(): Promise<
+      {
+        title: string;
+        bold: boolean;
+        italic: boolean;
+        /**
+         * - The color in RGB format to use for
+         * display purposes.
+         */
+        color: Uint8ClampedArray;
+        dest: string | Array<any> | null;
+        url: string | null;
+        unsafeUrl: string | undefined;
+        newWindow: boolean | undefined;
+        count: number | undefined;
+        items: any[];
+      }[]
+    >;
+    /**
+     * @returns {Promise<OptionalContentConfig>} A promise that is resolved with
+     *   an {@link OptionalContentConfig} that contains all the optional content
+     *   groups (assuming that the document has any).
+     */
+    getOptionalContentConfig(): Promise<OptionalContentConfig>;
+    /**
+     * @returns {Promise<Array<number> | null>} A promise that is resolved with
+     *   an {Array} that contains the permission flags for the PDF document, or
+     *   `null` when no permissions are present in the PDF file.
+     */
+    getPermissions(): Promise<Array<number> | null>;
+    /**
+     * @returns {Promise<{ info: Object, metadata: Metadata }>} A promise that is
+     *   resolved with an {Object} that has `info` and `metadata` properties.
+     *   `info` is an {Object} filled with anything available in the information
+     *   dictionary and similarly `metadata` is a {Metadata} object with
+     *   information from the metadata section of the PDF.
+     */
+    getMetadata(): Promise<{
+      info: Object;
+      metadata: Metadata;
+    }>;
+    /**
+     * @typedef {Object} MarkInfo
+     * Properties correspond to Table 321 of the PDF 32000-1:2008 spec.
+     * @property {boolean} Marked
+     * @property {boolean} UserProperties
+     * @property {boolean} Suspects
+     */
+    /**
+     * @returns {Promise<MarkInfo | null>} A promise that is resolved with
+     *   a {MarkInfo} object that contains the MarkInfo flags for the PDF
+     *   document, or `null` when no MarkInfo values are present in the PDF file.
+     */
+    getMarkInfo(): Promise<{
+      Marked: boolean;
+      UserProperties: boolean;
+      Suspects: boolean;
+    } | null>;
+    /**
+     * @returns {Promise<Uint8Array>} A promise that is resolved with a
+     *   {Uint8Array} containing the raw data of the PDF document.
+     */
+    getData(): Promise<Uint8Array>;
+    /**
+     * @returns {Promise<Uint8Array>} A promise that is resolved with a
+     *   {Uint8Array} containing the full data of the saved document.
+     */
+    saveDocument(): Promise<Uint8Array>;
+    /**
+     * @returns {Promise<{ length: number }>} A promise that is resolved when the
+     *   document's data is loaded. It is resolved with an {Object} that contains
+     *   the `length` property that indicates size of the PDF data in bytes.
+     */
+    getDownloadInfo(): Promise<{
+      length: number;
+    }>;
+    /**
+     * Cleans up resources allocated by the document on both the main and worker
+     * threads.
+     *
+     * NOTE: Do not, under any circumstances, call this method when rendering is
+     * currently ongoing since that may lead to rendering errors.
+     *
+     * @param {boolean} [keepLoadedFonts] - Let fonts remain attached to the DOM.
+     *   NOTE: This will increase persistent memory usage, hence don't use this
+     *   option unless absolutely necessary. The default value is `false`.
+     * @returns {Promise<void>} A promise that is resolved when clean-up has finished.
+     */
+    cleanup(keepLoadedFonts?: boolean | undefined): Promise<void>;
+    /**
+     * Destroys the current document instance and terminates the worker.
+     */
+    destroy(): Promise<void>;
+    /**
+     * @type {DocumentInitParameters} A subset of the current
+     *   {DocumentInitParameters}, which are needed in the viewer.
+     */
+    get loadingParams(): DocumentInitParameters;
+    /**
+     * @type {PDFDocumentLoadingTask} The loadingTask for the current document.
+     */
+    get loadingTask(): PDFDocumentLoadingTask;
+    /**
+     * @returns {Promise<Object<string, Array<Object>> | null>} A promise that is
+     *   resolved with an {Object} containing /AcroForm field data for the JS
+     *   sandbox, or `null` when no field data is present in the PDF file.
+     */
+    getFieldObjects(): Promise<{
+      [x: string]: Array<Object>;
+    } | null>;
+    /**
+     * @returns {Promise<boolean>} A promise that is resolved with `true`
+     *   if some /AcroForm fields have JavaScript actions.
+     */
+    hasJSActions(): Promise<boolean>;
+    /**
+     * @returns {Promise<Array<string> | null>} A promise that is resolved with an
+     *   {Array} containing IDs of annotations that have a calculation
+     *   action, or `null` when no such annotations are present in the PDF file.
+     */
+    getCalculationOrderIds(): Promise<Array<string> | null>;
+  }
+  /**
+   * Page getViewport parameters.
+   *
+   * @typedef {Object} GetViewportParameters
+   * @property {number} scale - The desired scale of the viewport.
+   * @property {number} [rotation] - The desired rotation, in degrees, of
+   *   the viewport. If omitted it defaults to the page rotation.
+   * @property {number} [offsetX] - The horizontal, i.e. x-axis, offset.
+   *   The default value is `0`.
+   * @property {number} [offsetY] - The vertical, i.e. y-axis, offset.
+   *   The default value is `0`.
+   * @property {boolean} [dontFlip] - If true, the y-axis will not be
+   *   flipped. The default value is `false`.
+   */
+  /**
+   * Page getTextContent parameters.
+   *
+   * @typedef {Object} getTextContentParameters
+   * @property {boolean} [includeMarkedContent] - When true include marked
+   *   content items in the items array of TextContent. The default is `false`.
+   */
+  /**
+   * Page text content.
+ * + * @typedef {Object} TextContent + * @property {Array} items - Array of + * {@link TextItem} and {@link TextMarkedContent} objects. TextMarkedContent + * items are included when includeMarkedContent is true. + * @property {Object} styles - {@link TextStyle} objects, + * indexed by font name. + */ + /** + * Page text content part. + * + * @typedef {Object} TextItem + * @property {string} str - Text content. + * @property {string} dir - Text direction: 'ttb', 'ltr' or 'rtl'. + * @property {Array} transform - Transformation matrix. + * @property {number} width - Width in device space. + * @property {number} height - Height in device space. + * @property {string} fontName - Font name used by PDF.js for converted font. + * @property {boolean} hasEOL - Indicating if the text content is followed by a + * line-break. + */ + /** + * Page text marked content part. + * + * @typedef {Object} TextMarkedContent + * @property {string} type - Either 'beginMarkedContent', + * 'beginMarkedContentProps', or 'endMarkedContent'. + * @property {string} id - The marked content identifier. Only used for type + * 'beginMarkedContentProps'. + */ + /** + * Text style. + * + * @typedef {Object} TextStyle + * @property {number} ascent - Font ascent. + * @property {number} descent - Font descent. + * @property {boolean} vertical - Whether or not the text is in vertical mode. + * @property {string} fontFamily - The possible font family. + */ + /** + * Page annotation parameters. + * + * @typedef {Object} GetAnnotationsParameters + * @property {string} [intent] - Determines the annotations that are fetched, + * can be 'display' (viewable annotations), 'print' (printable annotations), + * or 'any' (all annotations). The default value is 'display'. + */ + /** + * Page render parameters. + * + * @typedef {Object} RenderParameters + * @property {CanvasRenderingContext2D} canvasContext - A 2D context of a DOM + * Canvas object. 
+ * @property {PageViewport} viewport - Rendering viewport obtained by calling + * the `PDFPageProxy.getViewport` method. + * @property {string} [intent] - Rendering intent, can be 'display', 'print', + * or 'any'. The default value is 'display'. + * @property {number} [annotationMode] Controls which annotations are rendered + * onto the canvas, for annotations with appearance-data; the values from + * {@link AnnotationMode} should be used. The following values are supported: + * - `AnnotationMode.DISABLE`, which disables all annotations. + * - `AnnotationMode.ENABLE`, which includes all possible annotations (thus + * it also depends on the `intent`-option, see above). + * - `AnnotationMode.ENABLE_FORMS`, which excludes annotations that contain + * interactive form elements (those will be rendered in the display layer). + * - `AnnotationMode.ENABLE_STORAGE`, which includes all possible annotations + * (as above) but where interactive form elements are updated with data + * from the {@link AnnotationStorage}-instance; useful e.g. for printing. + * The default value is `AnnotationMode.ENABLE`. + * @property {Array} [transform] - Additional transform, applied just + * before viewport transform. + * @property {CanvasGradient | CanvasPattern | string} [background] - Background + * to use for the canvas. + * Any valid `canvas.fillStyle` can be used: a `DOMString` parsed as CSS + * value, a `CanvasGradient` object (a linear or radial gradient) or + * a `CanvasPattern` object (a repetitive image). The default value is + * 'rgb(255,255,255)'. + * + * NOTE: This option may be partially, or completely, ignored when the + * `pageColors`-option is used. + * @property {Object} [pageColors] - Overwrites background and foreground colors + * with user defined ones in order to improve readability in high contrast + * mode. 
+ * @property {Promise} [optionalContentConfigPromise] - + * A promise that should resolve with an {@link OptionalContentConfig} + * created from `PDFDocumentProxy.getOptionalContentConfig`. If `null`, + * the configuration will be fetched automatically with the default visibility + * states set. + * @property {Map} [annotationCanvasMap] - Map some + * annotation ids with canvases used to render them. + * @property {PrintAnnotationStorage} [printAnnotationStorage] + */ + /** + * Page getOperatorList parameters. + * + * @typedef {Object} GetOperatorListParameters + * @property {string} [intent] - Rendering intent, can be 'display', 'print', + * or 'any'. The default value is 'display'. + * @property {number} [annotationMode] Controls which annotations are included + * in the operatorList, for annotations with appearance-data; the values from + * {@link AnnotationMode} should be used. The following values are supported: + * - `AnnotationMode.DISABLE`, which disables all annotations. + * - `AnnotationMode.ENABLE`, which includes all possible annotations (thus + * it also depends on the `intent`-option, see above). + * - `AnnotationMode.ENABLE_FORMS`, which excludes annotations that contain + * interactive form elements (those will be rendered in the display layer). + * - `AnnotationMode.ENABLE_STORAGE`, which includes all possible annotations + * (as above) but where interactive form elements are updated with data + * from the {@link AnnotationStorage}-instance; useful e.g. for printing. + * The default value is `AnnotationMode.ENABLE`. + * @property {PrintAnnotationStorage} [printAnnotationStorage] + */ + /** + * Structure tree node. The root node will have a role "Root". + * + * @typedef {Object} StructTreeNode + * @property {Array} children - Array of + * {@link StructTreeNode} and {@link StructTreeContent} objects. + * @property {string} role - element's role, already mapped if a role map exists + * in the PDF. + */ + /** + * Structure tree content. 
+   *
+   * @typedef {Object} StructTreeContent
+   * @property {string} type - either "content" for page and stream structure
+   *   elements or "object" for object references.
+   * @property {string} id - unique id that will map to the text layer.
+   */
+  /**
+   * PDF page operator list.
+   *
+   * @typedef {Object} PDFOperatorList
+   * @property {Array<number>} fnArray - Array containing the operator functions.
+   * @property {Array<any>} argsArray - Array containing the arguments of the
+   *   functions.
+   */
+  /**
+   * Proxy to a `PDFPage` in the worker thread.
+   */
+  export class PDFPageProxy {
+    constructor(
+      pageIndex: any,
+      pageInfo: any,
+      transport: any,
+      pdfBug?: boolean
+    );
+    _pageIndex: any;
+    _pageInfo: any;
+    _transport: any;
+    _stats: StatTimer | null;
+    _pdfBug: boolean;
+    /** @type {PDFObjects} */
+    commonObjs: PDFObjects;
+    objs: PDFObjects;
+    _maybeCleanupAfterRender: boolean;
+    _intentStates: Map<any, any>;
+    destroyed: boolean;
+    /**
+     * @type {number} Page number of the page. First page is 1.
+     */
+    get pageNumber(): number;
+    /**
+     * @type {number} The number of degrees the page is rotated clockwise.
+     */
+    get rotate(): number;
+    /**
+     * @type {RefProxy | null} The reference that points to this page.
+     */
+    get ref(): RefProxy | null;
+    /**
+     * @type {number} The default size of units in 1/72nds of an inch.
+     */
+    get userUnit(): number;
+    /**
+     * @type {Array<number>} An array of the visible portion of the PDF page in
+     *   user space units [x1, y1, x2, y2].
+     */
+    get view(): number[];
+    /**
+     * @param {GetViewportParameters} params - Viewport parameters.
+     * @returns {PageViewport} Contains 'width' and 'height' properties
+     *   along with transforms required for rendering.
+     */
+    getViewport({
+      scale,
+      rotation,
+      offsetX,
+      offsetY,
+      dontFlip,
+    }?: GetViewportParameters): PageViewport;
+    /**
+     * @param {GetAnnotationsParameters} params - Annotation parameters.
+     * @returns {Promise<Array<any>>} A promise that is resolved with an
+     *   {Array} of the annotation objects.
+     */
+    getAnnotations({ intent }?: GetAnnotationsParameters): Promise<Array<any>>;
+    /**
+     * @returns {Promise<Object>} A promise that is resolved with an
+     *   {Object} with JS actions.
+     */
+    getJSActions(): Promise<Object>;
+    /**
+     * @type {boolean} True if only XFA form.
+     */
+    get isPureXfa(): boolean;
+    /**
+     * @returns {Promise<Object | null>} A promise that is resolved with
+     *   an {Object} with a fake DOM object (a tree structure where elements
+     *   are {Object} with a name, attributes (class, style, ...), value and
+     *   children, very similar to a HTML DOM tree), or `null` if no XFA exists.
+     */
+    getXfa(): Promise<Object | null>;
+    /**
+     * Begins the process of rendering a page to the desired context.
+     *
+     * @param {RenderParameters} params - Page render parameters.
+     * @returns {RenderTask} An object that contains a promise that is
+     *   resolved when the page finishes rendering.
+     */
+    render(
+      {
+        canvasContext,
+        viewport,
+        intent,
+        annotationMode,
+        transform,
+        background,
+        optionalContentConfigPromise,
+        annotationCanvasMap,
+        pageColors,
+        printAnnotationStorage,
+      }: RenderParameters,
+      ...args: any[]
+    ): RenderTask;
+    /**
+     * @param {GetOperatorListParameters} params - Page getOperatorList
+     *   parameters.
+     * @returns {Promise<PDFOperatorList>} A promise resolved with an
+     *   {@link PDFOperatorList} object that represents the page's operator list.
+     */
+    getOperatorList({
+      intent,
+      annotationMode,
+      printAnnotationStorage,
+    }?: GetOperatorListParameters): Promise<PDFOperatorList>;
+    /**
+     * NOTE: All occurrences of whitespace will be replaced by
+     * standard spaces (0x20).
+     *
+     * @param {getTextContentParameters} params - getTextContent parameters.
+     * @returns {ReadableStream} Stream for reading text content chunks.
+     */
+    streamTextContent({
+      includeMarkedContent,
+    }?: getTextContentParameters): ReadableStream;
+    /**
+     * NOTE: All occurrences of whitespace will be replaced by
+     * standard spaces (0x20).
+     *
+     * @param {getTextContentParameters} params - getTextContent parameters.
+     * @returns {Promise<TextContent>} A promise that is resolved with a
+     *   {@link TextContent} object that represents the page's text content.
+     */
+    getTextContent(params?: getTextContentParameters): Promise<TextContent>;
+    /**
+     * @returns {Promise<StructTreeNode>} A promise that is resolved with a
+     *   {@link StructTreeNode} object that represents the page's structure tree,
+     *   or `null` when no structure tree is present for the current page.
+     */
+    getStructTree(): Promise<StructTreeNode>;
+    /**
+     * Destroys the page object.
+     * @private
+     */
+    private _destroy;
+    /**
+     * Cleans up resources allocated by the page.
+     *
+     * @param {boolean} [resetStats] - Reset page stats, if enabled.
+     *   The default value is `false`.
+     * @returns {boolean} Indicates if clean-up was successfully run.
+     */
+    cleanup(resetStats?: boolean | undefined): boolean;
+    /**
+     * @private
+     */
+    private _startRenderPage;
+    /**
+     * @private
+     */
+    private _renderPageChunk;
+    /**
+     * @private
+     */
+    private _pumpOperatorList;
+    /**
+     * @private
+     */
+    private _abortOperatorList;
+    /**
+     * @type {StatTimer | null} Returns page stats, if enabled; returns `null`
+     *   otherwise.
+     */
+    get stats(): StatTimer | null;
+    #private;
+  }
+  /**
+   * PDF.js web worker abstraction that controls the instantiation of PDF
+   * documents. Message handlers are used to pass information from the main
+   * thread to the worker thread and vice versa. If the creation of a web
+   * worker is not possible, a "fake" worker will be used instead.
+   *
+   * @param {PDFWorkerParameters} params - The worker initialization parameters.
+   */
+  export class PDFWorker {
+    static "__#19@#workerPorts": WeakMap<object, any>;
+    /**
+     * @param {PDFWorkerParameters} params - The worker initialization parameters.
+     */
+    static fromPort(params: PDFWorkerParameters): any;
+    /**
+     * The current `workerSrc`, when it exists.
+     * @type {string}
+     */
+    static get workerSrc(): string;
+    static get _mainThreadWorkerMessageHandler(): any;
+    static get _setupFakeWorkerGlobal(): any;
+    constructor({
+      name,
+      port,
+      verbosity,
+    }?: {
+      name?: null | undefined;
+      port?: null | undefined;
+      verbosity?: number | undefined;
+    });
+    name: any;
+    destroyed: boolean;
+    verbosity: number;
+    _readyCapability: import("../shared/util.js").PromiseCapability;
+    _port: any;
+    _webWorker: Worker | null;
+    _messageHandler: MessageHandler | null;
+    /**
+     * Promise for worker initialization completion.
+     * @type {Promise<void>}
+     */
+    get promise(): Promise<void>;
+    /**
+     * The current `workerPort`, when it exists.
+     * @type {Worker}
+     */
+    get port(): Worker;
+    /**
+     * The current MessageHandler-instance.
+     * @type {MessageHandler}
+     */
+    get messageHandler(): MessageHandler;
+    _initializeFromPort(port: any): void;
+    _initialize(): void;
+    _setupFakeWorker(): void;
+    /**
+     * Destroys the worker instance.
+     */
+    destroy(): void;
+  }
+  export namespace PDFWorkerUtil {
+    const isWorkerDisabled: boolean;
+    const fallbackWorkerSrc: null;
+    const fakeWorkerId: number;
+  }
+  /**
+   * Allows controlling of the rendering tasks.
+   */
+  export class RenderTask {
+    constructor(internalRenderTask: any);
+    /**
+     * Callback for incremental rendering -- a function that will be called
+     * each time the rendering is paused. To continue rendering call the
+     * function that is the first argument to the callback.
+     * @type {function}
+     */
+    onContinue: Function;
+    /**
+     * Promise for rendering task completion.
+     * @type {Promise}
+     */
+    get promise(): Promise;
+    /**
+     * Cancels the rendering task. If the task is currently rendering it will
+     * not be cancelled until graphics pauses with a timeout. The promise that
+     * this object extends will be rejected when cancelled.
+     *
+     * @param {number} [extraDelay]
+     */
+    cancel(extraDelay?: number | undefined): void;
+    /**
+     * Whether form fields are rendered separately from the main operatorList.
+ * @type {boolean} + */ + get separateAnnots(): boolean; + #private; + } + /** @type {string} */ + export const version: string; +} diff --git a/libs/langchain-community/src/types/type-utils.ts b/libs/langchain-community/src/types/type-utils.ts new file mode 100644 index 000000000000..e2c1e6970a52 --- /dev/null +++ b/libs/langchain-community/src/types/type-utils.ts @@ -0,0 +1,3 @@ +// Utility for marking only some keys of an interface as optional +// Compare to Partial which marks all keys as optional +export type Optional = Omit & Partial>; diff --git a/libs/langchain-community/src/utils/bedrock.ts b/libs/langchain-community/src/utils/bedrock.ts new file mode 100644 index 000000000000..82a5d21ca4c0 --- /dev/null +++ b/libs/langchain-community/src/utils/bedrock.ts @@ -0,0 +1,134 @@ +import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types"; + +export type CredentialType = + | AwsCredentialIdentity + | Provider; + +/** Bedrock models. + To authenticate, the AWS client uses the following methods to automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + If a specific credential profile should be used, you must pass the name of the profile from the ~/.aws/credentials file that is to be used. + Make sure the credentials / roles used have the required policies to access the Bedrock service. +*/ +export interface BaseBedrockInput { + /** Model to use. + For example, "amazon.titan-tg1-large", this is equivalent to the modelId property in the list-foundation-models api. + */ + model: string; + + /** The AWS region e.g. `us-west-2`. + Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config in case it is not provided here. + */ + region?: string; + + /** AWS Credentials. + If no credentials are provided, the default credentials from `@aws-sdk/credential-provider-node` will be used. + */ + credentials?: CredentialType; + + /** Temperature. 
*/ + temperature?: number; + + /** Max tokens. */ + maxTokens?: number; + + /** A custom fetch function for low-level access to AWS API. Defaults to fetch(). */ + fetchFn?: typeof fetch; + + /** @deprecated Use endpointHost instead Override the default endpoint url. */ + endpointUrl?: string; + + /** Override the default endpoint hostname. */ + endpointHost?: string; + + /** + * Optional additional stop sequences to pass to the model. Currently only supported for Anthropic and AI21. + * @deprecated Use .bind({ "stop": [...] }) instead + * */ + stopSequences?: string[]; + + /** Additional kwargs to pass to the model. */ + modelKwargs?: Record; + + /** Whether or not to stream responses */ + streaming: boolean; +} + +type Dict = { [key: string]: unknown }; + +/** + * A helper class used within the `Bedrock` class. It is responsible for + * preparing the input and output for the Bedrock service. It formats the + * input prompt based on the provider (e.g., "anthropic", "ai21", + * "amazon") and extracts the generated text from the service response. + */ +export class BedrockLLMInputOutputAdapter { + /** Adapter class to prepare the inputs from Langchain to a format + that LLM model expects. Also, provides a helper function to extract + the generated text from the model response. 
*/
+
+  /**
+   * Builds the provider-specific request body for a Bedrock text model.
+   *
+   * The fields written depend on `provider` ("anthropic", "ai21", "meta",
+   * "amazon" or "cohere"); any other provider yields only `modelKwargs`.
+   * `modelKwargs` is spread last, so its keys override the generated ones.
+   *
+   * @param provider The provider name.
+   * @param prompt The prompt text.
+   * @param maxTokens Maximum number of tokens to generate.
+   * @param temperature Sampling temperature.
+   * @param stopSequences Stop sequences; only written for "anthropic",
+   * "ai21" and "cohere".
+   * @param modelKwargs Extra fields merged into the body.
+   * @param bedrockMethod When "invoke-with-response-stream", the "cohere"
+   * body additionally gets `stream: true`.
+   * @returns The request body.
+   */
+  static prepareInput(
+    provider: string,
+    prompt: string,
+    maxTokens = 50,
+    temperature = 0,
+    stopSequences: string[] | undefined = undefined,
+    modelKwargs: Record<string, unknown> = {},
+    bedrockMethod: "invoke" | "invoke-with-response-stream" = "invoke"
+  ): Dict {
+    const inputBody: Dict = {};
+
+    if (provider === "anthropic") {
+      inputBody.prompt = prompt;
+      inputBody.max_tokens_to_sample = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+    } else if (provider === "ai21") {
+      inputBody.prompt = prompt;
+      inputBody.maxTokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stopSequences = stopSequences;
+    } else if (provider === "meta") {
+      inputBody.prompt = prompt;
+      inputBody.max_gen_len = maxTokens;
+      inputBody.temperature = temperature;
+    } else if (provider === "amazon") {
+      inputBody.inputText = prompt;
+      inputBody.textGenerationConfig = {
+        maxTokenCount: maxTokens,
+        temperature,
+      };
+    } else if (provider === "cohere") {
+      inputBody.prompt = prompt;
+      inputBody.max_tokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+      if (bedrockMethod === "invoke-with-response-stream") {
+        inputBody.stream = true;
+      }
+    }
+    return { ...inputBody, ...modelKwargs };
+  }
+
+  /**
+   * Extracts the generated text from the service response.
+   * @param provider The provider name.
+   * @param responseBody The response body from the service.
+   * @returns The generated text.
+   */
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  static prepareOutput(provider: string, responseBody: any): string {
+    if (provider === "anthropic") {
+      return responseBody.completion;
+    } else if (provider === "ai21") {
+      return responseBody?.completions?.[0]?.data?.text ?? "";
+    } else if (provider === "cohere") {
+      return responseBody?.generations?.[0]?.text ?? responseBody?.text ?? "";
+    } else if (provider === "meta") {
+      return responseBody.generation;
+    }
+
+    // I haven't been able to get a response with more than one result in it.
+    return responseBody.results?.[0]?.outputText;
+  }
+}
diff --git a/libs/langchain-community/src/utils/chunk.ts b/libs/langchain-community/src/utils/chunk.ts
new file mode 100644
index 000000000000..340ce0a46604
--- /dev/null
+++ b/libs/langchain-community/src/utils/chunk.ts
@@ -0,0 +1,8 @@
+/**
+ * Splits `arr` into consecutive chunks; element `i` is placed in chunk
+ * `Math.floor(i / chunkSize)`. The input array is not modified.
+ *
+ * @param arr The items to split.
+ * @param chunkSize Maximum number of items per chunk. Must be greater than 0.
+ * @returns The chunks in input order; an empty array when `arr` is empty.
+ * @throws {RangeError} If `chunkSize` is zero, negative or NaN. Such values
+ * previously produced non-index keys on the result and silently dropped
+ * elements.
+ */
+export const chunkArray = <T>(arr: T[], chunkSize: number): T[][] => {
+  if (!(chunkSize > 0)) {
+    throw new RangeError(
+      `chunkSize must be greater than 0, received ${chunkSize}`
+    );
+  }
+  const chunks: T[][] = [];
+  arr.forEach((elem, index) => {
+    const chunkIndex = Math.floor(index / chunkSize);
+    let chunk = chunks[chunkIndex];
+    if (chunk === undefined) {
+      chunk = [];
+      chunks[chunkIndex] = chunk;
+    }
+    // Append in place instead of re-allocating the chunk for every element.
+    chunk.push(elem);
+  });
+  return chunks;
+};
diff --git a/libs/langchain-community/src/utils/convex.ts b/libs/langchain-community/src/utils/convex.ts
new file mode 100644
index 000000000000..8638aa9d04d4
--- /dev/null
+++ b/libs/langchain-community/src/utils/convex.ts
@@ -0,0 +1,82 @@
+/* eslint-disable spaced-comment */
+
+// eslint-disable-next-line import/no-extraneous-dependencies
+import {
+  internalQueryGeneric as internalQuery,
+  internalMutationGeneric as internalMutation,
+} from "convex/server";
+// eslint-disable-next-line import/no-extraneous-dependencies
+import { GenericId, v } from "convex/values";
+
+/** Internal query that loads a single document by its id via `ctx.db.get`. */
+export const get = /*#__PURE__*/ internalQuery({
+  args: {
+    id: /*#__PURE__*/ v.string(),
+  },
+  handler: async (ctx, args) => {
+    const result = await ctx.db.get(args.id as GenericId<string>);
+    return result;
+  },
+});
+
+/** Internal mutation that inserts `document` into `table`. */
+export const insert = /*#__PURE__*/ internalMutation({
+  args: {
+    table: /*#__PURE__*/ v.string(),
+    document: /*#__PURE__*/ v.any(),
+  },
+  handler: async (ctx, args) => {
+    await ctx.db.insert(args.table, args.document);
+  },
+});
+
+/**
+ * Internal query that collects every document of `table` whose `keyField`
+ * equals `key`, looked up through the index named `index`.
+ */
+export const lookup = /*#__PURE__*/ internalQuery({
+  args: {
+    table: /*#__PURE__*/ v.string(),
+    index: /*#__PURE__*/ v.string(),
+    keyField: /*#__PURE__*/ v.string(),
+    key: /*#__PURE__*/
v.string(), + }, + handler: async (ctx, args) => { + const result = await ctx.db + .query(args.table) + .withIndex(args.index, (q) => q.eq(args.keyField, args.key)) + .collect(); + return result; + }, +}); + +export const upsert = /*#__PURE__*/ internalMutation({ + args: { + table: /*#__PURE__*/ v.string(), + index: /*#__PURE__*/ v.string(), + keyField: /*#__PURE__*/ v.string(), + key: /*#__PURE__*/ v.string(), + document: /*#__PURE__*/ v.any(), + }, + handler: async (ctx, args) => { + const existing = await ctx.db + .query(args.table) + .withIndex(args.index, (q) => q.eq(args.keyField, args.key)) + .unique(); + if (existing !== null) { + await ctx.db.replace(existing._id, args.document); + } else { + await ctx.db.insert(args.table, args.document); + } + }, +}); + +export const deleteMany = /*#__PURE__*/ internalMutation({ + args: { + table: /*#__PURE__*/ v.string(), + index: /*#__PURE__*/ v.string(), + keyField: /*#__PURE__*/ v.string(), + key: /*#__PURE__*/ v.string(), + }, + handler: async (ctx, args) => { + const existing = await ctx.db + .query(args.table) + .withIndex(args.index, (q) => q.eq(args.keyField, args.key)) + .collect(); + await Promise.all(existing.map((doc) => ctx.db.delete(doc._id))); + }, +}); diff --git a/libs/langchain-community/src/utils/event_source_parse.ts b/libs/langchain-community/src/utils/event_source_parse.ts new file mode 100644 index 000000000000..9a279a2b28d4 --- /dev/null +++ b/libs/langchain-community/src/utils/event_source_parse.ts @@ -0,0 +1,287 @@ +/* eslint-disable prefer-template */ +/* eslint-disable default-case */ +/* eslint-disable no-plusplus */ +// Adapted from https://github.com/gfortaine/fetch-event-source/blob/main/src/parse.ts +// due to a packaging issue in the original. 
+// MIT License
+import { type Readable } from "stream";
+import { IterableReadableStream } from "@langchain/core/utils/stream";
+
+export const EventStreamContentType = "text/event-stream";
+
+/**
+ * Represents a message sent in an event stream
+ * https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format
+ */
+export interface EventSourceMessage {
+  /** The event ID to set the EventSource object's last event ID value. */
+  id: string;
+  /** A string identifying the type of event described. */
+  event: string;
+  /** The event data */
+  data: string;
+  /** The reconnection interval (in milliseconds) to wait before retrying the connection */
+  retry?: number;
+}
+
+/**
+ * Duck-type check used to tell a Node.js stream from a web ReadableStream:
+ * true for any non-null object that has an `on` property.
+ */
+function isNodeJSReadable(x: unknown): x is Readable {
+  return x != null && typeof x === "object" && "on" in x;
+}
+
+/**
+ * Converts a ReadableStream into a callback pattern.
+ *
+ * `onChunk` is called once per byte chunk and then exactly once more, with an
+ * empty array and `flush === true`, when the stream has ended.
+ *
+ * @param stream The input ReadableStream (or a Node.js Readable, see below).
+ * @param onChunk A function that will be called on each new byte chunk in the stream.
+ * @returns A promise that is resolved after the final flush call. It rejects
+ * if the stream errors or if `onChunk` throws.
+ */
+export async function getBytes(
+  stream: ReadableStream<Uint8Array>,
+  onChunk: (arr: Uint8Array, flush?: boolean) => void
+) {
+  // stream is a Node.js Readable / PassThrough stream
+  // this can happen if node-fetch is polyfilled
+  if (isNodeJSReadable(stream)) {
+    return new Promise<void>((resolve, reject) => {
+      stream.on("readable", () => {
+        try {
+          // `read()` returning null only means the internal buffer is drained
+          // for now, not that the stream is over: more "readable" events may
+          // follow, so neither flush nor resolve here.
+          let chunk = stream.read();
+          while (chunk != null) {
+            onChunk(chunk);
+            chunk = stream.read();
+          }
+        } catch (e) {
+          reject(e);
+        }
+      });
+      // Flush pending messages once, when the stream has really ended.
+      stream.once("end", () => {
+        try {
+          onChunk(new Uint8Array(), true);
+          resolve();
+        } catch (e) {
+          reject(e);
+        }
+      });
+      stream.once("error", reject);
+    });
+  }
+
+  const reader = stream.getReader();
+  // CHANGED: Introduced a "flush" mechanism to process potential pending messages when the stream ends.
+  // This change is essential to ensure that we capture every last piece of information from streams,
+  // such as those from Azure OpenAI, which may not terminate with a blank line. Without this
+  // mechanism, we risk ignoring a possibly significant last message.
+  // See https://github.com/langchain-ai/langchainjs/issues/1299 for details.
+  // eslint-disable-next-line no-constant-condition
+  while (true) {
+    const result = await reader.read();
+    if (result.done) {
+      onChunk(new Uint8Array(), true);
+      break;
+    }
+    onChunk(result.value);
+  }
+}
+
+const enum ControlChars {
+  NewLine = 10,
+  CarriageReturn = 13,
+  Space = 32,
+  Colon = 58,
+}
+
+/**
+ * Parses arbitrary byte chunks into EventSource line buffers.
+ * Each line should be of the format "field: value" and ends with \r, \n, or \r\n.
+ *
+ * A flush call is forwarded to `onLine` as-is (`fieldLength` 0, `flush` true);
+ * bytes of a line that was never terminated are not emitted.
+ *
+ * @param onLine A function that will be called on each new EventSource line.
+ * @returns A function that should be called for each incoming byte chunk.
+ */
+export function getLines(
+  onLine: (line: Uint8Array, fieldLength: number, flush?: boolean) => void
+) {
+  let buffer: Uint8Array | undefined;
+  let position: number; // current read position
+  let fieldLength: number; // length of the `field` portion of the line
+  let discardTrailingNewline = false;
+
+  // return a function that can process each incoming byte chunk:
+  return function onChunk(arr: Uint8Array, flush?: boolean) {
+    if (flush) {
+      onLine(arr, 0, true);
+      return;
+    }
+
+    if (buffer === undefined) {
+      buffer = arr;
+      position = 0;
+      fieldLength = -1;
+    } else {
+      // we're still parsing the old line. Append the new bytes into buffer:
+      buffer = concat(buffer, arr);
+    }
+
+    const bufLength = buffer.length;
+    let lineStart = 0; // index where the current line starts
+    while (position < bufLength) {
+      if (discardTrailingNewline) {
+        // the previous line ended with \r; swallow the \n of a \r\n pair
+        if (buffer[position] === ControlChars.NewLine) {
+          lineStart = ++position; // skip to next char
+        }
+
+        discardTrailingNewline = false;
+      }
+
+      // start looking forward till the end of line:
+      let lineEnd = -1; // index of the \r or \n char
+      for (; position < bufLength && lineEnd === -1; ++position) {
+        switch (buffer[position]) {
+          case ControlChars.Colon:
+            if (fieldLength === -1) {
+              // first colon in line
+              fieldLength = position - lineStart;
+            }
+            break;
+          // eslint-disable-next-line @typescript-eslint/ban-ts-comment
+          // @ts-ignore:7029 \r case below should fallthrough to \n:
+          case ControlChars.CarriageReturn:
+            discardTrailingNewline = true;
+          // eslint-disable-next-line no-fallthrough
+          case ControlChars.NewLine:
+            lineEnd = position;
+            break;
+        }
+      }
+
+      if (lineEnd === -1) {
+        // We reached the end of the buffer but the line hasn't ended.
+        // Wait for the next arr and then continue parsing:
+        break;
+      }
+
+      // we've reached the line end, send it out:
+      onLine(buffer.subarray(lineStart, lineEnd), fieldLength);
+      lineStart = position; // we're now on the next line
+      fieldLength = -1;
+    }
+
+    if (lineStart === bufLength) {
+      buffer = undefined; // we've finished reading it
+    } else if (lineStart !== 0) {
+      // Create a new view into buffer beginning at lineStart so we don't
+      // need to copy over the previous lines when we get the new arr:
+      buffer = buffer.subarray(lineStart);
+      position -= lineStart;
+    }
+  };
+}
+
+/**
+ * Parses line buffers into EventSourceMessages.
+ *
+ * An empty line dispatches the message built so far. A flush dispatches it
+ * only when at least one of its fields has been set.
+ *
+ * @param onMessage A function that will be called on each message.
+ * @param onId A function that will be called on each `id` field.
+ * @param onRetry A function that will be called on each `retry` field.
+ * @returns A function that should be called for each incoming line buffer.
+ */
+export function getMessages(
+  onMessage?: (msg: EventSourceMessage) => void,
+  onId?: (id: string) => void,
+  onRetry?: (retry: number) => void
+) {
+  let message = newMessage();
+  const decoder = new TextDecoder();
+
+  // return a function that can process each incoming line buffer:
+  return function onLine(
+    line: Uint8Array,
+    fieldLength: number,
+    flush?: boolean
+  ) {
+    if (flush) {
+      if (!isEmpty(message)) {
+        onMessage?.(message);
+        message = newMessage();
+      }
+      return;
+    }
+
+    if (line.length === 0) {
+      // empty line denotes end of message. Trigger the callback and start a new message:
+      onMessage?.(message);
+      message = newMessage();
+    } else if (fieldLength > 0) {
+      // exclude comments and lines with no values
+      // line is of format "<field>:<value>" or "<field>: <value>"
+      // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
+      const field = decoder.decode(line.subarray(0, fieldLength));
+      const valueOffset =
+        fieldLength + (line[fieldLength + 1] === ControlChars.Space ? 2 : 1);
+      const value = decoder.decode(line.subarray(valueOffset));
+
+      switch (field) {
+        case "data":
+          // if this message already has data, append the new value to the old.
+          // otherwise, just set to the new value:
+          message.data = message.data ? message.data + "\n" + value : value;
+          break;
+        case "event":
+          message.event = value;
+          break;
+        case "id":
+          onId?.((message.id = value));
+          break;
+        case "retry": {
+          const retry = parseInt(value, 10);
+          if (!Number.isNaN(retry)) {
+            // per spec, ignore non-integers
+            onRetry?.((message.retry = retry));
+          }
+          break;
+        }
+      }
+    }
+  };
+}
+
+/** Returns a new array holding the bytes of `a` followed by the bytes of `b`. */
+function concat(a: Uint8Array, b: Uint8Array) {
+  const res = new Uint8Array(a.length + b.length);
+  res.set(a);
+  res.set(b, a.length);
+  return res;
+}
+
+/** Creates an empty message with every field at its initial value. */
+function newMessage(): EventSourceMessage {
+  // data, event, and id must be initialized to empty strings:
+  // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
+  // retry should be initialized to undefined so we return a consistent shape
+  // to the js engine all the time: https://mathiasbynens.be/notes/shapes-ics#takeaways
+  return {
+    data: "",
+    event: "",
+    id: "",
+    retry: undefined,
+  };
+}
+
+/**
+ * Turns an event-stream byte stream into an IterableReadableStream that
+ * yields the `data` string of every message whose data is non-empty.
+ * The resulting stream is closed when the flush signal arrives.
+ *
+ * @param stream The event-stream response body.
+ * @returns An iterable stream of message data strings.
+ */
+export function convertEventStreamToIterableReadableDataStream(
+  stream: ReadableStream
+) {
+  const dataStream = new ReadableStream({
+    async start(controller) {
+      const enqueueLine = getMessages((msg) => {
+        if (msg.data) controller.enqueue(msg.data);
+      });
+      const onLine = (
+        line: Uint8Array,
+        fieldLength: number,
+        flush?: boolean
+      ) => {
+        enqueueLine(line, fieldLength, flush);
+        if (flush) controller.close();
+      };
+      await getBytes(stream, getLines(onLine));
+    },
+  });
+  return IterableReadableStream.fromReadableStream(dataStream);
+}
+
+/** True when no field of `message` has been set since `newMessage()`. */
+function isEmpty(message: EventSourceMessage): boolean {
+  return (
+    message.data === "" &&
+    message.event === "" &&
+    message.id === "" &&
+    message.retry === undefined
+  );
+}
diff --git a/libs/langchain-community/src/utils/googlevertexai-connection.ts b/libs/langchain-community/src/utils/googlevertexai-connection.ts
new file mode 100644
index 000000000000..f440abe364ac
--- /dev/null
+++ b/libs/langchain-community/src/utils/googlevertexai-connection.ts
@@ -0,0 +1,426 @@
+import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { + AsyncCaller, + AsyncCallerCallOptions, +} from "@langchain/core/utils/async_caller"; +import { GenerationChunk } from "@langchain/core/outputs"; +import type { + GoogleVertexAIBaseLLMInput, + GoogleVertexAIBasePrediction, + GoogleVertexAIConnectionParams, + GoogleVertexAILLMPredictions, + GoogleVertexAIModelParams, + GoogleResponse, + GoogleAbstractedClient, + GoogleAbstractedClientOps, + GoogleAbstractedClientOpsMethod, +} from "../types/googlevertexai-types.js"; + +export abstract class GoogleConnection< + CallOptions extends AsyncCallerCallOptions, + ResponseType extends GoogleResponse +> { + caller: AsyncCaller; + + client: GoogleAbstractedClient; + + streaming: boolean; + + constructor( + caller: AsyncCaller, + client: GoogleAbstractedClient, + streaming?: boolean + ) { + this.caller = caller; + this.client = client; + this.streaming = streaming ?? false; + } + + abstract buildUrl(): Promise; + + abstract buildMethod(): GoogleAbstractedClientOpsMethod; + + async _request( + data: unknown | undefined, + options: CallOptions + ): Promise { + const url = await this.buildUrl(); + const method = this.buildMethod(); + + const opts: GoogleAbstractedClientOps = { + url, + method, + }; + if (data && method === "POST") { + opts.data = data; + } + if (this.streaming) { + opts.responseType = "stream"; + } else { + opts.responseType = "json"; + } + + const callResponse = await this.caller.callWithOptions( + { signal: options?.signal }, + async () => this.client.request(opts) + ); + const response: unknown = callResponse; // Done for typecast safety, I guess + return response; + } +} + +export abstract class GoogleVertexAIConnection< + CallOptions extends AsyncCallerCallOptions, + ResponseType extends GoogleResponse, + AuthOptions + > + extends GoogleConnection + implements GoogleVertexAIConnectionParams +{ + endpoint = "us-central1-aiplatform.googleapis.com"; + + location = 
"us-central1"; + + apiVersion = "v1"; + + constructor( + fields: GoogleVertexAIConnectionParams | undefined, + caller: AsyncCaller, + client: GoogleAbstractedClient, + streaming?: boolean + ) { + super(caller, client, streaming); + this.caller = caller; + + this.endpoint = fields?.endpoint ?? this.endpoint; + this.location = fields?.location ?? this.location; + this.apiVersion = fields?.apiVersion ?? this.apiVersion; + this.client = client; + } + + buildMethod(): GoogleAbstractedClientOpsMethod { + return "POST"; + } +} + +export function complexValue(value: unknown): unknown { + if (value === null || typeof value === "undefined") { + // I dunno what to put here. An error, probably + return undefined; + } else if (typeof value === "object") { + if (Array.isArray(value)) { + return { + list_val: value.map((avalue) => complexValue(avalue)), + }; + } else { + const ret: Record = {}; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const v: Record = value; + Object.keys(v).forEach((key) => { + ret[key] = complexValue(v[key]); + }); + return { struct_val: ret }; + } + } else if (typeof value === "number") { + if (Number.isInteger(value)) { + return { int_val: value }; + } else { + return { float_val: value }; + } + } else { + return { + string_val: [value], + }; + } +} + +export function simpleValue(val: unknown): unknown { + if (val && typeof val === "object" && !Array.isArray(val)) { + // eslint-disable-next-line no-prototype-builtins + if (val.hasOwnProperty("stringVal")) { + return (val as { stringVal: string[] }).stringVal[0]; + + // eslint-disable-next-line no-prototype-builtins + } else if (val.hasOwnProperty("boolVal")) { + return (val as { boolVal: boolean[] }).boolVal[0]; + + // eslint-disable-next-line no-prototype-builtins + } else if (val.hasOwnProperty("listVal")) { + const { listVal } = val as { listVal: unknown[] }; + return listVal.map((aval) => simpleValue(aval)); + + // eslint-disable-next-line no-prototype-builtins + } else if 
(val.hasOwnProperty("structVal")) {
+      const ret: Record<string, unknown> = {};
+      const struct = (val as { structVal: Record<string, unknown> }).structVal;
+      Object.keys(struct).forEach((key) => {
+        ret[key] = simpleValue(struct[key]);
+      });
+      return ret;
+    } else {
+      const ret: Record<string, unknown> = {};
+      const struct = val as Record<string, unknown>;
+      Object.keys(struct).forEach((key) => {
+        ret[key] = simpleValue(struct[key]);
+      });
+      return ret;
+    }
+  } else if (Array.isArray(val)) {
+    return val.map((aval) => simpleValue(aval));
+  } else {
+    return val;
+  }
+}
+
+/**
+ * Connection for a Vertex AI publisher model. It builds the `predict` /
+ * `serverStreamingPredict` URL for `model` and shapes the request payload
+ * for the standard or the streaming endpoint.
+ */
+export class GoogleVertexAILLMConnection<
+    CallOptions extends BaseLanguageModelCallOptions,
+    InstanceType,
+    PredictionType extends GoogleVertexAIBasePrediction,
+    AuthOptions
+  >
+  extends GoogleVertexAIConnection<
+    CallOptions,
+    GoogleVertexAILLMResponse<PredictionType>,
+    AuthOptions
+  >
+  implements GoogleVertexAIBaseLLMInput<AuthOptions>
+{
+  model: string;
+
+  client: GoogleAbstractedClient;
+
+  constructor(
+    fields: GoogleVertexAIBaseLLMInput<AuthOptions> | undefined,
+    caller: AsyncCaller,
+    client: GoogleAbstractedClient,
+    streaming?: boolean
+  ) {
+    super(fields, caller, client, streaming);
+    this.client = client;
+    this.model = fields?.model ?? this.model;
+  }
+
+  /**
+   * Builds the request URL from the endpoint, API version, project id
+   * (fetched from the client), location and model.
+   *
+   * @returns The URL of the `serverStreamingPredict` method when streaming,
+   * otherwise of the `predict` method.
+   */
+  async buildUrl(): Promise<string> {
+    const projectId = await this.client.getProjectId();
+    const method = this.streaming ? "serverStreamingPredict" : "predict";
+    // Use the configured API version (default "v1") rather than a literal,
+    // so that an `apiVersion` passed to the constructor takes effect.
+    const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.model}:${method}`;
+    return url;
+  }
+
+  /**
+   * Payload for the streaming endpoint: every input and the parameters are
+   * converted with `complexValue`, and the inputs are wrapped in an outer
+   * array.
+   */
+  formatStreamingData(
+    inputs: InstanceType[],
+    parameters: GoogleVertexAIModelParams
+  ): unknown {
+    return {
+      inputs: [inputs.map((i) => complexValue(i))],
+      parameters: complexValue(parameters),
+    };
+  }
+
+  /** Payload for the standard endpoint: instances and parameters as given. */
+  formatStandardData(
+    instances: InstanceType[],
+    parameters: GoogleVertexAIModelParams
+  ): unknown {
+    return {
+      instances,
+      parameters,
+    };
+  }
+
+  /** Picks the streaming or the standard payload format based on `streaming`. */
+  formatData(
+    instances: InstanceType[],
+    parameters: GoogleVertexAIModelParams
+  ): unknown {
+    return this.streaming
+      ?
this.formatStreamingData(instances, parameters) + : this.formatStandardData(instances, parameters); + } + + async request( + instances: InstanceType[], + parameters: GoogleVertexAIModelParams, + options: CallOptions + ): Promise> { + const data = this.formatData(instances, parameters); + const response = await this._request(data, options); + return response; + } +} + +export interface GoogleVertexAILLMResponse< + PredictionType extends GoogleVertexAIBasePrediction +> extends GoogleResponse { + data: GoogleVertexAIStream | GoogleVertexAILLMPredictions; +} + +export class GoogleVertexAIStream { + _buffer = ""; + + _bufferOpen = true; + + _firstRun = true; + + /** + * Add data to the buffer. This may cause chunks to be generated, if available. + * @param data + */ + appendBuffer(data: string): void { + this._buffer += data; + // Our first time, skip to the opening of the array + if (this._firstRun) { + this._skipTo("["); + this._firstRun = false; + } + + this._parseBuffer(); + } + + /** + * Indicate there is no more data that will be added to the text buffer. + * This should be called when all the data has been read and added to indicate + * that we should process everything remaining in the buffer. + */ + closeBuffer(): void { + this._bufferOpen = false; + this._parseBuffer(); + } + + /** + * Skip characters in the buffer till we get to the start of an object. + * Then attempt to read a full object. + * If we do read a full object, turn it into a chunk and send it to the chunk handler. + * Repeat this for as much as we can. + */ + _parseBuffer(): void { + let obj = null; + do { + this._skipTo("{"); + obj = this._getFullObject(); + if (obj !== null) { + const chunk = this._simplifyObject(obj); + this._handleChunk(chunk); + } + } while (obj !== null); + + if (!this._bufferOpen) { + // No more data will be added, and we have parsed everything we could, + // so everything else is garbage. 
+ this._handleChunk(null); + this._buffer = ""; + } + } + + /** + * If the string is present, move the start of the buffer to the first occurrence + * of that string. This is useful for skipping over elements or parts that we're not + * really interested in parsing. (ie - the opening characters, comma separators, etc.) + * @param start The string to start the buffer with + */ + _skipTo(start: string): void { + const index = this._buffer.indexOf(start); + if (index > 0) { + this._buffer = this._buffer.slice(index); + } + } + + /** + * Given what is in the buffer, parse a single object out of it. + * If a complete object isn't available, return null. + * Assumes that we are at the start of an object to parse. + */ + _getFullObject(): object | null { + let ret: object | null = null; + + // Loop while we don't have something to return AND we have something in the buffer + let index = 0; + while (ret === null && this._buffer.length > index) { + // Advance to the next close bracket after our current index + index = this._buffer.indexOf("}", index + 1); + + // If we don't find one, exit with null + if (index === -1) { + return null; + } + + // If we have one, try to turn it into an object to return + try { + const objStr = this._buffer.substring(0, index + 1); + ret = JSON.parse(objStr); + + // We only get here if it parsed it ok + // If we did turn it into an object, remove it from the buffer + this._buffer = this._buffer.slice(index + 1); + } catch (xx) { + // It didn't parse it correctly, so we swallow the exception and continue + } + } + + return ret; + } + + _simplifyObject(obj: unknown): object { + return simpleValue(obj) as object; + } + + // Set up a potential Promise that the handler can resolve. 
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkResolution: (chunk: any) => void; + + // If there is no Promise (it is null), the handler must add it to the queue + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkPending: Promise | null = null; + + // A queue that will collect chunks while there is no Promise + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkQueue: any[] = []; + + /** + * Register that we have another chunk available for consumption. + * If we are waiting for a chunk, resolve the promise waiting for it immediately. + * If not, then add it to the queue. + * @param chunk + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _handleChunk(chunk: any): void { + if (this._chunkPending) { + this._chunkResolution(chunk); + this._chunkPending = null; + } else { + this._chunkQueue.push(chunk); + } + } + + /** + * Get the next chunk that is coming from the stream. + * This chunk may be null, usually indicating the last chunk in the stream. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async nextChunk(): Promise { + if (this._chunkQueue.length > 0) { + // If there is data in the queue, return the next queue chunk + return this._chunkQueue.shift() as GenerationChunk; + } else { + // Otherwise, set up a promise that handleChunk will cause to be resolved + this._chunkPending = new Promise((resolve) => { + this._chunkResolution = resolve; + }); + return this._chunkPending; + } + } + + /** + * Is the stream done? 
+   * A stream is only done if all of the following are true:
+   * - There is no more data to be added to the text buffer
+   * - There is no more data in the text buffer
+   * - There are no chunks that are waiting to be consumed
+   */
+  get streamDone(): boolean {
+    return (
+      !this._bufferOpen &&
+      this._buffer.length === 0 &&
+      this._chunkQueue.length === 0 &&
+      this._chunkPending === null
+    );
+  }
+}
diff --git a/libs/langchain-community/src/utils/googlevertexai-gauth.ts b/libs/langchain-community/src/utils/googlevertexai-gauth.ts
new file mode 100644
index 000000000000..e952391ef398
--- /dev/null
+++ b/libs/langchain-community/src/utils/googlevertexai-gauth.ts
@@ -0,0 +1,38 @@
+import { Readable } from "stream";
+import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
+import {
+  GoogleAbstractedClient,
+  GoogleAbstractedClientOps,
+} from "../types/googlevertexai-types.js";
+import { GoogleVertexAIStream } from "./googlevertexai-connection.js";
+
+/**
+ * Feeds the content of a Node.js `Readable` into the `GoogleVertexAIStream`
+ * text buffer and closes the buffer when the readable ends.
+ */
+class GoogleVertexAINodeStream extends GoogleVertexAIStream {
+  constructor(data: Readable) {
+    super();
+
+    // A streaming decoder holds back the bytes of an incomplete UTF-8
+    // sequence until the next chunk arrives. Decoding every chunk on its own
+    // turns a character that is split across two chunks into U+FFFD.
+    const decoder = new TextDecoder();
+    data.on("data", (chunk: unknown) => {
+      if (typeof chunk === "string") {
+        this.appendBuffer(chunk);
+      } else if (chunk instanceof Uint8Array) {
+        this.appendBuffer(decoder.decode(chunk, { stream: true }));
+      } else {
+        this.appendBuffer(String(chunk));
+      }
+    });
+    data.on("end", () => {
+      // Emit whatever the decoder is still holding before closing.
+      const rest = decoder.decode();
+      if (rest.length > 0) {
+        this.appendBuffer(rest);
+      }
+      this.closeBuffer();
+    });
+  }
+}
+
+/**
+ * `GoogleAbstractedClient` backed by `google-auth-library`. Responses
+ * requested with `responseType: "stream"` get their `data` wrapped in a
+ * `GoogleVertexAINodeStream`; all other responses are returned unchanged.
+ */
+export class GAuthClient implements GoogleAbstractedClient {
+  gauth: GoogleAuth;
+
+  constructor(options?: GoogleAuthOptions) {
+    this.gauth = new GoogleAuth(options);
+  }
+
+  async getProjectId(): Promise<string> {
+    return this.gauth.getProjectId();
+  }
+
+  async request(opts: GoogleAbstractedClientOps): Promise<unknown> {
+    const ret = await this.gauth.request(opts);
+    return opts.responseType !== "stream"
+      ? ret
+      : {
+          ...ret,
+          data: new GoogleVertexAINodeStream(ret.data),
+        };
+  }
+}
diff --git a/libs/langchain-community/src/utils/googlevertexai-webauth.ts b/libs/langchain-community/src/utils/googlevertexai-webauth.ts
new file mode 100644
index 000000000000..e5a5f39ff11a
--- /dev/null
+++ b/libs/langchain-community/src/utils/googlevertexai-webauth.ts
@@ -0,0 +1,119 @@
+import {
+  getAccessToken,
+  getCredentials,
+  Credentials,
+} from "web-auth-library/google";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import type {
+  GoogleAbstractedClient,
+  GoogleAbstractedClientOps,
+} from "../types/googlevertexai-types.js";
+import { GoogleVertexAIStream } from "./googlevertexai-connection.js";
+
+/**
+ * Feeds the content of a web `ReadableStream` response body into the
+ * `GoogleVertexAIStream` text buffer. When no body is given, an error is
+ * logged and nothing is read.
+ */
+class GoogleVertexAIResponseStream extends GoogleVertexAIStream {
+  decoder: TextDecoder;
+
+  constructor(body: ReadableStream | null) {
+    super();
+    this.decoder = new TextDecoder();
+    if (body) {
+      void this.run(body);
+    } else {
+      console.error("Unexpected empty body while streaming");
+    }
+  }
+
+  /**
+   * Reads `body` to its end, appending the decoded text of every chunk to
+   * the buffer, then closes the buffer.
+   */
+  async run(body: ReadableStream) {
+    const reader = body.getReader();
+    let isDone = false;
+    while (!isDone) {
+      const { value, done } = await reader.read();
+      if (!done) {
+        // `stream: true` carries an incomplete multi-byte sequence over to
+        // the next chunk instead of replacing it with U+FFFD.
+        const svalue = this.decoder.decode(value, { stream: true });
+        this.appendBuffer(svalue);
+      } else {
+        isDone = done;
+        // Emit whatever the decoder is still holding before closing.
+        const rest = this.decoder.decode();
+        if (rest.length > 0) {
+          this.appendBuffer(rest);
+        }
+        this.closeBuffer();
+      }
+    }
+  }
+}
+
+export type WebGoogleAuthOptions = {
+  credentials: string | Credentials;
+  scope?: string | string[];
+  accessToken?: string;
+};
+
+export class WebGoogleAuth implements GoogleAbstractedClient {
+  options: WebGoogleAuthOptions;
+
+  constructor(options?: WebGoogleAuthOptions) {
+    const accessToken = options?.accessToken;
+
+    const credentials =
+      options?.credentials ??
+      getEnvironmentVariable("GOOGLE_VERTEX_AI_WEB_CREDENTIALS");
+    if (credentials === undefined)
+      throw new Error(
+        `Credentials not found.
Please set the GOOGLE_VERTEX_AI_WEB_CREDENTIALS environment variable or pass credentials into "authOptions.credentials".` + ); + + const scope = + options?.scope ?? "https://www.googleapis.com/auth/cloud-platform"; + + this.options = { ...options, accessToken, credentials, scope }; + } + + async getProjectId() { + const credentials = getCredentials(this.options.credentials); + return credentials.project_id; + } + + async request(opts: GoogleAbstractedClientOps) { + let { accessToken } = this.options; + + if (accessToken === undefined) { + accessToken = await getAccessToken(this.options); + } + + if (opts.url == null) throw new Error("Missing URL"); + const fetchOptions: { + method?: string; + headers: Record; + body?: string; + } = { + method: opts.method, + headers: { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }; + if (opts.data !== undefined) { + fetchOptions.body = JSON.stringify(opts.data); + } + + const res = await fetch(opts.url, fetchOptions); + + if (!res.ok) { + const error = new Error( + `Could not get access token for Vertex AI with status code: ${res.status}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = res; + throw error; + } + + return { + data: + opts.responseType === "json" + ? 
await res.json() + : new GoogleVertexAIResponseStream(res.body), + config: {}, + status: res.status, + statusText: res.statusText, + headers: res.headers, + request: { responseURL: res.url }, + }; + } +} diff --git a/libs/langchain-community/src/utils/iflytek_websocket_stream.ts b/libs/langchain-community/src/utils/iflytek_websocket_stream.ts new file mode 100644 index 000000000000..85766ad37281 --- /dev/null +++ b/libs/langchain-community/src/utils/iflytek_websocket_stream.ts @@ -0,0 +1,95 @@ +export interface WebSocketConnection< + T extends Uint8Array | string = Uint8Array | string +> { + readable: ReadableStream; + writable: WritableStream; + protocol: string; + extensions: string; +} + +export interface WebSocketCloseInfo { + code?: number; + reason?: string; +} + +export interface WebSocketStreamOptions { + protocols?: string[]; + signal?: AbortSignal; +} + +/** + * [WebSocket](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket) with [Streams API](https://developer.mozilla.org/en-US/docs/Web/API/Streams_API) + * + * @see https://web.dev/websocketstream/ + */ +export abstract class BaseWebSocketStream< + T extends Uint8Array | string = Uint8Array | string +> { + readonly url: string; + + readonly connection: Promise>; + + readonly closed: Promise; + + readonly close: (closeInfo?: WebSocketCloseInfo) => void; + + constructor(url: string, options: WebSocketStreamOptions = {}) { + if (options.signal?.aborted) { + throw new DOMException("This operation was aborted", "AbortError"); + } + + this.url = url; + + const ws = this.openWebSocket(url, options); + + const closeWithInfo = ({ code, reason }: WebSocketCloseInfo = {}) => + ws.close(code, reason); + + this.connection = new Promise((resolve, reject) => { + ws.onopen = () => { + resolve({ + readable: new ReadableStream({ + start(controller) { + ws.onmessage = ({ data }) => controller.enqueue(data); + ws.onerror = (e) => controller.error(e); + }, + cancel: closeWithInfo, + }), + writable: new 
WritableStream({ + write(chunk) { + ws.send(chunk); + }, + abort() { + ws.close(); + }, + close: closeWithInfo, + }), + protocol: ws.protocol, + extensions: ws.extensions, + }); + ws.removeEventListener("error", reject); + }; + ws.addEventListener("error", reject); + }); + + this.closed = new Promise((resolve, reject) => { + ws.onclose = ({ code, reason }) => { + resolve({ code, reason }); + ws.removeEventListener("error", reject); + }; + ws.addEventListener("error", reject); + }); + + if (options.signal) { + // eslint-disable-next-line no-param-reassign + options.signal.onabort = () => ws.close(); + } + + this.close = closeWithInfo; + } + + abstract openWebSocket( + url: string, + options: WebSocketStreamOptions + ): WebSocket; +} diff --git a/libs/langchain-community/src/utils/llama_cpp.ts b/libs/langchain-community/src/utils/llama_cpp.ts new file mode 100644 index 000000000000..d961ab510942 --- /dev/null +++ b/libs/langchain-community/src/utils/llama_cpp.ts @@ -0,0 +1,79 @@ +import { LlamaModel, LlamaContext, LlamaChatSession } from "node-llama-cpp"; + +/** + * Note that the modelPath is the only required parameter. For testing you + * can set this in the environment variable `LLAMA_PATH`. + */ +export interface LlamaBaseCppInputs { + /** Prompt processing batch size. */ + batchSize?: number; + /** Text context size. */ + contextSize?: number; + /** Embedding mode only. */ + embedding?: boolean; + /** Use fp16 for KV cache. */ + f16Kv?: boolean; + /** Number of layers to store in VRAM. */ + gpuLayers?: number; + /** The llama_eval() call computes all logits, not just the last one. */ + logitsAll?: boolean; + /** */ + maxTokens?: number; + /** Path to the model on the filesystem. */ + modelPath: string; + /** Add the begining of sentence token. */ + prependBos?: boolean; + /** If null, a random seed will be used. */ + seed?: null | number; + /** The randomness of the responses, e.g. 0.1 deterministic, 1.5 creative, 0.8 balanced, 0 disables. 
*/ + temperature?: number; + /** Number of threads to use to evaluate tokens. */ + threads?: number; + /** Trim whitespace from the end of the generated text Disabled by default. */ + trimWhitespaceSuffix?: boolean; + /** Consider the n most likely tokens, where n is 1 to vocabulary size, 0 disables (uses full vocabulary). Note: only applies when `temperature` > 0. */ + topK?: number; + /** Selects the smallest token set whose probability exceeds P, where P is between 0 - 1, 1 disables. Note: only applies when `temperature` > 0. */ + topP?: number; + /** Force system to keep model in RAM. */ + useMlock?: boolean; + /** Use mmap if possible. */ + useMmap?: boolean; + /** Only load the vocabulary, no weights. */ + vocabOnly?: boolean; +} + +export function createLlamaModel(inputs: LlamaBaseCppInputs): LlamaModel { + const options = { + gpuLayers: inputs?.gpuLayers, + modelPath: inputs.modelPath, + useMlock: inputs?.useMlock, + useMmap: inputs?.useMmap, + vocabOnly: inputs?.vocabOnly, + }; + + return new LlamaModel(options); +} + +export function createLlamaContext( + model: LlamaModel, + inputs: LlamaBaseCppInputs +): LlamaContext { + const options = { + batchSize: inputs?.batchSize, + contextSize: inputs?.contextSize, + embedding: inputs?.embedding, + f16Kv: inputs?.f16Kv, + logitsAll: inputs?.logitsAll, + model, + prependBos: inputs?.prependBos, + seed: inputs?.seed, + threads: inputs?.threads, + }; + + return new LlamaContext(options); +} + +export function createLlamaSession(context: LlamaContext): LlamaChatSession { + return new LlamaChatSession({ context }); +} diff --git a/libs/langchain-community/src/utils/momento.ts b/libs/langchain-community/src/utils/momento.ts new file mode 100644 index 000000000000..2ef4666ed34e --- /dev/null +++ b/libs/langchain-community/src/utils/momento.ts @@ -0,0 +1,26 @@ +/* eslint-disable no-instanceof/no-instanceof */ +import { ICacheClient, CreateCache } from "@gomomento/sdk"; + +/** + * Utility function to ensure that a Momento 
cache exists. + * If the cache does not exist, it is created. + * + * @param client The Momento cache client. + * @param cacheName The name of the cache to ensure exists. + */ +export async function ensureCacheExists( + client: ICacheClient, + cacheName: string +): Promise { + const createResponse = await client.createCache(cacheName); + if ( + createResponse instanceof CreateCache.Success || + createResponse instanceof CreateCache.AlreadyExists + ) { + // pass + } else if (createResponse instanceof CreateCache.Error) { + throw createResponse.innerException(); + } else { + throw new Error(`Unknown response type: ${createResponse.toString()}`); + } +} diff --git a/libs/langchain-community/src/utils/ollama.ts b/libs/langchain-community/src/utils/ollama.ts new file mode 100644 index 000000000000..30f675f0d9b3 --- /dev/null +++ b/libs/langchain-community/src/utils/ollama.ts @@ -0,0 +1,146 @@ +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { IterableReadableStream } from "@langchain/core/utils/stream"; +import type { StringWithAutocomplete } from "@langchain/core/utils/types"; + +export interface OllamaInput { + embeddingOnly?: boolean; + f16KV?: boolean; + frequencyPenalty?: number; + logitsAll?: boolean; + lowVram?: boolean; + mainGpu?: number; + model?: string; + baseUrl?: string; + mirostat?: number; + mirostatEta?: number; + mirostatTau?: number; + numBatch?: number; + numCtx?: number; + numGpu?: number; + numGqa?: number; + numKeep?: number; + numThread?: number; + penalizeNewline?: boolean; + presencePenalty?: number; + repeatLastN?: number; + repeatPenalty?: number; + ropeFrequencyBase?: number; + ropeFrequencyScale?: number; + temperature?: number; + stop?: string[]; + tfsZ?: number; + topK?: number; + topP?: number; + typicalP?: number; + useMLock?: boolean; + useMMap?: boolean; + vocabOnly?: boolean; + format?: StringWithAutocomplete<"json">; +} + +export interface OllamaRequestParams { + model: string; + 
prompt: string; + format?: StringWithAutocomplete<"json">; + options: { + embedding_only?: boolean; + f16_kv?: boolean; + frequency_penalty?: number; + logits_all?: boolean; + low_vram?: boolean; + main_gpu?: number; + mirostat?: number; + mirostat_eta?: number; + mirostat_tau?: number; + num_batch?: number; + num_ctx?: number; + num_gpu?: number; + num_gqa?: number; + num_keep?: number; + num_thread?: number; + penalize_newline?: boolean; + presence_penalty?: number; + repeat_last_n?: number; + repeat_penalty?: number; + rope_frequency_base?: number; + rope_frequency_scale?: number; + temperature?: number; + stop?: string[]; + tfs_z?: number; + top_k?: number; + top_p?: number; + typical_p?: number; + use_mlock?: boolean; + use_mmap?: boolean; + vocab_only?: boolean; + }; +} + +export interface OllamaCallOptions extends BaseLanguageModelCallOptions {} + +export type OllamaGenerationChunk = { + response: string; + model: string; + created_at: string; + done: boolean; + total_duration?: number; + load_duration?: number; + prompt_eval_count?: number; + prompt_eval_duration?: number; + eval_count?: number; + eval_duration?: number; +}; + +export async function* createOllamaStream( + baseUrl: string, + params: OllamaRequestParams, + options: OllamaCallOptions +): AsyncGenerator { + let formattedBaseUrl = baseUrl; + if (formattedBaseUrl.startsWith("http://localhost:")) { + // Node 18 has issues with resolving "localhost" + // See https://github.com/node-fetch/node-fetch/issues/1624 + formattedBaseUrl = formattedBaseUrl.replace( + "http://localhost:", + "http://127.0.0.1:" + ); + } + const response = await fetch(`${formattedBaseUrl}/api/generate`, { + method: "POST", + body: JSON.stringify(params), + headers: { + "Content-Type": "application/json", + }, + signal: options.signal, + }); + if (!response.ok) { + const json = await response.json(); + const error = new Error( + `Ollama call failed with status code ${response.status}: ${json.error}` + ); + // 
eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = response; + throw error; + } + if (!response.body) { + throw new Error( + "Could not begin Ollama stream. Please check the given URL and try again." + ); + } + + const stream = IterableReadableStream.fromReadableStream(response.body); + const decoder = new TextDecoder(); + let extra = ""; + for await (const chunk of stream) { + const decoded = extra + decoder.decode(chunk); + const lines = decoded.split("\n"); + extra = lines.pop() || ""; + for (const line of lines) { + try { + yield JSON.parse(line); + } catch (e) { + console.warn(`Received a non-JSON parseable chunk: ${line}`); + } + } + } +} diff --git a/libs/langchain-community/src/utils/testing.ts b/libs/langchain-community/src/utils/testing.ts new file mode 100644 index 000000000000..205ebc941f5b --- /dev/null +++ b/libs/langchain-community/src/utils/testing.ts @@ -0,0 +1,107 @@ +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; + +/** + * A class that provides fake embeddings by overriding the embedDocuments + * and embedQuery methods to return fixed values. + */ +export class FakeEmbeddings extends Embeddings { + constructor(params?: EmbeddingsParams) { + super(params ?? {}); + } + + /** + * Generates fixed embeddings for a list of documents. + * @param documents List of documents to generate embeddings for. + * @returns A promise that resolves with a list of fixed embeddings for each document. + */ + embedDocuments(documents: string[]): Promise { + return Promise.resolve(documents.map(() => [0.1, 0.2, 0.3, 0.4])); + } + + /** + * Generates a fixed embedding for a query. + * @param _ The query to generate an embedding for. + * @returns A promise that resolves with a fixed embedding for the query. + */ + embedQuery(_: string): Promise { + return Promise.resolve([0.1, 0.2, 0.3, 0.4]); + } +} + +/** + * An interface that defines additional parameters specific to the + * SyntheticEmbeddings class. 
+ */ +interface SyntheticEmbeddingsParams extends EmbeddingsParams { + vectorSize: number; +} + +/** + * A class that provides synthetic embeddings by overriding the + * embedDocuments and embedQuery methods to generate embeddings based on + * the input documents. The embeddings are generated by converting each + * document into chunks, calculating a numerical value for each chunk, and + * returning an array of these values as the embedding. + */ +export class SyntheticEmbeddings + extends Embeddings + implements SyntheticEmbeddingsParams +{ + vectorSize: number; + + constructor(params?: SyntheticEmbeddingsParams) { + super(params ?? {}); + this.vectorSize = params?.vectorSize ?? 4; + } + + /** + * Generates synthetic embeddings for a list of documents. + * @param documents List of documents to generate embeddings for. + * @returns A promise that resolves with a list of synthetic embeddings for each document. + */ + async embedDocuments(documents: string[]): Promise { + return Promise.all(documents.map((doc) => this.embedQuery(doc))); + } + + /** + * Generates a synthetic embedding for a document. The document is + * converted into chunks, a numerical value is calculated for each chunk, + * and an array of these values is returned as the embedding. + * @param document The document to generate an embedding for. + * @returns A promise that resolves with a synthetic embedding for the document. + */ + async embedQuery(document: string): Promise { + let doc = document; + + // Only use the letters (and space) from the document, and make them lower case + doc = doc.toLowerCase().replaceAll(/[^a-z ]/g, ""); + + // Pad the document to make sure it has a divisible number of chunks + const padMod = doc.length % this.vectorSize; + const padGapSize = padMod === 0 ? 
0 : this.vectorSize - padMod; + const padSize = doc.length + padGapSize; + doc = doc.padEnd(padSize, " "); + + // Break it into chunks + const chunkSize = doc.length / this.vectorSize; + const docChunk = []; + for (let co = 0; co < doc.length; co += chunkSize) { + docChunk.push(doc.slice(co, co + chunkSize)); + } + + // Turn each chunk into a number + const ret: number[] = docChunk.map((s) => { + let sum = 0; + // Get a total value by adding the value of each character in the string + for (let co = 0; co < s.length; co += 1) { + sum += s === " " ? 0 : s.charCodeAt(co); + } + // Reduce this to a number between 0 and 25 inclusive + // Then get the fractional number by dividing it by 26 + const ret = (sum % 26) / 26; + return ret; + }); + + return ret; + } +} diff --git a/libs/langchain-community/src/utils/time.ts b/libs/langchain-community/src/utils/time.ts new file mode 100644 index 000000000000..f6f5263e4722 --- /dev/null +++ b/libs/langchain-community/src/utils/time.ts @@ -0,0 +1,10 @@ +/** + * Sleep for a given amount of time. + * @param ms - The number of milliseconds to sleep for. Defaults to 1000. + * @returns A promise that resolves when the sleep is complete. 
+ */ +export async function sleep(ms = 1000): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} diff --git a/libs/langchain-community/src/vectorstores/analyticdb.ts b/libs/langchain-community/src/vectorstores/analyticdb.ts new file mode 100644 index 000000000000..bd4f5caffabe --- /dev/null +++ b/libs/langchain-community/src/vectorstores/analyticdb.ts @@ -0,0 +1,390 @@ +import * as uuid from "uuid"; +import pg, { Pool, PoolConfig } from "pg"; +import { from as copyFrom } from "pg-copy-streams"; +import { pipeline } from "node:stream/promises"; +import { Readable } from "node:stream"; + +import { VectorStore } from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; + +const _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"; + +/** + * Interface defining the arguments required to create an instance of + * `AnalyticDBVectorStore`. + */ +export interface AnalyticDBArgs { + connectionOptions: PoolConfig; + embeddingDimension?: number; + collectionName?: string; + preDeleteCollection?: boolean; +} + +/** + * Interface defining the structure of data to be stored in the + * AnalyticDB. + */ +interface DataType { + id: string; + embedding: number[]; + document: string; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + metadata: Record; +} + +/** + * Class that provides methods for creating and managing a collection of + * documents in an AnalyticDB, adding documents or vectors to the + * collection, performing similarity search on vectors, and creating an + * instance of `AnalyticDBVectorStore` from texts or documents. 
+ */ +export class AnalyticDBVectorStore extends VectorStore { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + declare FilterType: Record; + + private pool: Pool; + + private embeddingDimension?: number; + + private collectionName: string; + + private preDeleteCollection: boolean; + + private isCreateCollection = false; + + _vectorstoreType(): string { + return "analyticdb"; + } + + constructor(embeddings: Embeddings, args: AnalyticDBArgs) { + super(embeddings, args); + + this.pool = new pg.Pool({ + host: args.connectionOptions.host, + port: args.connectionOptions.port, + database: args.connectionOptions.database, + user: args.connectionOptions.user, + password: args.connectionOptions.password, + }); + this.embeddingDimension = args.embeddingDimension; + this.collectionName = + args.collectionName || _LANGCHAIN_DEFAULT_COLLECTION_NAME; + this.preDeleteCollection = args.preDeleteCollection || false; + } + + /** + * Closes all the clients in the pool and terminates the pool. + * @returns Promise that resolves when all clients are closed and the pool is terminated. + */ + async end(): Promise { + return this.pool.end(); + } + + /** + * Creates a new table in the database if it does not already exist. The + * table is created with columns for id, embedding, document, and + * metadata. An index is also created on the embedding column if it does + * not already exist. + * @returns Promise that resolves when the table and index are created. 
+ */ + async createTableIfNotExists(): Promise { + if (!this.embeddingDimension) { + this.embeddingDimension = ( + await this.embeddings.embedQuery("test") + ).length; + } + const client = await this.pool.connect(); + try { + await client.query("BEGIN"); + // Create the table if it doesn't exist + await client.query(` + CREATE TABLE IF NOT EXISTS ${this.collectionName} ( + id TEXT PRIMARY KEY DEFAULT NULL, + embedding REAL[], + document TEXT, + metadata JSON + ); + `); + + // Check if the index exists + const indexName = `${this.collectionName}_embedding_idx`; + const indexQuery = ` + SELECT 1 + FROM pg_indexes + WHERE indexname = '${indexName}'; + `; + const result = await client.query(indexQuery); + + // Create the index if it doesn't exist + if (result.rowCount === 0) { + const indexStatement = ` + CREATE INDEX ${indexName} + ON ${this.collectionName} USING ann(embedding) + WITH ( + "dim" = ${this.embeddingDimension}, + "hnsw_m" = 100 + ); + `; + await client.query(indexStatement); + } + await client.query("COMMIT"); + } catch (err) { + await client.query("ROLLBACK"); + throw err; + } finally { + client.release(); + } + } + + /** + * Deletes the collection from the database if it exists. + * @returns Promise that resolves when the collection is deleted. + */ + async deleteCollection(): Promise { + const dropStatement = `DROP TABLE IF EXISTS ${this.collectionName};`; + await this.pool.query(dropStatement); + } + + /** + * Creates a new collection in the database. If `preDeleteCollection` is + * true, any existing collection with the same name is deleted before the + * new collection is created. + * @returns Promise that resolves when the collection is created. + */ + async createCollection(): Promise { + if (this.preDeleteCollection) { + await this.deleteCollection(); + } + await this.createTableIfNotExists(); + this.isCreateCollection = true; + } + + /** + * Adds an array of documents to the collection. 
The documents are first + * converted to vectors using the `embedDocuments` method of the + * `embeddings` instance. + * @param documents Array of Document instances to be added to the collection. + * @returns Promise that resolves when the documents are added. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Adds an array of vectors and corresponding documents to the collection. + * The vectors and documents are batch inserted into the database. + * @param vectors Array of vectors to be added to the collection. + * @param documents Array of Document instances corresponding to the vectors. + * @returns Promise that resolves when the vectors and documents are added. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length === 0) { + return; + } + if (vectors.length !== documents.length) { + throw new Error(`Vectors and documents must have the same length`); + } + if (!this.embeddingDimension) { + this.embeddingDimension = ( + await this.embeddings.embedQuery("test") + ).length; + } + if (vectors[0].length !== this.embeddingDimension) { + throw new Error( + `Vectors must have the same length as the number of dimensions (${this.embeddingDimension})` + ); + } + + if (!this.isCreateCollection) { + await this.createCollection(); + } + + const client = await this.pool.connect(); + try { + const chunkSize = 500; + const chunksTableData: DataType[] = []; + + for (let i = 0; i < documents.length; i += 1) { + chunksTableData.push({ + id: uuid.v4(), + embedding: vectors[i], + document: documents[i].pageContent, + metadata: documents[i].metadata, + }); + + // Execute the batch insert when the batch size is reached + if (chunksTableData.length === chunkSize) { + const rs = new Readable(); + let currentIndex = 0; + rs._read = function () { + if (currentIndex 
=== chunkSize) { + rs.push(null); + } else { + const data = chunksTableData[currentIndex]; + rs.push( + `${data.id}\t{${data.embedding.join(",")}}\t${ + data.document + }\t${JSON.stringify(data.metadata)}\n` + ); + currentIndex += 1; + } + }; + const ws = client.query( + copyFrom( + `COPY ${this.collectionName}(id, embedding, document, metadata) FROM STDIN` + ) + ); + + await pipeline(rs, ws); + // Clear the chunksTableData list for the next batch + chunksTableData.length = 0; + } + } + + // Insert any remaining records that didn't make up a full batch + if (chunksTableData.length > 0) { + const rs = new Readable(); + let currentIndex = 0; + rs._read = function () { + if (currentIndex === chunksTableData.length) { + rs.push(null); + } else { + const data = chunksTableData[currentIndex]; + rs.push( + `${data.id}\t{${data.embedding.join(",")}}\t${ + data.document + }\t${JSON.stringify(data.metadata)}\n` + ); + currentIndex += 1; + } + }; + const ws = client.query( + copyFrom( + `COPY ${this.collectionName}(id, embedding, document, metadata) FROM STDIN` + ) + ); + await pipeline(rs, ws); + } + } finally { + client.release(); + } + } + + /** + * Performs a similarity search on the vectors in the collection. The + * search is performed using the given query vector and returns the top k + * most similar vectors along with their corresponding documents and + * similarity scores. + * @param query Query vector for the similarity search. + * @param k Number of top similar vectors to return. + * @param filter Optional. Filter to apply on the metadata of the documents. + * @returns Promise that resolves to an array of tuples, each containing a Document instance and its similarity score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ): Promise<[Document, number][]> { + if (!this.isCreateCollection) { + await this.createCollection(); + } + + let filterCondition = ""; + const filterEntries = filter ? 
Object.entries(filter) : []; + if (filterEntries.length > 0) { + const conditions = filterEntries.map( + (_, index) => `metadata->>$${2 * index + 3} = $${2 * index + 4}` + ); + filterCondition = `WHERE ${conditions.join(" AND ")}`; + } + + const sqlQuery = ` + SELECT *, l2_distance(embedding, $1::real[]) AS distance + FROM ${this.collectionName} + ${filterCondition} + ORDER BY embedding <-> $1 + LIMIT $2; + `; + + // Execute the query and fetch the results + const { rows } = await this.pool.query(sqlQuery, [ + query, + k, + ...filterEntries.flatMap(([key, value]) => [key, value]), + ]); + + const result: [Document, number][] = rows.map((row) => [ + new Document({ pageContent: row.document, metadata: row.metadata }), + row.distance, + ]); + + return result; + } + + /** + * Creates an instance of `AnalyticDBVectorStore` from an array of texts + * and corresponding metadata. The texts are first converted to Document + * instances before being added to the collection. + * @param texts Array of texts to be added to the collection. + * @param metadatas Array or object of metadata corresponding to the texts. + * @param embeddings Embeddings instance used to convert the texts to vectors. + * @param dbConfig Configuration for the AnalyticDB. + * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: AnalyticDBArgs + ): Promise { + const docs = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return AnalyticDBVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Creates an instance of `AnalyticDBVectorStore` from an array of + * Document instances. The documents are added to the collection. 
+ * @param docs Array of Document instances to be added to the collection. + * @param embeddings Embeddings instance used to convert the documents to vectors. + * @param dbConfig Configuration for the AnalyticDB. + * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: AnalyticDBArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } + + /** + * Creates an instance of `AnalyticDBVectorStore` from an existing index + * in the database. A new collection is created in the database. + * @param embeddings Embeddings instance used to convert the documents to vectors. + * @param dbConfig Configuration for the AnalyticDB. + * @returns Promise that resolves to an instance of `AnalyticDBVectorStore`. + */ + static async fromExistingIndex( + embeddings: Embeddings, + dbConfig: AnalyticDBArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.createCollection(); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/cassandra.ts b/libs/langchain-community/src/vectorstores/cassandra.ts new file mode 100644 index 000000000000..85feef551e48 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/cassandra.ts @@ -0,0 +1,584 @@ +/* eslint-disable prefer-template */ +import { Client as CassandraClient, DseClientOptions } from "cassandra-driver"; + +import { + AsyncCaller, + AsyncCallerParams, +} from "@langchain/core/utils/async_caller"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +export interface Column { + type: string; + name: string; + partition?: boolean; +} + +export interface Index { + name: string; + value: string; +} + +export interface Filter { + name: string; + value: unknown; + operator?: 
string; +} + +export type WhereClause = Filter[] | Filter | Record; + +export type SupportedVectorTypes = "cosine" | "dot_product" | "euclidean"; + +export interface CassandraLibArgs extends DseClientOptions, AsyncCallerParams { + table: string; + keyspace: string; + vectorType?: SupportedVectorTypes; + dimensions: number; + primaryKey: Column | Column[]; + metadataColumns: Column[]; + withClause?: string; + indices?: Index[]; + batchSize?: number; +} + +/** + * Class for interacting with the Cassandra database. It extends the + * VectorStore class and provides methods for adding vectors and + * documents, searching for similar vectors, and creating instances from + * texts or documents. + */ +export class CassandraStore extends VectorStore { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + declare FilterType: WhereClause; + + private client: CassandraClient; + + private readonly vectorType: SupportedVectorTypes; + + private readonly dimensions: number; + + private readonly keyspace: string; + + private primaryKey: Column[]; + + private metadataColumns: Column[]; + + private withClause: string; + + private selectColumns: string; + + private readonly table: string; + + private indices: Index[]; + + private isInitialized = false; + + asyncCaller: AsyncCaller; + + private readonly batchSize: number; + + _vectorstoreType(): string { + return "cassandra"; + } + + constructor(embeddings: Embeddings, args: CassandraLibArgs) { + super(embeddings, args); + + const { + indices = [], + maxConcurrency = 25, + withClause = "", + batchSize = 1, + vectorType = "cosine", + dimensions, + keyspace, + table, + primaryKey, + metadataColumns, + } = args; + + const argsWithDefaults = { + ...args, + indices, + maxConcurrency, + withClause, + batchSize, + vectorType, + }; + this.asyncCaller = new AsyncCaller(argsWithDefaults); + this.client = new CassandraClient(argsWithDefaults); + + // Assign properties + this.vectorType = vectorType; + this.dimensions = dimensions; + 
this.keyspace = keyspace; + this.table = table; + this.primaryKey = Array.isArray(primaryKey) ? primaryKey : [primaryKey]; + this.metadataColumns = metadataColumns; + this.withClause = withClause.trim().replace(/^with\s*/i, ""); + this.indices = indices; + this.batchSize = batchSize >= 1 ? batchSize : 1; + } + + /** + * Method to save vectors to the Cassandra database. + * @param vectors Vectors to save. + * @param documents The documents associated with the vectors. + * @returns Promise that resolves when the vectors have been added. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length === 0) { + return; + } + + if (!this.isInitialized) { + await this.initialize(); + } + + await this.insertAll(vectors, documents); + } + + /** + * Method to add documents to the Cassandra database. + * @param documents The documents to add. + * @returns Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + return this.addVectors( + await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), + documents + ); + } + + /** + * Method to search for vectors that are similar to a given query vector. + * @param query The query vector. + * @param k The number of similar vectors to return. + * @param filter + * @returns Promise that resolves with an array of tuples, each containing a Document and a score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: WhereClause + ): Promise<[Document, number][]> { + if (!this.isInitialized) { + await this.initialize(); + } + + // Ensure we have an array of Filter from the public interface + const filters = this.asFilters(filter); + + const queryStr = this.buildSearchQuery(filters); + + // Search query will be of format: + // SELECT ..., text, similarity_x(?) AS similarity_score + // FROM ... + // + // ORDER BY vector ANN OF ? + // LIMIT ? 
+ // If any filter values are specified, they will be in the WHERE clause as + // filter.name filter.operator ? + // queryParams is a list of bind variables sent with the prepared statement + const queryParams = []; + const vectorAsFloat32Array = new Float32Array(query); + queryParams.push(vectorAsFloat32Array); + if (filters) { + const values = (filters as Filter[]).map(({ value }) => value); + queryParams.push(...values); + } + queryParams.push(vectorAsFloat32Array); + queryParams.push(k); + + const queryResultSet = await this.client.execute(queryStr, queryParams, { + prepare: true, + }); + + return queryResultSet?.rows.map((row) => { + const textContent = row.text; + const sanitizedRow = { ...row }; + delete sanitizedRow.text; + delete sanitizedRow.similarity_score; + + // A null value in Cassandra evaluates to a deleted column + // as this is treated as a tombstone record for the cell. + Object.keys(sanitizedRow).forEach((key) => { + if (sanitizedRow[key] === null) { + delete sanitizedRow[key]; + } + }); + + return [ + new Document({ pageContent: textContent, metadata: sanitizedRow }), + row.similarity_score, + ]; + }); + } + + /** + * Static method to create an instance of CassandraStore from texts. + * @param texts The texts to use. + * @param metadatas The metadata associated with the texts. + * @param embeddings The embeddings to use. + * @param args The arguments for the CassandraStore. + * @returns Promise that resolves with a new instance of CassandraStore. + */ + static async fromTexts( + texts: string[], + metadatas: object | object[], + embeddings: Embeddings, + args: CassandraLibArgs + ): Promise { + const docs: Document[] = []; + + for (let index = 0; index < texts.length; index += 1) { + const metadata = Array.isArray(metadatas) ? 
metadatas[index] : metadatas; + const doc = new Document({ + pageContent: texts[index], + metadata, + }); + docs.push(doc); + } + + return CassandraStore.fromDocuments(docs, embeddings, args); + } + + /** + * Static method to create an instance of CassandraStore from documents. + * @param docs The documents to use. + * @param embeddings The embeddings to use. + * @param args The arguments for the CassandraStore. + * @returns Promise that resolves with a new instance of CassandraStore. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + args: CassandraLibArgs + ): Promise { + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method to create an instance of CassandraStore from an existing + * index. + * @param embeddings The embeddings to use. + * @param args The arguments for the CassandraStore. + * @returns Promise that resolves with a new instance of CassandraStore. + */ + static async fromExistingIndex( + embeddings: Embeddings, + args: CassandraLibArgs + ): Promise { + const instance = new this(embeddings, args); + + await instance.initialize(); + return instance; + } + + /** + * Method to initialize the Cassandra database. + * @returns Promise that resolves when the database has been initialized. + */ + private async initialize(): Promise { + let cql = ""; + cql = `CREATE TABLE IF NOT EXISTS ${this.keyspace}.${this.table} ( + ${this.primaryKey.map((col) => `${col.name} ${col.type}`).join(", ")} + , text TEXT + ${ + this.metadataColumns.length > 0 + ? ", " + + this.metadataColumns + .map((col) => `${col.name} ${col.type}`) + .join(", ") + : "" + } + , vector VECTOR + , ${this.buildPrimaryKey(this.primaryKey)} + ) ${this.withClause ? `WITH ${this.withClause}` : ""};`; + + await this.client.execute(cql); + + this.selectColumns = `${this.primaryKey + .map((col) => `${col.name}`) + .join(", ")} + ${ + this.metadataColumns.length > 0 + ? 
", " + + this.metadataColumns + .map((col) => `${col.name}`) + .join(", ") + : "" + }`; + + cql = `CREATE CUSTOM INDEX IF NOT EXISTS idx_vector_${this.table} + ON ${this.keyspace}.${ + this.table + }(vector) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function': '${this.vectorType.toUpperCase()}'};`; + await this.client.execute(cql); + + for await (const { name, value } of this.indices) { + cql = `CREATE CUSTOM INDEX IF NOT EXISTS idx_${this.table}_${name} + ON ${this.keyspace}.${this.table} ${value} USING 'StorageAttachedIndex';`; + await this.client.execute(cql); + } + this.isInitialized = true; + } + + /** + * Method to build the PRIMARY KEY clause for CREATE TABLE. + * @param columns: list of Column to include in the key + * @returns The clause, including PRIMARY KEY + */ + private buildPrimaryKey(columns: Column[]): string { + // Partition columns may be specified with optional attribute col.partition + const partitionColumns = columns + .filter((col) => col.partition) + .map((col) => col.name) + .join(", "); + + // All columns not part of the partition key are clustering columns + const clusteringColumns = columns + .filter((col) => !col.partition) + .map((col) => col.name) + .join(", "); + + let primaryKey = ""; + + // If partition columns are specified, they are included in a () wrapper + // If not, the clustering columns are used, and the first clustering column + // is the partition key per normal Cassandra behaviour. + if (partitionColumns) { + primaryKey = `PRIMARY KEY ((${partitionColumns}), ${clusteringColumns})`; + } else { + primaryKey = `PRIMARY KEY (${clusteringColumns})`; + } + + return primaryKey; + } + + /** + * Type guard to check if an object is a Filter. 
+ * @param obj: the object to check + * @returns boolean indicating if the object is a Filter + */ + private isFilter(obj: unknown): obj is Filter { + return ( + typeof obj === "object" && obj !== null && "name" in obj && "value" in obj + ); + } + + /** + * Helper to convert Record to a Filter[] + * @param record: a key-value Record collection + * @returns Record as a Filter[] + */ + private convertToFilters(record: Record): Filter[] { + return Object.entries(record).map(([name, value]) => ({ + name, + value, + operator: "=", + })); + } + + /** + * Input santisation method for filters, as FilterType is not required to be + * Filter[], but we want to use Filter[] internally. + * @param record: the proposed filter + * @returns A Filter[], which may be empty + */ + private asFilters(record: WhereClause | undefined): Filter[] { + if (!record) { + return []; + } + + // If record is already an array + if (Array.isArray(record)) { + return record.flatMap((item) => { + // Check if item is a Filter before passing it to convertToFilters + if (this.isFilter(item)) { + return [item]; + } else { + // Here item is treated as Record + return this.convertToFilters(item); + } + }); + } + + // If record is a single Filter object, return it in an array + if (this.isFilter(record)) { + return [record]; + } + + // If record is a Record, convert it to an array of Filter + return this.convertToFilters(record); + } + + /** + * Method to build the WHERE clause of a CQL query, using bind variable ? + * @param filters list of filters to include in the WHERE clause + * @returns The WHERE clause + */ + private buildWhereClause(filters?: Filter[]): string { + if (!filters || filters.length === 0) { + return ""; + } + + const whereConditions = filters.map( + ({ name, operator = "=" }) => `${name} ${operator} ?` + ); + + return `WHERE ${whereConditions.join(" AND ")}`; + } + + /** + * Method to build an CQL query for searching for similar vectors in the + * Cassandra database. 
+ * @param query The query vector. + * @param k The number of similar vectors to return. + * @param filters + * @returns The CQL query string. + */ + private buildSearchQuery(filters: Filter[]): string { + const whereClause = filters ? this.buildWhereClause(filters) : ""; + + const cqlQuery = `SELECT ${this.selectColumns}, text, similarity_${this.vectorType}(vector, ?) AS similarity_score + FROM ${this.keyspace}.${this.table} ${whereClause} ORDER BY vector ANN OF ? LIMIT ?`; + + return cqlQuery; + } + + /** + * Method for inserting vectors and documents into the Cassandra database in a batch. + * @param batchVectors The list of vectors to insert. + * @param batchDocuments The list of documents to insert. + * @returns Promise that resolves when the batch has been inserted. + */ + private async executeInsert( + batchVectors: number[][], + batchDocuments: Document[] + ): Promise { + // Input validation: Check if the lengths of batchVectors and batchDocuments are the same + if (batchVectors.length !== batchDocuments.length) { + throw new Error( + `The lengths of vectors (${batchVectors.length}) and documents (${batchDocuments.length}) must be the same.` + ); + } + + // Initialize an array to hold query objects + const queries = []; + + // Loop through each vector and document in the batch + for (let i = 0; i < batchVectors.length; i += 1) { + // Convert the list of numbers to a Float32Array, the driver's expected format of a vector + const preparedVector = new Float32Array(batchVectors[i]); + // Retrieve the corresponding document + const document = batchDocuments[i]; + + // Extract metadata column names and values from the document + const metadataColNames = Object.keys(document.metadata); + const metadataVals = Object.values(document.metadata); + + // Prepare the metadata columns string for the query, if metadata exists + const metadataInsert = + metadataColNames.length > 0 ? 
", " + metadataColNames.join(", ") : ""; + + // Construct the query string and parameters + const query = { + query: `INSERT INTO ${this.keyspace}.${ + this.table + } (vector, text${metadataInsert}) + VALUES (?, ?${", ?".repeat(metadataColNames.length)})`, + params: [preparedVector, document.pageContent, ...metadataVals], + }; + + // Add the query to the list + queries.push(query); + } + + // Execute the queries: use a batch if multiple, otherwise execute a single query + if (queries.length === 1) { + await this.client.execute(queries[0].query, queries[0].params, { + prepare: true, + }); + } else { + await this.client.batch(queries, { prepare: true, logged: false }); + } + } + + /** + * Method for inserting vectors and documents into the Cassandra database in + * parallel, keeping within maxConcurrency number of active insert statements. + * @param vectors The vectors to insert. + * @param documents The documents to insert. + * @returns Promise that resolves when the documents have been added. 
+ */ + private async insertAll( + vectors: number[][], + documents: Document[] + ): Promise { + // Input validation: Check if the lengths of vectors and documents are the same + if (vectors.length !== documents.length) { + throw new Error( + `The lengths of vectors (${vectors.length}) and documents (${documents.length}) must be the same.` + ); + } + + // Early exit: If there are no vectors or documents to insert, return immediately + if (vectors.length === 0) { + return; + } + + // Ensure the store is initialized before proceeding + if (!this.isInitialized) { + await this.initialize(); + } + + // Initialize an array to hold promises for each batch insert + const insertPromises: Promise[] = []; + + // Buffers to hold the current batch of vectors and documents + let currentBatchVectors: number[][] = []; + let currentBatchDocuments: Document[] = []; + + // Loop through each vector/document pair to insert; we use + // <= vectors.length to ensure the last batch is inserted + for (let i = 0; i <= vectors.length; i += 1) { + // Check if we're still within the array boundaries + if (i < vectors.length) { + // Add the current vector and document to the batch + currentBatchVectors.push(vectors[i]); + currentBatchDocuments.push(documents[i]); + } + + // Check if we've reached the batch size or end of the array + if ( + currentBatchVectors.length >= this.batchSize || + i === vectors.length + ) { + // Only proceed if there are items in the current batch + if (currentBatchVectors.length > 0) { + // Create copies of the current batch arrays to use in the async insert operation + const batchVectors = [...currentBatchVectors]; + const batchDocuments = [...currentBatchDocuments]; + + // Execute the insert using the AsyncCaller - it will handle concurrency and queueing. 
+ insertPromises.push( + this.asyncCaller.call(() => + this.executeInsert(batchVectors, batchDocuments) + ) + ); + + // Clear the current buffers for the next iteration + currentBatchVectors = []; + currentBatchDocuments = []; + } + } + } + + // Wait for all insert operations to complete. + await Promise.all(insertPromises); + } +} diff --git a/libs/langchain-community/src/vectorstores/chroma.ts b/libs/langchain-community/src/vectorstores/chroma.ts new file mode 100644 index 000000000000..96e6b15475dc --- /dev/null +++ b/libs/langchain-community/src/vectorstores/chroma.ts @@ -0,0 +1,364 @@ +import * as uuid from "uuid"; +import type { ChromaClient as ChromaClientT, Collection } from "chromadb"; +import type { CollectionMetadata, Where } from "chromadb/dist/main/types.js"; + +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +/** + * Defines the arguments that can be passed to the `Chroma` class + * constructor. It can either contain a `url` for the Chroma database, the + * number of dimensions for the vectors (`numDimensions`), a + * `collectionName` for the collection to be used in the database, and a + * `filter` object; or it can contain an `index` which is an instance of + * `ChromaClientT`, along with the `numDimensions`, `collectionName`, and + * `filter`. + */ +export type ChromaLibArgs = + | { + url?: string; + numDimensions?: number; + collectionName?: string; + filter?: object; + collectionMetadata?: CollectionMetadata; + } + | { + index?: ChromaClientT; + numDimensions?: number; + collectionName?: string; + filter?: object; + collectionMetadata?: CollectionMetadata; + }; + +/** + * Defines the parameters for the `delete` method in the `Chroma` class. + * It can either contain an array of `ids` of the documents to be deleted + * or a `filter` object to specify the documents to be deleted. 
+ */ +export interface ChromaDeleteParams { + ids?: string[]; + filter?: T; +} + +/** + * The main class that extends the `VectorStore` class. It provides + * methods for interacting with the Chroma database, such as adding + * documents, deleting documents, and searching for similar vectors. + */ +export class Chroma extends VectorStore { + declare FilterType: Where; + + index?: ChromaClientT; + + collection?: Collection; + + collectionName: string; + + collectionMetadata?: CollectionMetadata; + + numDimensions?: number; + + url: string; + + filter?: object; + + _vectorstoreType(): string { + return "chroma"; + } + + constructor(embeddings: Embeddings, args: ChromaLibArgs) { + super(embeddings, args); + this.numDimensions = args.numDimensions; + this.embeddings = embeddings; + this.collectionName = ensureCollectionName(args.collectionName); + this.collectionMetadata = args.collectionMetadata; + if ("index" in args) { + this.index = args.index; + } else if ("url" in args) { + this.url = args.url || "http://localhost:8000"; + } + + this.filter = args.filter; + } + + /** + * Adds documents to the Chroma database. The documents are first + * converted to vectors using the `embeddings` instance, and then added to + * the database. + * @param documents An array of `Document` instances to be added to the database. + * @param options Optional. An object containing an array of `ids` for the documents. + * @returns A promise that resolves when the documents have been added to the database. + */ + async addDocuments(documents: Document[], options?: { ids?: string[] }) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + options + ); + } + + /** + * Ensures that a collection exists in the Chroma database. If the + * collection does not exist, it is created. + * @returns A promise that resolves with the `Collection` instance. 
+ */ + async ensureCollection(): Promise { + if (!this.collection) { + if (!this.index) { + const { ChromaClient } = await Chroma.imports(); + this.index = new ChromaClient({ path: this.url }); + } + try { + this.collection = await this.index.getOrCreateCollection({ + name: this.collectionName, + ...(this.collectionMetadata && { metadata: this.collectionMetadata }), + }); + } catch (err) { + throw new Error(`Chroma getOrCreateCollection error: ${err}`); + } + } + + return this.collection; + } + + /** + * Adds vectors to the Chroma database. The vectors are associated with + * the provided documents. + * @param vectors An array of vectors to be added to the database. + * @param documents An array of `Document` instances associated with the vectors. + * @param options Optional. An object containing an array of `ids` for the vectors. + * @returns A promise that resolves with an array of document IDs when the vectors have been added to the database. + */ + async addVectors( + vectors: number[][], + documents: Document[], + options?: { ids?: string[] } + ) { + if (vectors.length === 0) { + return []; + } + if (this.numDimensions === undefined) { + this.numDimensions = vectors[0].length; + } + if (vectors.length !== documents.length) { + throw new Error(`Vectors and metadatas must have the same length`); + } + if (vectors[0].length !== this.numDimensions) { + throw new Error( + `Vectors must have the same length as the number of dimensions (${this.numDimensions})` + ); + } + + const documentIds = + options?.ids ?? 
Array.from({ length: vectors.length }, () => uuid.v1()); + const collection = await this.ensureCollection(); + + const mappedMetadatas = documents.map(({ metadata }) => { + let locFrom; + let locTo; + + if (metadata?.loc) { + if (metadata.loc.lines?.from !== undefined) + locFrom = metadata.loc.lines.from; + if (metadata.loc.lines?.to !== undefined) locTo = metadata.loc.lines.to; + } + + const newMetadata: Document["metadata"] = { + ...metadata, + ...(locFrom !== undefined && { locFrom }), + ...(locTo !== undefined && { locTo }), + }; + + if (newMetadata.loc) delete newMetadata.loc; + + return newMetadata; + }); + + await collection.upsert({ + ids: documentIds, + embeddings: vectors, + metadatas: mappedMetadatas, + documents: documents.map(({ pageContent }) => pageContent), + }); + return documentIds; + } + + /** + * Deletes documents from the Chroma database. The documents to be deleted + * can be specified by providing an array of `ids` or a `filter` object. + * @param params An object containing either an array of `ids` of the documents to be deleted or a `filter` object to specify the documents to be deleted. + * @returns A promise that resolves when the specified documents have been deleted from the database. + */ + async delete(params: ChromaDeleteParams): Promise { + const collection = await this.ensureCollection(); + if (Array.isArray(params.ids)) { + await collection.delete({ ids: params.ids }); + } else if (params.filter) { + await collection.delete({ + where: { ...params.filter }, + }); + } else { + throw new Error(`You must provide one of "ids or "filter".`); + } + } + + /** + * Searches for vectors in the Chroma database that are similar to the + * provided query vector. The search can be filtered using the provided + * `filter` object or the `filter` property of the `Chroma` instance. + * @param query The query vector. + * @param k The number of similar vectors to return. + * @param filter Optional. A `filter` object to filter the search results. 
+ * @returns A promise that resolves with an array of tuples, each containing a `Document` instance and a similarity score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ) { + if (filter && this.filter) { + throw new Error("cannot provide both `filter` and `this.filter`"); + } + const _filter = filter ?? this.filter; + + const collection = await this.ensureCollection(); + + // similaritySearchVectorWithScore supports one query vector at a time + // chroma supports multiple query vectors at a time + const result = await collection.query({ + queryEmbeddings: query, + nResults: k, + where: { ..._filter }, + }); + + const { ids, distances, documents, metadatas } = result; + if (!ids || !distances || !documents || !metadatas) { + return []; + } + // get the result data from the first and only query vector + const [firstIds] = ids; + const [firstDistances] = distances; + const [firstDocuments] = documents; + const [firstMetadatas] = metadatas; + + const results: [Document, number][] = []; + for (let i = 0; i < firstIds.length; i += 1) { + let metadata: Document["metadata"] = firstMetadatas?.[i] ?? {}; + + if (metadata.locFrom && metadata.locTo) { + metadata = { + ...metadata, + loc: { + lines: { + from: metadata.locFrom, + to: metadata.locTo, + }, + }, + }; + + delete metadata.locFrom; + delete metadata.locTo; + } + + results.push([ + new Document({ + pageContent: firstDocuments?.[i] ?? "", + metadata, + }), + firstDistances[i], + ]); + } + return results; + } + + /** + * Creates a new `Chroma` instance from an array of text strings. The text + * strings are converted to `Document` instances and added to the Chroma + * database. + * @param texts An array of text strings. + * @param metadatas An array of metadata objects or a single metadata object. If an array is provided, it must have the same length as the `texts` array. 
+ * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. + * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. + * @returns A promise that resolves with a new `Chroma` instance. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: ChromaLibArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return this.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Creates a new `Chroma` instance from an array of `Document` instances. + * The documents are added to the Chroma database. + * @param docs An array of `Document` instances. + * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. + * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. + * @returns A promise that resolves with a new `Chroma` instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: ChromaLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } + + /** + * Creates a new `Chroma` instance from an existing collection in the + * Chroma database. + * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. + * @param dbConfig A `ChromaLibArgs` object containing the configuration for the Chroma database. + * @returns A promise that resolves with a new `Chroma` instance. 
+ */ + static async fromExistingCollection( + embeddings: Embeddings, + dbConfig: ChromaLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.ensureCollection(); + return instance; + } + + /** + * Imports the `ChromaClient` from the `chromadb` module. + * @returns A promise that resolves with an object containing the `ChromaClient` constructor. + */ + static async imports(): Promise<{ + ChromaClient: typeof ChromaClientT; + }> { + try { + const { ChromaClient } = await import("chromadb"); + return { ChromaClient }; + } catch (e) { + throw new Error( + "Please install chromadb as a dependency with, e.g. `npm install -S chromadb`" + ); + } + } +} + +/** + * Generates a unique collection name if none is provided. + */ +function ensureCollectionName(collectionName?: string) { + if (!collectionName) { + return `langchain-${uuid.v4()}`; + } + return collectionName; +} diff --git a/libs/langchain-community/src/vectorstores/clickhouse.ts b/libs/langchain-community/src/vectorstores/clickhouse.ts new file mode 100644 index 000000000000..4cb9bd127e7c --- /dev/null +++ b/libs/langchain-community/src/vectorstores/clickhouse.ts @@ -0,0 +1,338 @@ +import * as uuid from "uuid"; +import { ClickHouseClient, createClient } from "@clickhouse/client"; +import { format } from "mysql2"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +/** + * Arguments for the ClickHouseStore class, which include the host, port, + * protocol, username, password, index type, index parameters, + * index query params, column map, database, table. 
+ */ +export interface ClickHouseLibArgs { + host: string; + port: string | number; + protocol?: string; + username: string; + password: string; + indexType?: string; + indexParam?: Record; + indexQueryParams?: Record; + columnMap?: ColumnMap; + database?: string; + table?: string; +} + +/** + * Mapping of columns in the ClickHouse database. + */ +export interface ColumnMap { + id: string; + uuid: string; + document: string; + embedding: string; + metadata: string; +} + +/** + * Type for filtering search results in the ClickHouse database. + */ +export interface ClickHouseFilter { + whereStr: string; +} + +/** + * Class for interacting with the ClickHouse database. It extends the + * VectorStore class and provides methods for adding vectors and + * documents, searching for similar vectors, and creating instances from + * texts or documents. + */ +export class ClickHouseStore extends VectorStore { + declare FilterType: ClickHouseFilter; + + private client: ClickHouseClient; + + private indexType: string; + + private indexParam: Record; + + private indexQueryParams: Record; + + private columnMap: ColumnMap; + + private database: string; + + private table: string; + + private isInitialized = false; + + _vectorstoreType(): string { + return "clickhouse"; + } + + constructor(embeddings: Embeddings, args: ClickHouseLibArgs) { + super(embeddings, args); + + this.indexType = args.indexType || "annoy"; + this.indexParam = args.indexParam || { L2Distance: 100 }; + this.indexQueryParams = args.indexQueryParams || {}; + this.columnMap = args.columnMap || { + id: "id", + document: "document", + embedding: "embedding", + metadata: "metadata", + uuid: "uuid", + }; + this.database = args.database || "default"; + this.table = args.table || "vector_table"; + + this.client = createClient({ + host: `${args.protocol ?? 
"https://"}${args.host}:${args.port}`, + username: args.username, + password: args.password, + session_id: uuid.v4(), + }); + } + + /** + * Method to add vectors to the ClickHouse database. + * @param vectors The vectors to add. + * @param documents The documents associated with the vectors. + * @returns Promise that resolves when the vectors have been added. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length === 0) { + return; + } + + if (!this.isInitialized) { + await this.initialize(vectors[0].length); + } + + const queryStr = this.buildInsertQuery(vectors, documents); + await this.client.exec({ query: queryStr }); + } + + /** + * Method to add documents to the ClickHouse database. + * @param documents The documents to add. + * @returns Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + return this.addVectors( + await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), + documents + ); + } + + /** + * Method to search for vectors that are similar to a given query vector. + * @param query The query vector. + * @param k The number of similar vectors to return. + * @param filter Optional filter for the search results. + * @returns Promise that resolves with an array of tuples, each containing a Document and a score. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ): Promise<[Document, number][]> { + if (!this.isInitialized) { + await this.initialize(query.length); + } + const queryStr = this.buildSearchQuery(query, k, filter); + + const queryResultSet = await this.client.query({ query: queryStr }); + + const queryResult: { + data: { document: string; metadata: object; dist: number }[]; + } = await queryResultSet.json(); + + const result: [Document, number][] = queryResult.data.map((item) => [ + new Document({ pageContent: item.document, metadata: item.metadata }), + item.dist, + ]); + + return result; + } + + /** + * Static method to create an instance of ClickHouseStore from texts. + * @param texts The texts to use. + * @param metadatas The metadata associated with the texts. + * @param embeddings The embeddings to use. + * @param args The arguments for the ClickHouseStore. + * @returns Promise that resolves with a new instance of ClickHouseStore. + */ + static async fromTexts( + texts: string[], + metadatas: object | object[], + embeddings: Embeddings, + args: ClickHouseLibArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return ClickHouseStore.fromDocuments(docs, embeddings, args); + } + + /** + * Static method to create an instance of ClickHouseStore from documents. + * @param docs The documents to use. + * @param embeddings The embeddings to use. + * @param args The arguments for the ClickHouseStore. + * @returns Promise that resolves with a new instance of ClickHouseStore. 
+ */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + args: ClickHouseLibArgs + ): Promise { + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method to create an instance of ClickHouseStore from an existing + * index. + * @param embeddings The embeddings to use. + * @param args The arguments for the ClickHouseStore. + * @returns Promise that resolves with a new instance of ClickHouseStore. + */ + static async fromExistingIndex( + embeddings: Embeddings, + args: ClickHouseLibArgs + ): Promise { + const instance = new this(embeddings, args); + + await instance.initialize(); + return instance; + } + + /** + * Method to initialize the ClickHouse database. + * @param dimension Optional dimension of the vectors. + * @returns Promise that resolves when the database has been initialized. + */ + private async initialize(dimension?: number): Promise { + const dim = dimension ?? (await this.embeddings.embedQuery("test")).length; + + const indexParamStr = this.indexParam + ? 
Object.entries(this.indexParam) + .map(([key, value]) => `'${key}', ${value}`) + .join(", ") + : ""; + + const query = ` + CREATE TABLE IF NOT EXISTS ${this.database}.${this.table}( + ${this.columnMap.id} Nullable(String), + ${this.columnMap.document} Nullable(String), + ${this.columnMap.embedding} Array(Float32), + ${this.columnMap.metadata} JSON, + ${this.columnMap.uuid} UUID DEFAULT generateUUIDv4(), + CONSTRAINT cons_vec_len CHECK length(${this.columnMap.embedding}) = ${dim}, + INDEX vec_idx ${this.columnMap.embedding} TYPE ${this.indexType}(${indexParamStr}) GRANULARITY 1000 + ) ENGINE = MergeTree ORDER BY ${this.columnMap.uuid} SETTINGS index_granularity = 8192;`; + + await this.client.exec({ + query, + clickhouse_settings: { + allow_experimental_object_type: 1, + allow_experimental_annoy_index: 1, + }, + }); + this.isInitialized = true; + } + + /** + * Method to build an SQL query for inserting vectors and documents into + * the ClickHouse database. + * @param vectors The vectors to insert. + * @param documents The documents to insert. + * @returns The SQL query string. 
+ */ + private buildInsertQuery(vectors: number[][], documents: Document[]): string { + const columnsStr = Object.values( + Object.fromEntries( + Object.entries(this.columnMap).filter( + ([key]) => key !== "uuid" /* compare against the logical key: this.columnMap.uuid is the mapped column name, so the old comparison stopped excluding the uuid column (filled by its DEFAULT) once that name was customised */ + ) + ) + ).join(", "); + + const placeholders = vectors.map(() => "(?, ?, ?, ?)").join(", "); + const values = []; + + for (let i = 0; i < vectors.length; i += 1) { + const vector = vectors[i]; + const document = documents[i]; + values.push( + uuid.v4(), + this.escapeString(document.pageContent), + JSON.stringify(vector), + JSON.stringify(document.metadata) + ); + } + + const insertQueryStr = ` + INSERT INTO TABLE ${this.database}.${this.table}(${columnsStr}) + VALUES ${placeholders} + `; + + const insertQuery = format(insertQueryStr, values); + return insertQuery; + } + + /** Prefixes every backslash and every single quote in str with a backslash. */ private escapeString(str: string): string { + return str.replace(/\\/g, "\\\\").replace(/'/g, "\\'"); + } + + /** + * Method to build an SQL query for searching for similar vectors in the + * ClickHouse database. + * @param query The query vector. + * @param k The number of similar vectors to return. + * @param filter Optional filter for the search results. + * @returns The SQL query string. + */ + private buildSearchQuery( + query: number[], + k: number, + filter?: ClickHouseFilter + ): string { + const order = "ASC"; + const whereStr = filter ? 
`PREWHERE ${filter.whereStr}` : ""; + const placeholders = query.map(() => "?").join(", "); + + /* ClickHouse takes a single SETTINGS clause with comma-separated key=value pairs, so collect the pairs here and emit the keyword once below */ const settingStrings: string[] = []; + if (this.indexQueryParams) { + for (const [key, value] of Object.entries(this.indexQueryParams)) { + settingStrings.push(`${key}=${value}`); + } + } + + const searchQueryStr = ` + SELECT ${this.columnMap.document} AS document, ${ + this.columnMap.metadata + } AS metadata, dist + FROM ${this.database}.${this.table} + ${whereStr} + ORDER BY L2Distance(${ + this.columnMap.embedding + }, [${placeholders}]) AS dist ${order} + LIMIT ${k} ${settingStrings.length > 0 ? `SETTINGS ${settingStrings.join(", ")}` : ""} + `; + + // Format the query with actual values + const searchQuery = format(searchQueryStr, query); + return searchQuery; + } +} diff --git a/langchain/src/vectorstores/closevector/common.ts b/libs/langchain-community/src/vectorstores/closevector/common.ts similarity index 96% rename from langchain/src/vectorstores/closevector/common.ts rename to libs/langchain-community/src/vectorstores/closevector/common.ts index 8afef62324cc..ec62f28b7987 100644 --- a/langchain/src/vectorstores/closevector/common.ts +++ b/libs/langchain-community/src/vectorstores/closevector/common.ts @@ -1,8 +1,8 @@ import type { CloseVectorSaveableVectorStore } from "closevector-common"; -import { Embeddings } from "../../embeddings/base.js"; -import { Document } from "../../document.js"; -import { SaveableVectorStore } from "../base.js"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { SaveableVectorStore } from "@langchain/core/vectorstores"; type CloseVectorCredentials = { key?: string; diff --git a/libs/langchain-community/src/vectorstores/closevector/node.ts b/libs/langchain-community/src/vectorstores/closevector/node.ts new file mode 100644 index 
000000000000..c45bf012bcdc --- /dev/null +++ b/libs/langchain-community/src/vectorstores/closevector/node.ts @@ -0,0 +1,182 @@ +import { + CloseVectorHNSWNode, + 
HierarchicalNSWT, + CloseVectorHNSWLibArgs, + CloseVectorCredentials, +} from "closevector-node"; + +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; + +import { CloseVector } from "./common.js"; + +/** + * package closevector-node is largely based on hnswlib.ts in the current folder with the following exceptions: + * 1. It uses a modified version of hnswlib-node to ensure the generated index can be loaded by closevector_web.ts. + * 2. It adds features to upload and download the index to/from the CDN provided by CloseVector. + * + * For more information, check out https://closevector-docs.getmegaportal.com/ + */ + +/** + * Arguments for creating a CloseVectorNode instance, extending CloseVectorHNSWLibArgs. + */ +export interface CloseVectorNodeArgs + extends CloseVectorHNSWLibArgs { + instance?: CloseVectorHNSWNode; +} + +/** + * Class that implements a vector store using Hierarchical Navigable Small + * World (HNSW) graphs. It extends the SaveableVectorStore class and + * provides methods for adding documents and vectors, performing + * similarity searches, and saving and loading the vector store. + */ +export class CloseVectorNode extends CloseVector { + declare FilterType: (doc: Document) => boolean; + + constructor( + embeddings: Embeddings, + args: CloseVectorNodeArgs, + credentials?: CloseVectorCredentials + ) { + super(embeddings, args, credentials); + if (args.instance) { + this.instance = args.instance; + } else { + this.instance = new CloseVectorHNSWNode(embeddings, args); + } + if (this.credentials?.key) { + this.instance.accessKey = this.credentials.key; + } + if (this.credentials?.secret) { + this.instance.secret = this.credentials.secret; + } + } + + /** + * Method to save the index to the CloseVector CDN. + * @param options + * @param options.description A description of the index. + * @param options.public Whether the index should be public or private. Defaults to false. 
+ * @param options.uuid A UUID for the index. If not provided, a new index will be created. + * @param options.onProgress A callback function that will be called with the progress of the upload. + */ + async saveToCloud( + options: Parameters[0] + ) { + await this.instance.saveToCloud(options); + } + + /** + * Method to load the index from the CloseVector CDN. + * @param options + * @param options.uuid The UUID of the index to be downloaded. + * @param options.credentials The credentials to be used by the CloseVectorNode instance. Both key and secret must be set, otherwise an error is thrown. + * @param options.embeddings The embeddings to be used by the CloseVectorNode instance. + * @param options.onProgress A callback function that will be called with the progress of the download. + */ + static async loadFromCloud( + options: Omit< + Parameters<(typeof CloseVectorHNSWNode)["loadFromCloud"]>[0] & { + embeddings: Embeddings; + credentials: CloseVectorCredentials; + }, + "accessKey" | "secret" + > + ) { + if (!options.credentials.key || !options.credentials.secret) { + throw new Error("key and secret must be provided"); + } + const instance = await CloseVectorHNSWNode.loadFromCloud({ + ...options, + accessKey: options.credentials.key, + secret: options.credentials.secret, + }); + const vectorstore = new this( + options.embeddings, + instance.args, + options.credentials + ); + return vectorstore; + } + + /** + * Static method to load a vector store from a directory. It reads the + * HNSW index, the arguments, and the document store from the directory, + * then creates a new CloseVectorNode instance with these values. + * @param directory The directory from which to load the vector store. + * @param embeddings The embeddings to be used by the CloseVectorNode instance. + * @returns A Promise that resolves to a new CloseVectorNode instance. 
+ */ + static async load( + directory: string, + embeddings: Embeddings, + credentials?: CloseVectorCredentials + ) { + const instance = await CloseVectorHNSWNode.load(directory, embeddings); + const vectorstore = new this(embeddings, instance.args, credentials); + return vectorstore; + } + + /** + * Static method to create a new CloseVectorNode instance from texts and metadata. + * It converts the texts and metadata to documents with CloseVector.textsToDocuments, then + * calls the fromDocuments method to create the CloseVectorNode instance. + * @param texts The texts to be used to create the documents. + * @param metadatas The metadata to be used to create the documents. + * @param embeddings The embeddings to be used by the CloseVectorNode instance. + * @param args An optional configuration object for the CloseVectorNode instance. + * @param credential An optional credential object for the CloseVector API. + * @returns A Promise that resolves to a new CloseVectorNode instance. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + args?: Record, + credential?: CloseVectorCredentials + ): Promise { + const docs = CloseVector.textsToDocuments(texts, metadatas); + return await CloseVectorNode.fromDocuments( + docs, + embeddings, + args, + credential + ); + } + + /** + * Static method to create a new CloseVectorNode instance from documents. It + * creates a new CloseVectorNode instance, adds the documents to it, then returns + * the instance. + * @param docs The documents to be added to the CloseVectorNode instance. + * @param embeddings The embeddings to be used by the CloseVectorNode instance. + * @param args An optional configuration object for the CloseVectorNode instance. + * @param credentials An optional credential object for the CloseVector API. + * @returns A Promise that resolves to a new CloseVectorNode instance. 
+ */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + args?: Record, + credentials?: CloseVectorCredentials + ): Promise { + const _args: Record = args || { + space: "cosine", + }; + const instance = new this( + embeddings, + _args as unknown as CloseVectorNodeArgs, + credentials + ); + await instance.addDocuments(docs); + return instance; + } + + static async imports(): Promise<{ + HierarchicalNSW: typeof HierarchicalNSWT; + }> { + return CloseVectorHNSWNode.imports(); + } +} diff --git a/libs/langchain-community/src/vectorstores/closevector/web.ts b/libs/langchain-community/src/vectorstores/closevector/web.ts new file mode 100644 index 000000000000..9a0896041df5 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/closevector/web.ts @@ -0,0 +1,179 @@ +import { + CloseVectorHNSWWeb, + HierarchicalNSWT, + CloseVectorHNSWLibArgs, + CloseVectorCredentials, + HnswlibModule, +} from "closevector-web"; + +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; + +import { CloseVector } from "./common.js"; + +/** + * package closevector-node is largely based on hnswlib.ts in the current folder with the following exceptions: + * 1. It uses a modified version of hnswlib-node to ensure the generated index can be loaded by closevector_web.ts. + * 2. It adds features to upload and download the index to/from the CDN provided by CloseVector. + * + * For more information, check out https://closevector-docs.getmegaportal.com/ + */ + +/** + * Arguments for creating a CloseVectorWeb instance, extending CloseVectorHNSWLibArgs. + */ +export interface CloseVectorWebArgs + extends CloseVectorHNSWLibArgs { + instance?: CloseVectorHNSWWeb; +} + +/** + * Class that implements a vector store using CloseVector, It extends the SaveableVectorStore class and + * provides methods for adding documents and vectors, performing + * similarity searches, and saving and loading the vector store. 
+ */ +export class CloseVectorWeb extends CloseVector { + declare FilterType: (doc: Document) => boolean; + + constructor( + embeddings: Embeddings, + args: CloseVectorWebArgs, + credentials?: CloseVectorCredentials + ) { + super(embeddings, args, credentials); + if (args.instance) { + this.instance = args.instance; + } else { + this.instance = new CloseVectorHNSWWeb(embeddings, args); + } + } + + /** + * Method to save the index to the CloseVector CDN. + * @param options Upload options. options.uuid is assigned to the instance when the instance has no uuid yet; an error is thrown when neither has one. + * @param options.url the upload url generated by the CloseVector API: https://closevector-docs.getmegaportal.com/docs/api/http-api/file-url + * @param options.onProgress a callback function to track the upload progress + */ + async saveToCloud( + options: Parameters[0] & { + uuid?: string; + } + ) { + if (!this.instance.uuid && !options.uuid) { + throw new Error("No uuid provided"); + } + if (!this.instance.uuid) { + this.instance._uuid = options.uuid; + } + await this.save(this.instance.uuid); + await this.instance.saveToCloud(options); + } + + /** + * Method to load the index from the CloseVector CDN. + * @param options Download options. options.credentials, when given, is passed on to the new CloseVectorWeb instance. + * @param options.url the file url generated by the CloseVector API: https://closevector-docs.getmegaportal.com/docs/api/http-api/file-url + * @param options.onProgress a callback function to track the download progress + * @param options.uuid the uuid of the index to be downloaded + * @param options.embeddings the embeddings to be used by the CloseVectorWeb instance + */ + static async loadFromCloud( + options: Parameters[0] & { + embeddings: Embeddings; + credentials?: CloseVectorCredentials; + } + ) { + const instance = await CloseVectorHNSWWeb.loadFromCloud(options); + const vectorstore = new this( + options.embeddings, + instance.args, + options.credentials + ); + 
return vectorstore; + } + + /** + * Static method to load a vector store from a directory. 
It reads the + * HNSW index, the arguments, and the document store from the directory, + * then creates a new CloseVectorWeb instance with these values. + * @param directory The directory from which to load the vector store. + * @param embeddings The embeddings to be used by the CloseVectorWeb instance. + * @returns A Promise that resolves to a new CloseVectorWeb instance. + */ + static async load( + directory: string, + embeddings: Embeddings, + credentials?: CloseVectorCredentials + ) { + const instance = await CloseVectorHNSWWeb.load(directory, embeddings); + const vectorstore = new this(embeddings, instance.args, credentials); + return vectorstore; + } + + /** + * Static method to create a new CloseVectorWeb instance from texts and metadata. + * It creates a new Document instance for each text and metadata, then + * calls the fromDocuments method to create the CloseVectorWeb instance. + * @param texts The texts to be used to create the documents. + * @param metadatas The metadata to be used to create the documents. + * @param embeddings The embeddings to be used by the CloseVectorWeb instance. + * @param args An optional configuration object for the CloseVectorWeb instance. + * @param credential An optional credential object for the CloseVector API. + * @returns A Promise that resolves to a new CloseVectorWeb instance. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + args?: Record, + credential?: CloseVectorCredentials + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return await CloseVectorWeb.fromDocuments( + docs, + embeddings, + args, + credential + ); + } + + /** + * Static method to create a new CloseVectorWeb instance from documents. 
It + * creates a new CloseVectorWeb instance, adds the documents to it, then returns + * the instance. + * @param docs The documents to be added to the CloseVectorWeb instance. + * @param embeddings The embeddings to be used by the CloseVectorWeb instance. + * @param args An optional configuration object for the CloseVectorWeb instance. Defaults to { space: "cosine" } when omitted. + * @param credentials An optional credential object for the CloseVector API. + * @returns A Promise that resolves to a new CloseVectorWeb instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + args?: Record, + credentials?: CloseVectorCredentials + ): Promise { + const _args: Record = args || { + space: "cosine", + }; + const instance = new this( + embeddings, + _args as unknown as CloseVectorWebArgs, + credentials + ); + await instance.addDocuments(docs); + return instance; + } + + /** Returns whatever CloseVectorHNSWWeb.imports() resolves to. */ static async imports(): Promise { + return CloseVectorHNSWWeb.imports(); + } +} diff --git a/libs/langchain-community/src/vectorstores/cloudflare_vectorize.ts b/libs/langchain-community/src/vectorstores/cloudflare_vectorize.ts new file mode 100644 index 000000000000..605aa7ab44ec --- /dev/null +++ b/libs/langchain-community/src/vectorstores/cloudflare_vectorize.ts @@ -0,0 +1,230 @@ +import * as uuid from "uuid"; + +import { + VectorizeIndex, + VectorizeVectorMetadata, +} from "@cloudflare/workers-types"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { + AsyncCaller, + type AsyncCallerParams, +} from "@langchain/core/utils/async_caller"; +import { chunkArray } from "../utils/chunk.js"; + +export interface VectorizeLibArgs extends AsyncCallerParams { + index: VectorizeIndex; + textKey?: string; +} + +/** + * Type that defines the parameters for the delete operation in the + * 
CloudflareVectorizeStore class. It includes only the ids of the vectors to delete. 
+ */ +export type VectorizeDeleteParams = { + ids: string[]; +}; + +/** + * Class that extends the VectorStore class and provides methods to + * interact with the Cloudflare Vectorize vector database. + */ +export class CloudflareVectorizeStore extends VectorStore { + textKey: string; + + namespace?: string; + + index: VectorizeIndex; + + caller: AsyncCaller; + + _vectorstoreType(): string { + return "cloudflare_vectorize"; + } + + constructor(embeddings: Embeddings, args: VectorizeLibArgs) { + super(embeddings, args); + + this.embeddings = embeddings; + const { index, textKey, ...asyncCallerArgs } = args; + if (!index) { + throw new Error( + "Must supply a Vectorize index binding, eg { index: env.VECTORIZE }" + ); + } + this.index = index; + this.textKey = textKey ?? "text"; + this.caller = new AsyncCaller({ + maxConcurrency: 6, + maxRetries: 0, + ...asyncCallerArgs, + }); + } + + /** + * Method that adds documents to the Vectorize database. + * @param documents Array of documents to add. + * @param options Optional ids for the documents. + * @returns Promise that resolves with the ids of the added documents. + */ + async addDocuments( + documents: Document[], + options?: { ids?: string[] } | string[] + ) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + options + ); + } + + /** + * Method that adds vectors to the Vectorize database. + * @param vectors Array of vectors to add. + * @param documents Array of documents associated with the vectors. + * @param options Optional ids for the vectors. + * @returns Promise that resolves with the ids of the added vectors. + */ + async addVectors( + vectors: number[][], + documents: Document[], + options?: { ids?: string[] } | string[] + ) { + const ids = Array.isArray(options) ? options : options?.ids; + const documentIds = ids == null ? 
documents.map(() => uuid.v4()) : ids; + const vectorizeVectors = vectors.map((values, idx) => { + const metadata: Record = { + ...documents[idx].metadata, + [this.textKey]: documents[idx].pageContent, + }; + return { + id: documentIds[idx], + metadata, + values, + }; + }); + + // Stick to a limit of 500 vectors per upsert request + const chunkSize = 500; + const chunkedVectors = chunkArray(vectorizeVectors, chunkSize); + const batchRequests = chunkedVectors.map((chunk) => + this.caller.call(async () => this.index.upsert(chunk)) + ); + + await Promise.all(batchRequests); + + return documentIds; + } + + /** + * Method that deletes vectors from the Vectorize database. + * @param params Parameters for the delete operation. + * @returns Promise that resolves when the delete operation is complete. + */ + async delete(params: VectorizeDeleteParams): Promise { + const batchSize = 1000; + const batchedIds = chunkArray(params.ids, batchSize); + const batchRequests = batchedIds.map((batchIds) => + this.caller.call(async () => this.index.deleteByIds(batchIds)) + ); + await Promise.all(batchRequests); + } + + /** + * Method that performs a similarity search in the Vectorize database and + * returns the results along with their scores. + * @param query Query vector for the similarity search. + * @param k Number of top results to return. + * @returns Promise that resolves with an array of documents and their scores. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number + ): Promise<[Document, number][]> { + const results = await this.index.query(query, { + returnVectors: true, + topK: k, + }); + + const result: [Document, number][] = []; + + if (results.matches) { + for (const res of results.matches) { + const { [this.textKey]: pageContent, ...metadata } = + res.vector?.metadata ?? 
{}; + result.push([ + new Document({ metadata, pageContent: pageContent as string }), + res.score, + ]); + } + } + + return result; + } + + /** + * Static method that creates a new instance of the CloudflareVectorizeStore class + * from texts. + * @param texts Array of texts to add to the Vectorize database. + * @param metadatas Metadata associated with the texts: either an array indexed like texts, or a single object shared by every text. + * @param embeddings Embeddings to use for the texts. + * @param dbConfig Configuration for the Vectorize database + * (the index binding, an optional textKey, and AsyncCaller settings). + * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. + */ + static async fromTexts( + texts: string[], + metadatas: + | Record[] + | Record, + embeddings: Embeddings, + dbConfig: VectorizeLibArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return CloudflareVectorizeStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method that creates a new instance of the CloudflareVectorizeStore class + * from documents. + * @param docs Array of documents to add to the Vectorize database. + * @param embeddings Embeddings to use for the documents. + * @param dbConfig Configuration for the Vectorize database + * (the index binding, an optional textKey, and AsyncCaller settings). + * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. 
+ */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: VectorizeLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method that creates a new instance of the CloudflareVectorizeStore class + * from an existing index. + * @param embeddings Embeddings to use for the documents. 
+ * @param dbConfig Configuration for the Vectorize database. + * @returns Promise that resolves with a new instance of the CloudflareVectorizeStore class. + */ + static async fromExistingIndex( + embeddings: Embeddings, + dbConfig: VectorizeLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/convex.ts b/libs/langchain-community/src/vectorstores/convex.ts new file mode 100644 index 000000000000..72d3661a8b2c --- /dev/null +++ b/libs/langchain-community/src/vectorstores/convex.ts @@ -0,0 +1,376 @@ +// eslint-disable-next-line import/no-extraneous-dependencies +import { + DocumentByInfo, + FieldPaths, + FilterExpression, + FunctionReference, + GenericActionCtx, + GenericDataModel, + GenericTableInfo, + NamedTableInfo, + NamedVectorIndex, + TableNamesInDataModel, + VectorFilterBuilder, + VectorIndexNames, + makeFunctionReference, +} from "convex/server"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +/** + * Type that defines the config required to initialize the + * ConvexVectorStore class. It includes the table name, + * index name, text field name, and embedding field name. 
+ */ +export type ConvexVectorStoreConfig< + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel, + IndexName extends VectorIndexNames>, + TextFieldName extends FieldPaths>, + EmbeddingFieldName extends FieldPaths>, + MetadataFieldName extends FieldPaths>, + InsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + >, + GetQuery extends FunctionReference< + "query", + "internal", + { id: string }, + object | null + > +> = { + readonly ctx: GenericActionCtx; + /** + * Defaults to "documents" + */ + readonly table?: TableName; + /** + * Defaults to "byEmbedding" + */ + readonly index?: IndexName; + /** + * Defaults to "text" + */ + readonly textField?: TextFieldName; + /** + * Defaults to "embedding" + */ + readonly embeddingField?: EmbeddingFieldName; + /** + * Defaults to "metadata" + */ + readonly metadataField?: MetadataFieldName; + /** + * Defaults to `internal.langchain.db.insert` + */ + readonly insert?: InsertMutation; + /** + * Defaults to `internal.langchain.db.get` + */ + readonly get?: GetQuery; +}; + +/** + * Class that is a wrapper around Convex storage and vector search. It is used + * to insert embeddings in Convex documents with a vector search index, + * and perform a vector search on them. + * + * ConvexVectorStore does NOT implement maxMarginalRelevanceSearch. 
+ */ +export class ConvexVectorStore< + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel, + IndexName extends VectorIndexNames>, + TextFieldName extends FieldPaths>, + EmbeddingFieldName extends FieldPaths>, + MetadataFieldName extends FieldPaths>, + InsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + >, + GetQuery extends FunctionReference< + "query", + "internal", + { id: string }, + object | null + > +> extends VectorStore { + /** + * Type that defines the filter used in the + * similaritySearchVectorWithScore method. + * It includes an optional filter callback and a flag to include embeddings. + */ + declare FilterType: { + filter?: ( + q: VectorFilterBuilder< + DocumentByInfo, + NamedVectorIndex, IndexName> + > + ) => FilterExpression; + includeEmbeddings?: boolean; + }; + + private readonly ctx: GenericActionCtx; + + private readonly table: TableName; + + private readonly index: IndexName; + + private readonly textField: TextFieldName; + + private readonly embeddingField: EmbeddingFieldName; + + private readonly metadataField: MetadataFieldName; + + private readonly insert: InsertMutation; + + private readonly get: GetQuery; + + _vectorstoreType(): string { + return "convex"; + } + + constructor( + embeddings: Embeddings, + config: ConvexVectorStoreConfig< + DataModel, + TableName, + IndexName, + TextFieldName, + EmbeddingFieldName, + MetadataFieldName, + InsertMutation, + GetQuery + > + ) { + super(embeddings, config); + this.ctx = config.ctx; + this.table = config.table ?? ("documents" as TableName); + this.index = config.index ?? ("byEmbedding" as IndexName); + this.textField = config.textField ?? ("text" as TextFieldName); + this.embeddingField = + config.embeddingField ?? ("embedding" as EmbeddingFieldName); + this.metadataField = + config.metadataField ?? 
("metadata" as MetadataFieldName); + this.insert = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + config.insert ?? (makeFunctionReference("langchain/db:insert") as any); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.get = config.get ?? (makeFunctionReference("langchain/db:get") as any); + } + + /** + * Add vectors and their corresponding documents to the Convex table. + * @param vectors Vectors to be added. + * @param documents Corresponding documents to be added. + * @returns Promise that resolves when the vectors and documents have been added. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + const convexDocuments = vectors.map((embedding, idx) => ({ + [this.textField]: documents[idx].pageContent, + [this.embeddingField]: embedding, + [this.metadataField]: documents[idx].metadata, + })); + // TODO: Remove chunking when Convex handles the concurrent requests correctly + const PAGE_SIZE = 16; + for (let i = 0; i < convexDocuments.length; i += PAGE_SIZE) { + await Promise.all( + convexDocuments.slice(i, i + PAGE_SIZE).map((document) => + this.ctx.runMutation(this.insert, { + table: this.table, + document, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any) + ) + ); + } + } + + /** + * Add documents to the Convex table. It first converts + * the documents to vectors using the embeddings and then calls the + * addVectors method. + * @param documents Documents to be added. + * @returns Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Similarity search on the vectors stored in the + * Convex table. It returns a list of documents and their + * corresponding similarity scores. + * @param query Query vector for the similarity search. 
+ * @param k Number of nearest neighbors to return. + * @param filter Optional filter to be applied. + * @returns Promise that resolves to a list of documents and their corresponding similarity scores. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ): Promise<[Document, number][]> { + const idsAndScores = await this.ctx.vectorSearch(this.table, this.index, { + vector: query, + limit: k, + filter: filter?.filter, + }); + + const documents = await Promise.all( + idsAndScores.map(({ _id }) => + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.ctx.runQuery(this.get, { id: _id } as any) + ) + ); + + return documents.map( + ( + { + [this.textField]: text, + [this.embeddingField]: embedding, + [this.metadataField]: metadata, + }, + idx + ) => [ + new Document({ + pageContent: text as string, + metadata: { + ...metadata, + ...(filter?.includeEmbeddings ? { embedding } : null), + }, + }), + idsAndScores[idx]._score, + ] + ); + } + + /** + * Static method to create an instance of ConvexVectorStore from a + * list of texts. It first converts the texts to vectors and then adds + * them to the Convex table. + * @param texts List of texts to be converted to vectors. + * @param metadatas Metadata for the texts. + * @param embeddings Embeddings to be used for conversion. + * @param dbConfig Database configuration for Convex. + * @returns Promise that resolves to a new instance of ConvexVectorStore. 
+ */ + static async fromTexts< + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel, + IndexName extends VectorIndexNames>, + TextFieldName extends FieldPaths>, + EmbeddingFieldName extends FieldPaths>, + MetadataFieldName extends FieldPaths>, + InsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + >, + GetQuery extends FunctionReference< + "query", + "internal", + { id: string }, + object | null + > + >( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: ConvexVectorStoreConfig< + DataModel, + TableName, + IndexName, + TextFieldName, + EmbeddingFieldName, + MetadataFieldName, + InsertMutation, + GetQuery + > + ): Promise< + ConvexVectorStore< + DataModel, + TableName, + IndexName, + TextFieldName, + EmbeddingFieldName, + MetadataFieldName, + InsertMutation, + GetQuery + > + > { + const docs = texts.map( + (text, i) => + new Document({ + pageContent: text, + metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas, + }) + ); + return ConvexVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method to create an instance of ConvexVectorStore from a + * list of documents. It first converts the documents to vectors and then + * adds them to the Convex table. + * @param docs List of documents to be converted to vectors. + * @param embeddings Embeddings to be used for conversion. + * @param dbConfig Database configuration for Convex. + * @returns Promise that resolves to a new instance of ConvexVectorStore. 
+ */ + static async fromDocuments< + DataModel extends GenericDataModel, + TableName extends TableNamesInDataModel, + IndexName extends VectorIndexNames>, + TextFieldName extends FieldPaths>, + EmbeddingFieldName extends FieldPaths>, + MetadataFieldName extends FieldPaths>, + InsertMutation extends FunctionReference< + "mutation", + "internal", + { table: string; document: object } + >, + GetQuery extends FunctionReference< + "query", + "internal", + { id: string }, + object | null + > + >( + docs: Document[], + embeddings: Embeddings, + dbConfig: ConvexVectorStoreConfig< + DataModel, + TableName, + IndexName, + TextFieldName, + EmbeddingFieldName, + MetadataFieldName, + InsertMutation, + GetQuery + > + ): Promise< + ConvexVectorStore< + DataModel, + TableName, + IndexName, + TextFieldName, + EmbeddingFieldName, + MetadataFieldName, + InsertMutation, + GetQuery + > + > { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/elasticsearch.ts b/libs/langchain-community/src/vectorstores/elasticsearch.ts new file mode 100644 index 000000000000..f10369270892 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/elasticsearch.ts @@ -0,0 +1,341 @@ +import * as uuid from "uuid"; +import { Client, estypes } from "@elastic/elasticsearch"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +/** + * Type representing the k-nearest neighbors (k-NN) engine used in + * Elasticsearch. + */ +type ElasticKnnEngine = "hnsw"; +/** + * Type representing the similarity measure used in Elasticsearch. + */ +type ElasticSimilarity = "l2_norm" | "dot_product" | "cosine"; + +/** + * Interface defining the options for vector search in Elasticsearch. 
+ */
+interface VectorSearchOptions {
+  readonly engine?: ElasticKnnEngine;
+  readonly similarity?: ElasticSimilarity;
+  readonly m?: number;
+  readonly efConstruction?: number;
+  readonly candidates?: number;
+}
+
+/**
+ * Interface defining the arguments required to create an Elasticsearch
+ * client.
+ */
+export interface ElasticClientArgs {
+  readonly client: Client;
+  readonly indexName?: string;
+  readonly vectorSearchOptions?: VectorSearchOptions;
+}
+
+/**
+ * Type representing a filter object in Elasticsearch.
+ */
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+type ElasticFilter = object | { field: string; operator: string; value: any }[];
+
+/**
+ * Class for interacting with an Elasticsearch database. It extends the
+ * VectorStore base class and provides methods for adding documents and
+ * vectors to the Elasticsearch database, performing similarity searches,
+ * deleting documents, and more.
+ */
+export class ElasticVectorSearch extends VectorStore {
+  declare FilterType: ElasticFilter;
+
+  private readonly client: Client;
+
+  private readonly indexName: string;
+
+  private readonly engine: ElasticKnnEngine;
+
+  private readonly similarity: ElasticSimilarity;
+
+  private readonly efConstruction: number;
+
+  private readonly m: number;
+
+  private readonly candidates: number;
+
+  _vectorstoreType(): string {
+    return "elasticsearch";
+  }
+
+  /**
+   * Creates the store. Missing vector search options fall back to
+   * `hnsw` / `l2_norm` / m=16 / efConstruction=100 / candidates=200 and
+   * the index name defaults to "documents".
+   * @param embeddings The embeddings used to vectorise documents and queries.
+   * @param args The Elasticsearch client, index name and search options.
+   */
+  constructor(embeddings: Embeddings, args: ElasticClientArgs) {
+    super(embeddings, args);
+
+    this.engine = args.vectorSearchOptions?.engine ?? "hnsw";
+    this.similarity = args.vectorSearchOptions?.similarity ?? "l2_norm";
+    this.m = args.vectorSearchOptions?.m ?? 16;
+    this.efConstruction = args.vectorSearchOptions?.efConstruction ?? 100;
+    this.candidates = args.vectorSearchOptions?.candidates ?? 200;
+
+    // A child client is used so the custom user-agent header only applies
+    // to requests issued by this store.
+    this.client = args.client.child({
+      headers: { "user-agent": "langchain-js-vs/0.0.1" },
+    });
+    this.indexName = args.indexName ?? "documents";
+  }
+
+  /**
+   * Method to add documents to the Elasticsearch database. It first
+   * converts the documents to vectors using the embeddings, then adds the
+   * vectors to the database.
+   * @param documents The documents to add to the database.
+   * @param options Optional parameter that can contain the IDs for the documents.
+   * @returns A promise that resolves with the IDs of the added documents.
+   */
+  async addDocuments(documents: Document[], options?: { ids?: string[] }) {
+    const texts = documents.map(({ pageContent }) => pageContent);
+    return this.addVectors(
+      await this.embeddings.embedDocuments(texts),
+      documents,
+      options
+    );
+  }
+
+  /**
+   * Method to add vectors to the Elasticsearch database. It ensures the
+   * index exists, then adds the vectors and their corresponding documents
+   * to the database.
+   * @param vectors The vectors to add to the database.
+   * @param documents The documents corresponding to the vectors.
+   * @param options Optional parameter that can contain the IDs for the documents.
+   * @returns A promise that resolves with the IDs of the added documents
+   * (an empty array when no vectors are given).
+   * @throws If `vectors` and `documents` differ in length.
+   */
+  async addVectors(
+    vectors: number[][],
+    documents: Document[],
+    options?: { ids?: string[] }
+  ) {
+    // Nothing to index: `vectors[0].length` below would otherwise throw.
+    if (vectors.length === 0) {
+      return [];
+    }
+    if (vectors.length !== documents.length) {
+      throw new Error(`Vectors and documents must have the same length`);
+    }
+    // The index mapping is dimensioned from the first vector.
+    await this.ensureIndexExists(
+      vectors[0].length,
+      this.engine,
+      this.similarity,
+      this.efConstruction,
+      this.m
+    );
+    const documentIds =
+      options?.ids ?? Array.from({ length: vectors.length }, () => uuid.v4());
+    // The bulk body alternates an action line with the document it applies to.
+    const operations = vectors.flatMap((embedding, idx) => [
+      {
+        index: {
+          _id: documentIds[idx],
+          _index: this.indexName,
+        },
+      },
+      {
+        embedding,
+        metadata: documents[idx].metadata,
+        text: documents[idx].pageContent,
+      },
+    ]);
+    await this.client.bulk({ refresh: true, operations });
+    return documentIds;
+  }
+
+  /**
+   * Method to perform a similarity search in the Elasticsearch database
+   * using a vector. It returns the k most similar documents along with
+   * their similarity scores.
+   * @param query The query vector.
+   * @param k The number of most similar documents to return.
+   * @param filter Optional filter to apply to the search.
+   * @returns A promise that resolves with an array of tuples, where each tuple contains a Document and its similarity score.
+   */
+  async similaritySearchVectorWithScore(
+    query: number[],
+    k: number,
+    filter?: ElasticFilter
+  ): Promise<[Document, number][]> {
+    const result = await this.client.search({
+      index: this.indexName,
+      size: k,
+      knn: {
+        field: "embedding",
+        query_vector: query,
+        filter: this.buildMetadataTerms(filter),
+        k,
+        num_candidates: this.candidates,
+      },
+    });
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    return result.hits.hits.map((hit: any) => [
+      new Document({
+        pageContent: hit._source.text,
+        metadata: hit._source.metadata,
+      }),
+      hit._score,
+    ]);
+  }
+
+  /**
+   * Method to delete documents from the Elasticsearch database.
+   * @param params Object containing the IDs of the documents to delete.
+   * @returns A promise that resolves when the deletion is complete. It
+   * resolves immediately, without contacting Elasticsearch, when `ids` is empty.
+   */
+  async delete(params: { ids: string[] }): Promise<void> {
+    // Avoid sending a bulk request that carries no operations.
+    if (params.ids.length === 0) {
+      return;
+    }
+    const operations = params.ids.map((id) => ({
+      delete: {
+        _id: id,
+        _index: this.indexName,
+      },
+    }));
+    await this.client.bulk({ refresh: true, operations });
+  }
+
+  /**
+   * Static method to create an ElasticVectorSearch instance from texts. It
+   * creates Document instances from the texts and their corresponding
+   * metadata, then calls the fromDocuments method to create the
+   * ElasticVectorSearch instance.
+   * @param texts The texts to create the ElasticVectorSearch instance from.
+   * @param metadatas The metadata corresponding to the texts.
+   * @param embeddings The embeddings to use for the documents.
+   * @param args The arguments to create the Elasticsearch client.
+   * @returns A promise that resolves with the created ElasticVectorSearch instance.
+   */
+  static fromTexts(
+    texts: string[],
+    metadatas: object[] | object,
+    embeddings: Embeddings,
+    args: ElasticClientArgs
+  ): Promise<ElasticVectorSearch> {
+    // A single metadata object is shared by every text; an array is matched by position.
+    const documents = texts.map((text, idx) => {
+      const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas;
+      return new Document({ pageContent: text, metadata });
+    });
+
+    return ElasticVectorSearch.fromDocuments(documents, embeddings, args);
+  }
+
+  /**
+   * Static method to create an ElasticVectorSearch instance from Document
+   * instances. It adds the documents to the Elasticsearch database, then
+   * returns the ElasticVectorSearch instance.
+   * @param docs The Document instances to create the ElasticVectorSearch instance from.
+   * @param embeddings The embeddings to use for the documents.
+   * @param dbConfig The configuration for the Elasticsearch database.
+   * @returns A promise that resolves with the created ElasticVectorSearch instance.
+   */
+  static async fromDocuments(
+    docs: Document[],
+    embeddings: Embeddings,
+    dbConfig: ElasticClientArgs
+  ): Promise<ElasticVectorSearch> {
+    const store = new ElasticVectorSearch(embeddings, dbConfig);
+    await store.addDocuments(docs);
+    return store;
+  }
+
+  /**
+   * Static method to create an ElasticVectorSearch instance from an
+   * existing index in the Elasticsearch database. It checks if the index
+   * exists, then returns the ElasticVectorSearch instance if it does.
+   * @param embeddings The embeddings to use for the documents.
+   * @param dbConfig The configuration for the Elasticsearch database.
+   * @returns A promise that resolves with the created ElasticVectorSearch instance if the index exists, otherwise it throws an error.
+   */
+  static async fromExistingIndex(
+    embeddings: Embeddings,
+    dbConfig: ElasticClientArgs
+  ): Promise<ElasticVectorSearch> {
+    const store = new ElasticVectorSearch(embeddings, dbConfig);
+    const exists = await store.doesIndexExist();
+    if (!exists) {
+      throw new Error(`The index ${store.indexName} does not exist.`);
+    }
+    return store;
+  }
+
+  /**
+   * Creates the index with a dense-vector mapping unless it already exists.
+   * An existing index is left untouched, so its mapping is not updated.
+   * @param dimension Number of dimensions of the stored embeddings.
+   * @param engine k-NN index type written to `index_options.type`.
+   * @param similarity Similarity measure for the `embedding` field.
+   * @param efConstruction Value written to `index_options.ef_construction`.
+   * @param m Value written to `index_options.m`.
+   * @returns A promise that resolves once the index is known to exist.
+   */
+  private async ensureIndexExists(
+    dimension: number,
+    engine = "hnsw",
+    similarity = "l2_norm",
+    efConstruction = 100,
+    m = 16
+  ): Promise<void> {
+    // Check first: there is no need to build the request for an existing index.
+    const indexExists = await this.doesIndexExist();
+    if (indexExists) return;
+
+    const request: estypes.IndicesCreateRequest = {
+      index: this.indexName,
+      mappings: {
+        dynamic_templates: [
+          {
+            // map all metadata properties to be keyword
+            "metadata.*": {
+              match_mapping_type: "*",
+              mapping: { type: "keyword" },
+            },
+          },
+        ],
+        properties: {
+          text: { type: "text" },
+          metadata: { type: "object" },
+          embedding: {
+            type: "dense_vector",
+            dims: dimension,
+            index: true,
+            similarity,
+            index_options: {
+              type: engine,
+              m,
+              ef_construction: efConstruction,
+            },
+          },
+        },
+      },
+    };
+
+    await this.client.indices.create(request);
+  }
+
+  /**
+   * Translates a filter into the list of query clauses passed as the kNN
+   * `filter`. An array filter is used as given (each entry supplies its own
+   * operator); a plain object becomes one `term` clause per key. Every
+   * field name is prefixed with `metadata.`.
+   * @param filter Optional filter; `null`/`undefined` yields no clauses.
+   * @returns The clauses, one per condition.
+   */
+  private buildMetadataTerms(
+    filter?: ElasticFilter
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  ): { [operator: string]: { [field: string]: any } }[] {
+    if (filter == null) return [];
+    const filters = Array.isArray(filter)
+      ? filter
+      : Object.entries(filter).map(([key, value]) => ({
+          operator: "term",
+          field: key,
+          value,
+        }));
+    return filters.map((condition) => ({
+      [condition.operator]: {
+        [`metadata.${condition.field}`]: condition.value,
+      },
+    }));
+  }
+
+  /**
+   * Method to check if an index exists in the Elasticsearch database.
+   * @returns A promise that resolves with a boolean indicating whether the index exists.
+   */
+  async doesIndexExist(): Promise<boolean> {
+    return this.client.indices.exists({ index: this.indexName });
+  }
+
+  /**
+   * Method to delete an index from the Elasticsearch database if it exists.
+   * @returns A promise that resolves when the deletion is complete.
+   */
+  async deleteIfExists(): Promise<void> {
+    const indexExists = await this.doesIndexExist();
+    if (!indexExists) return;
+
+    await this.client.indices.delete({ index: this.indexName });
+  }
+}
diff --git a/libs/langchain-community/src/vectorstores/faiss.ts b/libs/langchain-community/src/vectorstores/faiss.ts
new file mode 100644
index 000000000000..0403aa89c726
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/faiss.ts
@@ -0,0 +1,461 @@
+import type { IndexFlatL2 } from "faiss-node";
+import type { NameRegistry, Parser } from "pickleparser";
+import * as uuid from "uuid";
+import { Embeddings } from "@langchain/core/embeddings";
+import { SaveableVectorStore } from "@langchain/core/vectorstores";
+import { Document } from "@langchain/core/documents";
+import { SynchronousInMemoryDocstore } from "../stores/doc/in_memory.js";
+
+/**
+ * Interface for the arguments required to initialize a FaissStore
+ * instance.
+ */
+export interface FaissLibArgs {
+  docstore?: SynchronousInMemoryDocstore;
+  index?: IndexFlatL2;
+  // Maps a numeric position in the FAISS index to the document ID stored in the docstore.
+  mapping?: Record<number, string>;
+}
+
+/**
+ * A class that wraps the FAISS (Facebook AI Similarity Search) vector
+ * database for efficient similarity search and clustering of dense
+ * vectors.
+ */
+export class FaissStore extends SaveableVectorStore {
+  _index?: IndexFlatL2;
+
+  // Numeric position in the FAISS index -> document ID in the docstore.
+  _mapping: Record<number, string>;
+
+  docstore: SynchronousInMemoryDocstore;
+
+  args: FaissLibArgs;
+
+  _vectorstoreType(): string {
+    return "faiss";
+  }
+
+  getMapping(): Record<number, string> {
+    return this._mapping;
+  }
+
+  getDocstore(): SynchronousInMemoryDocstore {
+    return this.docstore;
+  }
+
+  /**
+   * Creates the store. The FAISS index may be omitted; it is then created
+   * by the first call that adds or merges vectors.
+   * @param embeddings An Embeddings object.
+   * @param args Optional pre-existing docstore, index and mapping.
+   */
+  constructor(embeddings: Embeddings, args: FaissLibArgs) {
+    super(embeddings, args);
+    this.args = args;
+    this._index = args.index;
+    this._mapping = args.mapping ?? {};
+    this.embeddings = embeddings;
+    this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore();
+  }
+
+  /**
+   * Adds an array of Document objects to the store.
+   * @param documents An array of Document objects.
+   * @param options Optional parameter that can contain the IDs for the documents.
+   * @returns A Promise that resolves when the documents have been added.
+   */
+  async addDocuments(documents: Document[], options?: { ids?: string[] }) {
+    const texts = documents.map(({ pageContent }) => pageContent);
+    return this.addVectors(
+      await this.embeddings.embedDocuments(texts),
+      documents,
+      options
+    );
+  }
+
+  /**
+   * The underlying FAISS index.
+   * @throws If the index has not been created or supplied yet.
+   */
+  public get index(): IndexFlatL2 {
+    if (!this._index) {
+      throw new Error(
+        "Vector store not initialised yet. Try calling `fromTexts`, `fromDocuments` or `fromIndex` first."
+      );
+    }
+    return this._index;
+  }
+
+  private set index(index: IndexFlatL2) {
+    this._index = index;
+  }
+
+  /**
+   * Adds an array of vectors and their corresponding Document objects to
+   * the store.
+   * @param vectors An array of vectors.
+   * @param documents An array of Document objects corresponding to the vectors.
+   * @param options Optional parameter that can contain the IDs for the documents.
+   * @returns A Promise that resolves with an array of document IDs when the vectors and documents have been added.
+   */
+  async addVectors(
+    vectors: number[][],
+    documents: Document[],
+    options?: { ids?: string[] }
+  ) {
+    if (vectors.length === 0) {
+      return [];
+    }
+    if (vectors.length !== documents.length) {
+      throw new Error(`Vectors and documents must have the same length`);
+    }
+    // Caller-supplied IDs must line up one-to-one with the vectors,
+    // otherwise some vectors would be mapped to `undefined`.
+    if (options?.ids != null && options.ids.length !== vectors.length) {
+      throw new Error(`IDs and vectors must have the same length`);
+    }
+    const dv = vectors[0].length;
+    if (!this._index) {
+      // Lazily create the index, dimensioned from the first vector.
+      const { IndexFlatL2 } = await FaissStore.importFaiss();
+      this._index = new IndexFlatL2(dv);
+    }
+    const d = this.index.getDimension();
+    // Validate every vector up front so a bad entry cannot fail midway
+    // through the loop after earlier vectors were already added.
+    if (vectors.some((vector) => vector.length !== d)) {
+      throw new Error(
+        `Vectors must have the same length as the number of dimensions (${d})`
+      );
+    }
+
+    // New vectors are appended, so their positions start at the current size.
+    const indexSize = this.index.ntotal();
+    const documentIds = options?.ids ?? documents.map(() => uuid.v4());
+    for (let i = 0; i < vectors.length; i += 1) {
+      const documentId = documentIds[i];
+      const id = indexSize + i;
+      this.index.add(vectors[i]);
+      this._mapping[id] = documentId;
+      this.docstore.add({ [documentId]: documents[i] });
+    }
+    return documentIds;
+  }
+
+  /**
+   * Performs a similarity search in the vector store using a query vector
+   * and returns the top k results along with their scores.
+   * @param query A query vector.
+   * @param k The number of top results to return.
+   * @returns A Promise that resolves with an array of tuples, each containing a Document and its corresponding score.
+   */
+  async similaritySearchVectorWithScore(query: number[], k: number) {
+    const d = this.index.getDimension();
+    if (query.length !== d) {
+      throw new Error(
+        `Query vector must have the same length as the number of dimensions (${d})`
+      );
+    }
+    // Clamp k to the index size instead of reassigning the parameter.
+    const total = this.index.ntotal();
+    let limit = k;
+    if (k > total) {
+      console.warn(
+        `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`
+      );
+      limit = total;
+    }
+    const result = this.index.search(query, limit);
+    return result.labels.map((id, index) => {
+      // Named `documentId` so it does not shadow the imported `uuid` module.
+      const documentId = this._mapping[id];
+      return [this.docstore.search(documentId), result.distances[index]] as [
+        Document,
+        number
+      ];
+    });
+  }
+
+  /**
+   * Saves the current state of the FaissStore to a specified directory.
+   * The index is written to `faiss.index`; the docstore entries and the
+   * mapping are written together to `docstore.json`.
+   * @param directory The directory to save the state to. It is created if missing.
+   * @returns A Promise that resolves when the state has been saved.
+   */
+  async save(directory: string) {
+    const fs = await import("node:fs/promises");
+    const path = await import("node:path");
+    await fs.mkdir(directory, { recursive: true });
+    await Promise.all([
+      this.index.write(path.join(directory, "faiss.index")),
+      fs.writeFile(
+        path.join(directory, "docstore.json"),
+        JSON.stringify([
+          Array.from(this.docstore._docs.entries()),
+          this._mapping,
+        ])
+      ),
+    ]);
+  }
+
+  /**
+   * Method to delete documents.
+   * @param params Object containing the IDs of the documents to delete.
+   * @returns A promise that resolves when the deletion is complete.
+   */
+  async delete(params: { ids: string[] }) {
+    const { ids } = params;
+    if (ids == null) {
+      throw new Error("No documentIds provided to delete.");
+    }
+
+    // Index position -> document ID, with the object keys parsed back to numbers.
+    const positionToDocId = new Map(
+      Object.entries(this._mapping).map(([position, docId]) => [
+        parseInt(position, 10),
+        docId,
+      ])
+    );
+    // Document ID -> index position, for looking up what to remove.
+    const docIdToPosition = new Map(
+      Array.from(positionToDocId, ([position, docId]) => [docId, position])
+    );
+
+    const unknownIds = new Set(ids.filter((id) => !docIdToPosition.has(id)));
+    if (unknownIds.size > 0) {
+      throw new Error(
+        `Some specified documentIds do not exist in the current store. DocumentIds not found: ${Array.from(
+          unknownIds
+        ).join(", ")}`
+      );
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    const positionsToRemove = ids.map((id) => docIdToPosition.get(id)!);
+
+    // remove from index
+    this.index.removeIds(positionsToRemove);
+    // remove from docstore
+    for (const id of ids) {
+      this.docstore._docs.delete(id);
+    }
+    // remove from mappings
+    for (const position of positionsToRemove) {
+      positionToDocId.delete(position);
+    }
+
+    // Spreading the remaining IDs renumbers them 0..n-1 in their current order.
+    // NOTE(review): presumably this mirrors how the index compacts positions
+    // after removeIds - confirm against faiss-node.
+    this._mapping = { ...Array.from(positionToDocId.values()) };
+  }
+
+  /**
+   * Merges the current FaissStore with another FaissStore.
+   * @param targetIndex The FaissStore to merge with.
+   * @returns A Promise that resolves with an array of document IDs when the merge is complete.
+   */
+  async mergeFrom(targetIndex: FaissStore) {
+    const incomingDimensions = targetIndex.index.getDimension();
+    if (!this._index) {
+      // No index yet: create one that matches the incoming store.
+      const { IndexFlatL2 } = await FaissStore.importFaiss();
+      this._index = new IndexFlatL2(incomingDimensions);
+    }
+    if (incomingDimensions !== this.index.getDimension()) {
+      throw new Error("Cannot merge indexes with different dimensions.");
+    }
+    const incomingMapping = targetIndex.getMapping();
+    const incomingDocstore = targetIndex.getDocstore();
+    const incomingCount = targetIndex.index.ntotal();
+    const mergedIds = [];
+    // Incoming entries are numbered after the ones already in this store.
+    const offset = this.index.ntotal();
+    for (let position = 0; position < incomingCount; position += 1) {
+      const docId = incomingMapping[position];
+      mergedIds.push(docId);
+      const doc = incomingDocstore.search(docId);
+      this._mapping[offset + position] = docId;
+      this.docstore.add({ [docId]: doc });
+    }
+    // The vectors themselves are merged once the bookkeeping is done.
+    this.index.mergeFrom(targetIndex.index);
+    return mergedIds;
+  }
+
+  /**
+   * Loads a FaissStore from a specified directory.
+   * @param directory The directory to load the FaissStore from.
+   * @param embeddings An Embeddings object.
+   * @returns A Promise that resolves with a new FaissStore instance.
+   */
+  static async load(directory: string, embeddings: Embeddings) {
+    const fs = await import("node:fs/promises");
+    const path = await import("node:path");
+    // NOTE(review): `save` writes the docstore as an array of [id, document]
+    // entries, so the first tuple element is really an entries array at
+    // runtime; `new Map(...)` below accepts it either way.
+    const readStore = (directory: string) =>
+      fs
+        .readFile(path.join(directory, "docstore.json"), "utf8")
+        .then(JSON.parse) as Promise<
+        [Map<string, Document>, Record<number, string>]
+      >;
+    const readIndex = async (directory: string) => {
+      const { IndexFlatL2 } = await this.importFaiss();
+      return IndexFlatL2.read(path.join(directory, "faiss.index"));
+    };
+    const [[docstoreFiles, mapping], index] = await Promise.all([
+      readStore(directory),
+      readIndex(directory),
+    ]);
+    const docstore = new SynchronousInMemoryDocstore(new Map(docstoreFiles));
+    return new this(embeddings, { docstore, index, mapping });
+  }
+
+  /**
+   * Loads a FaissStore from a directory containing `index.pkl` (the pickled
+   * docstore and mapping) and `index.faiss` (the index). The file names and
+   * registered module paths suggest the Python LangChain layout.
+   * @param directory The directory to load the FaissStore from.
+   * @param embeddings An Embeddings object.
+   * @returns A Promise that resolves with a new FaissStore instance.
+   */
+  static async loadFromPython(directory: string, embeddings: Embeddings) {
+    const fs = await import("node:fs/promises");
+    const path = await import("node:path");
+    const { Parser, NameRegistry } = await this.importPickleparser();
+
+    // Stand-in for the pickled Python Document class.
+    class PyDocument extends Map {
+      toDocument(): Document {
+        return new Document({
+          pageContent: this.get("page_content"),
+          metadata: this.get("metadata"),
+        });
+      }
+    }
+
+    // Stand-in for the pickled Python InMemoryDocstore class.
+    class PyInMemoryDocstore {
+      _dict: Map<string, PyDocument>;
+
+      toInMemoryDocstore(): SynchronousInMemoryDocstore {
+        const s = new SynchronousInMemoryDocstore();
+        // NOTE(review): `_dict` is declared as a Map but iterated with
+        // Object.entries, which only yields entries if the parser
+        // materialises it as a plain object - confirm against pickleparser.
+        for (const [key, value] of Object.entries(this._dict)) {
+          s._docs.set(key, value.toDocument());
+        }
+        return s;
+      }
+    }
+
+    const readStore = async (directory: string) => {
+      // Read the raw bytes directly instead of round-tripping through a binary string.
+      const buffer = await fs.readFile(path.join(directory, "index.pkl"));
+
+      // Only these pickled class names are mapped to local stand-ins.
+      const registry = new NameRegistry()
+        .register(
+          "langchain.docstore.in_memory",
+          "InMemoryDocstore",
+          PyInMemoryDocstore
+        )
+        .register("langchain.schema", "Document", PyDocument)
+        .register("langchain.docstore.document", "Document", PyDocument)
+        .register("langchain.schema.document", "Document", PyDocument)
+        .register("pathlib", "WindowsPath", (...args) => args.join("\\"))
+        .register("pathlib", "PosixPath", (...args) => args.join("/"));
+
+      const pickleparser = new Parser({
+        nameResolver: registry,
+      });
+      const [rawStore, mapping] =
+        pickleparser.parse<[PyInMemoryDocstore, Record<number, string>]>(
+          buffer
+        );
+      const store = rawStore.toInMemoryDocstore();
+      return { store, mapping };
+    };
+    const readIndex = async (directory: string) => {
+      const { IndexFlatL2 } = await this.importFaiss();
+      return IndexFlatL2.read(path.join(directory, "index.faiss"));
+    };
+    const [store, index] = await Promise.all([
+      readStore(directory),
+      readIndex(directory),
+    ]);
+    return new this(embeddings, {
+      docstore: store.store,
+      index,
+      mapping: store.mapping,
+    });
+  }
+
+  /**
+   * Creates a new FaissStore from an array of texts, their corresponding
+   * metadata, and an Embeddings object.
+   * @param texts An array of texts.
+   * @param metadatas An array of metadata corresponding to the texts, or a single metadata object to be used for all texts.
+   * @param embeddings An Embeddings object.
+   * @param dbConfig An optional configuration object for the document store.
+   * @returns A Promise that resolves with a new FaissStore instance.
+   */
+  static async fromTexts(
+    texts: string[],
+    metadatas: object[] | object,
+    embeddings: Embeddings,
+    dbConfig?: {
+      docstore?: SynchronousInMemoryDocstore;
+    }
+  ): Promise<FaissStore> {
+    const docs: Document[] = texts.map(
+      (text, i) =>
+        new Document({
+          pageContent: text,
+          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
+        })
+    );
+    return this.fromDocuments(docs, embeddings, dbConfig);
+  }
+
+  /**
+   * Creates a new FaissStore from an array of Document objects and an
+   * Embeddings object.
+   * @param docs An array of Document objects.
+   * @param embeddings An Embeddings object.
+   * @param dbConfig An optional configuration object for the document store.
+   * @returns A Promise that resolves with a new FaissStore instance.
+   */
+  static async fromDocuments(
+    docs: Document[],
+    embeddings: Embeddings,
+    dbConfig?: {
+      docstore?: SynchronousInMemoryDocstore;
+    }
+  ): Promise<FaissStore> {
+    const args: FaissLibArgs = {
+      docstore: dbConfig?.docstore,
+    };
+    const instance = new this(embeddings, args);
+    await instance.addDocuments(docs);
+    return instance;
+  }
+
+  /**
+   * Creates a new FaissStore from an existing FaissStore and an Embeddings
+   * object.
+   * @param targetIndex An existing FaissStore.
+   * @param embeddings An Embeddings object.
+   * @param dbConfig An optional configuration object for the document store.
+   * @returns A Promise that resolves with a new FaissStore instance.
+   */
+  static async fromIndex(
+    targetIndex: FaissStore,
+    embeddings: Embeddings,
+    dbConfig?: {
+      docstore?: SynchronousInMemoryDocstore;
+    }
+  ): Promise<FaissStore> {
+    const args: FaissLibArgs = {
+      docstore: dbConfig?.docstore,
+    };
+    const instance = new this(embeddings, args);
+    await instance.mergeFrom(targetIndex);
+    return instance;
+  }
+
+  /**
+   * Dynamically imports `faiss-node`, so the dependency is only needed
+   * when a FAISS index is actually created or read.
+   * @returns A Promise that resolves with the `IndexFlatL2` constructor.
+   * @throws If `faiss-node` cannot be imported; the message includes install instructions.
+   */
+  static async importFaiss(): Promise<{ IndexFlatL2: typeof IndexFlatL2 }> {
+    try {
+      const {
+        default: { IndexFlatL2 },
+      } = await import("faiss-node");
+
+      return { IndexFlatL2 };
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    } catch (err: any) {
+      throw new Error(
+        `Could not import faiss-node. Please install faiss-node as a dependency with, e.g. \`npm install -S faiss-node\`.\n\nError: ${err?.message}`
+      );
+    }
+  }
+
+  /**
+   * Dynamically imports `pickleparser`, which is only needed by
+   * `loadFromPython`.
+   * @returns A Promise that resolves with the `Parser` and `NameRegistry` constructors.
+   * @throws If `pickleparser` cannot be imported; the message includes install instructions.
+   */
+  static async importPickleparser(): Promise<{
+    Parser: typeof Parser;
+    NameRegistry: typeof NameRegistry;
+  }> {
+    try {
+      const {
+        default: { Parser, NameRegistry },
+      } = await import("pickleparser");
+
+      return { Parser, NameRegistry };
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    } catch (err: any) {
+      throw new Error(
+        `Could not import pickleparser. Please install pickleparser as a dependency with, e.g. \`npm install -S pickleparser\`.\n\nError: ${err?.message}`
+      );
+    }
+  }
+}
diff --git a/libs/langchain-community/src/vectorstores/googlevertexai.ts b/libs/langchain-community/src/vectorstores/googlevertexai.ts
new file mode 100644
index 000000000000..d693f460f30d
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/googlevertexai.ts
@@ -0,0 +1,738 @@
+import * as uuid from "uuid";
+import flatten from "flat";
+import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
+import { VectorStore } from "@langchain/core/vectorstores";
+import { Embeddings } from "@langchain/core/embeddings";
+import { Document, DocumentInput } from "@langchain/core/documents";
+import {
+  AsyncCaller,
+  AsyncCallerCallOptions,
+  AsyncCallerParams,
+} from "@langchain/core/utils/async_caller";
+
+import { GoogleVertexAIConnection } from "../utils/googlevertexai-connection.js";
+import { Docstore } from "../stores/doc/base.js";
+import {
+  GoogleVertexAIConnectionParams,
+  GoogleResponse,
+  GoogleAbstractedClientOpsMethod,
+} from "../types/googlevertexai-types.js";
+
+/**
+ * Allows us to create IdDocument classes that contain the ID.
+ */
+export interface IdDocumentInput extends DocumentInput {
+  id?: string;
+}
+
+/**
+ * A Document that optionally includes the ID of the document.
+ */
+export class IdDocument extends Document implements IdDocumentInput {
+  id?: string;
+
+  constructor(fields: IdDocumentInput) {
+    super(fields);
+    this.id = fields.id;
+  }
+}
+
+interface IndexEndpointConnectionParams
+  extends GoogleVertexAIConnectionParams<GoogleAuthOptions> {
+  indexEndpoint: string;
+}
+
+interface DeployedIndex {
+  id: string;
+  index: string;
+  // There are other attributes, but we don't care about them right now
+}
+
+interface IndexEndpointResponse extends GoogleResponse {
+  data: {
+    deployedIndexes: DeployedIndex[];
+    publicEndpointDomainName: string;
+    // There are other attributes, but we don't care about them right now
+  };
+}
+
+/**
+ * Connection that issues a GET for a single index endpoint resource,
+ * whose response lists the deployed indexes and the public endpoint domain.
+ */
+class IndexEndpointConnection extends GoogleVertexAIConnection<
+  AsyncCallerCallOptions,
+  IndexEndpointResponse,
+  GoogleAuthOptions
+> {
+  indexEndpoint: string;
+
+  constructor(fields: IndexEndpointConnectionParams, caller: AsyncCaller) {
+    super(fields, caller, new GoogleAuth(fields.authOptions));
+
+    this.indexEndpoint = fields.indexEndpoint;
+  }
+
+  /**
+   * Builds the URL of the index endpoint resource for the authenticated project.
+   * @returns A promise that resolves with the request URL.
+   */
+  async buildUrl(): Promise<string> {
+    const projectId = await this.client.getProjectId();
+    const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}`;
+    return url;
+  }
+
+  buildMethod(): GoogleAbstractedClientOpsMethod {
+    return "GET";
+  }
+
+  /**
+   * Sends the request. No body is sent.
+   * @param options Call options forwarded to the underlying request.
+   * @returns A promise that resolves with the index endpoint response.
+   */
+  async request(
+    options: AsyncCallerCallOptions
+  ): Promise<IndexEndpointResponse> {
+    return this._request(undefined, options);
+  }
+}
+
+/**
+ * Used to represent parameters that are necessary to delete documents
+ * from the matching engine.
These must be a list of string IDs
+ */
+export interface MatchingEngineDeleteParams {
+  ids: string[];
+}
+
+interface RemoveDatapointParams
+  extends GoogleVertexAIConnectionParams<GoogleAuthOptions> {
+  index: string;
+}
+
+interface RemoveDatapointRequest {
+  datapointIds: string[];
+}
+
+interface RemoveDatapointResponse extends GoogleResponse {
+  // Should be empty
+}
+
+/**
+ * Connection that POSTs to the `:removeDatapoints` method of an index.
+ */
+class RemoveDatapointConnection extends GoogleVertexAIConnection<
+  AsyncCallerCallOptions,
+  RemoveDatapointResponse,
+  GoogleAuthOptions
+> {
+  index: string;
+
+  constructor(fields: RemoveDatapointParams, caller: AsyncCaller) {
+    super(fields, caller, new GoogleAuth(fields.authOptions));
+
+    this.index = fields.index;
+  }
+
+  /**
+   * Builds the `:removeDatapoints` URL for the configured index.
+   * @returns A promise that resolves with the request URL.
+   */
+  async buildUrl(): Promise<string> {
+    const projectId = await this.client.getProjectId();
+    const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:removeDatapoints`;
+    return url;
+  }
+
+  buildMethod(): GoogleAbstractedClientOpsMethod {
+    return "POST";
+  }
+
+  /**
+   * Sends the removal request.
+   * @param datapointIds IDs of the datapoints to remove.
+   * @param options Call options forwarded to the underlying request.
+   * @returns A promise that resolves with the (expected to be empty) response.
+   */
+  async request(
+    datapointIds: string[],
+    options: AsyncCallerCallOptions
+  ): Promise<RemoveDatapointResponse> {
+    const data: RemoveDatapointRequest = {
+      datapointIds,
+    };
+    return this._request(data, options);
+  }
+}
+
+interface UpsertDatapointParams
+  extends GoogleVertexAIConnectionParams<GoogleAuthOptions> {
+  index: string;
+}
+
+export interface Restriction {
+  namespace: string;
+  allowList?: string[];
+  denyList?: string[];
+}
+
+interface CrowdingTag {
+  crowdingAttribute: string;
+}
+
+interface IndexDatapoint {
+  datapointId: string;
+  featureVector: number[];
+  restricts?: Restriction[];
+  crowdingTag?: CrowdingTag;
+}
+
+interface UpsertDatapointRequest {
+  datapoints: IndexDatapoint[];
+}
+
+interface UpsertDatapointResponse extends GoogleResponse {
+  // Should be empty
+}
+
+/**
+ * Connection that POSTs to the `:upsertDatapoints` method of an index.
+ */
+class UpsertDatapointConnection extends GoogleVertexAIConnection<
+  AsyncCallerCallOptions,
+  UpsertDatapointResponse,
+  GoogleAuthOptions
+> {
+  index: string;
+
+  constructor(fields: UpsertDatapointParams, caller: AsyncCaller) {
+    super(fields, caller, new GoogleAuth(fields.authOptions));
+
+    this.index = fields.index;
+  }
+
+  /**
+   * Builds the `:upsertDatapoints` URL for the configured index.
+   * @returns A promise that resolves with the request URL.
+   */
+  async buildUrl(): Promise<string> {
+    const projectId = await this.client.getProjectId();
+    const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:upsertDatapoints`;
+    return url;
+  }
+
+  buildMethod(): GoogleAbstractedClientOpsMethod {
+    return "POST";
+  }
+
+  /**
+   * Sends the upsert request.
+   * @param datapoints Datapoints to insert or update.
+   * @param options Call options forwarded to the underlying request.
+   * @returns A promise that resolves with the (expected to be empty) response.
+   */
+  async request(
+    datapoints: IndexDatapoint[],
+    options: AsyncCallerCallOptions
+  ): Promise<UpsertDatapointResponse> {
+    const data: UpsertDatapointRequest = {
+      datapoints,
+    };
+    return this._request(data, options);
+  }
+}
+
+interface FindNeighborsConnectionParams
+  extends GoogleVertexAIConnectionParams<GoogleAuthOptions> {
+  indexEndpoint: string;
+
+  deployedIndexId: string;
+}
+
+interface FindNeighborsRequestQuery {
+  datapoint: {
+    datapointId: string;
+    featureVector: number[];
+    restricts?: Restriction[];
+  };
+  neighborCount: number;
+}
+
+interface FindNeighborsRequest {
+  deployedIndexId: string;
+  queries: FindNeighborsRequestQuery[];
+}
+
+interface FindNeighborsResponseNeighbor {
+  datapoint: {
+    datapointId: string;
+    crowdingTag: {
+      crowdingTagAttribute: string;
+    };
+  };
+  distance: number;
+}
+
+interface FindNeighborsResponseNearestNeighbor {
+  id: string;
+  neighbors: FindNeighborsResponseNeighbor[];
+}
+
+interface FindNeighborsResponse extends GoogleResponse {
+  data: {
+    nearestNeighbors: FindNeighborsResponseNearestNeighbor[];
+  };
+}
+
+/**
+ * Connection that POSTs to the `:findNeighbors` method of an index endpoint.
+ */
+class FindNeighborsConnection
+  extends GoogleVertexAIConnection<
+    AsyncCallerCallOptions,
+    FindNeighborsResponse,
+    GoogleAuthOptions
+  >
+  implements FindNeighborsConnectionParams
+{
+  indexEndpoint: string;
+
+  deployedIndexId: string;
+
+  constructor(params: FindNeighborsConnectionParams, caller: AsyncCaller) {
+    super(params, caller, new GoogleAuth(params.authOptions));
+
+    this.indexEndpoint = params.indexEndpoint;
+    this.deployedIndexId = params.deployedIndexId;
+  }
+
+  /**
+   * Builds the `:findNeighbors` URL for the configured index endpoint.
+   * @returns A promise that resolves with the request URL.
+   */
+  async buildUrl(): Promise<string> {
+    const projectId = await this.client.getProjectId();
+    const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}:findNeighbors`;
+    return url;
+  }
+
+  buildMethod(): GoogleAbstractedClientOpsMethod {
+    return "POST";
+  }
+
+  /**
+   * Sends the nearest-neighbour query.
+   * @param request The deployed index ID and the queries to run.
+   * @param options Call options forwarded to the underlying request.
+   * @returns A promise that resolves with the nearest neighbours per query.
+   */
+  async request(
+    request: FindNeighborsRequest,
+    options: AsyncCallerCallOptions
+  ): Promise<FindNeighborsResponse> {
+    return this._request(request, options);
+  }
+}
+
+/**
+ * Information about the Matching Engine public API endpoint.
+ * Primarily exported to allow for testing.
+ */
+export interface PublicAPIEndpointInfo {
+  apiEndpoint?: string;
+
+  deployedIndexId?: string;
+}
+
+/**
+ * Parameters necessary to configure the Matching Engine.
+ */
+export interface MatchingEngineArgs
+  extends GoogleVertexAIConnectionParams<GoogleAuthOptions>,
+    IndexEndpointConnectionParams,
+    UpsertDatapointParams {
+  docstore: Docstore;
+
+  callerParams?: AsyncCallerParams;
+
+  callerOptions?: AsyncCallerCallOptions;
+
+  apiEndpoint?: string;
+
+  deployedIndexId?: string;
+}
+
+/**
+ * A class that represents a connection to a Google Vertex AI Matching Engine
+ * instance.
+ */
+export class MatchingEngine extends VectorStore implements MatchingEngineArgs {
+  declare FilterType: Restriction[];
+
+  /**
+   * Docstore that retains the document, stored by ID
+   */
+  docstore: Docstore;
+
+  /**
+   * The host to connect to for queries and upserts.
+ */ + apiEndpoint: string; + + apiVersion = "v1"; + + endpoint = "us-central1-aiplatform.googleapis.com"; + + location = "us-central1"; + + /** + * The id for the index endpoint + */ + indexEndpoint: string; + + /** + * The id for the index + */ + index: string; + + /** + * The id for the "deployed index", which is an identifier in the + * index endpoint that references the index (but is not the index id) + */ + deployedIndexId: string; + + callerParams: AsyncCallerParams; + + callerOptions: AsyncCallerCallOptions; + + caller: AsyncCaller; + + indexEndpointClient: IndexEndpointConnection; + + removeDatapointClient: RemoveDatapointConnection; + + upsertDatapointClient: UpsertDatapointConnection; + + constructor(embeddings: Embeddings, args: MatchingEngineArgs) { + super(embeddings, args); + + this.embeddings = embeddings; + this.docstore = args.docstore; + + this.apiEndpoint = args.apiEndpoint ?? this.apiEndpoint; + this.deployedIndexId = args.deployedIndexId ?? this.deployedIndexId; + + this.apiVersion = args.apiVersion ?? this.apiVersion; + this.endpoint = args.endpoint ?? this.endpoint; + this.location = args.location ?? this.location; + this.indexEndpoint = args.indexEndpoint ?? this.indexEndpoint; + this.index = args.index ?? this.index; + + this.callerParams = args.callerParams ?? this.callerParams; + this.callerOptions = args.callerOptions ?? 
this.callerOptions; + this.caller = new AsyncCaller(this.callerParams || {}); + + const indexClientParams: IndexEndpointConnectionParams = { + endpoint: this.endpoint, + location: this.location, + apiVersion: this.apiVersion, + indexEndpoint: this.indexEndpoint, + }; + this.indexEndpointClient = new IndexEndpointConnection( + indexClientParams, + this.caller + ); + + const removeClientParams: RemoveDatapointParams = { + endpoint: this.endpoint, + location: this.location, + apiVersion: this.apiVersion, + index: this.index, + }; + this.removeDatapointClient = new RemoveDatapointConnection( + removeClientParams, + this.caller + ); + + const upsertClientParams: UpsertDatapointParams = { + endpoint: this.endpoint, + location: this.location, + apiVersion: this.apiVersion, + index: this.index, + }; + this.upsertDatapointClient = new UpsertDatapointConnection( + upsertClientParams, + this.caller + ); + } + + _vectorstoreType(): string { + return "googlevertexai"; + } + + async addDocuments(documents: Document[]): Promise { + const texts: string[] = documents.map((doc) => doc.pageContent); + const vectors: number[][] = await this.embeddings.embedDocuments(texts); + return this.addVectors(vectors, documents); + } + + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length !== documents.length) { + throw new Error(`Vectors and metadata must have the same length`); + } + const datapoints: IndexDatapoint[] = vectors.map((vector, idx) => + this.buildDatapoint(vector, documents[idx]) + ); + const options = {}; + const response = await this.upsertDatapointClient.request( + datapoints, + options + ); + if (Object.keys(response?.data ?? 
{}).length === 0) { + // Nothing in the response in the body means we saved it ok + const idDoc = documents as IdDocument[]; + const docsToStore: Record = {}; + idDoc.forEach((doc) => { + if (doc.id) { + docsToStore[doc.id] = doc; + } + }); + await this.docstore.add(docsToStore); + } + } + + // TODO: Refactor this into a utility type and use with pinecone as well? + // eslint-disable-next-line @typescript-eslint/no-explicit-any + cleanMetadata(documentMetadata: Record): { + [key: string]: string | number | boolean | string[] | null; + } { + type metadataType = { + [key: string]: string | number | boolean | string[] | null; + }; + + function getStringArrays( + prefix: string, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + m: Record + ): Record { + let ret: Record = {}; + + Object.keys(m).forEach((key) => { + const newPrefix = prefix.length > 0 ? `${prefix}.${key}` : key; + const val = m[key]; + if (!val) { + // Ignore it + } else if (Array.isArray(val)) { + // Make sure everything in the array is a string + ret[newPrefix] = val.map((v) => `${v}`); + } else if (typeof val === "object") { + const subArrays = getStringArrays(newPrefix, val); + ret = { ...ret, ...subArrays }; + } + }); + + return ret; + } + + const stringArrays: Record = getStringArrays( + "", + documentMetadata + ); + + const flatMetadata: metadataType = flatten(documentMetadata); + Object.keys(flatMetadata).forEach((key) => { + Object.keys(stringArrays).forEach((arrayKey) => { + const matchKey = `${arrayKey}.`; + if (key.startsWith(matchKey)) { + delete flatMetadata[key]; + } + }); + }); + + const metadata: metadataType = { + ...flatMetadata, + ...stringArrays, + }; + return metadata; + } + + /** + * Given the metadata from a document, convert it to an array of Restriction + * objects that may be passed to the Matching Engine and stored. + * The default implementation flattens any metadata and includes it as + * an "allowList". 
Subclasses can choose to convert some of these to + * "denyList" items or to add additional restrictions (for example, to format + * dates into a different structure or to add additional restrictions + * based on the date). + * @param documentMetadata - The metadata from a document + * @returns a Restriction[] (or an array of a subclass, from the FilterType) + */ + metadataToRestrictions( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + documentMetadata: Record + ): this["FilterType"] { + const metadata = this.cleanMetadata(documentMetadata); + + const restrictions: this["FilterType"] = []; + for (const key of Object.keys(metadata)) { + // Make sure the value is an array (or that we'll ignore it) + let valArray; + const val = metadata[key]; + if (val === null) { + valArray = null; + } else if (Array.isArray(val) && val.length > 0) { + valArray = val; + } else { + valArray = [`${val}`]; + } + + // Add to the restrictions if we do have a valid value + if (valArray) { + // Determine if this key is for the allowList or denyList + // TODO: get which ones should be on the deny list + const listType = "allowList"; + + // Create the restriction + const restriction: Restriction = { + namespace: key, + [listType]: valArray, + }; + + // Add it to the restriction list + restrictions.push(restriction); + } + } + return restrictions; + } + + /** + * Create an index datapoint for the vector and document id. + * If an id does not exist, create it and set the document to its value. 
+ * @param vector + * @param document + */ + buildDatapoint(vector: number[], document: IdDocument): IndexDatapoint { + if (!document.id) { + // eslint-disable-next-line no-param-reassign + document.id = uuid.v4(); + } + const ret: IndexDatapoint = { + datapointId: document.id, + featureVector: vector, + }; + const restrictions = this.metadataToRestrictions(document.metadata); + if (restrictions?.length > 0) { + ret.restricts = restrictions; + } + return ret; + } + + async delete(params: MatchingEngineDeleteParams): Promise { + const options = {}; + await this.removeDatapointClient.request(params.ids, options); + } + + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ): Promise<[Document, number][]> { + // Format the query into the request + const deployedIndexId = await this.getDeployedIndexId(); + const requestQuery: FindNeighborsRequestQuery = { + neighborCount: k, + datapoint: { + datapointId: `0`, + featureVector: query, + }, + }; + if (filter) { + requestQuery.datapoint.restricts = filter; + } + const request: FindNeighborsRequest = { + deployedIndexId, + queries: [requestQuery], + }; + + // Build the connection. + // Has to be done here, since we defer getting the endpoint until + // we need it. + const apiEndpoint = await this.getPublicAPIEndpoint(); + const findNeighborsParams: FindNeighborsConnectionParams = { + endpoint: apiEndpoint, + indexEndpoint: this.indexEndpoint, + apiVersion: this.apiVersion, + location: this.location, + deployedIndexId, + }; + const connection = new FindNeighborsConnection( + findNeighborsParams, + this.caller + ); + + // Make the call + const options = {}; + const response = await connection.request(request, options); + + // Get the document for each datapoint id and return them + const nearestNeighbors = response?.data?.nearestNeighbors ?? []; + const nearestNeighbor = nearestNeighbors[0]; + const neighbors = nearestNeighbor?.neighbors ?? 
[]; + const ret: [Document, number][] = await Promise.all( + neighbors.map(async (neighbor) => { + const id = neighbor?.datapoint?.datapointId; + const distance = neighbor?.distance; + let doc: IdDocument; + try { + doc = await this.docstore.search(id); + } catch (xx) { + // Documents that are in the index are returned, even if they + // are not in the document store, to allow for some way to get + // the id so they can be deleted. + console.error(xx); + console.warn( + [ + `Document with id "${id}" is missing from the backing docstore.`, + `This can occur if you clear the docstore without deleting from the corresponding Matching Engine index.`, + `To resolve this, you should call .delete() with this id as part of the "ids" parameter.`, + ].join("\n") + ); + doc = new Document({ pageContent: `Missing document ${id}` }); + } + doc.id ??= id; + return [doc, distance]; + }) + ); + + return ret; + } + + /** + * For this index endpoint, figure out what API Endpoint URL and deployed + * index ID should be used to do upserts and queries. + * Also sets the `apiEndpoint` and `deployedIndexId` property for future use. + * @return The URL + */ + async determinePublicAPIEndpoint(): Promise { + const response: IndexEndpointResponse = + await this.indexEndpointClient.request(this.callerOptions); + + // Get the endpoint + const publicEndpointDomainName = response?.data?.publicEndpointDomainName; + this.apiEndpoint = publicEndpointDomainName; + + // Determine which of the deployed indexes match the index id + // and get the deployed index id. The list of deployed index ids + // contain the "index name" or path, but not the index id by itself, + // so we need to extract it from the name + const indexPathPattern = /projects\/.+\/locations\/.+\/indexes\/(.+)$/; + const deployedIndexes = response?.data?.deployedIndexes ?? 
[]; + const deployedIndex = deployedIndexes.find((index) => { + const deployedIndexPath = index.index; + const match = deployedIndexPath.match(indexPathPattern); + if (match) { + const [, potentialIndexId] = match; + if (potentialIndexId === this.index) { + return true; + } + } + return false; + }); + if (deployedIndex) { + this.deployedIndexId = deployedIndex.id; + } + + return { + apiEndpoint: this.apiEndpoint, + deployedIndexId: this.deployedIndexId, + }; + } + + async getPublicAPIEndpoint(): Promise { + return ( + this.apiEndpoint ?? (await this.determinePublicAPIEndpoint()).apiEndpoint + ); + } + + async getDeployedIndexId(): Promise { + return ( + this.deployedIndexId ?? + (await this.determinePublicAPIEndpoint()).deployedIndexId + ); + } + + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: MatchingEngineArgs + ): Promise { + const docs: Document[] = texts.map( + (text, index): Document => ({ + pageContent: text, + metadata: Array.isArray(metadatas) ? 
metadatas[index] : metadatas, + }) + ); + return this.fromDocuments(docs, embeddings, dbConfig); + } + + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: MatchingEngineArgs + ): Promise { + const ret = new MatchingEngine(embeddings, dbConfig); + await ret.addDocuments(docs); + return ret; + } +} diff --git a/libs/langchain-community/src/vectorstores/hnswlib.ts b/libs/langchain-community/src/vectorstores/hnswlib.ts new file mode 100644 index 000000000000..86d896a70566 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/hnswlib.ts @@ -0,0 +1,354 @@ +import type { + HierarchicalNSW as HierarchicalNSWT, + SpaceName, +} from "hnswlib-node"; +import { Embeddings } from "@langchain/core/embeddings"; +import { SaveableVectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { SynchronousInMemoryDocstore } from "../stores/doc/in_memory.js"; + +/** + * Interface for the base configuration of HNSWLib. It includes the space + * name and the number of dimensions. + */ +export interface HNSWLibBase { + space: SpaceName; + numDimensions?: number; +} + +/** + * Interface for the arguments that can be passed to the HNSWLib + * constructor. It extends HNSWLibBase and includes properties for the + * document store and HNSW index. + */ +export interface HNSWLibArgs extends HNSWLibBase { + docstore?: SynchronousInMemoryDocstore; + index?: HierarchicalNSWT; +} + +/** + * Class that implements a vector store using Hierarchical Navigable Small + * World (HNSW) graphs. It extends the SaveableVectorStore class and + * provides methods for adding documents and vectors, performing + * similarity searches, and saving and loading the vector store. 
+ */ +export class HNSWLib extends SaveableVectorStore { + declare FilterType: (doc: Document) => boolean; + + _index?: HierarchicalNSWT; + + docstore: SynchronousInMemoryDocstore; + + args: HNSWLibBase; + + _vectorstoreType(): string { + return "hnswlib"; + } + + constructor(embeddings: Embeddings, args: HNSWLibArgs) { + super(embeddings, args); + this._index = args.index; + this.args = args; + this.embeddings = embeddings; + this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore(); + } + + /** + * Method to add documents to the vector store. It first converts the + * documents to vectors using the embeddings, then adds the vectors to the + * vector store. + * @param documents The documents to be added to the vector store. + * @returns A Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + private static async getHierarchicalNSW(args: HNSWLibBase) { + const { HierarchicalNSW } = await HNSWLib.imports(); + if (!args.space) { + throw new Error("hnswlib-node requires a space argument"); + } + if (args.numDimensions === undefined) { + throw new Error("hnswlib-node requires a numDimensions argument"); + } + return new HierarchicalNSW(args.space, args.numDimensions); + } + + private async initIndex(vectors: number[][]) { + if (!this._index) { + if (this.args.numDimensions === undefined) { + this.args.numDimensions = vectors[0].length; + } + this.index = await HNSWLib.getHierarchicalNSW(this.args); + } + if (!this.index.getCurrentCount()) { + this.index.initIndex(vectors.length); + } + } + + public get index(): HierarchicalNSWT { + if (!this._index) { + throw new Error( + "Vector store not initialised yet. Try calling `addTexts` first." 
+ ); + } + return this._index; + } + + private set index(index: HierarchicalNSWT) { + this._index = index; + } + + /** + * Method to add vectors to the vector store. It first initializes the + * index if it hasn't been initialized yet, then adds the vectors to the + * index and the documents to the document store. + * @param vectors The vectors to be added to the vector store. + * @param documents The documents corresponding to the vectors. + * @returns A Promise that resolves when the vectors and documents have been added. + */ + async addVectors(vectors: number[][], documents: Document[]) { + if (vectors.length === 0) { + return; + } + await this.initIndex(vectors); + + // TODO here we could optionally normalise the vectors to unit length + // so that dot product is equivalent to cosine similarity, like this + // https://github.com/nmslib/hnswlib/issues/384#issuecomment-1155737730 + // While we only support OpenAI embeddings this isn't necessary + if (vectors.length !== documents.length) { + throw new Error(`Vectors and metadatas must have the same length`); + } + if (vectors[0].length !== this.args.numDimensions) { + throw new Error( + `Vectors must have the same length as the number of dimensions (${this.args.numDimensions})` + ); + } + const capacity = this.index.getMaxElements(); + const needed = this.index.getCurrentCount() + vectors.length; + if (needed > capacity) { + this.index.resizeIndex(needed); + } + const docstoreSize = this.index.getCurrentCount(); + const toSave: Record = {}; + for (let i = 0; i < vectors.length; i += 1) { + this.index.addPoint(vectors[i], docstoreSize + i); + toSave[docstoreSize + i] = documents[i]; + } + this.docstore.add(toSave); + } + + /** + * Method to perform a similarity search in the vector store using a query + * vector. It returns the k most similar documents along with their + * similarity scores. An optional filter function can be provided to + * filter the documents. + * @param query The query vector. 
+ * @param k The number of most similar documents to return. + * @param filter An optional filter function to filter the documents. + * @returns A Promise that resolves to an array of tuples, where each tuple contains a document and its similarity score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ) { + if (this.args.numDimensions && !this._index) { + await this.initIndex([[]]); + } + if (query.length !== this.args.numDimensions) { + throw new Error( + `Query vector must have the same length as the number of dimensions (${this.args.numDimensions})` + ); + } + if (k > this.index.getCurrentCount()) { + const total = this.index.getCurrentCount(); + console.warn( + `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}` + ); + // eslint-disable-next-line no-param-reassign + k = total; + } + const filterFunction = (label: number): boolean => { + if (!filter) { + return true; + } + const document = this.docstore.search(String(label)); + // eslint-disable-next-line no-instanceof/no-instanceof + if (typeof document !== "string") { + return filter(document); + } + return false; + }; + const result = this.index.searchKnn( + query, + k, + filter ? filterFunction : undefined + ); + return result.neighbors.map( + (docIndex, resultIndex) => + [ + this.docstore.search(String(docIndex)), + result.distances[resultIndex], + ] as [Document, number] + ); + } + + /** + * Method to delete the vector store from a directory. It deletes the + * hnswlib.index file, the docstore.json file, and the args.json file from + * the directory. + * @param params An object with a directory property that specifies the directory from which to delete the vector store. + * @returns A Promise that resolves when the vector store has been deleted. 
+ */ + async delete(params: { directory: string }) { + const fs = await import("node:fs/promises"); + const path = await import("node:path"); + try { + await fs.access(path.join(params.directory, "hnswlib.index")); + } catch (err) { + throw new Error( + `Directory ${params.directory} does not contain a hnswlib.index file.` + ); + } + + await Promise.all([ + await fs.rm(path.join(params.directory, "hnswlib.index"), { + force: true, + }), + await fs.rm(path.join(params.directory, "docstore.json"), { + force: true, + }), + await fs.rm(path.join(params.directory, "args.json"), { force: true }), + ]); + } + + /** + * Method to save the vector store to a directory. It saves the HNSW + * index, the arguments, and the document store to the directory. + * @param directory The directory to which to save the vector store. + * @returns A Promise that resolves when the vector store has been saved. + */ + async save(directory: string) { + const fs = await import("node:fs/promises"); + const path = await import("node:path"); + await fs.mkdir(directory, { recursive: true }); + await Promise.all([ + this.index.writeIndex(path.join(directory, "hnswlib.index")), + await fs.writeFile( + path.join(directory, "args.json"), + JSON.stringify(this.args) + ), + await fs.writeFile( + path.join(directory, "docstore.json"), + JSON.stringify(Array.from(this.docstore._docs.entries())) + ), + ]); + } + + /** + * Static method to load a vector store from a directory. It reads the + * HNSW index, the arguments, and the document store from the directory, + * then creates a new HNSWLib instance with these values. + * @param directory The directory from which to load the vector store. + * @param embeddings The embeddings to be used by the HNSWLib instance. + * @returns A Promise that resolves to a new HNSWLib instance. 
+ */ + static async load(directory: string, embeddings: Embeddings) { + const fs = await import("node:fs/promises"); + const path = await import("node:path"); + const args = JSON.parse( + await fs.readFile(path.join(directory, "args.json"), "utf8") + ); + const index = await HNSWLib.getHierarchicalNSW(args); + const [docstoreFiles] = await Promise.all([ + fs + .readFile(path.join(directory, "docstore.json"), "utf8") + .then(JSON.parse), + index.readIndex(path.join(directory, "hnswlib.index")), + ]); + args.docstore = new SynchronousInMemoryDocstore(new Map(docstoreFiles)); + + args.index = index; + + return new HNSWLib(embeddings, args); + } + + /** + * Static method to create a new HNSWLib instance from texts and metadata. + * It creates a new Document instance for each text and metadata, then + * calls the fromDocuments method to create the HNSWLib instance. + * @param texts The texts to be used to create the documents. + * @param metadatas The metadata to be used to create the documents. + * @param embeddings The embeddings to be used by the HNSWLib instance. + * @param dbConfig An optional configuration object for the document store. + * @returns A Promise that resolves to a new HNSWLib instance. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig?: { + docstore?: SynchronousInMemoryDocstore; + } + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return HNSWLib.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method to create a new HNSWLib instance from documents. It + * creates a new HNSWLib instance, adds the documents to it, then returns + * the instance. + * @param docs The documents to be added to the HNSWLib instance. 
+ * @param embeddings The embeddings to be used by the HNSWLib instance. + * @param dbConfig An optional configuration object for the document store. + * @returns A Promise that resolves to a new HNSWLib instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig?: { + docstore?: SynchronousInMemoryDocstore; + } + ): Promise { + const args: HNSWLibArgs = { + docstore: dbConfig?.docstore, + space: "cosine", + }; + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + static async imports(): Promise<{ + HierarchicalNSW: typeof HierarchicalNSWT; + }> { + try { + const { + default: { HierarchicalNSW }, + } = await import("hnswlib-node"); + + return { HierarchicalNSW }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (err: any) { + throw new Error( + `Could not import hnswlib-node. Please install hnswlib-node as a dependency with, e.g. \`npm install -S hnswlib-node\`.\n\nError: ${err?.message}` + ); + } + } +} diff --git a/libs/langchain-community/src/vectorstores/lancedb.ts b/libs/langchain-community/src/vectorstores/lancedb.ts new file mode 100644 index 000000000000..86b87a111758 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/lancedb.ts @@ -0,0 +1,152 @@ +import { Table } from "vectordb"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +/** + * Defines the arguments for the LanceDB class constructor. It includes a + * table and an optional textKey. + */ +export type LanceDBArgs = { + table: Table; + textKey?: string; +}; + +/** + * A wrapper for an open-source database for vector-search with persistent + * storage. It simplifies retrieval, filtering, and management of + * embeddings. 
+ */ +export class LanceDB extends VectorStore { + private table: Table; + + private textKey: string; + + constructor(embeddings: Embeddings, args: LanceDBArgs) { + super(embeddings, args); + this.table = args.table; + this.embeddings = embeddings; + this.textKey = args.textKey || "text"; + } + + /** + * Adds documents to the database. + * @param documents The documents to be added. + * @returns A Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + _vectorstoreType(): string { + return "lancedb"; + } + + /** + * Adds vectors and their corresponding documents to the database. + * @param vectors The vectors to be added. + * @param documents The corresponding documents to be added. + * @returns A Promise that resolves when the vectors and documents have been added. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length === 0) { + return; + } + if (vectors.length !== documents.length) { + throw new Error(`Vectors and documents must have the same length`); + } + + const data: Array> = []; + for (let i = 0; i < documents.length; i += 1) { + const record = { + vector: vectors[i], + [this.textKey]: documents[i].pageContent, + }; + Object.keys(documents[i].metadata).forEach((metaKey) => { + record[metaKey] = documents[i].metadata[metaKey]; + }); + data.push(record); + } + await this.table.add(data); + } + + /** + * Performs a similarity search on the vectors in the database and returns + * the documents and their scores. + * @param query The query vector. + * @param k The number of results to return. + * @returns A Promise that resolves with an array of tuples, each containing a Document and its score. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k: number + ): Promise<[Document, number][]> { + const results = await this.table.search(query).limit(k).execute(); + + const docsAndScore: [Document, number][] = []; + results.forEach((item) => { + const metadata: Record = {}; + Object.keys(item).forEach((key) => { + if (key !== "vector" && key !== "score" && key !== this.textKey) { + metadata[key] = item[key]; + } + }); + + docsAndScore.push([ + new Document({ + pageContent: item[this.textKey] as string, + metadata, + }), + item.score as number, + ]); + }); + return docsAndScore; + } + + /** + * Creates a new instance of LanceDB from texts. + * @param texts The texts to be converted into documents. + * @param metadatas The metadata for the texts. + * @param embeddings The embeddings to be managed. + * @param dbConfig The configuration for the LanceDB instance. + * @returns A Promise that resolves with a new instance of LanceDB. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: LanceDBArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return LanceDB.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Creates a new instance of LanceDB from documents. + * @param docs The documents to be added to the database. + * @param embeddings The embeddings to be managed. + * @param dbConfig The configuration for the LanceDB instance. + * @returns A Promise that resolves with a new instance of LanceDB. 
/**
 * Vector store that keeps every embedding in process memory.
 *
 * Documents are embedded with the supplied `Embeddings` instance and stored
 * as plain `MemoryVector` records. Search is a linear scan that scores every
 * stored vector with `similarity` and returns the highest-scoring entries.
 */
export class MemoryVectorStore extends VectorStore {
  declare FilterType: (doc: Document) => boolean;

  /** Stored vectors, in insertion order. */
  memoryVectors: MemoryVector[] = [];

  /** Scoring function; results are ranked by descending return value. */
  similarity: typeof ml_distance_similarity.cosine;

  _vectorstoreType(): string {
    return "memory";
  }

  /**
   * @param embeddings Embeddings instance used to embed documents.
   * @param args Optional settings; `similarity` defaults to cosine
   * similarity from `ml-distance`, the remaining keys go to `VectorStore`.
   */
  constructor(
    embeddings: Embeddings,
    { similarity, ...rest }: MemoryVectorStoreArgs = {}
  ) {
    super(embeddings, rest);

    this.similarity = similarity ?? ml_distance_similarity.cosine;
  }

  /**
   * Embeds the page content of each document and stores the result.
   * @param documents Array of `Document` instances to be added to the store.
   * @returns Promise that resolves when all documents have been added.
   */
  async addDocuments(documents: Document[]): Promise<void> {
    const texts = documents.map(({ pageContent }) => pageContent);
    return this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents
    );
  }

  /**
   * Stores pre-computed vectors. `vectors[i]` is paired with `documents[i]`.
   * @param vectors Array of vectors to be added to the store.
   * @param documents Array of `Document` instances corresponding to the vectors.
   * @returns Promise that resolves when all vectors have been added.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    const memoryVectors = vectors.map((embedding, idx) => ({
      content: documents[idx].pageContent,
      embedding,
      metadata: documents[idx].metadata,
    }));

    this.memoryVectors = this.memoryVectors.concat(memoryVectors);
  }

  /**
   * Scores every stored vector (after the optional filter) against the query
   * and returns the `k` best matches, highest score first.
   * @param query Query vector to compare against the vectors in the store.
   * @param k Number of top results to return.
   * @param filter Optional predicate; only documents for which it returns true are scored.
   * @returns Promise that resolves with `[Document, score]` tuples.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: this["FilterType"]
  ): Promise<[Document, number][]> {
    const candidates = filter
      ? this.memoryVectors.filter((memoryVector) =>
          filter(
            new Document({
              metadata: memoryVector.metadata,
              pageContent: memoryVector.content,
            })
          )
        )
      : this.memoryVectors;

    return (
      candidates
        // `map` yields a fresh array, so the in-place sort below never
        // reorders `this.memoryVectors`.
        .map((memoryVector) => ({
          memoryVector,
          score: this.similarity(query, memoryVector.embedding),
        }))
        // A comparator must be antisymmetric: returning 0 for "a < b" (as a
        // `? -1 : 0` comparator does) leaves the ordering engine-dependent.
        .sort((a, b) => b.score - a.score)
        .slice(0, k)
        .map(({ memoryVector, score }): [Document, number] => [
          new Document({
            metadata: memoryVector.metadata,
            pageContent: memoryVector.content,
          }),
          score,
        ])
    );
  }

  /**
   * Builds a store from raw texts.
   * @param texts Array of texts to be added to the store.
   * @param metadatas One metadata object per text, or a single object shared by all texts.
   * @param embeddings `Embeddings` instance used to generate embeddings for the texts.
   * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the instance.
   * @returns Promise that resolves with a new `MemoryVectorStore` instance.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig?: MemoryVectorStoreArgs
  ): Promise<MemoryVectorStore> {
    const docs = texts.map(
      (text, i) =>
        new Document({
          pageContent: text,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );
    return MemoryVectorStore.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Builds a store from `Document` instances.
   * @param docs Array of `Document` instances to be added to the store.
   * @param embeddings `Embeddings` instance used to generate embeddings for the documents.
   * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the instance.
   * @returns Promise that resolves with a new `MemoryVectorStore` instance.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig?: MemoryVectorStoreArgs
  ): Promise<MemoryVectorStore> {
    const instance = new this(embeddings, dbConfig);
    await instance.addDocuments(docs);
    return instance;
  }

  /**
   * Creates an empty store; no documents or vectors are added.
   * @param embeddings `Embeddings` instance used to generate embeddings for later additions.
   * @param dbConfig Optional `MemoryVectorStoreArgs` to configure the instance.
   * @returns Promise that resolves with a new `MemoryVectorStore` instance.
   */
  static async fromExistingIndex(
    embeddings: Embeddings,
    dbConfig?: MemoryVectorStoreArgs
  ): Promise<MemoryVectorStore> {
    const instance = new this(embeddings, dbConfig);
    return instance;
  }
}
/**
 * Class for interacting with a Milvus database. Extends the VectorStore
 * class.
 *
 * A collection is created lazily on the first insert; its schema is derived
 * from the first batch of documents (one field per metadata key plus the
 * primary, text and vector fields).
 */
export class Milvus extends VectorStore {
  get lc_secrets(): { [key: string]: string } {
    return {
      ssl: "MILVUS_SSL",
      username: "MILVUS_USERNAME",
      password: "MILVUS_PASSWORD",
    };
  }

  declare FilterType: string;

  collectionName: string;

  numDimensions?: number;

  autoId?: boolean;

  primaryField: string;

  vectorField: string;

  textField: string;

  textFieldMaxLength: number;

  /** Names of the collection fields that must be supplied on insert (auto-id fields excluded). */
  fields: string[];

  client: MilvusClient;

  indexParams: Record<IndexType, IndexParam> = {
    IVF_FLAT: { params: { nprobe: 10 } },
    IVF_SQ8: { params: { nprobe: 10 } },
    IVF_PQ: { params: { nprobe: 10 } },
    HNSW: { params: { ef: 10 } },
    RHNSW_FLAT: { params: { ef: 10 } },
    RHNSW_SQ: { params: { ef: 10 } },
    RHNSW_PQ: { params: { ef: 10 } },
    IVF_HNSW: { params: { nprobe: 10, ef: 10 } },
    ANNOY: { params: { search_k: 10 } },
  };

  indexCreateParams = {
    index_type: "HNSW",
    metric_type: "L2",
    params: JSON.stringify({ M: 8, efConstruction: 64 }),
  };

  indexSearchParams = JSON.stringify({ ef: 64 });

  _vectorstoreType(): string {
    return "milvus";
  }

  /**
   * @param embeddings Embeddings instance used to embed documents and queries.
   * @param args Connection and schema settings. The address is taken from
   * `args.url`, then the `MILVUS_URL` environment variable, then
   * `args.clientConfig.address`.
   * @throws Error if no address can be resolved.
   */
  constructor(embeddings: Embeddings, args: MilvusLibArgs) {
    super(embeddings, args);
    this.embeddings = embeddings;
    this.collectionName = args.collectionName ?? genCollectionName();
    this.textField = args.textField ?? MILVUS_TEXT_FIELD_NAME;

    this.autoId = args.autoId ?? true;
    this.primaryField = args.primaryField ?? MILVUS_PRIMARY_FIELD_NAME;
    this.vectorField = args.vectorField ?? MILVUS_VECTOR_FIELD_NAME;

    // 0 means "derive the max length from the first inserted batch".
    this.textFieldMaxLength = args.textFieldMaxLength ?? 0;

    this.fields = [];

    const url = args.url ?? getEnvironmentVariable("MILVUS_URL");
    const {
      address = "",
      username = "",
      password = "",
      ssl,
    } = args.clientConfig || {};

    // combine args clientConfig and env variables
    const clientConfig: ClientConfig = {
      ...(args.clientConfig || {}),
      address: url || address,
      username: args.username || username,
      password: args.password || password,
      ssl: args.ssl || ssl,
    };

    if (!clientConfig.address) {
      throw new Error("Milvus URL address is not provided.");
    }
    this.client = new MilvusClient(clientConfig);
  }

  /**
   * Adds documents to the Milvus database.
   * @param documents Array of Document instances to be added to the database.
   * @returns Promise resolving to void.
   */
  async addDocuments(documents: Document[]): Promise<void> {
    const texts = documents.map(({ pageContent }) => pageContent);
    await this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents
    );
  }

  /**
   * Adds vectors to the Milvus database, creating the collection first if it
   * does not exist yet.
   * @param vectors Array of vectors to be added to the database.
   * @param documents Array of Document instances associated with the vectors.
   * @returns Promise resolving to void.
   * @throws Error if a vector has no matching document, if a required field
   * is missing from a document's metadata, or if the insert fails.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    if (vectors.length === 0) {
      return;
    }
    // Every vector is paired with documents[index]; fail with a clear message
    // instead of a TypeError on `undefined.pageContent`.
    if (documents.length < vectors.length) {
      throw new Error(
        `Number of documents (${documents.length}) is smaller than number of vectors (${vectors.length})`
      );
    }
    await this.ensureCollection(vectors, documents);

    const insertDatas: InsertRow[] = vectors.map((vec, index) => {
      const doc = documents[index];
      const data: InsertRow = {
        [this.textField]: doc.pageContent,
        [this.vectorField]: vec,
      };
      this.fields.forEach((field) => {
        switch (field) {
          case this.primaryField:
            if (!this.autoId) {
              if (doc.metadata[this.primaryField] === undefined) {
                throw new Error(
                  `The Collection's primaryField is configured with autoId=false, thus its value must be provided through metadata.`
                );
              }
              data[field] = doc.metadata[this.primaryField];
            }
            break;
          case this.textField:
            data[field] = doc.pageContent;
            break;
          case this.vectorField:
            data[field] = vec;
            break;
          default: // metadata fields
            if (doc.metadata[field] === undefined) {
              throw new Error(
                `The field "${field}" is not provided in documents[${index}].metadata.`
              );
            } else if (typeof doc.metadata[field] === "object") {
              // Object-valued metadata is stored as a JSON string.
              data[field] = JSON.stringify(doc.metadata[field]);
            } else {
              data[field] = doc.metadata[field];
            }
            break;
        }
      });
      return data;
    });

    const insertResp = await this.client.insert({
      collection_name: this.collectionName,
      fields_data: insertDatas,
    });
    if (insertResp.status.error_code !== ErrorCode.SUCCESS) {
      throw new Error(`Error inserting data: ${JSON.stringify(insertResp)}`);
    }
    await this.client.flushSync({ collection_names: [this.collectionName] });
  }

  /**
   * Searches for vectors in the Milvus database that are similar to a given
   * vector.
   * @param query Vector to compare with the vectors in the database.
   * @param k Number of similar vectors to return.
   * @param filter Optional filter expression passed to the search.
   * @returns Promise resolving to an array of tuples, each containing a Document instance and the score reported by Milvus.
   * @throws Error if the collection does not exist or any client call fails.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: string
  ): Promise<[Document, number][]> {
    const hasColResp = await this.client.hasCollection({
      collection_name: this.collectionName,
    });
    if (hasColResp.status.error_code !== ErrorCode.SUCCESS) {
      // Serialize the response: interpolating the object prints "[object Object]".
      throw new Error(
        `Error checking collection: ${JSON.stringify(hasColResp)}`
      );
    }
    if (hasColResp.value === false) {
      throw new Error(
        `Collection not found: ${this.collectionName}, please create collection before search.`
      );
    }

    const filterStr = filter ?? "";

    await this.grabCollectionFields();

    const loadResp = await this.client.loadCollectionSync({
      collection_name: this.collectionName,
    });
    if (loadResp.error_code !== ErrorCode.SUCCESS) {
      throw new Error(`Error loading collection: ${JSON.stringify(loadResp)}`);
    }

    // clone this.field and remove vectorField
    const outputFields = this.fields.filter(
      (field) => field !== this.vectorField
    );

    const searchResp = await this.client.search({
      collection_name: this.collectionName,
      search_params: {
        anns_field: this.vectorField,
        topk: k.toString(),
        metric_type: this.indexCreateParams.metric_type,
        params: this.indexSearchParams,
      },
      output_fields: outputFields,
      vector_type: DataType.FloatVector,
      vectors: [query],
      filter: filterStr,
    });
    if (searchResp.status.error_code !== ErrorCode.SUCCESS) {
      throw new Error(`Error searching data: ${JSON.stringify(searchResp)}`);
    }
    const results: [Document, number][] = [];
    searchResp.results.forEach((result) => {
      const fields = {
        pageContent: "",
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        metadata: {} as Record<string, any>,
      };
      Object.keys(result).forEach((key) => {
        if (key === this.textField) {
          fields.pageContent = result[key];
        } else if (this.fields.includes(key) || key === this.primaryField) {
          if (typeof result[key] === "string") {
            // String values that parse as JSON are returned parsed; this
            // reverses the JSON.stringify applied to objects on insert.
            const { isJson, obj } = checkJsonString(result[key]);
            fields.metadata[key] = isJson ? obj : result[key];
          } else {
            fields.metadata[key] = result[key];
          }
        }
      });
      results.push([new Document(fields), result.score]);
    });
    return results;
  }

  /**
   * Ensures that a collection exists in the Milvus database. Creates it from
   * the given vectors/documents when missing, otherwise loads its field list.
   * @param vectors Optional array of vectors to be used if a new collection needs to be created.
   * @param documents Optional array of Document instances to be used if a new collection needs to be created.
   * @returns Promise resolving to void.
   * @throws Error if the collection is missing and no vectors/documents were given.
   */
  async ensureCollection(
    vectors?: number[][],
    documents?: Document[]
  ): Promise<void> {
    const hasColResp = await this.client.hasCollection({
      collection_name: this.collectionName,
    });
    if (hasColResp.status.error_code !== ErrorCode.SUCCESS) {
      throw new Error(
        `Error checking collection: ${JSON.stringify(hasColResp, null, 2)}`
      );
    }

    if (hasColResp.value === false) {
      if (vectors === undefined || documents === undefined) {
        throw new Error(
          `Collection not found: ${this.collectionName}, please provide vectors and documents to create collection.`
        );
      }
      await this.createCollection(vectors, documents);
    } else {
      await this.grabCollectionFields();
    }
  }

  /**
   * Creates a collection in the Milvus database and an index on its vector
   * field. The schema is derived from the given documents' metadata.
   * @param vectors Array of vectors to be added to the new collection; the first one fixes the vector dimension.
   * @param documents Array of Document instances to be added to the new collection.
   * @returns Promise resolving to void.
   * @throws Error if the collection cannot be created.
   */
  async createCollection(
    vectors: number[][],
    documents: Document[]
  ): Promise<void> {
    const fieldList: FieldType[] = [];

    fieldList.push(...createFieldTypeForMetadata(documents, this.primaryField));

    fieldList.push(
      {
        name: this.primaryField,
        description: "Primary key",
        data_type: DataType.Int64,
        is_primary_key: true,
        autoID: this.autoId,
      },
      {
        name: this.textField,
        description: "Text field",
        data_type: DataType.VarChar,
        type_params: {
          max_length:
            this.textFieldMaxLength > 0
              ? this.textFieldMaxLength.toString()
              : getTextFieldMaxLength(documents).toString(),
        },
      },
      {
        name: this.vectorField,
        description: "Vector field",
        data_type: DataType.FloatVector,
        type_params: {
          dim: getVectorFieldDim(vectors).toString(),
        },
      }
    );

    // Auto-id fields are generated by Milvus and must not be sent on insert.
    fieldList.forEach((field) => {
      if (!field.autoID) {
        this.fields.push(field.name);
      }
    });

    const createRes = await this.client.createCollection({
      collection_name: this.collectionName,
      fields: fieldList,
    });

    if (createRes.error_code !== ErrorCode.SUCCESS) {
      throw new Error(
        `Failed to create collection: ${JSON.stringify(createRes)}`
      );
    }

    await this.client.createIndex({
      collection_name: this.collectionName,
      field_name: this.vectorField,
      extra_params: this.indexCreateParams,
    });
  }

  /**
   * Retrieves the fields of a collection in the Milvus database and caches
   * them on the instance. Does nothing when the fields are already known.
   * @returns Promise resolving to void.
   */
  async grabCollectionFields(): Promise<void> {
    if (!this.collectionName) {
      throw new Error("Need collection name to grab collection fields");
    }
    if (
      this.primaryField &&
      this.vectorField &&
      this.textField &&
      this.fields.length > 0
    ) {
      return;
    }
    const desc = await this.client.describeCollection({
      collection_name: this.collectionName,
    });
    desc.schema.fields.forEach((field) => {
      // Auto-id fields are not part of the insertable field list.
      if (!field.autoID) {
        this.fields.push(field.name);
      }
      if (field.is_primary_key) {
        this.primaryField = field.name;
      }
      const dtype = DataTypeMap[field.data_type];
      if (dtype === DataType.FloatVector || dtype === DataType.BinaryVector) {
        this.vectorField = field.name;
      }

      if (dtype === DataType.VarChar && field.name === MILVUS_TEXT_FIELD_NAME) {
        this.textField = field.name;
      }
    });
  }

  /**
   * Creates a Milvus instance from a set of texts and their associated
   * metadata.
   * @param texts Array of texts to be added to the database.
   * @param metadatas One metadata object per text, or a single object shared by all texts.
   * @param embeddings Embeddings instance used to generate vector embeddings for the texts.
   * @param dbConfig Optional configuration for the Milvus database.
   * @returns Promise resolving to a new Milvus instance.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig?: MilvusLibArgs
  ): Promise<Milvus> {
    const docs: Document[] = texts.map(
      (text, i) =>
        new Document({
          pageContent: text,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );
    return Milvus.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Creates a Milvus instance from a set of Document instances.
   * @param docs Array of Document instances to be added to the database.
   * @param embeddings Embeddings instance used to generate vector embeddings for the documents.
   * @param dbConfig Optional configuration for the Milvus database. A random collection name is generated when none is given.
   * @returns Promise resolving to a new Milvus instance.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig?: MilvusLibArgs
  ): Promise<Milvus> {
    // Forward the whole config: copying keys one by one is how
    // `textFieldMaxLength` used to get lost on this path.
    const args: MilvusLibArgs = {
      ...dbConfig,
      collectionName: dbConfig?.collectionName || genCollectionName(),
    };
    const instance = new this(embeddings, args);
    await instance.addDocuments(docs);
    return instance;
  }

  /**
   * Creates a Milvus instance from an existing collection in the Milvus
   * database.
   * @param embeddings Embeddings instance used to generate vector embeddings for queries and later additions.
   * @param dbConfig Configuration for the Milvus database.
   * @returns Promise resolving to a new Milvus instance.
   * @throws Error if the collection does not exist.
   */
  static async fromExistingCollection(
    embeddings: Embeddings,
    dbConfig: MilvusLibArgs
  ): Promise<Milvus> {
    const instance = new this(embeddings, dbConfig);
    await instance.ensureCollection();
    return instance;
  }

  /**
   * Deletes data from the Milvus database.
   * @param params Object containing the filter expression selecting the entities to delete.
   * @returns Promise resolving to void.
   * @throws Error if the collection does not exist or the deletion fails.
   */
  async delete(params: { filter: string }): Promise<void> {
    const hasColResp = await this.client.hasCollection({
      collection_name: this.collectionName,
    });
    if (hasColResp.status.error_code !== ErrorCode.SUCCESS) {
      throw new Error(
        `Error checking collection: ${JSON.stringify(hasColResp)}`
      );
    }
    if (hasColResp.value === false) {
      throw new Error(
        `Collection not found: ${this.collectionName}, please create collection before delete.`
      );
    }

    const { filter } = params;

    const deleteResp = await this.client.deleteEntities({
      collection_name: this.collectionName,
      expr: filter,
    });

    if (deleteResp.status.error_code !== ErrorCode.SUCCESS) {
      throw new Error(`Error deleting data: ${JSON.stringify(deleteResp)}`);
    }
  }
}
/**
 * Derives the Milvus field definitions for document metadata.
 *
 * The first document's metadata is the schema sample: strings become VarChar,
 * numbers Float, booleans Bool, `null` values are skipped, and every other
 * type is stored as a JSON string in a VarChar field. The primary field is
 * skipped here because the caller defines it separately.
 *
 * @param documents Documents whose metadata defines the schema.
 * @param primaryFieldName Name of the primary key field to leave out.
 * @returns One field definition per usable metadata key.
 * @throws Error if `documents` is empty, or if the documents do not share the
 * same metadata keys and value types.
 */
function createFieldTypeForMetadata(
  documents: Document[],
  primaryFieldName: string
): FieldType[] {
  if (documents.length === 0) {
    throw new Error("Cannot infer metadata fields: no documents provided");
  }
  const sampleMetadata = documents[0].metadata;

  // Keys that become collection fields; every document must provide them.
  const requiredKeys = Object.keys(sampleMetadata).filter(
    (key) => key !== primaryFieldName && sampleMetadata[key] !== null
  );

  // Lengths are measured in UTF-8 bytes, matching getTextFieldMaxLength.
  // NOTE(review): assumes Milvus counts VarChar max_length in bytes — confirm.
  const textEncoder = new TextEncoder();
  const byteLength = (text: string) => textEncoder.encode(text).length;

  let textFieldMaxLength = 0;
  let jsonFieldMaxLength = 0;
  documents.forEach(({ metadata }) => {
    // A document missing a sampled key would otherwise only fail at insert
    // time, after the collection has already been created.
    if (requiredKeys.some((key) => !(key in metadata))) {
      throw new Error(
        "All documents must have same metadata keys and datatype"
      );
    }

    Object.keys(metadata).forEach((key) => {
      const value = metadata[key];
      if (typeof value !== typeof sampleMetadata[key]) {
        throw new Error(
          "All documents must have same metadata keys and datatype"
        );
      }

      // Track the widest string value and the widest serialized object value.
      if (typeof value === "string") {
        textFieldMaxLength = Math.max(textFieldMaxLength, byteLength(value));
      } else if (typeof value === "object") {
        jsonFieldMaxLength = Math.max(
          jsonFieldMaxLength,
          byteLength(JSON.stringify(value))
        );
      }
    });
  });

  const fields: FieldType[] = [];
  for (const [key, value] of Object.entries(sampleMetadata)) {
    const type = typeof value;

    if (key === primaryFieldName) {
      // The primary field is created by the caller.
      continue;
    }
    if (type === "string") {
      fields.push({
        name: key,
        description: `Metadata String field`,
        data_type: DataType.VarChar,
        type_params: {
          max_length: textFieldMaxLength.toString(),
        },
      });
    } else if (type === "number") {
      fields.push({
        name: key,
        description: `Metadata Number field`,
        data_type: DataType.Float,
      });
    } else if (type === "boolean") {
      fields.push({
        name: key,
        description: `Metadata Boolean field`,
        data_type: DataType.Bool,
      });
    } else if (value === null) {
      // null carries no type information, so no field is created for it.
    } else {
      // Any other type is stored as its JSON serialization.
      fields.push({
        name: key,
        description: `Metadata JSON field`,
        data_type: DataType.VarChar,
        type_params: {
          max_length: jsonFieldMaxLength.toString(),
        },
      });
    }
  }
  return fields;
}

/**
 * Generates a collection name: the prefix followed by a UUID v4 with the
 * dashes removed.
 */
function genCollectionName(): string {
  return `${MILVUS_COLLECTION_NAME_PREFIX}_${uuid.v4().replace(/-/g, "")}`;
}

/**
 * Returns the largest UTF-8 byte length of any document's page content
 * (0 for an empty array).
 */
function getTextFieldMaxLength(documents: Document[]) {
  const textEncoder = new TextEncoder();
  let textMaxLength = 0;
  for (const { pageContent } of documents) {
    const textLengthInBytes = textEncoder.encode(pageContent).length;
    if (textLengthInBytes > textMaxLength) {
      textMaxLength = textLengthInBytes;
    }
  }
  return textMaxLength;
}

/**
 * Returns the dimension of the first vector.
 * @throws Error if `vectors` is empty.
 */
function getVectorFieldDim(vectors: number[][]) {
  if (vectors.length === 0) {
    throw new Error("No vectors found");
  }
  return vectors[0].length;
}

/**
 * Tries to parse a string as JSON.
 * @returns `{ isJson: true, obj }` with the parsed value, or
 * `{ isJson: false, obj: null }` when parsing throws.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function checkJsonString(value: string): { isJson: boolean; obj: any } {
  try {
    const result = JSON.parse(value);
    return { isJson: true, obj: result };
  } catch (e) {
    return { isJson: false, obj: null };
  }
}
000000000000..e9623ee1a131 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/momento_vector_index.ts @@ -0,0 +1,402 @@ +/* eslint-disable no-instanceof/no-instanceof */ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + ALL_VECTOR_METADATA, + IVectorIndexClient, + VectorIndexItem, + CreateVectorIndex, + VectorUpsertItemBatch, + VectorDeleteItemBatch, + VectorSearch, + VectorSearchAndFetchVectors, +} from "@gomomento/sdk-core"; +import * as uuid from "uuid"; +import { Document } from "@langchain/core/documents"; +import { Embeddings } from "@langchain/core/embeddings"; +import { + MaxMarginalRelevanceSearchOptions, + VectorStore, +} from "@langchain/core/vectorstores"; +import { maximalMarginalRelevance } from "@langchain/core/utils/math"; + +export interface DocumentProps { + ids: string[]; +} + +export interface MomentoVectorIndexLibArgs { + /** + * The Momento Vector Index client. + */ + client: IVectorIndexClient; + /** + * The name of the index to use to store the data. + * Defaults to "default". + */ + indexName?: string; + /** + * The name of the metadata field to use to store the text of the document. + * Defaults to "text". + */ + textField?: string; + /** + * Whether to create the index if it does not already exist. + * Defaults to true. + */ + ensureIndexExists?: boolean; +} + +export interface DeleteProps { + /** + * The ids of the documents to delete. + */ + ids: string[]; +} + +/** + * A vector store that uses the Momento Vector Index. + * + * @remarks + * To sign up for a free Momento account, visit https://console.gomomento.com. + */ +export class MomentoVectorIndex extends VectorStore { + private client: IVectorIndexClient; + + private indexName: string; + + private textField: string; + + private _ensureIndexExists: boolean; + + _vectorstoreType(): string { + return "momento"; + } + + /** + * Creates a new `MomentoVectorIndex` instance. 
+ * @param embeddings The embeddings instance to use to generate embeddings from documents. + * @param args The arguments to use to configure the vector store. + */ + constructor(embeddings: Embeddings, args: MomentoVectorIndexLibArgs) { + super(embeddings, args); + + this.embeddings = embeddings; + this.client = args.client; + this.indexName = args.indexName ?? "default"; + this.textField = args.textField ?? "text"; + this._ensureIndexExists = args.ensureIndexExists ?? true; + } + + /** + * Returns the Momento Vector Index client. + * @returns The Momento Vector Index client. + */ + public getClient(): IVectorIndexClient { + return this.client; + } + + /** + * Creates the index if it does not already exist. + * @param numDimensions The number of dimensions of the vectors to be stored in the index. + * @returns Promise that resolves to true if the index was created, false if it already existed. + */ + private async ensureIndexExists(numDimensions: number): Promise { + const response = await this.client.createIndex( + this.indexName, + numDimensions + ); + if (response instanceof CreateVectorIndex.Success) { + return true; + } else if (response instanceof CreateVectorIndex.AlreadyExists) { + return false; + } else if (response instanceof CreateVectorIndex.Error) { + throw new Error(response.toString()); + } else { + throw new Error(`Unknown response type: ${response.toString()}`); + } + } + + /** + * Converts the documents to a format that can be stored in the index. + * + * This is necessary because the Momento Vector Index requires that the metadata + * be a map of strings to strings. + * @param vectors The vectors to convert. + * @param documents The documents to convert. + * @param ids The ids to convert. + * @returns The converted documents. 
+ */ + private prepareItemBatch( + vectors: number[][], + documents: Document>[], + ids: string[] + ): VectorIndexItem[] { + return vectors.map((vector, idx) => ({ + id: ids[idx], + vector, + metadata: { + ...documents[idx].metadata, + [this.textField]: documents[idx].pageContent, + }, + })); + } + + /** + * Adds vectors to the index. + * + * @remarks If the index does not already exist, it will be created if `ensureIndexExists` is true. + * @param vectors The vectors to add to the index. + * @param documents The documents to add to the index. + * @param documentProps The properties of the documents to add to the index, specifically the ids. + * @returns Promise that resolves when the vectors have been added to the index. Also returns the ids of the + * documents that were added. + */ + public async addVectors( + vectors: number[][], + documents: Document>[], + documentProps?: DocumentProps + ): Promise { + if (vectors.length === 0) { + return; + } + + if (documents.length !== vectors.length) { + throw new Error( + `Number of vectors (${vectors.length}) does not equal number of documents (${documents.length})` + ); + } + + if (vectors.some((v) => v.length !== vectors[0].length)) { + throw new Error("All vectors must have the same length"); + } + + if ( + documentProps?.ids !== undefined && + documentProps.ids.length !== vectors.length + ) { + throw new Error( + `Number of ids (${ + documentProps?.ids?.length || "null" + }) does not equal number of vectors (${vectors.length})` + ); + } + + if (this._ensureIndexExists) { + await this.ensureIndexExists(vectors[0].length); + } + const documentIds = documentProps?.ids ?? 
documents.map(() => uuid.v4()); + + const batchSize = 128; + const numBatches = Math.ceil(vectors.length / batchSize); + + // Add each batch of vectors to the index + for (let i = 0; i < numBatches; i += 1) { + const [startIndex, endIndex] = [ + i * batchSize, + Math.min((i + 1) * batchSize, vectors.length), + ]; + + const batchVectors = vectors.slice(startIndex, endIndex); + const batchDocuments = documents.slice(startIndex, endIndex); + const batchDocumentIds = documentIds.slice(startIndex, endIndex); + + // Insert the items to the index + const response = await this.client.upsertItemBatch( + this.indexName, + this.prepareItemBatch(batchVectors, batchDocuments, batchDocumentIds) + ); + if (response instanceof VectorUpsertItemBatch.Success) { + // eslint-disable-next-line no-continue + continue; + } else if (response instanceof VectorUpsertItemBatch.Error) { + throw new Error(response.toString()); + } else { + throw new Error(`Unknown response type: ${response.toString()}`); + } + } + } + + /** + * Adds vectors to the index. Generates embeddings from the documents + * using the `Embeddings` instance passed to the constructor. + * @param documents Array of `Document` instances to be added to the index. + * @returns Promise that resolves when the documents have been added to the index. + */ + async addDocuments( + documents: Document[], + documentProps?: DocumentProps + ): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + await this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + documentProps + ); + } + + /** + * Deletes vectors from the index by id. + * @param params The parameters to use to delete the vectors, specifically the ids. 
+ */ + public async delete(params: DeleteProps): Promise { + const response = await this.client.deleteItemBatch( + this.indexName, + params.ids + ); + if (response instanceof VectorDeleteItemBatch.Success) { + // pass + } else if (response instanceof VectorDeleteItemBatch.Error) { + throw new Error(response.toString()); + } else { + throw new Error(`Unknown response type: ${response.toString()}`); + } + } + + /** + * Searches the index for the most similar vectors to the query vector. + * @param query The query vector. + * @param k The number of results to return. + * @returns Promise that resolves to the documents of the most similar vectors + * to the query vector. + */ + public async similaritySearchVectorWithScore( + query: number[], + k: number + ): Promise<[Document>, number][]> { + const response = await this.client.search(this.indexName, query, { + topK: k, + metadataFields: ALL_VECTOR_METADATA, + }); + if (response instanceof VectorSearch.Success) { + if (response.hits === undefined) { + return []; + } + + return response.hits().map((hit) => [ + new Document({ + pageContent: hit.metadata[this.textField]?.toString() ?? "", + metadata: Object.fromEntries( + Object.entries(hit.metadata).filter( + ([key]) => key !== this.textField + ) + ), + }), + hit.score, + ]); + } else if (response instanceof VectorSearch.Error) { + throw new Error(response.toString()); + } else { + throw new Error(`Unknown response type: ${response.toString()}`); + } + } + + /** + * Return documents selected using the maximal marginal relevance. + * Maximal marginal relevance optimizes for similarity to the query AND diversity + * among selected documents. + * + * @param {string} query - Text to look up documents similar to. + * @param {number} options.k - Number of documents to return. + * @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm. 
+ * @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results, + * where 0 corresponds to maximum diversity and 1 to minimum diversity. + * @param {this["FilterType"]} options.filter - Optional filter + * @param _callbacks + * + * @returns {Promise} - List of documents selected by maximal marginal relevance. + */ + async maxMarginalRelevanceSearch( + query: string, + options: MaxMarginalRelevanceSearchOptions + ): Promise { + const queryEmbedding = await this.embeddings.embedQuery(query); + const response = await this.client.searchAndFetchVectors( + this.indexName, + queryEmbedding, + { topK: options.fetchK ?? 20, metadataFields: ALL_VECTOR_METADATA } + ); + + if (response instanceof VectorSearchAndFetchVectors.Success) { + const hits = response.hits(); + + // Gather the embeddings of the search results + const embeddingList = hits.map((hit) => hit.vector); + + // Gather the ids of the most relevant results when applying MMR + const mmrIndexes = maximalMarginalRelevance( + queryEmbedding, + embeddingList, + options.lambda, + options.k + ); + + const finalResult = mmrIndexes.map((index) => { + const hit = hits[index]; + const { [this.textField]: pageContent, ...metadata } = hit.metadata; + return new Document({ metadata, pageContent: pageContent as string }); + }); + return finalResult; + } else if (response instanceof VectorSearchAndFetchVectors.Error) { + throw new Error(response.toString()); + } else { + throw new Error(`Unknown response type: ${response.toString()}`); + } + } + + /** + * Stores the documents in the index. + * + * Converts the documents to vectors using the `Embeddings` instance passed. + * @param texts The texts to store in the index. + * @param metadatas The metadata to store in the index. + * @param embeddings The embeddings instance to use to generate embeddings from the documents. + * @param dbConfig The configuration to use to instantiate the vector store. 
* @param documentProps The properties of the documents to add to the index, specifically the ids.
   * @returns Promise that resolves to the vector store.
   * @throws Error when `metadatas` is an array whose length differs from `texts`.
   */
  public static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig: MomentoVectorIndexLibArgs,
    documentProps?: DocumentProps
  ): Promise<MomentoVectorIndex> {
    if (Array.isArray(metadatas) && texts.length !== metadatas.length) {
      throw new Error(
        `Number of texts (${texts.length}) does not equal number of metadatas (${metadatas.length})`
      );
    }

    // An array of metadatas is matched to the texts by position; a single
    // metadata object is shared by every text
    const docs: Document[] = texts.map((text, i) => {
      // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
      const metadata: object = Array.isArray(metadatas)
        ? metadatas[i]
        : metadatas;
      return new Document({ pageContent: text, metadata });
    });
    return await this.fromDocuments(docs, embeddings, dbConfig, documentProps);
  }

  /**
   * Stores the documents in the index.
   * @param docs The documents to store in the index.
   * @param embeddings The embeddings instance to use to generate embeddings from the documents.
   * @param dbConfig The configuration to use to instantiate the vector store.
   * @param documentProps The properties of the documents to add to the index, specifically the ids.
   * @returns Promise that resolves to the vector store.
*/
  public static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig: MomentoVectorIndexLibArgs,
    documentProps?: DocumentProps
  ): Promise<MomentoVectorIndex> {
    // Build the store first, then embed and upload the documents through it
    const vectorStore = new MomentoVectorIndex(embeddings, dbConfig);
    await vectorStore.addDocuments(docs, documentProps);
    return vectorStore;
  }
}
diff --git a/libs/langchain-community/src/vectorstores/mongodb_atlas.ts b/libs/langchain-community/src/vectorstores/mongodb_atlas.ts
new file mode 100755
index 000000000000..10a56dc3f7df
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/mongodb_atlas.ts
@@ -0,0 +1,282 @@
import type { Collection, Document as MongoDBDocument } from "mongodb";
import {
  MaxMarginalRelevanceSearchOptions,
  VectorStore,
} from "@langchain/core/vectorstores";
import { Embeddings } from "@langchain/core/embeddings";
import { Document } from "@langchain/core/documents";
import { maximalMarginalRelevance } from "@langchain/core/utils/math";

/**
 * Type that defines the arguments required to initialize the
 * MongoDBAtlasVectorSearch class. It includes the MongoDB collection,
 * index name, text key, and embedding key.
 */
export type MongoDBAtlasVectorSearchLibArgs = {
  readonly collection: Collection<MongoDBDocument>;
  readonly indexName?: string;
  readonly textKey?: string;
  readonly embeddingKey?: string;
};

/**
 * Type that defines the filter used in the
 * similaritySearchVectorWithScore and maxMarginalRelevanceSearch methods.
 * It includes pre-filter, post-filter pipeline, and a flag to include
 * embeddings.
 */
type MongoDBAtlasFilter = {
  preFilter?: MongoDBDocument;
  postFilterPipeline?: MongoDBDocument[];
  includeEmbeddings?: boolean;
} & MongoDBDocument;

/**
 * Class that is a wrapper around MongoDB Atlas Vector Search. It is used
 * to store embeddings in MongoDB documents, create a vector search index,
 * and perform K-Nearest Neighbors (KNN) search with an approximate
 * nearest neighbor algorithm.
*/
export class MongoDBAtlasVectorSearch extends VectorStore {
  declare FilterType: MongoDBAtlasFilter;

  private readonly collection: Collection<MongoDBDocument>;

  private readonly indexName: string;

  private readonly textKey: string;

  private readonly embeddingKey: string;

  /** Identifier string for this vector store implementation. */
  _vectorstoreType(): string {
    return "mongodb_atlas";
  }

  /**
   * @param embeddings Embeddings instance used to embed documents and queries.
   * @param args Collection plus optional index name (default `"default"`),
   * text key (default `"text"`) and embedding key (default `"embedding"`).
   */
  constructor(embeddings: Embeddings, args: MongoDBAtlasVectorSearchLibArgs) {
    super(embeddings, args);
    this.collection = args.collection;
    this.indexName = args.indexName ?? "default";
    this.textKey = args.textKey ?? "text";
    this.embeddingKey = args.embeddingKey ?? "embedding";
  }

  /**
   * Method to add vectors and their corresponding documents to the MongoDB
   * collection.
   * @param vectors Vectors to be added.
   * @param documents Corresponding documents to be added.
   * @returns Promise that resolves when the vectors and documents have been added.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    // Metadata is spread last, so a metadata key equal to the text or
    // embedding key replaces that field in the stored record
    const docs = vectors.map((embedding, idx) => ({
      [this.textKey]: documents[idx].pageContent,
      [this.embeddingKey]: embedding,
      ...documents[idx].metadata,
    }));
    if (docs.length === 0) {
      // Nothing to write; the driver rejects an empty insertMany batch
      return;
    }
    await this.collection.insertMany(docs);
  }

  /**
   * Method to add documents to the MongoDB collection. It first converts
   * the documents to vectors using the embeddings and then calls the
   * addVectors method.
   * @param documents Documents to be added.
   * @returns Promise that resolves when the documents have been added.
   */
  async addDocuments(documents: Document[]): Promise<void> {
    const texts = documents.map(({ pageContent }) => pageContent);
    return this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents
    );
  }

  /**
   * Method that performs a similarity search on the vectors stored in the
   * MongoDB collection. It returns a list of documents and their
   * corresponding similarity scores.
   * @param query Query vector for the similarity search.
* @param k Number of nearest neighbors to return.
   * @param filter Optional filter to be applied.
   * @returns Promise that resolves to a list of documents and their corresponding similarity scores.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: MongoDBAtlasFilter
  ): Promise<[Document, number][]> {
    // A filter that sets any of the structured fields is read field by
    // field; any other object is used as the pre-filter in its entirety
    const isStructuredFilter = Boolean(
      filter?.preFilter ||
        filter?.postFilterPipeline ||
        filter?.includeEmbeddings
    );
    const preFilter: MongoDBDocument | undefined = isStructuredFilter
      ? filter?.preFilter
      : filter;

    const vectorSearchStage: MongoDBDocument = {
      $vectorSearch: {
        queryVector: MongoDBAtlasVectorSearch.fixArrayPrecision(query),
        index: this.indexName,
        path: this.embeddingKey,
        limit: k,
        numCandidates: 10 * k,
        ...(preFilter && { filter: preFilter }),
      },
    };
    const scoreStage: MongoDBDocument = {
      $set: {
        score: { $meta: "vectorSearchScore" },
      },
    };

    const stages: MongoDBDocument[] = [vectorSearchStage, scoreStage];
    if (!filter?.includeEmbeddings) {
      // Strip the stored embedding from the results unless it was requested
      stages.push({
        $project: {
          [this.embeddingKey]: 0,
        },
      });
    }
    stages.push(...(filter?.postFilterPipeline ?? []));

    const cursor = this.collection
      .aggregate(stages)
      .map<[Document, number]>((row) => {
        // Everything that is neither the score nor the text becomes metadata
        const { score, [this.textKey]: text, ...metadata } = row;
        return [new Document({ pageContent: text, metadata }), score];
      });

    return cursor.toArray();
  }

  /**
   * Return documents selected using the maximal marginal relevance.
   * Maximal marginal relevance optimizes for similarity to the query AND diversity
   * among selected documents.
   *
   * @param {string} query - Text to look up documents similar to.
   * @param {number} options.k - Number of documents to return.
   * @param {number} options.fetchK=20 - Number of documents to fetch before passing to the MMR algorithm.
   * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results,
   *                 where 0 corresponds to maximum diversity and 1 to minimum diversity.
* @param {MongoDBAtlasFilter} options.filter - Optional Atlas Search operator to pre-filter on document fields
   *      or post-filter following the vector search.
   *
   * @returns {Promise<Document[]>} - List of documents selected by maximal marginal relevance.
   */
  async maxMarginalRelevanceSearch(
    query: string,
    options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>
  ): Promise<Document[]> {
    const { k, fetchK = 20, lambda = 0.5, filter } = options;

    const queryEmbedding = await this.embeddings.embedQuery(query);

    // preserve the original value of includeEmbeddings
    const includeEmbeddingsFlag = filter?.includeEmbeddings || false;

    // similaritySearchVectorWithScore uses a filter object without any of the
    // structured fields as the pre-filter itself. Merging `includeEmbeddings`
    // into such an object would make it count as structured and its
    // conditions would be dropped, so wrap it as `preFilter` instead.
    const isStructuredFilter =
      filter != null &&
      ("preFilter" in filter ||
        "postFilterPipeline" in filter ||
        "includeEmbeddings" in filter);

    // update filter to include embeddings, as they will be used in MMR
    const includeEmbeddingsFilter: MongoDBAtlasFilter =
      filter == null || isStructuredFilter
        ? { ...filter, includeEmbeddings: true }
        : { preFilter: filter, includeEmbeddings: true };

    const resultDocs = await this.similaritySearchVectorWithScore(
      MongoDBAtlasVectorSearch.fixArrayPrecision(queryEmbedding),
      fetchK,
      includeEmbeddingsFilter
    );

    const embeddingList = resultDocs.map(
      (doc) => doc[0].metadata[this.embeddingKey]
    );

    const mmrIndexes = maximalMarginalRelevance(
      queryEmbedding,
      embeddingList,
      lambda,
      k
    );

    return mmrIndexes.map((idx) => {
      const doc = resultDocs[idx][0];

      // remove embeddings if they were not requested originally
      if (!includeEmbeddingsFlag) {
        delete doc.metadata[this.embeddingKey];
      }
      return doc;
    });
  }

  /**
   * Static method to create an instance of MongoDBAtlasVectorSearch from a
   * list of texts. It first converts the texts to vectors and then adds
   * them to the MongoDB collection.
   * @param texts List of texts to be converted to vectors.
   * @param metadatas Metadata for the texts.
   * @param embeddings Embeddings to be used for conversion.
   * @param dbConfig Database configuration for MongoDB Atlas.
   * @returns Promise that resolves to a new instance of MongoDBAtlasVectorSearch.
*/
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig: MongoDBAtlasVectorSearchLibArgs
  ): Promise<MongoDBAtlasVectorSearch> {
    // An array of metadatas is matched to the texts by position; a single
    // metadata object is shared by every text
    const docs: Document[] = texts.map(
      (text, i) =>
        new Document({
          pageContent: text,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );
    return MongoDBAtlasVectorSearch.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Static method to create an instance of MongoDBAtlasVectorSearch from a
   * list of documents. It first converts the documents to vectors and then
   * adds them to the MongoDB collection.
   * @param docs List of documents to be converted to vectors.
   * @param embeddings Embeddings to be used for conversion.
   * @param dbConfig Database configuration for MongoDB Atlas.
   * @returns Promise that resolves to a new instance of MongoDBAtlasVectorSearch.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig: MongoDBAtlasVectorSearchLibArgs
  ): Promise<MongoDBAtlasVectorSearch> {
    // `new this` so that a subclass calling fromDocuments gets its own type
    const instance = new this(embeddings, dbConfig);
    await instance.addDocuments(docs);
    return instance;
  }

  /**
   * Static method to fix the precision of the array that ensures that
   * every number in this array is always float when casted to other types.
   * This is needed since MongoDB Atlas Vector Search does not cast integer
   * inside vector search to float automatically.
   * This method shall introduce a hint of error but should be safe to use
   * since introduced error is very small, only applies to integer numbers
   * returned by embeddings, and most embeddings shall not have precision
   * as high as 15 decimal places.
   * @param array Array of number to be fixed.
* @returns A new array in which every integer-valued entry has been nudged
   * to a nearby non-integer; all other entries are returned unchanged.
   */
  static fixArrayPrecision(array: number[]): number[] {
    return array.map((value) => {
      if (!Number.isInteger(value)) {
        return value;
      }
      const nudged = value + 0.000000000000001;
      if (!Number.isInteger(nudged)) {
        return nudged;
      }
      // For larger magnitudes (e.g. 16) the fixed offset is below the float
      // spacing and the sum rounds straight back to the integer, so fall
      // back to an offset scaled to the value. Integers too large to have
      // any fractional neighbour are necessarily returned as they are.
      return value + Math.abs(value) * Number.EPSILON;
    });
  }
}
diff --git a/libs/langchain-community/src/vectorstores/myscale.ts b/libs/langchain-community/src/vectorstores/myscale.ts
new file mode 100644
index 000000000000..316505414b57
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/myscale.ts
@@ -0,0 +1,314 @@
import * as uuid from "uuid";
import { ClickHouseClient, createClient } from "@clickhouse/client";

import { Embeddings } from "@langchain/core/embeddings";
import { VectorStore } from "@langchain/core/vectorstores";
import { Document } from "@langchain/core/documents";

/**
 * Arguments for the MyScaleStore class, which include the host, port,
 * protocol, username, password, index type, index parameters, column map,
 * database, table, and metric.
 */
export interface MyScaleLibArgs {
  host: string;
  port: string | number;
  protocol?: string;
  username: string;
  password: string;
  indexType?: string;
  indexParam?: Record<string, string>;
  columnMap?: ColumnMap;
  database?: string;
  table?: string;
  metric?: metric;
}

/**
 * Mapping of columns in the MyScale database.
 */
export interface ColumnMap {
  id: string;
  text: string;
  vector: string;
  metadata: string;
}

/**
 * Type of metric used in the MyScale database.
 */
export type metric = "L2" | "Cosine" | "IP";

/**
 * Type for filtering search results in the MyScale database.
 */
export interface MyScaleFilter {
  whereStr: string;
}

/**
 * Class for interacting with the MyScale database. It extends the
 * VectorStore class and provides methods for adding vectors and
 * documents, searching for similar vectors, and creating instances from
 * texts or documents.
*/
export class MyScaleStore extends VectorStore {
  declare FilterType: MyScaleFilter;

  private client: ClickHouseClient;

  private indexType: string;

  private indexParam: Record<string, string>;

  private columnMap: ColumnMap;

  private database: string;

  private table: string;

  private metric: metric;

  // Set once `initialize` has run its setup statements for this instance
  private isInitialized = false;

  /** Identifier string for this vector store implementation. */
  _vectorstoreType(): string {
    return "myscale";
  }

  /**
   * Stores the configuration (falling back to defaults for every optional
   * field) and creates the ClickHouse client. No query is sent here; the
   * table is created lazily by `initialize`.
   * @param embeddings Embeddings instance used to embed documents and queries.
   * @param args Connection details and optional table/index settings.
   */
  constructor(embeddings: Embeddings, args: MyScaleLibArgs) {
    super(embeddings, args);

    this.indexType = args.indexType || "MSTG";
    this.indexParam = args.indexParam || {};
    this.columnMap = args.columnMap || {
      id: "id",
      text: "text",
      vector: "vector",
      metadata: "metadata",
    };
    this.database = args.database || "default";
    this.table = args.table || "vector_table";
    this.metric = args.metric || "Cosine";

    this.client = createClient({
      host: `${args.protocol ?? "https://"}${args.host}:${args.port}`,
      username: args.username,
      password: args.password,
      session_id: uuid.v4(),
    });
  }

  /**
   * Method to add vectors to the MyScale database.
   * @param vectors The vectors to add.
   * @param documents The documents associated with the vectors.
   * @returns Promise that resolves when the vectors have been added.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    if (vectors.length === 0) {
      return;
    }

    if (!this.isInitialized) {
      // The first vector's length is used as the table's vector dimension
      await this.initialize(vectors[0].length);
    }

    const queryStr = this.buildInsertQuery(vectors, documents);
    await this.client.exec({ query: queryStr });
  }

  /**
   * Method to add documents to the MyScale database.
   * @param documents The documents to add.
   * @returns Promise that resolves when the documents have been added.
*/
  async addDocuments(documents: Document[]): Promise<void> {
    return this.addVectors(
      await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)),
      documents
    );
  }

  /**
   * Method to search for vectors that are similar to a given query vector.
   * @param query The query vector.
   * @param k The number of similar vectors to return.
   * @param filter Optional filter for the search results.
   * @returns Promise that resolves with an array of tuples, each containing a Document and a score.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: this["FilterType"]
  ): Promise<[Document, number][]> {
    if (!this.isInitialized) {
      await this.initialize(query.length);
    }
    const queryStr = this.buildSearchQuery(query, k, filter);

    const queryResultSet = await this.client.query({ query: queryStr });
    const queryResult: {
      data: { text: string; metadata: object; dist: number }[];
    } = await queryResultSet.json();

    return queryResult.data.map((item): [Document, number] => [
      new Document({ pageContent: item.text, metadata: item.metadata }),
      item.dist,
    ]);
  }

  /**
   * Static method to create an instance of MyScaleStore from texts.
   * @param texts The texts to use.
   * @param metadatas The metadata associated with the texts.
   * @param embeddings The embeddings to use.
   * @param args The arguments for the MyScaleStore.
   * @returns Promise that resolves with a new instance of MyScaleStore.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object | object[],
    embeddings: Embeddings,
    args: MyScaleLibArgs
  ): Promise<MyScaleStore> {
    // An array of metadatas is matched to the texts by position; a single
    // metadata object is shared by every text
    const docs: Document[] = texts.map(
      (text, i) =>
        new Document({
          pageContent: text,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );
    return MyScaleStore.fromDocuments(docs, embeddings, args);
  }

  /**
   * Static method to create an instance of MyScaleStore from documents.
   * @param docs The documents to use.
   * @param embeddings The embeddings to use.
   * @param args The arguments for the MyScaleStore.
   * @returns Promise that resolves with a new instance of MyScaleStore.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    args: MyScaleLibArgs
  ): Promise<MyScaleStore> {
    const instance = new this(embeddings, args);
    await instance.addDocuments(docs);
    return instance;
  }

  /**
   * Static method to create an instance of MyScaleStore from an existing
   * index.
   * @param embeddings The embeddings to use.
   * @param args The arguments for the MyScaleStore.
   * @returns Promise that resolves with a new instance of MyScaleStore.
   */
  static async fromExistingIndex(
    embeddings: Embeddings,
    args: MyScaleLibArgs
  ): Promise<MyScaleStore> {
    const instance = new this(embeddings, args);

    await instance.initialize();
    return instance;
  }

  /**
   * Method to initialize the MyScale database.
   * @param dimension Optional dimension of the vectors. When omitted, a probe
   * string is embedded and the length of the resulting vector is used.
   * @returns Promise that resolves when the database has been initialized.
   */
  private async initialize(dimension?: number): Promise<void> {
    const dim = dimension ?? (await this.embeddings.embedQuery("test")).length;

    // Extra index parameters are appended as `, 'key=value'` arguments
    let indexParamStr = "";
    for (const [key, value] of Object.entries(this.indexParam)) {
      indexParamStr += `, '${key}=${value}'`;
    }

    const query = `
      CREATE TABLE IF NOT EXISTS ${this.database}.${this.table}(
        ${this.columnMap.id} String,
        ${this.columnMap.text} String,
        ${this.columnMap.vector} Array(Float32),
        ${this.columnMap.metadata} JSON,
        CONSTRAINT cons_vec_len CHECK length(${this.columnMap.vector}) = ${dim},
        VECTOR INDEX vidx ${this.columnMap.vector} TYPE ${this.indexType}('metric_type=${this.metric}'${indexParamStr})
      ) ENGINE = MergeTree ORDER BY ${this.columnMap.id}
    `;

    // Session settings are applied before the table statement is executed
    await this.client.exec({ query: "SET allow_experimental_object_type=1" });
    await this.client.exec({
      query: "SET output_format_json_named_tuples_as_objects = 1",
    });
    await this.client.exec({ query });
    this.isInitialized = true;
  }

  /**
   * Method to build an SQL query for inserting vectors and documents into
   * the MyScale database.
   * @param vectors The vectors to insert.
   * @param documents The documents to insert.
   * @returns The SQL query string.
   */
  private buildInsertQuery(vectors: number[][], documents: Document[]): string {
    // Name the columns explicitly, in the same order as the values built
    // below, so a caller-supplied column map whose keys are declared in a
    // different order cannot misalign columns and values
    const { id, text, vector, metadata } = this.columnMap;
    const columnsStr = [id, text, vector, metadata].join(", ");

    const data = vectors.map((embedding, i) => {
      const document = documents[i];
      const item = [
        `'${uuid.v4()}'`,
        `'${this.escapeString(document.pageContent)}'`,
        `[${embedding}]`,
        // The serialized metadata sits inside a quoted SQL literal too, so
        // its quotes and backslashes need the same escaping as the text
        `'${this.escapeString(JSON.stringify(document.metadata))}'`,
      ].join(", ");
      return `(${item})`;
    });
    const dataStr = data.join(", ");

    return `
      INSERT INTO TABLE
        ${this.database}.${this.table}(${columnsStr})
      VALUES
        ${dataStr}
    `;
  }

  /**
   * Escapes a value for use inside a single-quoted SQL string literal by
   * backslash-escaping backslashes and single quotes.
   * @param str The raw string.
   * @returns The escaped string, without surrounding quotes.
   */
  private escapeString(str: string): string {
    return str.replace(/\\/g, "\\\\").replace(/'/g, "\\'");
  }

  /**
   * Method to build an SQL query for searching for similar vectors in the
   * MyScale database.
* @param query The query vector.
   * @param k The number of similar vectors to return.
   * @param filter Optional filter for the search results.
   * @returns The SQL query string.
   */
  private buildSearchQuery(
    query: number[],
    k: number,
    filter?: MyScaleFilter
  ): string {
    // The caller's condition, when given, is placed verbatim after PREWHERE
    let prewhereClause = "";
    if (filter) {
      prewhereClause = `PREWHERE ${filter.whereStr}`;
    }

    // Only the "IP" metric sorts descending; the other metrics sort ascending
    const sortDirection = this.metric === "IP" ? "DESC" : "ASC";

    const { text, metadata, vector } = this.columnMap;
    return `
      SELECT ${text} AS text, ${metadata} AS metadata, dist
      FROM ${this.database}.${this.table}
      ${prewhereClause}
      ORDER BY distance(${vector}, [${query}]) AS dist ${sortDirection}
      LIMIT ${k}
    `;
  }
}
diff --git a/libs/langchain-community/src/vectorstores/neo4j_vector.ts b/libs/langchain-community/src/vectorstores/neo4j_vector.ts
new file mode 100644
index 000000000000..e35f24fde603
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/neo4j_vector.ts
@@ -0,0 +1,731 @@
import neo4j from "neo4j-driver";
import * as uuid from "uuid";
import { Embeddings } from "@langchain/core/embeddings";
import { VectorStore } from "@langchain/core/vectorstores";
import { Document } from "@langchain/core/documents";

/** Search modes accepted by the Neo4j vector store. */
export type SearchType = "vector" | "hybrid";

/** Distance strategies accepted by the Neo4j vector store. */
export type DistanceStrategy = "euclidean" | "cosine";

/**
 * Configuration for the Neo4j vector store: connection credentials plus
 * optional database, label, property, index and query settings.
 */
interface Neo4jVectorStoreArgs {
  url: string;
  username: string;
  password: string;
  database?: string;
  preDeleteCollection?: boolean;
  textNodeProperty?: string;
  textNodeProperties?: string[];
  embeddingNodeProperty?: string;
  keywordIndexName?: string;
  indexName?: string;
  searchType?: SearchType;
  retrievalQuery?: string;
  nodeLabel?: string;
  createIdIndex?: boolean;
}

const DEFAULT_SEARCH_TYPE = "vector";
const DEFAULT_DISTANCE_STRATEGY = "cosine";

/**
 * @security *Security note*: Make sure that the database connection uses credentials
 * that are narrowly-scoped to only include necessary permissions.
+ * Failure to do so may result in data corruption or loss, since the calling + * code may attempt commands that would result in deletion, mutation + * of data if appropriately prompted or reading sensitive data if such + * data is present in the database. + * The best way to guard against such negative outcomes is to (as appropriate) + * limit the permissions granted to the credentials used with this tool. + * For example, creating read only users for the database is a good way to + * ensure that the calling code cannot mutate or delete data. + * + * @link See https://js.langchain.com/docs/security for more information. + */ +export class Neo4jVectorStore extends VectorStore { + private driver: neo4j.Driver; + + private database: string; + + private preDeleteCollection: boolean; + + private nodeLabel: string; + + private embeddingNodeProperty: string; + + private embeddingDimension: number; + + private textNodeProperty: string; + + private keywordIndexName: string; + + private indexName: string; + + private retrievalQuery: string; + + private searchType: SearchType; + + private distanceStrategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY; + + _vectorstoreType(): string { + return "neo4jvector"; + } + + constructor(embeddings: Embeddings, config: Neo4jVectorStoreArgs) { + super(embeddings, config); + } + + static async initialize( + embeddings: Embeddings, + config: Neo4jVectorStoreArgs + ) { + const store = new Neo4jVectorStore(embeddings, config); + await store._initializeDriver(config); + await store._verifyConnectivity(); + + const { + preDeleteCollection = false, + nodeLabel = "Chunk", + textNodeProperty = "text", + embeddingNodeProperty = "embedding", + keywordIndexName = "keyword", + indexName = "vector", + retrievalQuery = "", + searchType = DEFAULT_SEARCH_TYPE, + } = config; + + store.embeddingDimension = (await embeddings.embedQuery("foo")).length; + store.preDeleteCollection = preDeleteCollection; + store.nodeLabel = nodeLabel; + store.textNodeProperty 
= textNodeProperty; + store.embeddingNodeProperty = embeddingNodeProperty; + store.keywordIndexName = keywordIndexName; + store.indexName = indexName; + store.retrievalQuery = retrievalQuery; + store.searchType = searchType; + + if (store.preDeleteCollection) { + await store._dropIndex(); + } + + return store; + } + + async _initializeDriver({ + url, + username, + password, + database = "neo4j", + }: Neo4jVectorStoreArgs) { + try { + this.driver = neo4j.driver(url, neo4j.auth.basic(username, password)); + this.database = database; + } catch (error) { + throw new Error( + "Could not create a Neo4j driver instance. Please check the connection details." + ); + } + } + + async _verifyConnectivity() { + await this.driver.verifyAuthentication(); + } + + async close() { + await this.driver.close(); + } + + async _dropIndex() { + try { + await this.query(` + MATCH (n:\`${this.nodeLabel}\`) + CALL { + WITH n + DETACH DELETE n + } + IN TRANSACTIONS OF 10000 ROWS; + `); + await this.query(`DROP INDEX ${this.indexName}`); + } catch (error) { + console.error("An error occurred while dropping the index:", error); + } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async query(query: string, params: any = {}): Promise { + const session = this.driver.session({ database: this.database }); + const result = await session.run(query, params); + return toObjects(result.records); + } + + static async fromTexts( + texts: string[], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + metadatas: any, + embeddings: Embeddings, + config: Neo4jVectorStoreArgs + ): Promise { + const docs = []; + + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? 
metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + return Neo4jVectorStore.fromDocuments(docs, embeddings, config); + } + + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + config: Neo4jVectorStoreArgs + ): Promise { + const { + searchType = DEFAULT_SEARCH_TYPE, + createIdIndex = true, + textNodeProperties = [], + } = config; + + const store = await this.initialize(embeddings, config); + + const embeddingDimension = await store.retrieveExistingIndex(); + + if (!embeddingDimension) { + await store.createNewIndex(); + } else if (store.embeddingDimension !== embeddingDimension) { + throw new Error( + `Index with name "${store.indexName}" already exists. The provided embedding function and vector index dimensions do not match. + Embedding function dimension: ${store.embeddingDimension} + Vector index dimension: ${embeddingDimension}` + ); + } + + if (searchType === "hybrid") { + const ftsNodeLabel = await store.retrieveExistingFtsIndex(); + + if (!ftsNodeLabel) { + await store.createNewKeywordIndex(textNodeProperties); + } else { + if (ftsNodeLabel !== store.nodeLabel) { + throw Error( + "Vector and keyword index don't index the same node label" + ); + } + } + } + + if (createIdIndex) { + await store.query( + `CREATE CONSTRAINT IF NOT EXISTS FOR (n:${store.nodeLabel}) REQUIRE n.id IS UNIQUE;` + ); + } + + await store.addDocuments(docs); + + return store; + } + + static async fromExistingIndex( + embeddings: Embeddings, + config: Neo4jVectorStoreArgs + ) { + const { searchType = DEFAULT_SEARCH_TYPE, keywordIndexName = "keyword" } = + config; + + if (searchType === "hybrid" && !keywordIndexName) { + throw Error( + "keyword_index name has to be specified when using hybrid search option" + ); + } + + const store = await this.initialize(embeddings, config); + const embeddingDimension = await store.retrieveExistingIndex(); + + if (!embeddingDimension) { + throw Error( + 
"The specified vector index name does not exist. Make sure to check if you spelled it correctly" + ); + } + + if (store.embeddingDimension !== embeddingDimension) { + throw new Error( + `The provided embedding function and vector index dimensions do not match. + Embedding function dimension: ${store.embeddingDimension} + Vector index dimension: ${embeddingDimension}` + ); + } + + if (searchType === "hybrid") { + const ftsNodeLabel = await store.retrieveExistingFtsIndex(); + + if (!ftsNodeLabel) { + throw Error( + "The specified keyword index name does not exist. Make sure to check if you spelled it correctly" + ); + } else { + if (ftsNodeLabel !== store.nodeLabel) { + throw Error( + "Vector and keyword index don't index the same node label" + ); + } + } + } + + return store; + } + + static async fromExistingGraph( + embeddings: Embeddings, + config: Neo4jVectorStoreArgs + ) { + const { + textNodeProperties = [], + embeddingNodeProperty, + searchType = DEFAULT_SEARCH_TYPE, + retrievalQuery = "", + nodeLabel, + } = config; + + let _retrievalQuery = retrievalQuery; + + if (textNodeProperties.length === 0) { + throw Error( + "Parameter `text_node_properties` must not be an empty array" + ); + } + + if (!retrievalQuery) { + _retrievalQuery = ` + RETURN reduce(str='', k IN ${JSON.stringify(textNodeProperties)} | + str + '\\n' + k + ': ' + coalesce(node[k], '')) AS text, + node {.*, \`${embeddingNodeProperty}\`: Null, id: Null, ${textNodeProperties + .map((prop) => `\`${prop}\`: Null`) + .join(", ")} } AS metadata, score + `; + } + + const store = await this.initialize(embeddings, { + ...config, + retrievalQuery: _retrievalQuery, + }); + + const embeddingDimension = await store.retrieveExistingIndex(); + + if (!embeddingDimension) { + await store.createNewIndex(); + } else if (store.embeddingDimension !== embeddingDimension) { + throw new Error( + `Index with name ${store.indexName} already exists. 
The provided embedding function and vector index dimensions do not match.\nEmbedding function dimension: ${store.embeddingDimension}\nVector index dimension: ${embeddingDimension}` + ); + } + + if (searchType === "hybrid") { + const ftsNodeLabel = await store.retrieveExistingFtsIndex( + textNodeProperties + ); + + if (!ftsNodeLabel) { + await store.createNewKeywordIndex(textNodeProperties); + } else { + if (ftsNodeLabel !== store.nodeLabel) { + throw Error( + "Vector and keyword index don't index the same node label" + ); + } + } + } + + // eslint-disable-next-line no-constant-condition + while (true) { + const fetchQuery = ` + MATCH (n:\`${nodeLabel}\`) + WHERE n.${embeddingNodeProperty} IS null + AND any(k in $props WHERE n[k] IS NOT null) + RETURN elementId(n) AS id, reduce(str='', k IN $props | + str + '\\n' + k + ':' + coalesce(n[k], '')) AS text + LIMIT 1000 + `; + + const data = await store.query(fetchQuery, { props: textNodeProperties }); + + if (!data) { + continue; + } + + const textEmbeddings = await embeddings.embedDocuments( + data.map((el) => el.text) + ); + + const params = { + data: data.map((el, index) => ({ + id: el.id, + embedding: textEmbeddings[index], + })), + }; + + await store.query( + ` + UNWIND $data AS row + MATCH (n:\`${nodeLabel}\`) + WHERE elementId(n) = row.id + CALL db.create.setVectorProperty(n, '${embeddingNodeProperty}', row.embedding) + YIELD node RETURN count(*) + `, + params + ); + + if (data.length < 1000) { + break; + } + } + + return store; + } + + async createNewIndex(): Promise { + const indexQuery = ` + CALL db.index.vector.createNodeIndex( + $index_name, + $node_label, + $embedding_node_property, + toInteger($embedding_dimension), + $similarity_metric + ) + `; + + const parameters = { + index_name: this.indexName, + node_label: this.nodeLabel, + embedding_node_property: this.embeddingNodeProperty, + embedding_dimension: this.embeddingDimension, + similarity_metric: this.distanceStrategy, + }; + + await 
this.query(indexQuery, parameters); + } + + async retrieveExistingIndex() { + let indexInformation = await this.query( + ` + SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options + WHERE type = 'VECTOR' AND (name = $index_name + OR (labelsOrTypes[0] = $node_label AND + properties[0] = $embedding_node_property)) + RETURN name, labelsOrTypes, properties, options + `, + { + index_name: this.indexName, + node_label: this.nodeLabel, + embedding_node_property: this.embeddingNodeProperty, + } + ); + + if (indexInformation) { + indexInformation = this.sortByIndexName(indexInformation, this.indexName); + + try { + const [index] = indexInformation; + const [labelOrType] = index.labelsOrTypes; + const [property] = index.properties; + + this.indexName = index.name; + this.nodeLabel = labelOrType; + this.embeddingNodeProperty = property; + + const embeddingDimension = + index.options.indexConfig["vector.dimensions"]; + return Number(embeddingDimension); + } catch (error) { + return null; + } + } + + return null; + } + + async retrieveExistingFtsIndex( + textNodeProperties: string[] = [] + ): Promise { + const indexInformation = await this.query( + ` + SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options + WHERE type = 'FULLTEXT' AND (name = $keyword_index_name + OR (labelsOrTypes = [$node_label] AND + properties = $text_node_property)) + RETURN name, labelsOrTypes, properties, options + `, + { + keyword_index_name: this.keywordIndexName, + node_label: this.nodeLabel, + text_node_property: + textNodeProperties.length > 0 + ? 
textNodeProperties + : [this.textNodeProperty], + } + ); + + if (indexInformation) { + // Sort the index information by index name + const sortedIndexInformation = this.sortByIndexName( + indexInformation, + this.indexName + ); + + try { + const [index] = sortedIndexInformation; + const [labelOrType] = index.labelsOrTypes; + const [property] = index.properties; + + this.keywordIndexName = index.name; + this.textNodeProperty = property; + this.nodeLabel = labelOrType; + + return labelOrType; + } catch (error) { + return null; + } + } + + return null; + } + + async createNewKeywordIndex( + textNodeProperties: string[] = [] + ): Promise { + const nodeProps = + textNodeProperties.length > 0 + ? textNodeProperties + : [this.textNodeProperty]; + + // Construct the Cypher query to create a new full text index + const ftsIndexQuery = ` + CREATE FULLTEXT INDEX ${this.keywordIndexName} + FOR (n:\`${this.nodeLabel}\`) ON EACH + [${nodeProps.map((prop) => `n.\`${prop}\``).join(", ")}] + `; + + await this.query(ftsIndexQuery); + } + + sortByIndexName( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + values: Array<{ [key: string]: any }>, + indexName: string + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ): Array<{ [key: string]: any }> { + return values.sort( + (a, b) => + (a.index_name === indexName ? -1 : 0) - + (b.index_name === indexName ? 
-1 : 0) + ); + } + + async addVectors( + vectors: number[][], + documents: Document[], + // eslint-disable-next-line @typescript-eslint/no-explicit-any + metadatas?: Record[], + ids?: string[] + ): Promise { + let _ids = ids; + let _metadatas = metadatas; + + if (!_ids) { + _ids = documents.map(() => uuid.v1()); + } + + if (!metadatas) { + _metadatas = documents.map(() => ({})); + } + + const importQuery = ` + UNWIND $data AS row + CALL { + WITH row + MERGE (c:\`${this.nodeLabel}\` {id: row.id}) + WITH c, row + CALL db.create.setVectorProperty(c, '${this.embeddingNodeProperty}', row.embedding) + YIELD node + SET c.\`${this.textNodeProperty}\` = row.text + SET c += row.metadata + } IN TRANSACTIONS OF 1000 ROWS + `; + + const parameters = { + data: documents.map(({ pageContent, metadata }, index) => ({ + text: pageContent, + metadata: _metadatas ? _metadatas[index] : metadata, + embedding: vectors[index], + id: _ids ? _ids[index] : null, + })), + }; + + await this.query(importQuery, parameters); + + return _ids; + } + + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + async similaritySearch(query: string, k = 4): Promise { + const embedding = await this.embeddings.embedQuery(query); + + const results = await this.similaritySearchVectorWithScore( + embedding, + k, + query + ); + + return results.map((result) => result[0]); + } + + async similaritySearchVectorWithScore( + vector: number[], + k: number, + query: string + ): Promise<[Document, number][]> { + const defaultRetrieval = ` + RETURN node.${this.textNodeProperty} AS text, score, + node {.*, ${this.textNodeProperty}: Null, + ${this.embeddingNodeProperty}: Null, id: Null } AS metadata + `; + + const retrievalQuery = this.retrievalQuery + ? 
this.retrievalQuery + : defaultRetrieval; + + const readQuery = `${getSearchIndexQuery( + this.searchType + )} ${retrievalQuery}`; + + const parameters = { + index: this.indexName, + k: Number(k), + embedding: vector, + keyword_index: this.keywordIndexName, + query, + }; + const results = await this.query(readQuery, parameters); + + if (results) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const docs: [Document, number][] = results.map((result: any) => [ + new Document({ + pageContent: result.text, + metadata: Object.fromEntries( + Object.entries(result.metadata).filter(([_, v]) => v !== null) + ), + }), + result.score, + ]); + + return docs; + } + + return []; + } +} + +function toObjects(records: neo4j.Record[]) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const recordValues: Record[] = records.map((record) => { + const rObj = record.toObject(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const out: { [key: string]: any } = {}; + Object.keys(rObj).forEach((key) => { + out[key] = itemIntToString(rObj[key]); + }); + return out; + }); + return recordValues; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function itemIntToString(item: any): any { + if (neo4j.isInt(item)) return item.toString(); + if (Array.isArray(item)) return item.map((ii) => itemIntToString(ii)); + if (["number", "string", "boolean"].indexOf(typeof item) !== -1) return item; + if (item === null) return item; + if (typeof item === "object") return objIntToString(item); +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function objIntToString(obj: any) { + const entry = extractFromNeoObjects(obj); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let newObj: any = null; + if (Array.isArray(entry)) { + newObj = entry.map((item) => itemIntToString(item)); + } else if (entry !== null && typeof entry === "object") { + newObj = {}; + Object.keys(entry).forEach((key) => { + 
newObj[key] = itemIntToString(entry[key]); + }); + } + return newObj; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function extractFromNeoObjects(obj: any) { + if ( + // eslint-disable-next-line + obj instanceof (neo4j.types.Node as any) || + // eslint-disable-next-line + obj instanceof (neo4j.types.Relationship as any) + ) { + return obj.properties; + // eslint-disable-next-line + } else if (obj instanceof (neo4j.types.Path as any)) { + // eslint-disable-next-line + return [].concat.apply([], extractPathForRows(obj)); + } + return obj; +} + +function extractPathForRows(path: neo4j.Path) { + let { segments } = path; + // Zero length path. No relationship, end === start + if (!Array.isArray(path.segments) || path.segments.length < 1) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + segments = [{ ...path, end: null } as any]; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return segments.map((segment: any) => + [ + objIntToString(segment.start), + objIntToString(segment.relationship), + objIntToString(segment.end), + ].filter((part) => part !== null) + ); +} + +function getSearchIndexQuery(searchType: SearchType): string { + const typeToQueryMap: { [key in SearchType]: string } = { + vector: + "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score", + hybrid: ` + CALL { + CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score + RETURN node, score UNION + CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) YIELD node, score + WITH collect({node: node, score: score}) AS nodes, max(score) AS max + UNWIND nodes AS n + RETURN n.node AS node, (n.score / max) AS score + } + WITH node, max(score) AS score ORDER BY score DESC LIMIT toInteger($k) + `, + }; + + return typeToQueryMap[searchType]; +} diff --git a/libs/langchain-community/src/vectorstores/opensearch.ts b/libs/langchain-community/src/vectorstores/opensearch.ts new file mode 100644 index 
000000000000..0ea90dff456a --- /dev/null +++ b/libs/langchain-community/src/vectorstores/opensearch.ts @@ -0,0 +1,326 @@ +import { Client, RequestParams, errors } from "@opensearch-project/opensearch"; +import * as uuid from "uuid"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +type OpenSearchEngine = "nmslib" | "hnsw"; +type OpenSearchSpaceType = "l2" | "cosinesimil" | "ip"; + +/** + * Interface defining the options for vector search in OpenSearch. It + * includes the engine type, space type, and parameters for the HNSW + * algorithm. + */ +interface VectorSearchOptions { + readonly engine?: OpenSearchEngine; + readonly spaceType?: OpenSearchSpaceType; + readonly m?: number; + readonly efConstruction?: number; + readonly efSearch?: number; +} + +/** + * Interface defining the arguments required to create an instance of the + * OpenSearchVectorStore class. It includes the OpenSearch client, index + * name, and vector search options. + */ +export interface OpenSearchClientArgs { + readonly client: Client; + readonly indexName?: string; + + readonly vectorSearchOptions?: VectorSearchOptions; +} + +/** + * Type alias for an object. It's used to define filters for OpenSearch + * queries. + */ +type OpenSearchFilter = object; + +/** + * Class that provides a wrapper around the OpenSearch service for vector + * search. It provides methods for adding documents and vectors to the + * OpenSearch index, searching for similar vectors, and managing the + * OpenSearch index. 
+ */ +export class OpenSearchVectorStore extends VectorStore { + declare FilterType: OpenSearchFilter; + + private readonly client: Client; + + private readonly indexName: string; + + private readonly engine: OpenSearchEngine; + + private readonly spaceType: OpenSearchSpaceType; + + private readonly efConstruction: number; + + private readonly efSearch: number; + + private readonly m: number; + + _vectorstoreType(): string { + return "opensearch"; + } + + constructor(embeddings: Embeddings, args: OpenSearchClientArgs) { + super(embeddings, args); + + this.spaceType = args.vectorSearchOptions?.spaceType ?? "l2"; + this.engine = args.vectorSearchOptions?.engine ?? "nmslib"; + this.m = args.vectorSearchOptions?.m ?? 16; + this.efConstruction = args.vectorSearchOptions?.efConstruction ?? 512; + this.efSearch = args.vectorSearchOptions?.efSearch ?? 512; + + this.client = args.client; + this.indexName = args.indexName ?? "documents"; + } + + /** + * Method to add documents to the OpenSearch index. It first converts the + * documents to vectors using the embeddings, then adds the vectors to the + * index. + * @param documents The documents to be added to the OpenSearch index. + * @returns Promise resolving to void. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Method to add vectors to the OpenSearch index. It ensures the index + * exists, then adds the vectors and associated documents to the index. + * @param vectors The vectors to be added to the OpenSearch index. + * @param documents The documents associated with the vectors. + * @param options Optional parameter that can contain the IDs for the documents. + * @returns Promise resolving to void. 
+ */ + async addVectors( + vectors: number[][], + documents: Document[], + options?: { ids?: string[] } + ): Promise { + await this.ensureIndexExists( + vectors[0].length, + this.engine, + this.spaceType, + this.efSearch, + this.efConstruction, + this.m + ); + const documentIds = + options?.ids ?? Array.from({ length: vectors.length }, () => uuid.v4()); + const operations = vectors.flatMap((embedding, idx) => [ + { + index: { + _index: this.indexName, + _id: documentIds[idx], + }, + }, + { + embedding, + metadata: documents[idx].metadata, + text: documents[idx].pageContent, + }, + ]); + await this.client.bulk({ body: operations }); + await this.client.indices.refresh({ index: this.indexName }); + } + + /** + * Method to perform a similarity search on the OpenSearch index using a + * query vector. It returns the k most similar documents and their scores. + * @param query The query vector. + * @param k The number of similar documents to return. + * @param filter Optional filter for the OpenSearch query. + * @returns Promise resolving to an array of tuples, each containing a Document and its score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: OpenSearchFilter | undefined + ): Promise<[Document, number][]> { + const search: RequestParams.Search = { + index: this.indexName, + body: { + query: { + bool: { + filter: { bool: { must: this.buildMetadataTerms(filter) } }, + must: [ + { + knn: { + embedding: { vector: query, k }, + }, + }, + ], + }, + }, + size: k, + }, + }; + + const { body } = await this.client.search(search); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return body.hits.hits.map((hit: any) => [ + new Document({ + pageContent: hit._source.text, + metadata: hit._source.metadata, + }), + hit._score, + ]); + } + + /** + * Static method to create a new OpenSearchVectorStore from an array of + * texts, their metadata, embeddings, and OpenSearch client arguments. 
+ * @param texts The texts to be converted into documents and added to the OpenSearch index. + * @param metadatas The metadata associated with the texts. Can be an array of objects or a single object. + * @param embeddings The embeddings used to convert the texts into vectors. + * @param args The OpenSearch client arguments. + * @returns Promise resolving to a new instance of OpenSearchVectorStore. + */ + static fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + args: OpenSearchClientArgs + ): Promise { + const documents = texts.map((text, idx) => { + const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas; + return new Document({ pageContent: text, metadata }); + }); + + return OpenSearchVectorStore.fromDocuments(documents, embeddings, args); + } + + /** + * Static method to create a new OpenSearchVectorStore from an array of + * Documents, embeddings, and OpenSearch client arguments. + * @param docs The documents to be added to the OpenSearch index. + * @param embeddings The embeddings used to convert the documents into vectors. + * @param dbConfig The OpenSearch client arguments. + * @returns Promise resolving to a new instance of OpenSearchVectorStore. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: OpenSearchClientArgs + ): Promise { + const store = new OpenSearchVectorStore(embeddings, dbConfig); + await store.addDocuments(docs).then(() => store); + return store; + } + + /** + * Static method to create a new OpenSearchVectorStore from an existing + * OpenSearch index, embeddings, and OpenSearch client arguments. + * @param embeddings The embeddings used to convert the documents into vectors. + * @param dbConfig The OpenSearch client arguments. + * @returns Promise resolving to a new instance of OpenSearchVectorStore. 
+ */ + static async fromExistingIndex( + embeddings: Embeddings, + dbConfig: OpenSearchClientArgs + ): Promise { + const store = new OpenSearchVectorStore(embeddings, dbConfig); + await store.client.cat.indices({ index: store.indexName }); + return store; + } + + private async ensureIndexExists( + dimension: number, + engine = "nmslib", + spaceType = "l2", + efSearch = 512, + efConstruction = 512, + m = 16 + ): Promise { + const body = { + settings: { + index: { + number_of_shards: 5, + number_of_replicas: 1, + knn: true, + "knn.algo_param.ef_search": efSearch, + }, + }, + mappings: { + dynamic_templates: [ + { + // map all metadata properties to be keyword + "metadata.*": { + match_mapping_type: "*", + mapping: { type: "keyword" }, + }, + }, + ], + properties: { + text: { type: "text" }, + metadata: { type: "object" }, + embedding: { + type: "knn_vector", + dimension, + method: { + name: "hnsw", + engine, + space_type: spaceType, + parameters: { ef_construction: efConstruction, m }, + }, + }, + }, + }, + }; + + const indexExists = await this.doesIndexExist(); + if (indexExists) return; + + await this.client.indices.create({ index: this.indexName, body }); + } + + private buildMetadataTerms( + filter?: OpenSearchFilter + ): { [key: string]: Record }[] { + if (filter == null) return []; + const result = []; + for (const [key, value] of Object.entries(filter)) { + const aggregatorKey = Array.isArray(value) ? "terms" : "term"; + result.push({ [aggregatorKey]: { [`metadata.${key}`]: value } }); + } + return result; + } + + /** + * Method to check if the OpenSearch index exists. + * @returns Promise resolving to a boolean indicating whether the index exists. 
+ */ + async doesIndexExist(): Promise { + try { + await this.client.cat.indices({ index: this.indexName }); + return true; + } catch (err: unknown) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (err instanceof errors.ResponseError && err.statusCode === 404) { + return false; + } + throw err; + } + } + + /** + * Method to delete the OpenSearch index if it exists. + * @returns Promise resolving to void. + */ + async deleteIfExists(): Promise { + const indexExists = await this.doesIndexExist(); + if (!indexExists) return; + + await this.client.indices.delete({ index: this.indexName }); + } +} diff --git a/libs/langchain-community/src/vectorstores/pgvector.ts b/libs/langchain-community/src/vectorstores/pgvector.ts new file mode 100644 index 000000000000..6755e365c75c --- /dev/null +++ b/libs/langchain-community/src/vectorstores/pgvector.ts @@ -0,0 +1,440 @@ +import pg, { type Pool, type PoolClient, type PoolConfig } from "pg"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +type Metadata = Record; + +/** + * Interface that defines the arguments required to create a + * `PGVectorStore` instance. It includes Postgres connection options, + * table name, filter, and verbosity level. + */ +export interface PGVectorStoreArgs { + postgresConnectionOptions: PoolConfig; + tableName: string; + collectionTableName?: string; + collectionName?: string; + collectionMetadata?: Metadata | null; + columns?: { + idColumnName?: string; + vectorColumnName?: string; + contentColumnName?: string; + metadataColumnName?: string; + }; + filter?: Metadata; + verbose?: boolean; + /** + * The amount of documents to chunk by when + * adding vectors. + * @default 500 + */ + chunkSize?: number; +} + +/** + * Class that provides an interface to a Postgres vector database. 
It + * extends the `VectorStore` base class and implements methods for adding + * documents and vectors, performing similarity searches, and ensuring the + * existence of a table in the database. + */ +export class PGVectorStore extends VectorStore { + declare FilterType: Metadata; + + tableName: string; + + collectionTableName?: string; + + collectionName = "langchain"; + + collectionMetadata: Metadata | null; + + idColumnName: string; + + vectorColumnName: string; + + contentColumnName: string; + + metadataColumnName: string; + + filter?: Metadata; + + _verbose?: boolean; + + pool: Pool; + + client?: PoolClient; + + chunkSize = 500; + + _vectorstoreType(): string { + return "pgvector"; + } + + private constructor(embeddings: Embeddings, config: PGVectorStoreArgs) { + super(embeddings, config); + this.tableName = config.tableName; + this.collectionTableName = config.collectionTableName; + this.collectionName = config.collectionName ?? "langchain"; + this.collectionMetadata = config.collectionMetadata ?? null; + this.filter = config.filter; + + this.vectorColumnName = config.columns?.vectorColumnName ?? "embedding"; + this.contentColumnName = config.columns?.contentColumnName ?? "text"; + this.idColumnName = config.columns?.idColumnName ?? "id"; + this.metadataColumnName = config.columns?.metadataColumnName ?? "metadata"; + + const pool = new pg.Pool(config.postgresConnectionOptions); + this.pool = pool; + this.chunkSize = config.chunkSize ?? 500; + + this._verbose = + getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ?? + !!config.verbose; + } + + /** + * Static method to create a new `PGVectorStore` instance from a + * connection. It creates a table if one does not exist, and calls + * `connect` to return a new instance of `PGVectorStore`. + * + * @param embeddings - Embeddings instance. + * @param fields - `PGVectorStoreArgs` instance. + * @returns A new instance of `PGVectorStore`. 
+ */ + static async initialize( + embeddings: Embeddings, + config: PGVectorStoreArgs + ): Promise { + const postgresqlVectorStore = new PGVectorStore(embeddings, config); + + await postgresqlVectorStore._initializeClient(); + await postgresqlVectorStore.ensureTableInDatabase(); + if (postgresqlVectorStore.collectionTableName) { + await postgresqlVectorStore.ensureCollectionTableInDatabase(); + } + + return postgresqlVectorStore; + } + + protected async _initializeClient() { + this.client = await this.pool.connect(); + } + + /** + * Method to add documents to the vector store. It converts the documents into + * vectors, and adds them to the store. + * + * @param documents - Array of `Document` instances. + * @returns Promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Inserts a row for the collectionName provided at initialization if it does not + * exist and returns the collectionId. + * + * @returns The collectionId for the given collectionName. + */ + async getOrCreateCollection(): Promise { + const queryString = ` + SELECT uuid from ${this.collectionTableName} + WHERE name = $1; + `; + const queryResult = await this.pool.query(queryString, [ + this.collectionName, + ]); + let collectionId = queryResult.rows[0]?.uuid; + + if (!collectionId) { + const insertString = ` + INSERT INTO ${this.collectionTableName}( + uuid, + name, + cmetadata + ) + VALUES ( + uuid_generate_v4(), + $1, + $2 + ) + RETURNING uuid; + `; + const insertResult = await this.pool.query(insertString, [ + this.collectionName, + this.collectionMetadata, + ]); + collectionId = insertResult.rows[0]?.uuid; + } + + return collectionId; + } + + /** + * Generates the SQL placeholders for a specific row at the provided index. 
+ * + * @param index - The index of the row for which placeholders need to be generated. + * @param numOfColumns - The number of columns we are inserting data into. + * @returns The SQL placeholders for the row values. + */ + private generatePlaceholderForRowAt( + index: number, + numOfColumns: number + ): string { + const placeholders = []; + for (let i = 0; i < numOfColumns; i += 1) { + placeholders.push(`$${index * numOfColumns + i + 1}`); + } + return `(${placeholders.join(", ")})`; + } + + /** + * Constructs the SQL query for inserting rows into the specified table. + * + * @param rows - The rows of data to be inserted, consisting of values and records. + * @param chunkIndex - The starting index for generating query placeholders based on chunk positioning. + * @returns The complete SQL INSERT INTO query string. + */ + private async buildInsertQuery(rows: (string | Record)[][]) { + let collectionId; + if (this.collectionTableName) { + collectionId = await this.getOrCreateCollection(); + } + + const columns = [ + this.contentColumnName, + this.vectorColumnName, + this.metadataColumnName, + ]; + + if (collectionId) { + columns.push("collection_id"); + } + + const valuesPlaceholders = rows + .map((_, j) => this.generatePlaceholderForRowAt(j, columns.length)) + .join(", "); + + const text = ` + INSERT INTO ${this.tableName}( + ${columns} + ) + VALUES ${valuesPlaceholders} + `; + return text; + } + + /** + * Method to add vectors to the vector store. It converts the vectors into + * rows and inserts them into the database. + * + * @param vectors - Array of vectors. + * @param documents - Array of `Document` instances. + * @returns Promise that resolves when the vectors have been added. 
+ */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + const rows = []; + let collectionId; + if (this.collectionTableName) { + collectionId = await this.getOrCreateCollection(); + } + + for (let i = 0; i < vectors.length; i += 1) { + const values = []; + const embedding = vectors[i]; + const embeddingString = `[${embedding.join(",")}]`; + values.push( + documents[i].pageContent, + embeddingString, + documents[i].metadata + ); + if (collectionId) { + values.push(collectionId); + } + rows.push(values); + } + + for (let i = 0; i < rows.length; i += this.chunkSize) { + const chunk = rows.slice(i, i + this.chunkSize); + const insertQuery = await this.buildInsertQuery(chunk); + const flatValues = chunk.flat(); + try { + await this.pool.query(insertQuery, flatValues); + } catch (e) { + console.error(e); + throw new Error(`Error inserting: ${(e as Error).message}`); + } + } + } + + /** + * Method to perform a similarity search in the vector store. It returns + * the `k` most similar documents to the query vector, along with their + * similarity scores. + * + * @param query - Query vector. + * @param k - Number of most similar documents to return. + * @param filter - Optional filter to apply to the search. + * @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ): Promise<[Document, number][]> { + const embeddingString = `[${query.join(",")}]`; + const _filter = filter ?? "{}"; + let collectionId; + if (this.collectionTableName) { + collectionId = await this.getOrCreateCollection(); + } + + const parameters = [embeddingString, _filter, k]; + if (collectionId) { + parameters.push(collectionId); + } + + const queryString = ` + SELECT *, ${this.vectorColumnName} <=> $1 as "_distance" + FROM ${this.tableName} + WHERE ${this.metadataColumnName}::jsonb @> $2 + ${collectionId ? 
"AND collection_id = $4" : ""} + ORDER BY "_distance" ASC + LIMIT $3; + `; + + const documents = (await this.pool.query(queryString, parameters)).rows; + + const results = [] as [Document, number][]; + for (const doc of documents) { + if (doc._distance != null && doc[this.contentColumnName] != null) { + const document = new Document({ + pageContent: doc[this.contentColumnName], + metadata: doc[this.metadataColumnName], + }); + results.push([document, doc._distance]); + } + } + return results; + } + + /** + * Method to ensure the existence of the table in the database. It creates + * the table if it does not already exist. + * + * @returns Promise that resolves when the table has been ensured. + */ + async ensureTableInDatabase(): Promise { + await this.pool.query("CREATE EXTENSION IF NOT EXISTS vector;"); + await this.pool.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'); + + await this.pool.query(` + CREATE TABLE IF NOT EXISTS ${this.tableName} ( + "${this.idColumnName}" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, + "${this.contentColumnName}" text, + "${this.metadataColumnName}" jsonb, + "${this.vectorColumnName}" vector + ); + `); + } + + /** + * Method to ensure the existence of the collection table in the database. + * It creates the table if it does not already exist. + * + * @returns Promise that resolves when the collection table has been ensured. 
+ */ + async ensureCollectionTableInDatabase(): Promise { + try { + await this.pool.query(` + CREATE TABLE IF NOT EXISTS ${this.collectionTableName} ( + uuid uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY, + name character varying, + cmetadata jsonb + ); + + ALTER TABLE ${this.tableName} + ADD COLUMN collection_id uuid; + + ALTER TABLE ${this.tableName} + ADD CONSTRAINT ${this.tableName}_collection_id_fkey + FOREIGN KEY (collection_id) + REFERENCES ${this.collectionTableName}(uuid) + ON DELETE CASCADE; + `); + } catch (e) { + if (!(e as Error).message.includes("already exists")) { + console.error(e); + throw new Error(`Error adding column: ${(e as Error).message}`); + } + } + } + + /** + * Static method to create a new `PGVectorStore` instance from an + * array of texts and their metadata. It converts the texts into + * `Document` instances and adds them to the store. + * + * @param texts - Array of texts. + * @param metadatas - Array of metadata objects or a single metadata object. + * @param embeddings - Embeddings instance. + * @param dbConfig - `PGVectorStoreArgs` instance. + * @returns Promise that resolves with a new instance of `PGVectorStore`. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: PGVectorStoreArgs + ): Promise { + const docs = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + return PGVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method to create a new `PGVectorStore` instance from an + * array of `Document` instances. It adds the documents to the store. + * + * @param docs - Array of `Document` instances. + * @param embeddings - Embeddings instance. + * @param dbConfig - `PGVectorStoreArgs` instance. 
+ * @returns Promise that resolves with a new instance of `PGVectorStore`. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: PGVectorStoreArgs + ): Promise { + const instance = await PGVectorStore.initialize(embeddings, dbConfig); + await instance.addDocuments(docs); + + return instance; + } + + /** + * Closes all the clients in the pool and terminates the pool. + * + * @returns Promise that resolves when all clients are closed and the pool is terminated. + */ + async end(): Promise { + this.client?.release(); + return this.pool.end(); + } +} diff --git a/libs/langchain-community/src/vectorstores/pinecone.ts b/libs/langchain-community/src/vectorstores/pinecone.ts new file mode 100644 index 000000000000..e368fe1670e9 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/pinecone.ts @@ -0,0 +1,366 @@ +/* eslint-disable no-process-env */ +import * as uuid from "uuid"; +import flatten from "flat"; + +import { + RecordMetadata, + PineconeRecord, + Index as PineconeIndex, +} from "@pinecone-database/pinecone"; + +import { + MaxMarginalRelevanceSearchOptions, + VectorStore, +} from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { + AsyncCaller, + AsyncCallerParams, +} from "@langchain/core/utils/async_caller"; +import { maximalMarginalRelevance } from "@langchain/core/utils/math"; +import { chunkArray } from "../utils/chunk.js"; + +// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any +type PineconeMetadata = Record; + +export interface PineconeLibArgs extends AsyncCallerParams { + pineconeIndex: PineconeIndex; + textKey?: string; + namespace?: string; + filter?: PineconeMetadata; +} + +/** + * Type that defines the parameters for the delete operation in the + * PineconeStore class. It includes ids, filter, deleteAll flag, and namespace. 
+ */ +export type PineconeDeleteParams = { + ids?: string[]; + deleteAll?: boolean; + filter?: object; + namespace?: string; +}; + +/** + * Class that extends the VectorStore class and provides methods to + * interact with the Pinecone vector database. + */ +export class PineconeStore extends VectorStore { + declare FilterType: PineconeMetadata; + + textKey: string; + + namespace?: string; + + pineconeIndex: PineconeIndex; + + filter?: PineconeMetadata; + + caller: AsyncCaller; + + _vectorstoreType(): string { + return "pinecone"; + } + + constructor(embeddings: Embeddings, args: PineconeLibArgs) { + super(embeddings, args); + + this.embeddings = embeddings; + const { namespace, pineconeIndex, textKey, filter, ...asyncCallerArgs } = + args; + this.namespace = namespace; + this.pineconeIndex = pineconeIndex; + this.textKey = textKey ?? "text"; + this.filter = filter; + this.caller = new AsyncCaller(asyncCallerArgs); + } + + /** + * Method that adds documents to the Pinecone database. + * @param documents Array of documents to add to the Pinecone database. + * @param options Optional ids for the documents. + * @returns Promise that resolves with the ids of the added documents. + */ + async addDocuments( + documents: Document[], + options?: { ids?: string[] } | string[] + ) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + options + ); + } + + /** + * Method that adds vectors to the Pinecone database. + * @param vectors Array of vectors to add to the Pinecone database. + * @param documents Array of documents associated with the vectors. + * @param options Optional ids for the vectors. + * @returns Promise that resolves with the ids of the added vectors. + */ + async addVectors( + vectors: number[][], + documents: Document[], + options?: { ids?: string[] } | string[] + ) { + const ids = Array.isArray(options) ? 
options : options?.ids; + const documentIds = ids == null ? documents.map(() => uuid.v4()) : ids; + const pineconeVectors = vectors.map((values, idx) => { + // Pinecone doesn't support nested objects, so we flatten them + const documentMetadata = { ...documents[idx].metadata }; + // preserve string arrays which are allowed + const stringArrays: Record = {}; + for (const key of Object.keys(documentMetadata)) { + if ( + Array.isArray(documentMetadata[key]) && + // eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any + documentMetadata[key].every((el: any) => typeof el === "string") + ) { + stringArrays[key] = documentMetadata[key]; + delete documentMetadata[key]; + } + } + const metadata: { + [key: string]: string | number | boolean | string[] | null; + } = { + ...flatten(documentMetadata), + ...stringArrays, + [this.textKey]: documents[idx].pageContent, + }; + // Pinecone doesn't support null values, so we remove them + for (const key of Object.keys(metadata)) { + if (metadata[key] == null) { + delete metadata[key]; + } else if ( + typeof metadata[key] === "object" && + Object.keys(metadata[key] as unknown as object).length === 0 + ) { + delete metadata[key]; + } + } + + return { + id: documentIds[idx], + metadata, + values, + } as PineconeRecord; + }); + + const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); + // Pinecone recommends a limit of 100 vectors per upsert request + const chunkSize = 100; + const chunkedVectors = chunkArray(pineconeVectors, chunkSize); + const batchRequests = chunkedVectors.map((chunk) => + this.caller.call(async () => namespace.upsert(chunk)) + ); + + await Promise.all(batchRequests); + + return documentIds; + } + + /** + * Method that deletes vectors from the Pinecone database. + * @param params Parameters for the delete operation. + * @returns Promise that resolves when the delete operation is complete. 
+ */ + async delete(params: PineconeDeleteParams): Promise { + const { deleteAll, ids, filter } = params; + const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); + + if (deleteAll) { + await namespace.deleteAll(); + } else if (ids) { + const batchSize = 1000; + for (let i = 0; i < ids.length; i += batchSize) { + const batchIds = ids.slice(i, i + batchSize); + await namespace.deleteMany(batchIds); + } + } else if (filter) { + await namespace.deleteMany(filter); + } else { + throw new Error("Either ids or delete_all must be provided."); + } + } + + protected async _runPineconeQuery( + query: number[], + k: number, + filter?: PineconeMetadata, + options?: { includeValues: boolean } + ) { + if (filter && this.filter) { + throw new Error("cannot provide both `filter` and `this.filter`"); + } + const _filter = filter ?? this.filter; + const namespace = this.pineconeIndex.namespace(this.namespace ?? ""); + + const results = await namespace.query({ + includeMetadata: true, + topK: k, + vector: query, + filter: _filter, + ...options, + }); + + return results; + } + + /** + * Method that performs a similarity search in the Pinecone database and + * returns the results along with their scores. + * @param query Query vector for the similarity search. + * @param k Number of top results to return. + * @param filter Optional filter to apply to the search. + * @returns Promise that resolves with an array of documents and their scores. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: PineconeMetadata + ): Promise<[Document, number][]> { + const results = await this._runPineconeQuery(query, k, filter); + const result: [Document, number][] = []; + + if (results.matches) { + for (const res of results.matches) { + const { [this.textKey]: pageContent, ...metadata } = (res.metadata ?? 
+ {}) as PineconeMetadata; + if (res.score) { + result.push([new Document({ metadata, pageContent }), res.score]); + } + } + } + + return result; + } + + /** + * Return documents selected using the maximal marginal relevance. + * Maximal marginal relevance optimizes for similarity to the query AND diversity + * among selected documents. + * + * @param {string} query - Text to look up documents similar to. + * @param {number} options.k - Number of documents to return. + * @param {number} options.fetchK=20 - Number of documents to fetch before passing to the MMR algorithm. + * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, + * where 0 corresponds to maximum diversity and 1 to minimum diversity. + * @param {PineconeMetadata} options.filter - Optional filter to apply to the search. + * + * @returns {Promise} - List of documents selected by maximal marginal relevance. + */ + async maxMarginalRelevanceSearch( + query: string, + options: MaxMarginalRelevanceSearchOptions + ): Promise { + const queryEmbedding = await this.embeddings.embedQuery(query); + + const results = await this._runPineconeQuery( + queryEmbedding, + options.fetchK ?? 20, + options.filter, + { includeValues: true } + ); + + const matches = results?.matches ?? []; + const embeddingList = matches.map((match) => match.values); + + const mmrIndexes = maximalMarginalRelevance( + queryEmbedding, + embeddingList, + options.lambda, + options.k + ); + + const topMmrMatches = mmrIndexes.map((idx) => matches[idx]); + + const finalResult: Document[] = []; + for (const res of topMmrMatches) { + const { [this.textKey]: pageContent, ...metadata } = (res.metadata ?? + {}) as PineconeMetadata; + if (res.score) { + finalResult.push(new Document({ metadata, pageContent })); + } + } + + return finalResult; + } + + /** + * Static method that creates a new instance of the PineconeStore class + * from texts. 
+ * @param texts Array of texts to add to the Pinecone database. + * @param metadatas Metadata associated with the texts. + * @param embeddings Embeddings to use for the texts. + * @param dbConfig Configuration for the Pinecone database. + * @returns Promise that resolves with a new instance of the PineconeStore class. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: + | { + pineconeIndex: PineconeIndex; + textKey?: string; + namespace?: string | undefined; + } + | PineconeLibArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + const args: PineconeLibArgs = { + pineconeIndex: dbConfig.pineconeIndex, + textKey: dbConfig.textKey, + namespace: dbConfig.namespace, + }; + return PineconeStore.fromDocuments(docs, embeddings, args); + } + + /** + * Static method that creates a new instance of the PineconeStore class + * from documents. + * @param docs Array of documents to add to the Pinecone database. + * @param embeddings Embeddings to use for the documents. + * @param dbConfig Configuration for the Pinecone database. + * @returns Promise that resolves with a new instance of the PineconeStore class. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: PineconeLibArgs + ): Promise { + const args = dbConfig; + args.textKey = dbConfig.textKey ?? "text"; + + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method that creates a new instance of the PineconeStore class + * from an existing index. + * @param embeddings Embeddings to use for the documents. + * @param dbConfig Configuration for the Pinecone database. 
+ * @returns Promise that resolves with a new instance of the PineconeStore class. + */ + static async fromExistingIndex( + embeddings: Embeddings, + dbConfig: PineconeLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/prisma.ts b/libs/langchain-community/src/vectorstores/prisma.ts new file mode 100644 index 000000000000..170a41d96367 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/prisma.ts @@ -0,0 +1,511 @@ +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { Callbacks } from "@langchain/core/callbacks/manager"; + +const IdColumnSymbol = Symbol("id"); +const ContentColumnSymbol = Symbol("content"); + +type ColumnSymbol = typeof IdColumnSymbol | typeof ContentColumnSymbol; + +declare type Value = unknown; +declare type RawValue = Value | Sql; + +declare class Sql { + strings: string[]; + + constructor( + rawStrings: ReadonlyArray, + rawValues: ReadonlyArray + ); +} + +type PrismaNamespace = { + ModelName: Record; + Sql: typeof Sql; + raw: (sql: string) => Sql; + join: ( + values: RawValue[], + separator?: string, + prefix?: string, + suffix?: string + ) => Sql; + sql: (strings: ReadonlyArray, ...values: RawValue[]) => Sql; +}; + +type PrismaClient = { + $queryRaw( + query: TemplateStringsArray | Sql, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ...values: any[] + ): Promise; + $executeRaw( + query: TemplateStringsArray | Sql, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ...values: any[] + ): // eslint-disable-next-line @typescript-eslint/no-explicit-any + Promise; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + $transaction

[]>(arg: [...P]): Promise; +}; + +type ObjectIntersect = { + [P in keyof A & keyof B]: A[P] | B[P]; +}; + +type ModelColumns> = { + [K in keyof TModel]?: true | ColumnSymbol; +}; + +export type PrismaSqlFilter> = { + [K in keyof TModel]?: { + equals?: TModel[K]; + in?: TModel[K][]; + isNull?: TModel[K]; + isNotNull?: TModel[K]; + like?: TModel[K]; + lt?: TModel[K]; + lte?: TModel[K]; + gt?: TModel[K]; + gte?: TModel[K]; + not?: TModel[K]; + }; +}; + +const OpMap = { + equals: "=", + in: "IN", + isNull: "IS NULL", + isNotNull: "IS NOT NULL", + like: "LIKE", + lt: "<", + lte: "<=", + gt: ">", + gte: ">=", + not: "<>", +}; + +type SimilarityModel< + TModel extends Record = Record, + TColumns extends ModelColumns = ModelColumns +> = Pick> & { + _distance: number | null; +}; + +type DefaultPrismaVectorStore = PrismaVectorStore< + Record, + string, + ModelColumns>, + PrismaSqlFilter> +>; + +/** + * A specific implementation of the VectorStore class that is designed to + * work with Prisma. It provides methods for adding models, documents, and + * vectors, as well as for performing similarity searches. 
+ */ +export class PrismaVectorStore< + TModel extends Record, + TModelName extends string, + TSelectModel extends ModelColumns, + TFilterModel extends PrismaSqlFilter +> extends VectorStore { + protected tableName: string; + + protected vectorColumnName: string; + + protected selectColumns: string[]; + + filter?: TFilterModel; + + idColumn: keyof TModel & string; + + contentColumn: keyof TModel & string; + + static IdColumn: typeof IdColumnSymbol = IdColumnSymbol; + + static ContentColumn: typeof ContentColumnSymbol = ContentColumnSymbol; + + protected db: PrismaClient; + + protected Prisma: PrismaNamespace; + + _vectorstoreType(): string { + return "prisma"; + } + + constructor( + embeddings: Embeddings, + config: { + db: PrismaClient; + prisma: PrismaNamespace; + tableName: TModelName; + vectorColumnName: string; + columns: TSelectModel; + filter?: TFilterModel; + } + ) { + super(embeddings, {}); + + this.Prisma = config.prisma; + this.db = config.db; + + const entries = Object.entries(config.columns); + const idColumn = entries.find((i) => i[1] === IdColumnSymbol)?.[0]; + const contentColumn = entries.find( + (i) => i[1] === ContentColumnSymbol + )?.[0]; + + if (idColumn == null) throw new Error("Missing ID column"); + if (contentColumn == null) throw new Error("Missing content column"); + + this.idColumn = idColumn; + this.contentColumn = contentColumn; + + this.tableName = config.tableName; + this.vectorColumnName = config.vectorColumnName; + + this.selectColumns = entries + .map(([key, alias]) => (alias && key) || null) + .filter((x): x is string => !!x); + + if (config.filter) { + this.filter = config.filter; + } + } + + /** + * Creates a new PrismaVectorStore with the specified model. + * @param db The PrismaClient instance. + * @returns An object with create, fromTexts, and fromDocuments methods. 
+ */ + static withModel>(db: PrismaClient) { + function create< + TPrisma extends PrismaNamespace, + TColumns extends ModelColumns, + TFilters extends PrismaSqlFilter + >( + embeddings: Embeddings, + config: { + prisma: TPrisma; + tableName: keyof TPrisma["ModelName"] & string; + vectorColumnName: string; + columns: TColumns; + filter?: TFilters; + } + ) { + type ModelName = keyof TPrisma["ModelName"] & string; + return new PrismaVectorStore( + embeddings, + { ...config, db } + ); + } + + async function fromTexts< + TPrisma extends PrismaNamespace, + TColumns extends ModelColumns + >( + texts: string[], + metadatas: TModel[], + embeddings: Embeddings, + dbConfig: { + prisma: TPrisma; + tableName: keyof TPrisma["ModelName"] & string; + vectorColumnName: string; + columns: TColumns; + } + ) { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + return PrismaVectorStore.fromDocuments(docs, embeddings, { + ...dbConfig, + db, + }); + } + + async function fromDocuments< + TPrisma extends PrismaNamespace, + TColumns extends ModelColumns, + TFilters extends PrismaSqlFilter + >( + docs: Document[], + embeddings: Embeddings, + dbConfig: { + prisma: TPrisma; + tableName: keyof TPrisma["ModelName"] & string; + vectorColumnName: string; + columns: TColumns; + } + ) { + type ModelName = keyof TPrisma["ModelName"] & string; + const instance = new PrismaVectorStore< + TModel, + ModelName, + TColumns, + TFilters + >(embeddings, { ...dbConfig, db }); + await instance.addDocuments(docs); + return instance; + } + + return { create, fromTexts, fromDocuments }; + } + + /** + * Adds the specified models to the store. + * @param models The models to add. + * @returns A promise that resolves when the models have been added. 
+ */ + async addModels(models: TModel[]) { + return this.addDocuments( + models.map((metadata) => { + const pageContent = metadata[this.contentColumn]; + if (typeof pageContent !== "string") + throw new Error("Content column must be a string"); + return new Document({ pageContent, metadata }); + }) + ); + } + + /** + * Adds the specified documents to the store. + * @param documents The documents to add. + * @returns A promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[]) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Adds the specified vectors to the store. + * @param vectors The vectors to add. + * @param documents The documents associated with the vectors. + * @returns A promise that resolves when the vectors have been added. + */ + async addVectors(vectors: number[][], documents: Document[]) { + // table name, column name cannot be parametrised + // these fields are thus not escaped by Prisma and can be dangerous if user input is used + const idColumnRaw = this.Prisma.raw(`"${this.idColumn}"`); + const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`); + const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`); + + await this.db.$transaction( + vectors.map( + (vector, idx) => this.db.$executeRaw` + UPDATE ${tableNameRaw} + SET ${vectorColumnRaw} = ${`[${vector.join(",")}]`}::vector + WHERE ${idColumnRaw} = ${documents[idx].metadata[this.idColumn]} + ` + ) + ); + } + + /** + * Performs a similarity search with the specified query. + * @param query The query to use for the similarity search. + * @param k The number of results to return. + * @param _filter The filter to apply to the results. + * @param _callbacks The callbacks to use during the search. + * @returns A promise that resolves with the search results. 
+ */ + async similaritySearch( + query: string, + k = 4, + _filter: this["FilterType"] | undefined = undefined, // not used. here to make the interface compatible with the other stores + _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later + ): Promise>[]> { + const results = await this.similaritySearchVectorWithScore( + await this.embeddings.embedQuery(query), + k + ); + + return results.map((result) => result[0]); + } + + /** + * Performs a similarity search with the specified query and returns the + * results along with their scores. + * @param query The query to use for the similarity search. + * @param k The number of results to return. + * @param filter The filter to apply to the results. + * @param _callbacks The callbacks to use during the search. + * @returns A promise that resolves with the search results and their scores. + */ + async similaritySearchWithScore( + query: string, + k?: number, + filter?: TFilterModel, + _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later + ) { + return super.similaritySearchWithScore(query, k, filter); + } + + /** + * Performs a similarity search with the specified vector and returns the + * results along with their scores. + * @param query The vector to use for the similarity search. + * @param k The number of results to return. + * @param filter The filter to apply to the results. + * @returns A promise that resolves with the search results and their scores. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: TFilterModel + ): Promise<[Document>, number][]> { + // table name, column names cannot be parametrised + // these fields are thus not escaped by Prisma and can be dangerous if user input is used + const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`); + const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`); + const selectRaw = this.Prisma.raw( + this.selectColumns.map((x) => `"${x}"`).join(", ") + ); + + const vector = `[${query.join(",")}]`; + const articles = await this.db.$queryRaw< + Array> + >( + this.Prisma.join( + [ + this.Prisma.sql` + SELECT ${selectRaw}, ${vectorColumnRaw} <=> ${vector}::vector as "_distance" + FROM ${tableNameRaw} + `, + this.buildSqlFilterStr(filter ?? this.filter), + this.Prisma.sql` + ORDER BY "_distance" ASC + LIMIT ${k}; + `, + ].filter((x) => x != null), + "" + ) + ); + + const results: [Document>, number][] = + []; + for (const article of articles) { + if (article._distance != null && article[this.contentColumn] != null) { + results.push([ + new Document({ + pageContent: article[this.contentColumn] as string, + metadata: article, + }), + article._distance, + ]); + } + } + + return results; + } + + buildSqlFilterStr(filter?: TFilterModel) { + if (filter == null) return null; + return this.Prisma.join( + Object.entries(filter).flatMap(([key, ops]) => + Object.entries(ops).map(([opName, value]) => { + // column name, operators cannot be parametrised + // these fields are thus not escaped by Prisma and can be dangerous if user input is used + const opNameKey = opName as keyof typeof OpMap; + const colRaw = this.Prisma.raw(`"${key}"`); + const opRaw = this.Prisma.raw(OpMap[opNameKey]); + + switch (OpMap[opNameKey]) { + case OpMap.in: { + if ( + !Array.isArray(value) || + !value.every((v) => typeof v === "string") + ) { + throw new Error( + `Invalid filter: IN operator requires an array of strings. 
Received: ${JSON.stringify( + value, + null, + 2 + )}` + ); + } + return this.Prisma.sql`${colRaw} ${opRaw} (${this.Prisma.join( + value + )})`; + } + case OpMap.isNull: + case OpMap.isNotNull: + return this.Prisma.sql`${colRaw} ${opRaw}`; + default: + return this.Prisma.sql`${colRaw} ${opRaw} ${value}`; + } + }) + ), + " AND ", + " WHERE " + ); + } + + /** + * Creates a new PrismaVectorStore from the specified texts. + * @param texts The texts to use to create the store. + * @param metadatas The metadata for the texts. + * @param embeddings The embeddings to use. + * @param dbConfig The database configuration. + * @returns A promise that resolves with the new PrismaVectorStore. + */ + static async fromTexts( + texts: string[], + metadatas: object[], + embeddings: Embeddings, + dbConfig: { + db: PrismaClient; + prisma: PrismaNamespace; + tableName: string; + vectorColumnName: string; + columns: ModelColumns>; + } + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + return PrismaVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Creates a new PrismaVectorStore from the specified documents. + * @param docs The documents to use to create the store. + * @param embeddings The embeddings to use. + * @param dbConfig The database configuration. + * @returns A promise that resolves with the new PrismaVectorStore. 
+ */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: { + db: PrismaClient; + prisma: PrismaNamespace; + tableName: string; + vectorColumnName: string; + columns: ModelColumns>; + } + ): Promise { + const instance = new PrismaVectorStore(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/qdrant.ts b/libs/langchain-community/src/vectorstores/qdrant.ts new file mode 100644 index 000000000000..21062e1e86b5 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/qdrant.ts @@ -0,0 +1,260 @@ +import { QdrantClient } from "@qdrant/js-client-rest"; +import type { Schemas as QdrantSchemas } from "@qdrant/js-client-rest"; +import { v4 as uuid } from "uuid"; + +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface for the arguments that can be passed to the + * `QdrantVectorStore` constructor. It includes options for specifying a + * `QdrantClient` instance, the URL and API key for a Qdrant database, and + * the name and configuration for a collection. + */ +export interface QdrantLibArgs { + client?: QdrantClient; + url?: string; + apiKey?: string; + collectionName?: string; + collectionConfig?: QdrantSchemas["CreateCollection"]; +} + +/** + * Type for the response returned by a search operation in the Qdrant + * database. It includes the score and payload (metadata and content) for + * each point (document) in the search results. + */ +type QdrantSearchResponse = QdrantSchemas["ScoredPoint"] & { + payload: { + metadata: object; + content: string; + }; +}; + +/** + * Class that extends the `VectorStore` base class to interact with a + * Qdrant database. 
It includes methods for adding documents and vectors + * to the Qdrant database, searching for similar vectors, and ensuring the + * existence of a collection in the database. + */ +export class QdrantVectorStore extends VectorStore { + get lc_secrets(): { [key: string]: string } { + return { + apiKey: "QDRANT_API_KEY", + url: "QDRANT_URL", + }; + } + + client: QdrantClient; + + collectionName: string; + + collectionConfig?: QdrantSchemas["CreateCollection"]; + + _vectorstoreType(): string { + return "qdrant"; + } + + constructor(embeddings: Embeddings, args: QdrantLibArgs) { + super(embeddings, args); + + const url = args.url ?? getEnvironmentVariable("QDRANT_URL"); + const apiKey = args.apiKey ?? getEnvironmentVariable("QDRANT_API_KEY"); + + if (!args.client && !url) { + throw new Error("Qdrant client or url address must be set."); + } + + this.client = + args.client || + new QdrantClient({ + url, + apiKey, + }); + + this.collectionName = args.collectionName ?? "documents"; + + this.collectionConfig = args.collectionConfig; + } + + /** + * Method to add documents to the Qdrant database. It generates vectors + * from the documents using the `Embeddings` instance and then adds the + * vectors to the database. + * @param documents Array of `Document` instances to be added to the Qdrant database. + * @returns Promise that resolves when the documents have been added to the database. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + await this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Method to add vectors to the Qdrant database. Each vector is associated + * with a document, which is stored as the payload for a point in the + * database. + * @param vectors Array of vectors to be added to the Qdrant database. + * @param documents Array of `Document` instances associated with the vectors. 
+ * @returns Promise that resolves when the vectors have been added to the database. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + if (vectors.length === 0) { + return; + } + + await this.ensureCollection(); + + const points = vectors.map((embedding, idx) => ({ + id: uuid(), + vector: embedding, + payload: { + content: documents[idx].pageContent, + metadata: documents[idx].metadata, + }, + })); + + try { + await this.client.upsert(this.collectionName, { + wait: true, + points, + }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + const error = new Error( + `${e?.status ?? "Undefined error code"} ${e?.message}: ${ + e?.data?.status?.error + }` + ); + throw error; + } + } + + /** + * Method to search for vectors in the Qdrant database that are similar to + * a given query vector. The search results include the score and payload + * (metadata and content) for each similar vector. + * @param query Query vector to search for similar vectors in the Qdrant database. + * @param k Optional number of similar vectors to return. If not specified, all similar vectors are returned. + * @param filter Optional filter to apply to the search results. + * @returns Promise that resolves with an array of tuples, where each tuple includes a `Document` instance and a score for a similar vector. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k?: number, + filter?: QdrantSchemas["Filter"] + ): Promise<[Document, number][]> { + if (!query) { + return []; + } + + await this.ensureCollection(); + + const results = await this.client.search(this.collectionName, { + vector: query, + limit: k, + filter, + }); + + const result: [Document, number][] = ( + results as QdrantSearchResponse[] + ).map((res) => [ + new Document({ + metadata: res.payload.metadata, + pageContent: res.payload.content, + }), + res.score, + ]); + + return result; + } + + /** + * Method to ensure the existence of a collection in the Qdrant database. + * If the collection does not exist, it is created. + * @returns Promise that resolves when the existence of the collection has been ensured. + */ + async ensureCollection() { + const response = await this.client.getCollections(); + + const collectionNames = response.collections.map( + (collection) => collection.name + ); + + if (!collectionNames.includes(this.collectionName)) { + const collectionConfig = this.collectionConfig ?? { + vectors: { + size: (await this.embeddings.embedQuery("test")).length, + distance: "Cosine", + }, + }; + await this.client.createCollection(this.collectionName, collectionConfig); + } + } + + /** + * Static method to create a `QdrantVectorStore` instance from texts. Each + * text is associated with metadata and converted to a `Document` + * instance, which is then added to the Qdrant database. + * @param texts Array of texts to be converted to `Document` instances and added to the Qdrant database. + * @param metadatas Array or single object of metadata to be associated with the texts. + * @param embeddings `Embeddings` instance used to generate vectors from the texts. + * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. + * @returns Promise that resolves with a new `QdrantVectorStore` instance. 
+ */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: QdrantLibArgs + ): Promise { + const docs = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return QdrantVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method to create a `QdrantVectorStore` instance from `Document` + * instances. The documents are added to the Qdrant database. + * @param docs Array of `Document` instances to be added to the Qdrant database. + * @param embeddings `Embeddings` instance used to generate vectors from the documents. + * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. + * @returns Promise that resolves with a new `QdrantVectorStore` instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: QdrantLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method to create a `QdrantVectorStore` instance from an existing + * collection in the Qdrant database. + * @param embeddings `Embeddings` instance used to generate vectors from the documents in the collection. + * @param dbConfig `QdrantLibArgs` instance specifying the configuration for the Qdrant database. + * @returns Promise that resolves with a new `QdrantVectorStore` instance. 
+ */ + static async fromExistingCollection( + embeddings: Embeddings, + dbConfig: QdrantLibArgs + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.ensureCollection(); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/redis.ts b/libs/langchain-community/src/vectorstores/redis.ts new file mode 100644 index 000000000000..5df94f4646b4 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/redis.ts @@ -0,0 +1,458 @@ +import type { + createCluster, + createClient, + RediSearchSchema, + SearchOptions, +} from "redis"; +import { SchemaFieldTypes, VectorAlgorithms } from "redis"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +// Adapated from internal redis types which aren't exported +/** + * Type for creating a schema vector field. It includes the algorithm, + * distance metric, and initial capacity. + */ +export type CreateSchemaVectorField< + T extends VectorAlgorithms, + A extends Record +> = { + ALGORITHM: T; + DISTANCE_METRIC: "L2" | "IP" | "COSINE"; + INITIAL_CAP?: number; +} & A; +/** + * Type for creating a flat schema vector field. It extends + * CreateSchemaVectorField with a block size property. + */ +export type CreateSchemaFlatVectorField = CreateSchemaVectorField< + VectorAlgorithms.FLAT, + { + BLOCK_SIZE?: number; + } +>; +/** + * Type for creating a HNSW schema vector field. It extends + * CreateSchemaVectorField with M, EF_CONSTRUCTION, and EF_RUNTIME + * properties. 
+ */ +export type CreateSchemaHNSWVectorField = CreateSchemaVectorField< + VectorAlgorithms.HNSW, + { + M?: number; + EF_CONSTRUCTION?: number; + EF_RUNTIME?: number; + } +>; + +type CreateIndexOptions = NonNullable< + Parameters["ft"]["create"]>[3] +>; + +export type RedisSearchLanguages = `${NonNullable< + CreateIndexOptions["LANGUAGE"] +>}`; + +export type RedisVectorStoreIndexOptions = Omit< + CreateIndexOptions, + "LANGUAGE" +> & { LANGUAGE?: RedisSearchLanguages }; + +/** + * Interface for the configuration of the RedisVectorStore. It includes + * the Redis client, index name, index options, key prefix, content key, + * metadata key, vector key, and filter. + */ +export interface RedisVectorStoreConfig { + redisClient: + | ReturnType + | ReturnType; + indexName: string; + indexOptions?: CreateSchemaFlatVectorField | CreateSchemaHNSWVectorField; + createIndexOptions?: Omit; // PREFIX must be set with keyPrefix + keyPrefix?: string; + contentKey?: string; + metadataKey?: string; + vectorKey?: string; + filter?: RedisVectorStoreFilterType; +} + +/** + * Interface for the options when adding documents to the + * RedisVectorStore. It includes keys and batch size. + */ +export interface RedisAddOptions { + keys?: string[]; + batchSize?: number; +} + +/** + * Type for the filter used in the RedisVectorStore. It is an array of + * strings. + */ +export type RedisVectorStoreFilterType = string[]; + +/** + * Class representing a RedisVectorStore. It extends the VectorStore class + * and includes methods for adding documents and vectors, performing + * similarity searches, managing the index, and more. 
+ */ +export class RedisVectorStore extends VectorStore { + declare FilterType: RedisVectorStoreFilterType; + + private redisClient: + | ReturnType + | ReturnType; + + indexName: string; + + indexOptions: CreateSchemaFlatVectorField | CreateSchemaHNSWVectorField; + + createIndexOptions: CreateIndexOptions; + + keyPrefix: string; + + contentKey: string; + + metadataKey: string; + + vectorKey: string; + + filter?: RedisVectorStoreFilterType; + + _vectorstoreType(): string { + return "redis"; + } + + constructor(embeddings: Embeddings, _dbConfig: RedisVectorStoreConfig) { + super(embeddings, _dbConfig); + + this.redisClient = _dbConfig.redisClient; + this.indexName = _dbConfig.indexName; + this.indexOptions = _dbConfig.indexOptions ?? { + ALGORITHM: VectorAlgorithms.HNSW, + DISTANCE_METRIC: "COSINE", + }; + this.keyPrefix = _dbConfig.keyPrefix ?? `doc:${this.indexName}:`; + this.contentKey = _dbConfig.contentKey ?? "content"; + this.metadataKey = _dbConfig.metadataKey ?? "metadata"; + this.vectorKey = _dbConfig.vectorKey ?? "content_vector"; + this.filter = _dbConfig.filter; + this.createIndexOptions = { + ON: "HASH", + PREFIX: this.keyPrefix, + ...(_dbConfig.createIndexOptions as CreateIndexOptions), + }; + } + + /** + * Method for adding documents to the RedisVectorStore. It first converts + * the documents to texts and then adds them as vectors. + * @param documents The documents to add. + * @param options Optional parameters for adding the documents. + * @returns A promise that resolves when the documents have been added. + */ + async addDocuments(documents: Document[], options?: RedisAddOptions) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + options + ); + } + + /** + * Method for adding vectors to the RedisVectorStore. It checks if the + * index exists and creates it if it doesn't, then adds the vectors in + * batches. 
+ * @param vectors The vectors to add. + * @param documents The documents associated with the vectors. + * @param keys Optional keys for the vectors. + * @param batchSize The size of the batches in which to add the vectors. Defaults to 1000. + * @returns A promise that resolves when the vectors have been added. + */ + async addVectors( + vectors: number[][], + documents: Document[], + { keys, batchSize = 1000 }: RedisAddOptions = {} + ) { + if (!vectors.length || !vectors[0].length) { + throw new Error("No vectors provided"); + } + // check if the index exists and create it if it doesn't + await this.createIndex(vectors[0].length); + + const info = await this.redisClient.ft.info(this.indexName); + const lastKeyCount = parseInt(info.numDocs, 10) || 0; + const multi = this.redisClient.multi(); + + vectors.map(async (vector, idx) => { + const key = + keys && keys.length + ? keys[idx] + : `${this.keyPrefix}${idx + lastKeyCount}`; + const metadata = + documents[idx] && documents[idx].metadata + ? documents[idx].metadata + : {}; + + multi.hSet(key, { + [this.vectorKey]: this.getFloat32Buffer(vector), + [this.contentKey]: documents[idx].pageContent, + [this.metadataKey]: this.escapeSpecialChars(JSON.stringify(metadata)), + }); + + // write batch + if (idx % batchSize === 0) { + await multi.exec(); + } + }); + + // insert final batch + await multi.exec(); + } + + /** + * Method for performing a similarity search in the RedisVectorStore. It + * returns the documents and their scores. + * @param query The query vector. + * @param k The number of nearest neighbors to return. + * @param filter Optional filter to apply to the search. + * @returns A promise that resolves to an array of documents and their scores. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: RedisVectorStoreFilterType + ): Promise<[Document, number][]> { + if (filter && this.filter) { + throw new Error("cannot provide both `filter` and `this.filter`"); + } + + const _filter = filter ?? this.filter; + const results = await this.redisClient.ft.search( + this.indexName, + ...this.buildQuery(query, k, _filter) + ); + const result: [Document, number][] = []; + + if (results.total) { + for (const res of results.documents) { + if (res.value) { + const document = res.value; + if (document.vector_score) { + result.push([ + new Document({ + pageContent: document[this.contentKey] as string, + metadata: JSON.parse( + this.unEscapeSpecialChars(document.metadata as string) + ), + }), + Number(document.vector_score), + ]); + } + } + } + } + + return result; + } + + /** + * Static method for creating a new instance of RedisVectorStore from + * texts. It creates documents from the texts and metadata, then adds them + * to the RedisVectorStore. + * @param texts The texts to add. + * @param metadatas The metadata associated with the texts. + * @param embeddings The embeddings to use. + * @param dbConfig The configuration for the RedisVectorStore. + * @returns A promise that resolves to a new instance of RedisVectorStore. + */ + static fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: RedisVectorStoreConfig + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return RedisVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Static method for creating a new instance of RedisVectorStore from + * documents. It adds the documents to the RedisVectorStore. + * @param docs The documents to add. 
+ * @param embeddings The embeddings to use. + * @param dbConfig The configuration for the RedisVectorStore. + * @returns A promise that resolves to a new instance of RedisVectorStore. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: RedisVectorStoreConfig + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } + + /** + * Method for checking if an index exists in the RedisVectorStore. + * @returns A promise that resolves to a boolean indicating whether the index exists. + */ + async checkIndexExists() { + try { + await this.redisClient.ft.info(this.indexName); + } catch (err) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + if ((err as any)?.message.includes("unknown command")) { + throw new Error( + "Failed to run FT.INFO command. Please ensure that you are running a RediSearch-capable Redis instance: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/redis#setup" + ); + } + // index doesn't exist + return false; + } + + return true; + } + + /** + * Method for creating an index in the RedisVectorStore. If the index + * already exists, it does nothing. + * @param dimensions The dimensions of the index + * @returns A promise that resolves when the index has been created. + */ + async createIndex(dimensions = 1536): Promise { + if (await this.checkIndexExists()) { + return; + } + + const schema: RediSearchSchema = { + [this.vectorKey]: { + type: SchemaFieldTypes.VECTOR, + TYPE: "FLOAT32", + DIM: dimensions, + ...this.indexOptions, + }, + [this.contentKey]: SchemaFieldTypes.TEXT, + [this.metadataKey]: SchemaFieldTypes.TEXT, + }; + + await this.redisClient.ft.create( + this.indexName, + schema, + this.createIndexOptions + ); + } + + /** + * Method for dropping an index from the RedisVectorStore. + * @param deleteDocuments Optional boolean indicating whether to drop the associated documents. 
+ * @returns A promise that resolves to a boolean indicating whether the index was dropped. + */ + async dropIndex(deleteDocuments?: boolean): Promise { + try { + const options = deleteDocuments ? { DD: deleteDocuments } : undefined; + await this.redisClient.ft.dropIndex(this.indexName, options); + + return true; + } catch (err) { + return false; + } + } + + /** + * Deletes vectors from the vector store. + * @param params The parameters for deleting vectors. + * @returns A promise that resolves when the vectors have been deleted. + */ + async delete(params: { deleteAll: boolean }): Promise { + if (params.deleteAll) { + await this.dropIndex(true); + } else { + throw new Error(`Invalid parameters passed to "delete".`); + } + } + + private buildQuery( + query: number[], + k: number, + filter?: RedisVectorStoreFilterType + ): [string, SearchOptions] { + const vectorScoreField = "vector_score"; + + let hybridFields = "*"; + // if a filter is set, modify the hybrid query + if (filter && filter.length) { + // `filter` is a list of strings, then it's applied using the OR operator in the metadata key + // for example: filter = ['foo', 'bar'] => this will filter all metadata containing either 'foo' OR 'bar' + hybridFields = `@${this.metadataKey}:(${this.prepareFilter(filter)})`; + } + + const baseQuery = `${hybridFields} => [KNN ${k} @${this.vectorKey} $vector AS ${vectorScoreField}]`; + const returnFields = [this.metadataKey, this.contentKey, vectorScoreField]; + + const options: SearchOptions = { + PARAMS: { + vector: this.getFloat32Buffer(query), + }, + RETURN: returnFields, + SORTBY: vectorScoreField, + DIALECT: 2, + LIMIT: { + from: 0, + size: k, + }, + }; + + return [baseQuery, options]; + } + + private prepareFilter(filter: RedisVectorStoreFilterType) { + return filter.map(this.escapeSpecialChars).join("|"); + } + + /** + * Escapes all '-' characters. 
+ * RediSearch considers '-' as a negative operator, hence we need + * to escape it + * @see https://redis.io/docs/stack/search/reference/query_syntax + * + * @param str + * @returns + */ + private escapeSpecialChars(str: string) { + return str.replaceAll("-", "\\-"); + } + + /** + * Unescapes all '-' characters, returning the original string + * + * @param str + * @returns + */ + private unEscapeSpecialChars(str: string) { + return str.replaceAll("\\-", "-"); + } + + /** + * Converts the vector to the buffer Redis needs to + * correctly store an embedding + * + * @param vector + * @returns Buffer + */ + private getFloat32Buffer(vector: number[]) { + return Buffer.from(new Float32Array(vector).buffer); + } +} diff --git a/libs/langchain-community/src/vectorstores/rockset.ts b/libs/langchain-community/src/vectorstores/rockset.ts new file mode 100644 index 000000000000..04a93f4f6689 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/rockset.ts @@ -0,0 +1,452 @@ +import { MainApi } from "@rockset/client"; +import type { CreateCollectionRequest } from "@rockset/client/dist/codegen/api.d.ts"; +import { Collection } from "@rockset/client/dist/codegen/api.js"; + +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +/** + * Generic Rockset vector storage error + */ +export class RocksetStoreError extends Error { + /** + * Constructs a RocksetStoreError + * @param message The error message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Error that is thrown when a RocksetStore function is called + * after `destroy()` is called (meaning the collection would be + * deleted). 
+ */ +export class RocksetStoreDestroyedError extends RocksetStoreError { + constructor() { + super("The Rockset store has been destroyed"); + this.name = this.constructor.name; + } +} + +/** + * Functions to measure vector distance/similarity by. + * See https://rockset.com/docs/vector-functions/#vector-distance-functions + * @enum SimilarityMetric + */ +export const SimilarityMetric = { + CosineSimilarity: "COSINE_SIM", + EuclideanDistance: "EUCLIDEAN_DIST", + DotProduct: "DOT_PRODUCT", +} as const; + +export type SimilarityMetric = + (typeof SimilarityMetric)[keyof typeof SimilarityMetric]; + +interface CollectionNotFoundError { + message_key: string; +} + +/** + * Vector store arguments + * @interface RocksetStore + */ +export interface RocksetLibArgs { + /** + * The rockset client object constructed with `rocksetConfigure` + * @type {MainAPI} + */ + client: MainApi; + /** + * The name of the Rockset collection to store vectors + * @type {string} + */ + collectionName: string; + /** + * The name of othe Rockset workspace that holds @member collectionName + * @type {string} + */ + workspaceName?: string; + /** + * The name of the collection column to contain page contnent of documents + * @type {string} + */ + textKey?: string; + /** + * The name of the collection column to contain vectors + * @type {string} + */ + embeddingKey?: string; + /** + * The SQL `WHERE` clause to filter by + * @type {string} + */ + filter?: string; + /** + * The metric used to measure vector relationship + * @type {SimilarityMetric} + */ + similarityMetric?: SimilarityMetric; +} + +/** + * Exposes Rockset's vector store/search functionality + */ +export class RocksetStore extends VectorStore { + declare FilterType: string; + + client: MainApi; + + collectionName: string; + + workspaceName: string; + + textKey: string; + + embeddingKey: string; + + filter?: string; + + private _similarityMetric: SimilarityMetric; + + private similarityOrder: "ASC" | "DESC"; + + private destroyed: 
boolean; + + /** + * Gets a string representation of the type of this VectorStore + * @returns {"rockset"} + */ + _vectorstoreType(): "rockset" { + return "rockset"; + } + + /** + * Constructs a new RocksetStore + * @param {Embeddings} embeddings Object used to embed queries and + * page content + * @param {RocksetLibArgs} args + */ + constructor(embeddings: Embeddings, args: RocksetLibArgs) { + super(embeddings, args); + + this.embeddings = embeddings; + this.client = args.client; + this.collectionName = args.collectionName; + this.workspaceName = args.workspaceName ?? "commons"; + this.textKey = args.textKey ?? "text"; + this.embeddingKey = args.embeddingKey ?? "embedding"; + this.filter = args.filter; + this.similarityMetric = + args.similarityMetric ?? SimilarityMetric.CosineSimilarity; + this.setSimilarityOrder(); + } + + /** + * Sets the object's similarity order based on what + * SimilarityMetric is being used + */ + private setSimilarityOrder() { + this.checkIfDestroyed(); + this.similarityOrder = + this.similarityMetric === SimilarityMetric.EuclideanDistance + ? "ASC" + : "DESC"; + } + + /** + * Embeds and adds Documents to the store. 
+ * @param {Documents[]} documents The documents to store + * @returns {Promise} The _id's of the documents added + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + return await this.addVectors( + await this.embeddings.embedDocuments(texts), + documents + ); + } + + /** + * Adds vectors to the store given their corresponding Documents + * @param {number[][]} vectors The vectors to store + * @param {Document[]} documents The Documents they represent + * @return {Promise} The _id's of the added documents + */ + async addVectors(vectors: number[][], documents: Document[]) { + this.checkIfDestroyed(); + const rocksetDocs = []; + for (let i = 0; i < documents.length; i += 1) { + const currDoc = documents[i]; + const currVector = vectors[i]; + rocksetDocs.push({ + [this.textKey]: currDoc.pageContent, + [this.embeddingKey]: currVector, + ...currDoc.metadata, + }); + } + + return ( + await this.client.documents.addDocuments( + this.workspaceName, + this.collectionName, + { + data: rocksetDocs, + } + ) + ).data?.map((docStatus) => docStatus._id || ""); + } + + /** + * Deletes Rockset documements given their _id's + * @param {string[]} ids The IDS to remove documents with + */ + async delete(ids: string[]): Promise { + this.checkIfDestroyed(); + await this.client.documents.deleteDocuments( + this.workspaceName, + this.collectionName, + { + data: ids.map((id) => ({ _id: id })), + } + ); + } + + /** + * Gets the most relevant documents to a query along + * with their similarity score. 
The returned documents + * are ordered by similarity (most similar at the first + * index) + * @param {number[]} query The embedded query to search + * the store by + * @param {number} k The number of documents to retreive + * @param {string?} filter The SQL `WHERE` clause to filter by + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: string + ): Promise<[Document, number][]> { + this.checkIfDestroyed(); + if (filter && this.filter) { + throw new RocksetStoreError( + "cannot provide both `filter` and `this.filter`" + ); + } + const similarityKey = "similarity"; + const _filter = filter ?? this.filter; + return ( + ( + await this.client.queries.query({ + sql: { + query: ` + SELECT + * EXCEPT("${this.embeddingKey}"), + "${this.textKey}", + ${this.similarityMetric}(:query, "${ + this.embeddingKey + }") AS "${similarityKey}" + FROM + "${this.workspaceName}"."${this.collectionName}" + ${_filter ? `WHERE ${_filter}` : ""} + ORDER BY + "${similarityKey}" ${this.similarityOrder} + LIMIT + ${k} + `, + parameters: [ + { + name: "query", + type: "", + value: `[${query.toString()}]`, + }, + ], + }, + }) + ).results?.map((rocksetDoc) => [ + new Document>({ + pageContent: rocksetDoc[this.textKey], + metadata: (({ + [this.textKey]: t, + [similarityKey]: s, + ...rocksetDoc + }) => rocksetDoc)(rocksetDoc), + }), + rocksetDoc[similarityKey] as number, + ]) ?? [] + ); + } + + /** + * Constructs and returns a RocksetStore object given texts to store. 
+ * @param {string[]} texts The texts to store + * @param {object[] | object} metadatas The metadatas that correspond + * to @param texts + * @param {Embeddings} embeddings The object used to embed queries + * and page content + * @param {RocksetLibArgs} dbConfig The options to be passed into the + * RocksetStore constructor + * @returns {RocksetStore} + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + dbConfig: RocksetLibArgs + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + + return RocksetStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Constructs, adds docs to, and returns a RocksetStore object + * @param {Document[]} docs The Documents to store + * @param {Embeddings} embeddings The object used to embed queries + * and page content + * @param {RocksetLibArgs} dbConfig The options to be passed into the + * RocksetStore constructor + * @returns {RocksetStore} + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: RocksetLibArgs + ): Promise { + const args = { ...dbConfig, textKey: dbConfig.textKey ?? "text" }; + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + /** + * Checks if a Rockset collection exists. + * @param {RocksetLibArgs} dbConfig The object containing the collection + * and workspace names + * @return {boolean} whether the collection exists + */ + private static async collectionExists(dbConfig: RocksetLibArgs) { + try { + await dbConfig.client.collections.getCollection( + dbConfig.workspaceName ?? 
"commons", + dbConfig.collectionName + ); + } catch (err) { + if ( + (err as CollectionNotFoundError).message_key === + "COLLECTION_DOES_NOT_EXIST" + ) { + return false; + } + throw err; + } + return true; + } + + /** + * Checks whether a Rockset collection is ready to be queried. + * @param {RocksetLibArgs} dbConfig The object containing the collection + * name and workspace + * @return {boolean} whether the collection is ready + */ + private static async collectionReady(dbConfig: RocksetLibArgs) { + return ( + ( + await dbConfig.client.collections.getCollection( + dbConfig.workspaceName ?? "commons", + dbConfig.collectionName + ) + ).data?.status === Collection.StatusEnum.READY + ); + } + + /** + * Deletes the collection this RocksetStore uses + * @param {boolean?} waitUntilDeletion Whether to sleep until the + * collection is ready to be + * queried + */ + async destroy(waitUntilDeletion?: boolean) { + await this.client.collections.deleteCollection( + this.workspaceName, + this.collectionName + ); + this.destroyed = true; + if (waitUntilDeletion) { + while ( + await RocksetStore.collectionExists({ + collectionName: this.collectionName, + client: this.client, + }) + ); + } + } + + /** + * Checks if this RocksetStore has been destroyed. + * @throws {RocksetStoreDestroyederror} if it has. + */ + private checkIfDestroyed() { + if (this.destroyed) { + throw new RocksetStoreDestroyedError(); + } + } + + /** + * Creates a new Rockset collection and returns a RocksetStore that + * uses it + * @param {Embeddings} embeddings Object used to embed queries and + * page content + * @param {RocksetLibArgs} dbConfig The options to be passed into the + * RocksetStore constructor + * @param {CreateCollectionRequest?} collectionOptions The arguments to sent with the + * HTTP request when creating the + * collection. Setting a field mapping + * that `VECTOR_ENFORCE`s is recommended + * when using this function. 
See + * https://rockset.com/docs/vector-functions/#vector_enforce + * @returns {RocsketStore} + */ + static async withNewCollection( + embeddings: Embeddings, + dbConfig: RocksetLibArgs, + collectionOptions?: CreateCollectionRequest + ): Promise { + if ( + collectionOptions?.name && + dbConfig.collectionName !== collectionOptions?.name + ) { + throw new RocksetStoreError( + "`dbConfig.name` and `collectionOptions.name` do not match" + ); + } + await dbConfig.client.collections.createCollection( + dbConfig.workspaceName ?? "commons", + collectionOptions || { name: dbConfig.collectionName } + ); + while ( + !(await this.collectionExists(dbConfig)) || + !(await this.collectionReady(dbConfig)) + ); + return new this(embeddings, dbConfig); + } + + public get similarityMetric() { + return this._similarityMetric; + } + + public set similarityMetric(metric: SimilarityMetric) { + this._similarityMetric = metric; + this.setSimilarityOrder(); + } +} diff --git a/libs/langchain-community/src/vectorstores/singlestore.ts b/libs/langchain-community/src/vectorstores/singlestore.ts new file mode 100644 index 000000000000..34abd51598a0 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/singlestore.ts @@ -0,0 +1,294 @@ +import type { + Pool, + RowDataPacket, + OkPacket, + ResultSetHeader, + FieldPacket, + PoolOptions, +} from "mysql2/promise"; +import { format } from "mysql2"; +import { createPool } from "mysql2/promise"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type Metadata = Record; + +export type DistanceMetrics = "DOT_PRODUCT" | "EUCLIDEAN_DISTANCE"; + +const OrderingDirective: Record = { + DOT_PRODUCT: "DESC", + EUCLIDEAN_DISTANCE: "", +}; + +export interface ConnectionOptions extends PoolOptions {} + +type ConnectionWithUri = { + connectionOptions?: never; + 
connectionURI: string; +}; + +type ConnectionWithOptions = { + connectionURI?: never; + connectionOptions: ConnectionOptions; +}; + +type ConnectionConfig = ConnectionWithUri | ConnectionWithOptions; + +export type SingleStoreVectorStoreConfig = ConnectionConfig & { + tableName?: string; + contentColumnName?: string; + vectorColumnName?: string; + metadataColumnName?: string; + distanceMetric?: DistanceMetrics; +}; + +function withConnectAttributes( + config: SingleStoreVectorStoreConfig +): ConnectionOptions { + let newOptions: ConnectionOptions = {}; + if (config.connectionURI) { + newOptions = { + uri: config.connectionURI, + }; + } else if (config.connectionOptions) { + newOptions = { + ...config.connectionOptions, + }; + } + const result: ConnectionOptions = { + ...newOptions, + connectAttributes: { + ...newOptions.connectAttributes, + }, + }; + + if (!result.connectAttributes) { + result.connectAttributes = {}; + } + + result.connectAttributes = { + ...result.connectAttributes, + _connector_name: "langchain js sdk", + _connector_version: "1.0.0", + _driver_name: "Node-MySQL-2", + }; + + return result; +} + +/** + * Class for interacting with SingleStoreDB, a high-performance + * distributed SQL database. It provides vector storage and vector + * functions. + */ +export class SingleStoreVectorStore extends VectorStore { + connectionPool: Pool; + + tableName: string; + + contentColumnName: string; + + vectorColumnName: string; + + metadataColumnName: string; + + distanceMetric: DistanceMetrics; + + _vectorstoreType(): string { + return "singlestore"; + } + + constructor(embeddings: Embeddings, config: SingleStoreVectorStoreConfig) { + super(embeddings, config); + this.connectionPool = createPool(withConnectAttributes(config)); + this.tableName = config.tableName ?? "embeddings"; + this.contentColumnName = config.contentColumnName ?? "content"; + this.vectorColumnName = config.vectorColumnName ?? "vector"; + this.metadataColumnName = config.metadataColumnName ?? 
"metadata"; + this.distanceMetric = config.distanceMetric ?? "DOT_PRODUCT"; + } + + /** + * Creates a new table in the SingleStoreDB database if it does not + * already exist. + */ + async createTableIfNotExists(): Promise { + await this.connectionPool + .execute(`CREATE TABLE IF NOT EXISTS ${this.tableName} ( + ${this.contentColumnName} TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci, + ${this.vectorColumnName} BLOB, + ${this.metadataColumnName} JSON);`); + } + + /** + * Ends the connection to the SingleStoreDB database. + */ + async end(): Promise { + return this.connectionPool.end(); + } + + /** + * Adds new documents to the SingleStoreDB database. + * @param documents An array of Document objects. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + const vectors = await this.embeddings.embedDocuments(texts); + return this.addVectors(vectors, documents); + } + + /** + * Adds new vectors to the SingleStoreDB database. + * @param vectors An array of vectors. + * @param documents An array of Document objects. + */ + async addVectors(vectors: number[][], documents: Document[]): Promise { + await this.createTableIfNotExists(); + const { tableName } = this; + + await Promise.all( + vectors.map(async (vector, idx) => { + try { + await this.connectionPool.execute( + format( + `INSERT INTO ${tableName} VALUES (?, JSON_ARRAY_PACK('[?]'), ?);`, + [ + documents[idx].pageContent, + vector, + JSON.stringify(documents[idx].metadata), + ] + ) + ); + } catch (error) { + console.error(`Error adding vector at index ${idx}:`, error); + } + }) + ); + } + + /** + * Performs a similarity search on the vectors stored in the SingleStoreDB + * database. + * @param query An array of numbers representing the query vector. + * @param k The number of nearest neighbors to return. + * @param filter Optional metadata to filter the vectors by. 
+ * @returns Top matching vectors with score + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: Metadata + ): Promise<[Document, number][]> { + // build the where clause from filter + const whereArgs: string[] = []; + const buildWhereClause = (record: Metadata, argList: string[]): string => { + const whereTokens: string[] = []; + for (const key in record) + if (record[key] !== undefined) { + if ( + typeof record[key] === "object" && + record[key] != null && + !Array.isArray(record[key]) + ) { + whereTokens.push( + buildWhereClause(record[key], argList.concat([key])) + ); + } else { + whereTokens.push( + `JSON_EXTRACT_JSON(${this.metadataColumnName}, `.concat( + Array.from({ length: argList.length + 1 }, () => "?").join( + ", " + ), + ") = ?" + ) + ); + whereArgs.push(...argList, key, JSON.stringify(record[key])); + } + } + return whereTokens.join(" AND "); + }; + const whereClause = filter + ? "WHERE ".concat(buildWhereClause(filter, [])) + : ""; + + const [rows]: [ + ( + | RowDataPacket[] + | RowDataPacket[][] + | OkPacket + | OkPacket[] + | ResultSetHeader + ), + FieldPacket[] + ] = await this.connectionPool.query( + format( + `SELECT ${this.contentColumnName}, + ${this.metadataColumnName}, + ${this.distanceMetric}(${ + this.vectorColumnName + }, JSON_ARRAY_PACK('[?]')) as __score FROM ${ + this.tableName + } ${whereClause} + ORDER BY __score ${OrderingDirective[this.distanceMetric]} LIMIT ?;`, + [query, ...whereArgs, k] + ) + ); + const result: [Document, number][] = []; + for (const row of rows as RowDataPacket[]) { + const rowData = row as unknown as Record; + result.push([ + new Document({ + pageContent: rowData[this.contentColumnName] as string, + metadata: rowData[this.metadataColumnName] as Record, + }), + Number(rowData.score), + ]); + } + return result; + } + + /** + * Creates a new instance of the SingleStoreVectorStore class from a list + * of texts. + * @param texts An array of strings. 
+ * @param metadatas An array of metadata objects. + * @param embeddings An Embeddings object. + * @param dbConfig A SingleStoreVectorStoreConfig object. + * @returns A new SingleStoreVectorStore instance + */ + static async fromTexts( + texts: string[], + metadatas: object[], + embeddings: Embeddings, + dbConfig: SingleStoreVectorStoreConfig + ): Promise { + const docs = texts.map((text, idx) => { + const metadata = Array.isArray(metadatas) ? metadatas[idx] : metadatas; + return new Document({ + pageContent: text, + metadata, + }); + }); + return SingleStoreVectorStore.fromDocuments(docs, embeddings, dbConfig); + } + + /** + * Creates a new instance of the SingleStoreVectorStore class from a list + * of Document objects. + * @param docs An array of Document objects. + * @param embeddings An Embeddings object. + * @param dbConfig A SingleStoreVectorStoreConfig object. + * @returns A new SingleStoreVectorStore instance + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + dbConfig: SingleStoreVectorStoreConfig + ): Promise { + const instance = new this(embeddings, dbConfig); + await instance.addDocuments(docs); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/supabase.ts b/libs/langchain-community/src/vectorstores/supabase.ts new file mode 100644 index 000000000000..5d8e9a5d67cb --- /dev/null +++ b/libs/langchain-community/src/vectorstores/supabase.ts @@ -0,0 +1,313 @@ +import type { SupabaseClient } from "@supabase/supabase-js"; +import type { PostgrestFilterBuilder } from "@supabase/postgrest-js"; +import { + MaxMarginalRelevanceSearchOptions, + VectorStore, +} from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { maximalMarginalRelevance } from "@langchain/core/utils/math"; + +/** + * Interface for the parameters required for searching embeddings. 
+ */
+interface SearchEmbeddingsParams {
+  query_embedding: number[];
+  match_count: number; // int
+  filter?: SupabaseMetadata | SupabaseFilterRPCCall;
+}
+
+// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any
+export type SupabaseMetadata = Record<string, any>;
+// eslint-disable-next-line @typescript-eslint/ban-types, @typescript-eslint/no-explicit-any
+export type SupabaseFilter = PostgrestFilterBuilder<any, any, any>;
+export type SupabaseFilterRPCCall = (rpcCall: SupabaseFilter) => SupabaseFilter;
+
+/**
+ * Interface for the response returned when searching embeddings.
+ */
+interface SearchEmbeddingsResponse {
+  id: number;
+  content: string;
+  metadata: object;
+  embedding: number[];
+  similarity: number;
+}
+
+/**
+ * Interface for the arguments required to initialize a Supabase library.
+ */
+export interface SupabaseLibArgs {
+  client: SupabaseClient;
+  /** Table rows are upserted into / deleted from. Defaults to "documents". */
+  tableName?: string;
+  /** Name of the RPC used for searching. Defaults to "match_documents". */
+  queryName?: string;
+  /**
+   * Store-level filter used when a search call passes no filter of its own.
+   * Passing a filter to a search while this is set throws.
+   */
+  filter?: SupabaseMetadata | SupabaseFilterRPCCall;
+  /** Maximum number of rows sent per upsert request. Defaults to 500. */
+  upsertBatchSize?: number;
+}
+
+/**
+ * Class for interacting with a Supabase database to store and manage
+ * vectors.
+ */
+export class SupabaseVectorStore extends VectorStore {
+  declare FilterType: SupabaseMetadata | SupabaseFilterRPCCall;
+
+  client: SupabaseClient;
+
+  tableName: string;
+
+  queryName: string;
+
+  filter?: SupabaseMetadata | SupabaseFilterRPCCall;
+
+  upsertBatchSize = 500;
+
+  _vectorstoreType(): string {
+    return "supabase";
+  }
+
+  constructor(embeddings: Embeddings, args: SupabaseLibArgs) {
+    super(embeddings, args);
+
+    this.client = args.client;
+    this.tableName = args.tableName || "documents";
+    this.queryName = args.queryName || "match_documents";
+    this.filter = args.filter;
+    this.upsertBatchSize = args.upsertBatchSize ?? this.upsertBatchSize;
+  }
+
+  /**
+   * Adds documents to the vector store.
+   * @param documents The documents to add.
+   * @param options Optional parameters for adding the documents.
+   * @returns A promise that resolves when the documents have been added.
+   */
+  async addDocuments(
+    documents: Document[],
+    options?: { ids?: string[] | number[] }
+  ) {
+    const texts = documents.map(({ pageContent }) => pageContent);
+    return this.addVectors(
+      await this.embeddings.embedDocuments(texts),
+      documents,
+      options
+    );
+  }
+
+  /**
+   * Adds vectors to the vector store.
+   * @param vectors The vectors to add.
+   * @param documents The documents associated with the vectors.
+   * @param options Optional parameters for adding the vectors.
+   * @returns A promise that resolves with the IDs of the added vectors when the vectors have been added.
+   * @throws If any upsert batch returns an error.
+   */
+  async addVectors(
+    vectors: number[][],
+    documents: Document[],
+    options?: { ids?: string[] | number[] }
+  ) {
+    const rows = vectors.map((embedding, idx) => ({
+      content: documents[idx].pageContent,
+      embedding,
+      metadata: documents[idx].metadata,
+    }));
+
+    // upsert returns 500/502/504 (yes really any of them) if given too many rows/characters
+    // ~2000 trips it, but my data is probably smaller than average pageContent and metadata
+    let returnedIds: string[] = [];
+    for (let i = 0; i < rows.length; i += this.upsertBatchSize) {
+      const chunk = rows.slice(i, i + this.upsertBatchSize).map((row, j) => {
+        if (options?.ids) {
+          // `i + j` is the row's index in the full (un-batched) input.
+          return { id: options.ids[i + j], ...row };
+        }
+        return row;
+      });
+
+      const res = await this.client.from(this.tableName).upsert(chunk).select();
+      if (res.error) {
+        throw new Error(
+          `Error inserting: ${res.error.message} ${res.status} ${res.statusText}`
+        );
+      }
+      if (res.data) {
+        returnedIds = returnedIds.concat(res.data.map((row) => row.id));
+      }
+    }
+    return returnedIds;
+  }
+
+  /**
+   * Deletes vectors from the vector store.
+   * @param params The parameters for deleting vectors.
+   * @returns A promise that resolves when the vectors have been deleted.
+   * @throws If any delete request returns an error.
+   */
+  async delete(params: { ids: string[] | number[] }): Promise<void> {
+    const { ids } = params;
+    for (const id of ids) {
+      const res = await this.client
+        .from(this.tableName)
+        .delete()
+        .eq("id", id);
+      // The client reports failures through `error` on the response object,
+      // so it must be checked explicitly or a failed delete passes silently.
+      if (res.error) {
+        throw new Error(
+          `Error deleting: ${res.error.message} ${res.status} ${res.statusText}`
+        );
+      }
+    }
+  }
+
+  /**
+   * Runs the search RPC (`this.queryName`) and returns the raw rows.
+   *
+   * A function filter is applied to the RPC builder and the result is capped
+   * with `.limit(k)`; an object filter is sent as the `filter` RPC parameter
+   * together with `match_count: k`.
+   * @param query The query embedding.
+   * @param k The number of rows to return.
+   * @param filter Optional per-call filter; falls back to `this.filter`.
+   * @throws If both `filter` and `this.filter` are set, if the filter is
+   * neither a function nor an object, or if the RPC returns an error.
+   */
+  protected async _searchSupabase(
+    query: number[],
+    k: number,
+    filter?: this["FilterType"]
+  ): Promise<SearchEmbeddingsResponse[]> {
+    if (filter && this.filter) {
+      throw new Error("cannot provide both `filter` and `this.filter`");
+    }
+    const _filter = filter ?? this.filter ?? {};
+    const matchDocumentsParams: Partial<SearchEmbeddingsParams> = {
+      query_embedding: query,
+    };
+
+    let filterFunction: SupabaseFilterRPCCall;
+
+    if (typeof _filter === "function") {
+      filterFunction = (rpcCall) => _filter(rpcCall).limit(k);
+    } else if (typeof _filter === "object") {
+      matchDocumentsParams.filter = _filter;
+      matchDocumentsParams.match_count = k;
+      filterFunction = (rpcCall) => rpcCall;
+    } else {
+      throw new Error("invalid filter type");
+    }
+
+    const rpcCall = this.client.rpc(this.queryName, matchDocumentsParams);
+
+    const { data: searches, error } = await filterFunction(rpcCall);
+
+    if (error) {
+      throw new Error(
+        `Error searching for documents: ${error.code} ${error.message} ${error.details}`
+      );
+    }
+
+    return searches;
+  }
+
+  /**
+   * Performs a similarity search on the vector store.
+   * @param query The query vector.
+   * @param k The number of results to return.
+   * @param filter Optional filter to apply to the search.
+   * @returns A promise that resolves with the search results when the search is complete.
+   */
+  async similaritySearchVectorWithScore(
+    query: number[],
+    k: number,
+    filter?: this["FilterType"]
+  ): Promise<[Document, number][]> {
+    const searches = await this._searchSupabase(query, k, filter);
+    const result: [Document, number][] = searches.map((resp) => [
+      new Document({
+        metadata: resp.metadata,
+        pageContent: resp.content,
+      }),
+      resp.similarity,
+    ]);
+
+    return result;
+  }
+
+  /**
+   * Return documents selected using the maximal marginal relevance.
+   * Maximal marginal relevance optimizes for similarity to the query AND diversity
+   * among selected documents.
+   *
+   * @param {string} query - Text to look up documents similar to.
+   * @param {number} options.k - Number of documents to return.
+   * @param {number} options.fetchK=20 - Number of documents to fetch before passing to the MMR algorithm.
+   * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results,
+   *                 where 0 corresponds to maximum diversity and 1 to minimum diversity.
+   * @param {this["FilterType"]} options.filter - Optional filter to apply to the search.
+   *
+   * @returns {Promise<Document[]>} - List of documents selected by maximal marginal relevance.
+   */
+  async maxMarginalRelevanceSearch(
+    query: string,
+    options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>
+  ): Promise<Document[]> {
+    const queryEmbedding = await this.embeddings.embedQuery(query);
+
+    const searches = await this._searchSupabase(
+      queryEmbedding,
+      options.fetchK ?? 20,
+      options.filter
+    );
+
+    // MMR re-ranks on the stored vectors, so each returned row must carry
+    // its `embedding` value.
+    const embeddingList = searches.map((searchResp) => searchResp.embedding);
+
+    const mmrIndexes = maximalMarginalRelevance(
+      queryEmbedding,
+      embeddingList,
+      options.lambda,
+      options.k
+    );
+
+    return mmrIndexes.map(
+      (idx) =>
+        new Document({
+          metadata: searches[idx].metadata,
+          pageContent: searches[idx].content,
+        })
+    );
+  }
+
+  /**
+   * Creates a new SupabaseVectorStore instance from an array of texts.
+   * @param texts The texts to create documents from.
+   * @param metadatas The metadata for the documents.
+   * @param embeddings The embeddings to use.
+   * @param dbConfig The configuration for the Supabase database.
+   * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created.
+   */
+  static async fromTexts(
+    texts: string[],
+    metadatas: object[] | object,
+    embeddings: Embeddings,
+    dbConfig: SupabaseLibArgs
+  ): Promise<SupabaseVectorStore> {
+    const docs: Document[] = [];
+    for (let i = 0; i < texts.length; i += 1) {
+      // A single metadata object is shared by every text; an array is
+      // matched to the texts by index.
+      const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas;
+      const newDoc = new Document({
+        pageContent: texts[i],
+        metadata,
+      });
+      docs.push(newDoc);
+    }
+    return SupabaseVectorStore.fromDocuments(docs, embeddings, dbConfig);
+  }
+
+  /**
+   * Creates a new SupabaseVectorStore instance from an array of documents.
+   * @param docs The documents to create the instance from.
+   * @param embeddings The embeddings to use.
+   * @param dbConfig The configuration for the Supabase database.
+   * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created.
+   */
+  static async fromDocuments(
+    docs: Document[],
+    embeddings: Embeddings,
+    dbConfig: SupabaseLibArgs
+  ): Promise<SupabaseVectorStore> {
+    const instance = new this(embeddings, dbConfig);
+    await instance.addDocuments(docs);
+    return instance;
+  }
+
+  /**
+   * Creates a new SupabaseVectorStore instance from an existing index.
+   * @param embeddings The embeddings to use.
+   * @param dbConfig The configuration for the Supabase database.
+   * @returns A promise that resolves with a new SupabaseVectorStore instance when the instance has been created.
+   */
+  static async fromExistingIndex(
+    embeddings: Embeddings,
+    dbConfig: SupabaseLibArgs
+  ): Promise<SupabaseVectorStore> {
+    const instance = new this(embeddings, dbConfig);
+    return instance;
+  }
+}
diff --git a/langchain/src/vectorstores/tests/analyticdb.int.test.ts b/libs/langchain-community/src/vectorstores/tests/analyticdb.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/analyticdb.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/analyticdb.int.test.ts index 4607dc83ec8b..5ade9b6dd8bf 100644 --- a/langchain/src/vectorstores/tests/analyticdb.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/analyticdb.int.test.ts @@ -2,9 +2,9 @@ /* eslint-disable import/no-extraneous-dependencies */ import { test } from "@jest/globals"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { AnalyticDBVectorStore } from "../analyticdb.js"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; const connectionOptions = { host: process.env.ANALYTICDB_HOST || "localhost", diff --git a/langchain/src/vectorstores/tests/cassandra.int.test.ts b/libs/langchain-community/src/vectorstores/tests/cassandra.int.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/cassandra.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/cassandra.int.test.ts index 66ceb4198c52..d9d4df2d3a3d 100644 --- a/langchain/src/vectorstores/tests/cassandra.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/cassandra.int.test.ts @@ -2,9 +2,9 @@ import { test, expect, describe } from "@jest/globals"; import { Client } from "cassandra-driver"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { CassandraStore } from "../cassandra.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from
"../../document.js"; const cassandraConfig = { cloud: { diff --git a/langchain/src/vectorstores/tests/chroma.int.test.ts b/libs/langchain-community/src/vectorstores/tests/chroma.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/chroma.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/chroma.int.test.ts index 7c440c1a1095..af9da7661dc6 100644 --- a/langchain/src/vectorstores/tests/chroma.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/chroma.int.test.ts @@ -4,9 +4,9 @@ import { beforeEach, describe, expect, test } from "@jest/globals"; import { ChromaClient } from "chromadb"; import { faker } from "@faker-js/faker"; import * as uuid from "uuid"; -import { Document } from "../../document.js"; +import { Document } from "@langchain/core/documents"; +import { OpenAIEmbeddings } from "@langchain/openai"; import { Chroma } from "../chroma.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; describe.skip("Chroma", () => { let chromaStore: Chroma; diff --git a/langchain/src/vectorstores/tests/chroma.test.ts b/libs/langchain-community/src/vectorstores/tests/chroma.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/chroma.test.ts rename to libs/langchain-community/src/vectorstores/tests/chroma.test.ts index 3ac187331015..25b78e436c0a 100644 --- a/langchain/src/vectorstores/tests/chroma.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/chroma.test.ts @@ -3,7 +3,7 @@ import { jest, test, expect } from "@jest/globals"; import { type Collection } from "chromadb"; import { Chroma } from "../chroma.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; const mockCollection = { count: jest.fn().mockResolvedValue(5), diff --git a/langchain/src/vectorstores/tests/clickhouse.int.test.ts b/libs/langchain-community/src/vectorstores/tests/clickhouse.int.test.ts similarity index 98% rename from 
langchain/src/vectorstores/tests/clickhouse.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/clickhouse.int.test.ts index 36c68d275940..ba3e51ad3669 100644 --- a/langchain/src/vectorstores/tests/clickhouse.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/clickhouse.int.test.ts @@ -1,10 +1,10 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; import { ClickHouseStore } from "../clickhouse.js"; // Import OpenAIEmbeddings if you have a valid OpenAI API key import { HuggingFaceInferenceEmbeddings } from "../../embeddings/hf.js"; -import { Document } from "../../document.js"; test.skip("ClickHouseStore.fromText", async () => { const vectorStore = await ClickHouseStore.fromTexts( diff --git a/langchain/src/vectorstores/tests/closevector_node.int.test.ts b/libs/langchain-community/src/vectorstores/tests/closevector_node.int.test.ts similarity index 89% rename from langchain/src/vectorstores/tests/closevector_node.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/closevector_node.int.test.ts index b5718dc34015..6c7d05d30c2b 100644 --- a/langchain/src/vectorstores/tests/closevector_node.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/closevector_node.int.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { CloseVectorNode } from "../closevector/node.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { getEnvironmentVariable } from "../../util/env.js"; test.skip("Test CloseVectorNode.fromTexts + addVectors", async () => { const key = getEnvironmentVariable("CLOSEVECTOR_API_KEY"); diff --git a/langchain/src/vectorstores/tests/closevector_node.test.ts b/libs/langchain-community/src/vectorstores/tests/closevector_node.test.ts similarity index 
93% rename from langchain/src/vectorstores/tests/closevector_node.test.ts rename to libs/langchain-community/src/vectorstores/tests/closevector_node.test.ts index 207b9ea4879d..b40ea928d2fc 100644 --- a/langchain/src/vectorstores/tests/closevector_node.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/closevector_node.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { CloseVectorNode } from "../closevector/node.js"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; test("Test CloseVectorNode.fromTexts + addVectors", async () => { const vectorStore = await CloseVectorNode.fromTexts( diff --git a/langchain/src/vectorstores/tests/convex.int.test.ts b/libs/langchain-community/src/vectorstores/tests/convex.int.test.ts similarity index 100% rename from langchain/src/vectorstores/tests/convex.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/convex.int.test.ts diff --git a/langchain/src/vectorstores/tests/convex/convex/_generated/api.d.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/api.d.ts similarity index 100% rename from langchain/src/vectorstores/tests/convex/convex/_generated/api.d.ts rename to libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/api.d.ts diff --git a/langchain/src/vectorstores/tests/convex/convex/_generated/api.js b/libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/api.js similarity index 100% rename from langchain/src/vectorstores/tests/convex/convex/_generated/api.js rename to libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/api.js diff --git a/langchain/src/vectorstores/tests/convex/convex/_generated/dataModel.d.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/dataModel.d.ts similarity index 100% 
rename from langchain/src/vectorstores/tests/convex/convex/_generated/dataModel.d.ts rename to libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/dataModel.d.ts diff --git a/langchain/src/vectorstores/tests/convex/convex/_generated/server.d.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/server.d.ts similarity index 100% rename from langchain/src/vectorstores/tests/convex/convex/_generated/server.d.ts rename to libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/server.d.ts diff --git a/langchain/src/vectorstores/tests/convex/convex/_generated/server.js b/libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/server.js similarity index 100% rename from langchain/src/vectorstores/tests/convex/convex/_generated/server.js rename to libs/langchain-community/src/vectorstores/tests/convex/convex/_generated/server.js diff --git a/libs/langchain-community/src/vectorstores/tests/convex/convex/langchain/db.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/langchain/db.ts new file mode 100644 index 000000000000..02d53f0c4aff --- /dev/null +++ b/libs/langchain-community/src/vectorstores/tests/convex/convex/langchain/db.ts @@ -0,0 +1 @@ +export * from "../../../../../utils/convex.js"; diff --git a/langchain/src/vectorstores/tests/convex/convex/lib.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/lib.ts similarity index 72% rename from langchain/src/vectorstores/tests/convex/convex/lib.ts rename to libs/langchain-community/src/vectorstores/tests/convex/convex/lib.ts index d0cdbb922acc..95175610faf8 100644 --- a/langchain/src/vectorstores/tests/convex/convex/lib.ts +++ b/libs/langchain-community/src/vectorstores/tests/convex/convex/lib.ts @@ -1,8 +1,8 @@ // eslint-disable-next-line import/no-extraneous-dependencies import { v } from "convex/values"; -import { OpenAIEmbeddings } from "../../../../embeddings/openai.js"; import { ConvexVectorStore } from 
"../../../convex.js"; import { action, mutation } from "./_generated/server.js"; +import { FakeEmbeddings } from "../../../../utils/testing.js"; export const reset = mutation({ args: {}, @@ -18,11 +18,11 @@ export const ingest = action({ texts: v.array(v.string()), metadatas: v.array(v.any()), }, - handler: async (ctx, { openAIApiKey, texts, metadatas }) => { + handler: async (ctx, { texts, metadatas }) => { await ConvexVectorStore.fromTexts( texts, metadatas, - new OpenAIEmbeddings({ openAIApiKey }), + new FakeEmbeddings({}), { ctx } ); }, @@ -33,11 +33,8 @@ export const similaritySearch = action({ openAIApiKey: v.string(), query: v.string(), }, - handler: async (ctx, { openAIApiKey, query }) => { - const vectorStore = new ConvexVectorStore( - new OpenAIEmbeddings({ openAIApiKey }), - { ctx } - ); + handler: async (ctx, { query }) => { + const vectorStore = new ConvexVectorStore(new FakeEmbeddings({}), { ctx }); const result = await vectorStore.similaritySearch(query, 3); return result.map(({ metadata }) => metadata); diff --git a/langchain/src/vectorstores/tests/convex/convex/schema.ts b/libs/langchain-community/src/vectorstores/tests/convex/convex/schema.ts similarity index 100% rename from langchain/src/vectorstores/tests/convex/convex/schema.ts rename to libs/langchain-community/src/vectorstores/tests/convex/convex/schema.ts diff --git a/langchain/src/vectorstores/tests/convex/package.json b/libs/langchain-community/src/vectorstores/tests/convex/package.json similarity index 100% rename from langchain/src/vectorstores/tests/convex/package.json rename to libs/langchain-community/src/vectorstores/tests/convex/package.json diff --git a/langchain/src/vectorstores/tests/elasticsearch.int.test.ts b/libs/langchain-community/src/vectorstores/tests/elasticsearch.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/elasticsearch.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/elasticsearch.int.test.ts index 
4aa3be383bfc..7fc968331b9a 100644 --- a/langchain/src/vectorstores/tests/elasticsearch.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/elasticsearch.int.test.ts @@ -1,9 +1,9 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; import { Client, ClientOptions } from "@elastic/elasticsearch"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { ElasticVectorSearch } from "../elasticsearch.js"; -import { Document } from "../../document.js"; describe("ElasticVectorSearch", () => { let store: ElasticVectorSearch; diff --git a/langchain/src/vectorstores/tests/faiss.int.test.data/faiss.int.test.py b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss.int.test.py similarity index 100% rename from langchain/src/vectorstores/tests/faiss.int.test.data/faiss.int.test.py rename to libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss.int.test.py diff --git a/langchain/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.faiss b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.faiss similarity index 100% rename from langchain/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.faiss rename to libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.faiss diff --git a/langchain/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.pkl b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.pkl similarity index 100% rename from langchain/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.pkl rename to libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/faiss_index/index.pkl diff --git a/langchain/src/vectorstores/tests/faiss.int.test.data/requirements.txt 
b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/requirements.txt similarity index 100% rename from langchain/src/vectorstores/tests/faiss.int.test.data/requirements.txt rename to libs/langchain-community/src/vectorstores/tests/faiss.int.test.data/requirements.txt diff --git a/langchain/src/vectorstores/tests/faiss.int.test.ts b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/faiss.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/faiss.int.test.ts index c136d5ed3dc1..cf3fc3e894a8 100644 --- a/langchain/src/vectorstores/tests/faiss.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/faiss.int.test.ts @@ -4,9 +4,9 @@ import * as path from "node:path"; import * as os from "node:os"; import { fileURLToPath } from "node:url"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { FaissStore } from "../faiss.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; test("Test FaissStore.fromTexts", async () => { const vectorStore = await FaissStore.fromTexts( diff --git a/langchain/src/vectorstores/tests/faiss.test.ts b/libs/langchain-community/src/vectorstores/tests/faiss.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/faiss.test.ts rename to libs/langchain-community/src/vectorstores/tests/faiss.test.ts index e2708f40d582..bcc0afc347fc 100644 --- a/langchain/src/vectorstores/tests/faiss.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/faiss.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; import { FaissStore } from "../faiss.js"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; 
test("Test FaissStore.fromTexts + addVectors", async () => { const vectorStore = await FaissStore.fromTexts( diff --git a/langchain/src/vectorstores/tests/googlevertexai.int.test.ts b/libs/langchain-community/src/vectorstores/tests/googlevertexai.int.test.ts similarity index 96% rename from langchain/src/vectorstores/tests/googlevertexai.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/googlevertexai.int.test.ts index 681379814149..db93d4be707a 100644 --- a/langchain/src/vectorstores/tests/googlevertexai.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/googlevertexai.int.test.ts @@ -1,16 +1,16 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { beforeAll, expect, test } from "@jest/globals"; -import { SyntheticEmbeddings } from "../../embeddings/fake.js"; +import { Document } from "@langchain/core/documents"; +import { Embeddings } from "@langchain/core/embeddings"; +import { SyntheticEmbeddings } from "../../utils/testing.js"; import { InMemoryDocstore } from "../../stores/doc/in_memory.js"; -import { Document } from "../../document.js"; import { MatchingEngineArgs, MatchingEngine, IdDocument, Restriction, } from "../googlevertexai.js"; -import { Embeddings } from "../../embeddings/base.js"; describe("Vertex AI matching", () => { let embeddings: Embeddings; diff --git a/langchain/src/vectorstores/tests/googlevertexai.test.ts b/libs/langchain-community/src/vectorstores/tests/googlevertexai.test.ts similarity index 95% rename from langchain/src/vectorstores/tests/googlevertexai.test.ts rename to libs/langchain-community/src/vectorstores/tests/googlevertexai.test.ts index fab9e981c070..07b2acbe2d17 100644 --- a/langchain/src/vectorstores/tests/googlevertexai.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/googlevertexai.test.ts @@ -1,10 +1,10 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { 
beforeEach, expect, test } from "@jest/globals"; -import { SyntheticEmbeddings } from "../../embeddings/fake.js"; +import { Embeddings } from "@langchain/core/embeddings"; +import { SyntheticEmbeddings } from "../../utils/testing.js"; import { InMemoryDocstore } from "../../stores/doc/in_memory.js"; import { MatchingEngineArgs, MatchingEngine } from "../googlevertexai.js"; -import { Embeddings } from "../../embeddings/base.js"; describe("Vertex AI matching", () => { let embeddings: Embeddings; diff --git a/langchain/src/vectorstores/tests/hnswlib.int.test.ts b/libs/langchain-community/src/vectorstores/tests/hnswlib.int.test.ts similarity index 96% rename from langchain/src/vectorstores/tests/hnswlib.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/hnswlib.int.test.ts index bc5ba7019fbe..79872e8cf474 100644 --- a/langchain/src/vectorstores/tests/hnswlib.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/hnswlib.int.test.ts @@ -3,9 +3,10 @@ import * as fs from "node:fs/promises"; import * as path from "node:path"; import * as os from "node:os"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; + import { HNSWLib } from "../hnswlib.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; test("Test HNSWLib.fromTexts", async () => { const vectorStore = await HNSWLib.fromTexts( diff --git a/langchain/src/vectorstores/tests/hnswlib.test.ts b/libs/langchain-community/src/vectorstores/tests/hnswlib.test.ts similarity index 93% rename from langchain/src/vectorstores/tests/hnswlib.test.ts rename to libs/langchain-community/src/vectorstores/tests/hnswlib.test.ts index b191c7bfd74d..dab197901c42 100644 --- a/langchain/src/vectorstores/tests/hnswlib.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/hnswlib.test.ts @@ -1,7 +1,8 @@ import { test, expect } from "@jest/globals"; +import { Document } from 
"@langchain/core/documents"; + import { HNSWLib } from "../hnswlib.js"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; test("Test HNSWLib.fromTexts + addVectors", async () => { const vectorStore = await HNSWLib.fromTexts( diff --git a/langchain/src/vectorstores/tests/lancedb.int.test.ts b/libs/langchain-community/src/vectorstores/tests/lancedb.int.test.ts similarity index 91% rename from langchain/src/vectorstores/tests/lancedb.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/lancedb.int.test.ts index 7ca3bdab9d0d..ec9bb2bb566e 100644 --- a/langchain/src/vectorstores/tests/lancedb.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/lancedb.int.test.ts @@ -3,9 +3,11 @@ import * as fs from "node:fs/promises"; import * as path from "node:path"; import * as os from "node:os"; import { connect, Table } from "vectordb"; + +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; + import { LanceDB } from "../lancedb.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; describe("LanceDB", () => { let lanceDBTable: Table; diff --git a/langchain/src/vectorstores/tests/milvus.int.test.ts b/libs/langchain-community/src/vectorstores/tests/milvus.int.test.ts similarity index 99% rename from langchain/src/vectorstores/tests/milvus.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/milvus.int.test.ts index c7398e250526..bc328304ddbf 100644 --- a/langchain/src/vectorstores/tests/milvus.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/milvus.int.test.ts @@ -1,7 +1,7 @@ import { test, expect, afterAll, beforeAll } from "@jest/globals"; import { ErrorCode, MilvusClient } from "@zilliz/milvus2-sdk-node"; +import { OpenAIEmbeddings } from "@langchain/openai"; import { Milvus } from 
"../milvus.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; let collectionName: string; let embeddings: OpenAIEmbeddings; diff --git a/langchain/src/vectorstores/tests/momento_vector_index.int.test.ts b/libs/langchain-community/src/vectorstores/tests/momento_vector_index.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/momento_vector_index.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/momento_vector_index.int.test.ts index 79e598111628..0af54f2f92d0 100644 --- a/langchain/src/vectorstores/tests/momento_vector_index.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/momento_vector_index.int.test.ts @@ -8,10 +8,11 @@ import { CredentialProvider, } from "@gomomento/sdk"; import * as uuid from "uuid"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; + +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; +import { sleep } from "../../utils/time.js"; import { MomentoVectorIndex } from "../momento_vector_index.js"; -import { sleep } from "../../util/time.js"; async function withVectorStore( block: (vectorStore: MomentoVectorIndex) => Promise diff --git a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts b/libs/langchain-community/src/vectorstores/tests/mongodb_atlas.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/mongodb_atlas.int.test.ts index 26b254fd495c..6f4a9f0730b7 100755 --- a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/mongodb_atlas.int.test.ts @@ -4,10 +4,10 @@ import { test, expect } from "@jest/globals"; import { MongoClient } from "mongodb"; import { setTimeout } from "timers/promises"; -import { MongoDBAtlasVectorSearch } from "../mongodb_atlas.js"; +import { 
OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { MongoDBAtlasVectorSearch } from "../mongodb_atlas.js"; /** * The following json can be used to create an index in atlas for Cohere embeddings. diff --git a/langchain/src/vectorstores/tests/myscale.int.test.ts b/libs/langchain-community/src/vectorstores/tests/myscale.int.test.ts similarity index 95% rename from langchain/src/vectorstores/tests/myscale.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/myscale.int.test.ts index 294e120581c7..d50ea8cf5926 100644 --- a/langchain/src/vectorstores/tests/myscale.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/myscale.int.test.ts @@ -1,9 +1,10 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; + import { MyScaleStore } from "../myscale.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; test.skip("MyScaleStore.fromText", async () => { const vectorStore = await MyScaleStore.fromTexts( diff --git a/langchain/src/vectorstores/tests/neo4j_vector.int.test.ts b/libs/langchain-community/src/vectorstores/tests/neo4j_vector.int.test.ts similarity index 99% rename from langchain/src/vectorstores/tests/neo4j_vector.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/neo4j_vector.int.test.ts index 5bdfbf02c9ce..7aed797d4cc3 100644 --- a/langchain/src/vectorstores/tests/neo4j_vector.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/neo4j_vector.int.test.ts @@ -1,7 +1,7 @@ /* eslint-disable no-process-env */ -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { Document } from "@langchain/core/documents"; +import { FakeEmbeddings 
} from "../../utils/testing.js"; import { Neo4jVectorStore } from "../neo4j_vector.js"; -import { Document } from "../../document.js"; const OS_TOKEN_COUNT = 1536; diff --git a/langchain/src/vectorstores/tests/opensearch.int.test.ts b/libs/langchain-community/src/vectorstores/tests/opensearch.int.test.ts similarity index 91% rename from langchain/src/vectorstores/tests/opensearch.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/opensearch.int.test.ts index 5f04262928b8..f3a497dc5b9d 100644 --- a/langchain/src/vectorstores/tests/opensearch.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/opensearch.int.test.ts @@ -1,9 +1,9 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; import { Client } from "@opensearch-project/opensearch"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { OpenSearchVectorStore } from "../opensearch.js"; -import { Document } from "../../document.js"; test.skip("OpenSearchVectorStore integration", async () => { if (!process.env.OPENSEARCH_URL) { diff --git a/langchain/src/vectorstores/tests/pgvector.int.test.ts b/libs/langchain-community/src/vectorstores/tests/pgvector.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/pgvector.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/pgvector.int.test.ts index f70a777a41db..5d3ae78bd0fc 100644 --- a/langchain/src/vectorstores/tests/pgvector.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/pgvector.int.test.ts @@ -1,6 +1,6 @@ import { expect, test } from "@jest/globals"; import type { PoolConfig } from "pg"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; import { PGVectorStore } from "../pgvector.js"; describe("PGVectorStore", () => { diff --git 
a/langchain/src/vectorstores/tests/pinecone.int.test.ts b/libs/langchain-community/src/vectorstores/tests/pinecone.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/pinecone.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/pinecone.int.test.ts index f9c6894b15b5..744c6e73ae31 100644 --- a/langchain/src/vectorstores/tests/pinecone.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/pinecone.int.test.ts @@ -5,8 +5,8 @@ import { describe, expect, test } from "@jest/globals"; import { faker } from "@faker-js/faker"; import { Pinecone } from "@pinecone-database/pinecone"; import * as uuid from "uuid"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { PineconeLibArgs, PineconeStore } from "../pinecone.js"; describe("PineconeStore", () => { diff --git a/langchain/src/vectorstores/tests/pinecone.test.ts b/libs/langchain-community/src/vectorstores/tests/pinecone.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/pinecone.test.ts rename to libs/langchain-community/src/vectorstores/tests/pinecone.test.ts index 8e4dd19439c4..dbd50faf1ef1 100644 --- a/langchain/src/vectorstores/tests/pinecone.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/pinecone.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { jest, test, expect } from "@jest/globals"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { PineconeStore } from "../pinecone.js"; test("PineconeStore with external ids", async () => { diff --git a/langchain/src/vectorstores/tests/qdrant.int.test.ts b/libs/langchain-community/src/vectorstores/tests/qdrant.int.test.ts similarity index 94% rename from 
langchain/src/vectorstores/tests/qdrant.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/qdrant.int.test.ts index 5779d1acdda7..4b7c0b1208a2 100644 --- a/langchain/src/vectorstores/tests/qdrant.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/qdrant.int.test.ts @@ -3,8 +3,8 @@ import { describe, expect, test } from "@jest/globals"; import { QdrantClient } from "@qdrant/js-client-rest"; import { faker } from "@faker-js/faker"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { QdrantVectorStore } from "../qdrant.js"; import { OllamaEmbeddings } from "../../embeddings/ollama.js"; diff --git a/langchain/src/vectorstores/tests/qdrant.test.ts b/libs/langchain-community/src/vectorstores/tests/qdrant.test.ts similarity index 93% rename from langchain/src/vectorstores/tests/qdrant.test.ts rename to libs/langchain-community/src/vectorstores/tests/qdrant.test.ts index 38d5a8c1c85e..7e3aaf0eab2b 100644 --- a/langchain/src/vectorstores/tests/qdrant.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/qdrant.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { jest, test, expect } from "@jest/globals"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { QdrantVectorStore } from "../qdrant.js"; diff --git a/langchain/src/vectorstores/tests/redis.int.test.ts b/libs/langchain-community/src/vectorstores/tests/redis.int.test.ts similarity index 95% rename from langchain/src/vectorstores/tests/redis.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/redis.int.test.ts index 92249db7d50a..f44f52b8d273 100644 --- a/langchain/src/vectorstores/tests/redis.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/redis.int.test.ts 
@@ -5,9 +5,9 @@ import { RedisClientType, createClient } from "redis"; import { v4 as uuidv4 } from "uuid"; import { test, expect } from "@jest/globals"; import { faker } from "@faker-js/faker"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { RedisVectorStore } from "../redis.js"; -import { Document } from "../../document.js"; describe("RedisVectorStore", () => { let vectorStore: RedisVectorStore; diff --git a/langchain/src/vectorstores/tests/redis.test.ts b/libs/langchain-community/src/vectorstores/tests/redis.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/redis.test.ts rename to libs/langchain-community/src/vectorstores/tests/redis.test.ts index 550c4aff7ccb..681e63c67685 100644 --- a/langchain/src/vectorstores/tests/redis.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/redis.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { jest, test, expect, describe } from "@jest/globals"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { RedisVectorStore } from "../redis.js"; diff --git a/langchain/src/vectorstores/tests/rockset.int.test.ts b/libs/langchain-community/src/vectorstores/tests/rockset.int.test.ts similarity index 92% rename from langchain/src/vectorstores/tests/rockset.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/rockset.int.test.ts index 88d93bac8283..6891aa921a8d 100644 --- a/langchain/src/vectorstores/tests/rockset.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/rockset.int.test.ts @@ -2,12 +2,12 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import rockset from "@rockset/client"; import { test, expect } from "@jest/globals"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { 
OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { RocksetStore, SimilarityMetric } from "../rockset.js"; -import { Document } from "../../document.js"; -import { formatDocumentsAsString } from "../../util/document.js"; -const getPageContents = formatDocumentsAsString; +const getPageContents = (documents: Document[]) => + documents.map((document) => document.pageContent); const embeddings = new OpenAIEmbeddings(); let store: RocksetStore | undefined; diff --git a/langchain/src/vectorstores/tests/singlestore.int.test.ts b/libs/langchain-community/src/vectorstores/tests/singlestore.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/singlestore.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/singlestore.int.test.ts index c07f3e6d38dd..aacfd5dd71b4 100644 --- a/langchain/src/vectorstores/tests/singlestore.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/singlestore.int.test.ts @@ -1,9 +1,9 @@ /* eslint-disable no-process-env */ /* eslint-disable import/no-extraneous-dependencies */ import { test, expect } from "@jest/globals"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { SingleStoreVectorStore } from "../singlestore.js"; -import { Document } from "../../document.js"; test.skip("SingleStoreVectorStore", async () => { expect(process.env.SINGLESTORE_HOST).toBeDefined(); diff --git a/langchain/src/vectorstores/tests/supabase.int.test.ts b/libs/langchain-community/src/vectorstores/tests/supabase.int.test.ts similarity index 99% rename from langchain/src/vectorstores/tests/supabase.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/supabase.int.test.ts index 04739a143767..cc785d39a06b 100644 --- a/langchain/src/vectorstores/tests/supabase.int.test.ts +++ 
b/libs/langchain-community/src/vectorstores/tests/supabase.int.test.ts @@ -3,8 +3,8 @@ import { test, expect } from "@jest/globals"; import { createClient } from "@supabase/supabase-js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { SupabaseVectorStore, SupabaseFilterRPCCall } from "../supabase.js"; test("SupabaseVectorStore with external ids", async () => { diff --git a/langchain/src/vectorstores/tests/supabase.test.ts b/libs/langchain-community/src/vectorstores/tests/supabase.test.ts similarity index 96% rename from langchain/src/vectorstores/tests/supabase.test.ts rename to libs/langchain-community/src/vectorstores/tests/supabase.test.ts index 76ed734f1f66..0e82073da3d0 100644 --- a/langchain/src/vectorstores/tests/supabase.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/supabase.test.ts @@ -3,7 +3,7 @@ import { SupabaseClient } from "@supabase/supabase-js"; import { SupabaseVectorStore } from "../supabase.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; test("similaritySearchVectorWithScore should call RPC with the vectorstore filters", async () => { const supabaseClientMock = { diff --git a/langchain/src/vectorstores/tests/tigris.test.ts b/libs/langchain-community/src/vectorstores/tests/tigris.test.ts similarity index 96% rename from langchain/src/vectorstores/tests/tigris.test.ts rename to libs/langchain-community/src/vectorstores/tests/tigris.test.ts index 0e722572e7dc..f838ff10fdc0 100644 --- a/langchain/src/vectorstores/tests/tigris.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/tigris.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { jest, test, expect } from "@jest/globals"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; 
+import { FakeEmbeddings } from "../../utils/testing.js"; import { TigrisVectorStore } from "../tigris.js"; diff --git a/langchain/src/vectorstores/tests/typeorm.int.test.ts b/libs/langchain-community/src/vectorstores/tests/typeorm.int.test.ts similarity index 95% rename from langchain/src/vectorstores/tests/typeorm.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/typeorm.int.test.ts index be1068437883..1a59fde00409 100644 --- a/langchain/src/vectorstores/tests/typeorm.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/typeorm.int.test.ts @@ -1,6 +1,6 @@ import { expect, test } from "@jest/globals"; import { DataSourceOptions } from "typeorm"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; import { TypeORMVectorStore } from "../typeorm.js"; test.skip("Test embeddings creation", async () => { diff --git a/langchain/src/vectorstores/tests/typesense.test.ts b/libs/langchain-community/src/vectorstores/tests/typesense.test.ts similarity index 96% rename from langchain/src/vectorstores/tests/typesense.test.ts rename to libs/langchain-community/src/vectorstores/tests/typesense.test.ts index 3ae95690bde0..08c872528001 100644 --- a/langchain/src/vectorstores/tests/typesense.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/typesense.test.ts @@ -1,6 +1,6 @@ import { Client } from "typesense"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { Document } from "@langchain/core/documents"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { Typesense } from "../typesense.js"; test("documentsToTypesenseRecords should return the correct typesense records", async () => { diff --git a/langchain/src/vectorstores/tests/usearch.int.test.ts b/libs/langchain-community/src/vectorstores/tests/usearch.int.test.ts similarity index 94% rename from 
langchain/src/vectorstores/tests/usearch.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/usearch.int.test.ts index 08fe7836a51f..88ecf2e3a955 100644 --- a/langchain/src/vectorstores/tests/usearch.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/usearch.int.test.ts @@ -1,7 +1,7 @@ import { test, expect } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; import { USearch } from "../usearch.js"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; test("Test USearch.fromTexts + addVectors", async () => { const vectorStore = await USearch.fromTexts( diff --git a/langchain/src/vectorstores/tests/vectara.int.test.ts b/libs/langchain-community/src/vectorstores/tests/vectara.int.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/vectara.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/vectara.int.test.ts index c29ac59a8de6..e8bde2c1bbf5 100644 --- a/langchain/src/vectorstores/tests/vectara.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/vectara.int.test.ts @@ -3,8 +3,8 @@ import fs from "fs"; import { expect, beforeAll } from "@jest/globals"; import { insecureHash } from "@langchain/core/utils/hash"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; -import { Document } from "../../document.js"; +import { Document } from "@langchain/core/documents"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { VectaraFile, VectaraLibArgs, VectaraStore } from "../vectara.js"; const getDocs = (): Document[] => { diff --git a/langchain/src/vectorstores/tests/vercel_postgres.int.test.ts b/libs/langchain-community/src/vectorstores/tests/vercel_postgres.int.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/vercel_postgres.int.test.ts rename to 
libs/langchain-community/src/vectorstores/tests/vercel_postgres.int.test.ts index 16ce499cd24d..a7a8a43230e4 100644 --- a/langchain/src/vectorstores/tests/vercel_postgres.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/vercel_postgres.int.test.ts @@ -1,5 +1,5 @@ import { expect, test } from "@jest/globals"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; import { VercelPostgres } from "../vercel_postgres.js"; let vercelPostgresStore: VercelPostgres; diff --git a/langchain/src/vectorstores/tests/voy.int.test.ts b/libs/langchain-community/src/vectorstores/tests/voy.int.test.ts similarity index 92% rename from langchain/src/vectorstores/tests/voy.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/voy.int.test.ts index 5c4abfb357af..d8d221100aa3 100644 --- a/langchain/src/vectorstores/tests/voy.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/voy.int.test.ts @@ -1,7 +1,7 @@ import { expect, test } from "@jest/globals"; import { Voy as VoyOriginClient } from "voy-search"; -import { Document } from "../../document.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { VoyVectorStore } from "../voy.js"; const client = new VoyOriginClient(); diff --git a/langchain/src/vectorstores/tests/voy.test.ts b/libs/langchain-community/src/vectorstores/tests/voy.test.ts similarity index 92% rename from langchain/src/vectorstores/tests/voy.test.ts rename to libs/langchain-community/src/vectorstores/tests/voy.test.ts index 5f05a789a2cf..4d8cfcb473a5 100644 --- a/langchain/src/vectorstores/tests/voy.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/voy.test.ts @@ -1,6 +1,6 @@ import { test, expect } from "@jest/globals"; -import { Document } from "../../document.js"; -import { FakeEmbeddings } from 
"../../embeddings/fake.js"; +import { Document } from "@langchain/core/documents"; +import { FakeEmbeddings } from "../../utils/testing.js"; import { VoyVectorStore, VoyClient } from "../voy.js"; const fakeClient: VoyClient = { diff --git a/langchain/src/vectorstores/tests/weaviate.int.test.ts b/libs/langchain-community/src/vectorstores/tests/weaviate.int.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/weaviate.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/weaviate.int.test.ts index c12e4e53129c..945bcc7dadbd 100644 --- a/langchain/src/vectorstores/tests/weaviate.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/weaviate.int.test.ts @@ -1,9 +1,9 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; import weaviate from "weaviate-ts-client"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { WeaviateStore } from "../weaviate.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; test("WeaviateStore", async () => { // Something wrong with the weaviate-ts-client types, so we need to disable diff --git a/langchain/src/vectorstores/tests/weaviate.test.ts b/libs/langchain-community/src/vectorstores/tests/weaviate.test.ts similarity index 100% rename from langchain/src/vectorstores/tests/weaviate.test.ts rename to libs/langchain-community/src/vectorstores/tests/weaviate.test.ts diff --git a/langchain/src/vectorstores/tests/xata.int.test.ts b/libs/langchain-community/src/vectorstores/tests/xata.int.test.ts similarity index 97% rename from langchain/src/vectorstores/tests/xata.int.test.ts rename to libs/langchain-community/src/vectorstores/tests/xata.int.test.ts index 977754d0189c..c7040c029c76 100644 --- a/langchain/src/vectorstores/tests/xata.int.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/xata.int.test.ts @@ -1,10 +1,10 @@ /* 
eslint-disable no-process-env */ // eslint-disable-next-line import/no-extraneous-dependencies import { BaseClient } from "@xata.io/client"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { Document } from "@langchain/core/documents"; import { XataVectorSearch } from "../xata.js"; -import { OpenAIEmbeddings } from "../../embeddings/openai.js"; -import { Document } from "../../document.js"; // Tests require a DB with a table called "docs" with: // * a column name content of type Text diff --git a/langchain/src/vectorstores/tests/zep.test.ts b/libs/langchain-community/src/vectorstores/tests/zep.test.ts similarity index 98% rename from langchain/src/vectorstores/tests/zep.test.ts rename to libs/langchain-community/src/vectorstores/tests/zep.test.ts index 5468cc9dab9c..1b6109d64fa6 100644 --- a/langchain/src/vectorstores/tests/zep.test.ts +++ b/libs/langchain-community/src/vectorstores/tests/zep.test.ts @@ -8,10 +8,10 @@ import { NotFoundError, ZepClient, } from "@getzep/zep-js"; -import { Document } from "../../document.js"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; import { IZepConfig, ZepVectorStore } from "../zep.js"; -import { Embeddings } from "../../embeddings/base.js"; -import { FakeEmbeddings } from "../../embeddings/fake.js"; +import { FakeEmbeddings } from "../../utils/testing.js"; jest.mock("@getzep/zep-js");
diff --git a/libs/langchain-community/src/vectorstores/tigris.ts b/libs/langchain-community/src/vectorstores/tigris.ts
new file mode 100644
index 000000000000..e4bc57623991
--- /dev/null
+++ b/libs/langchain-community/src/vectorstores/tigris.ts
@@ -0,0 +1,186 @@
+import * as uuid from "uuid";
+
+import { Embeddings } from "@langchain/core/embeddings";
+import { VectorStore } from "@langchain/core/vectorstores";
+import { Document } from "@langchain/core/documents";
+
+/**
+ * Type definition for the arguments required to initialize a
+ * TigrisVectorStore instance.
+ */
+export type TigrisLibArgs = {
+  /** Tigris index object that the store forwards writes and searches to. */
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  index: any;
+};
+
+/**
+ * Class for managing and operating vector search applications with
+ * Tigris, an open-source Serverless NoSQL Database and Search Platform.
+ */
+export class TigrisVectorStore extends VectorStore {
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  index?: any;
+
+  /** Returns the identifier of this vector store implementation. */
+  _vectorstoreType(): string {
+    return "tigris";
+  }
+
+  /**
+   * @param embeddings Embeddings instance used to embed documents.
+   * @param args Arguments carrying the Tigris index to operate on.
+   */
+  constructor(embeddings: Embeddings, args: TigrisLibArgs) {
+    super(embeddings, args);
+
+    this.embeddings = embeddings;
+    this.index = args.index;
+  }
+
+  /**
+   * Method to add an array of documents to the Tigris database.
+   * @param documents An array of Document instances to be added to the Tigris database.
+   * @param options Optional parameter that can either be an array of string IDs or an object with a property 'ids' that is an array of string IDs.
+   * @returns A Promise that resolves when the documents have been added to the Tigris database.
+   */
+  async addDocuments(
+    documents: Document[],
+    options?: { ids?: string[] } | string[]
+  ): Promise<void> {
+    const texts = documents.map(({ pageContent }) => pageContent);
+    await this.addVectors(
+      await this.embeddings.embedDocuments(texts),
+      documents,
+      options
+    );
+  }
+
+  /**
+   * Method to add vectors to the Tigris database.
+   * @param vectors An array of vectors to be added to the Tigris database.
+   * @param documents An array of Document instances corresponding to the vectors.
+   * @param options Optional parameter that can either be an array of string IDs or an object with a property 'ids' that is an array of string IDs.
+   * @returns A Promise that resolves when the vectors have been added to the Tigris database.
+   * @throws Error if the number of vectors differs from the number of documents.
+   */
+  async addVectors(
+    vectors: number[][],
+    documents: Document[],
+    options?: { ids?: string[] } | string[]
+  ): Promise<void> {
+    if (vectors.length === 0) {
+      return;
+    }
+
+    if (vectors.length !== documents.length) {
+      throw new Error(`Vectors and documents must have the same length`);
+    }
+
+    // Fall back to freshly generated UUIDs when the caller supplies no ids.
+    const ids = Array.isArray(options) ? options : options?.ids;
+    const documentIds = ids == null ? documents.map(() => uuid.v4()) : ids;
+    await this.index?.addDocumentsWithVectors({
+      ids: documentIds,
+      embeddings: vectors,
+      documents: documents.map(({ metadata, pageContent }) => ({
+        content: pageContent,
+        metadata,
+      })),
+    });
+  }
+
+  /**
+   * Method to perform a similarity search in the Tigris database and return
+   * the k most similar vectors along with their similarity scores.
+   * @param query The query vector.
+   * @param k The number of most similar vectors to return.
+   * @param filter Optional filter object to apply during the search.
+   * @returns A Promise that resolves to an array of tuples, each containing a Document and its similarity score.
+   */
+  async similaritySearchVectorWithScore(
+    query: number[],
+    k: number,
+    filter?: object
+  ): Promise<[Document, number][]> {
+    const result = await this.index?.similaritySearchVectorWithScore({
+      query,
+      k,
+      filter,
+    });
+
+    // Either no index is set or the index call produced a falsy result.
+    if (!result) {
+      return [];
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    return result.map(([document, score]: [any, any]) => [
+      new Document({
+        pageContent: document.content,
+        metadata: document.metadata,
+      }),
+      score,
+    ]) as [Document, number][];
+  }
+
+  /**
+   * Static method to create a new instance of TigrisVectorStore from an
+   * array of texts.
+   * @param texts An array of texts to be converted into Document instances and added to the Tigris database.
+   * @param metadatas Either an array of metadata objects or a single metadata object to be associated with the texts.
+   * @param embeddings An instance of Embeddings to be used for embedding the texts.
+   * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database.
+   * @returns A Promise that resolves to a new instance of TigrisVectorStore.
+   */
+  static async fromTexts(
+    texts: string[],
+    metadatas: object[] | object,
+    embeddings: Embeddings,
+    dbConfig: TigrisLibArgs
+  ): Promise<TigrisVectorStore> {
+    const docs: Document[] = [];
+    for (let i = 0; i < texts.length; i += 1) {
+      const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas;
+      const newDoc = new Document({
+        pageContent: texts[i],
+        metadata,
+      });
+      docs.push(newDoc);
+    }
+    return TigrisVectorStore.fromDocuments(docs, embeddings, dbConfig);
+  }
+
+  /**
+   * Static method to create a new instance of TigrisVectorStore from an
+   * array of Document instances.
+   * @param docs An array of Document instances to be added to the Tigris database.
+   * @param embeddings An instance of Embeddings to be used for embedding the documents.
+   * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database.
+   * @returns A Promise that resolves to a new instance of TigrisVectorStore.
+   */
+  static async fromDocuments(
+    docs: Document[],
+    embeddings: Embeddings,
+    dbConfig: TigrisLibArgs
+  ): Promise<TigrisVectorStore> {
+    const instance = new this(embeddings, dbConfig);
+    await instance.addDocuments(docs);
+    return instance;
+  }
+
+  /**
+   * Static method to create a new instance of TigrisVectorStore from an
+   * existing index.
+   * @param embeddings An instance of Embeddings to be used for embedding the documents.
+   * @param dbConfig An instance of TigrisLibArgs to be used for configuring the Tigris database.
+   * @returns A Promise that resolves to a new instance of TigrisVectorStore.
+   */
+  static async fromExistingIndex(
+    embeddings: Embeddings,
+    dbConfig: TigrisLibArgs
+  ): Promise<TigrisVectorStore> {
+    const instance = new this(embeddings, dbConfig);
+    return instance;
+  }
+}
diff --git a/libs/langchain-community/src/vectorstores/typeorm.ts b/libs/langchain-community/src/vectorstores/typeorm.ts new file mode 100644 index 000000000000..22d2d85e1f95 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/typeorm.ts @@ -0,0 +1,298 @@ +import { Metadata } from "@opensearch-project/opensearch/api/types.js"; +import { DataSource, DataSourceOptions, EntitySchema } from "typeorm"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +/** + * Interface that defines the arguments required to create a + * `TypeORMVectorStore` instance. It includes Postgres connection options, + * table name, filter, and verbosity level. + */ +export interface TypeORMVectorStoreArgs { + postgresConnectionOptions: DataSourceOptions; + tableName?: string; + filter?: Metadata; + verbose?: boolean; +} + +/** + * Class that extends the `Document` base class and adds an `embedding` + * property. It represents a document in the vector store. + */ +export class TypeORMVectorStoreDocument extends Document { + embedding: string; + + id?: string; +} + +const defaultDocumentTableName = "documents"; + +/** + * Class that provides an interface to a Postgres vector database. It + * extends the `VectorStore` base class and implements methods for adding + * documents and vectors, performing similarity searches, and ensuring the + * existence of a table in the database.
export class TypeORMVectorStore extends VectorStore {
  declare FilterType: Metadata;

  /** Name of the table that rows are inserted into and searched from. */
  tableName: string;

  /** TypeORM entity schema describing one stored document row. */
  documentEntity: EntitySchema;

  /** Store-level metadata filter copied from the constructor arguments. */
  filter?: Metadata;

  /** TypeORM data source used for every query issued by this store. */
  appDataSource: DataSource;

  _verbose?: boolean;

  _vectorstoreType(): string {
    return "typeorm";
  }

  /**
   * Builds the entity schema and the (not yet initialized) `DataSource`.
   * Private: use `fromDataSource` so that the data source gets initialized.
   * @param embeddings Embeddings instance.
   * @param fields `TypeORMVectorStoreArgs` instance.
   */
  private constructor(embeddings: Embeddings, fields: TypeORMVectorStoreArgs) {
    super(embeddings, fields);
    this.tableName = fields.tableName || defaultDocumentTableName;
    this.filter = fields.filter;

    // Reuse the resolved table name so that the entity and the raw SQL
    // queries can never disagree (e.g. when `tableName` is an empty string).
    const TypeORMDocumentEntity = new EntitySchema({
      name: this.tableName,
      columns: {
        id: {
          generated: "uuid",
          type: "uuid",
          primary: true,
        },
        pageContent: {
          type: String,
        },
        metadata: {
          type: "jsonb",
        },
        embedding: {
          type: String,
        },
      },
    });
    this.documentEntity = TypeORMDocumentEntity;
    this.appDataSource = new DataSource({
      entities: [TypeORMDocumentEntity],
      ...fields.postgresConnectionOptions,
    });

    // `a === "true" ?? b` never falls through to `b`, because a comparison
    // is always a boolean. Let the environment variable win only when it is
    // actually set, otherwise honour `fields.verbose`.
    const envVerbose = getEnvironmentVariable("LANGCHAIN_VERBOSE");
    this._verbose =
      envVerbose != null ? envVerbose === "true" : fields.verbose ?? false;
  }

  /**
   * Static method to create a new `TypeORMVectorStore` instance from a
   * `DataSource`. It initializes the `DataSource` if it is not already
   * initialized.
   * @param embeddings Embeddings instance.
   * @param fields `TypeORMVectorStoreArgs` instance.
   * @returns A new instance of `TypeORMVectorStore`.
   */
  static async fromDataSource(
    embeddings: Embeddings,
    fields: TypeORMVectorStoreArgs
  ): Promise<TypeORMVectorStore> {
    const postgresqlVectorStore = new TypeORMVectorStore(embeddings, fields);

    if (!postgresqlVectorStore.appDataSource.isInitialized) {
      await postgresqlVectorStore.appDataSource.initialize();
    }

    return postgresqlVectorStore;
  }

  /**
   * Method to add documents to the vector store. It ensures the existence
   * of the table in the database, converts the documents into vectors, and
   * adds them to the store.
   * @param documents Array of `Document` instances.
   * @returns Promise that resolves when the documents have been added.
   */
  async addDocuments(documents: Document[]): Promise<void> {
    const texts = documents.map(({ pageContent }) => pageContent);
    // Uses `CREATE ... IF NOT EXISTS`, so it is safe to run on every call.
    await this.ensureTableInDatabase();
    return this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents
    );
  }

  /**
   * Method to add vectors to the vector store. It converts the vectors into
   * rows and inserts them into the database in chunks of 500 rows.
   * @param vectors Array of vectors.
   * @param documents Array of `Document` instances, index-aligned with `vectors`.
   * @returns Promise that resolves when the vectors have been added.
   * @throws Error naming the first document of the chunk that failed to insert.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    const rows = vectors.map((embedding, idx) => ({
      pageContent: documents[idx].pageContent,
      // Serialized as "[v1,v2,...]" for the `vector` column.
      embedding: `[${embedding.join(",")}]`,
      metadata: documents[idx].metadata,
    }));

    const documentRepository = this.appDataSource.getRepository(
      this.documentEntity
    );

    const chunkSize = 500;
    for (let i = 0; i < rows.length; i += chunkSize) {
      const chunk = rows.slice(i, i + chunkSize);

      try {
        await documentRepository.save(chunk);
      } catch (e) {
        console.error(e);
        throw new Error(`Error inserting: ${chunk[0].pageContent}`);
      }
    }
  }

  /**
   * Method to perform a similarity search in the vector store. It returns
   * the `k` closest documents to the query vector, ordered by the `<=>`
   * distance, along with that distance.
   * @param query Query vector.
   * @param k Number of most similar documents to return.
   * @param filter Optional metadata filter, matched with the `@>` operator.
   * @returns Promise that resolves with an array of tuples, each containing a `TypeORMVectorStoreDocument` and its distance.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: this["FilterType"]
  ): Promise<[TypeORMVectorStoreDocument, number][]> {
    const embeddingString = `[${query.join(",")}]`;
    // NOTE(review): the store-level `this.filter` is not consulted here; only
    // the per-call `filter` is applied. Confirm whether that is intended.
    const _filter = filter ?? "{}";

    // The table name comes from the store configuration and is interpolated
    // as-is; every value is passed as a bound parameter.
    const queryString = `
      SELECT *, embedding <=> $1 as "_distance"
      FROM ${this.tableName}
      WHERE metadata @> $2
      ORDER BY "_distance" ASC
      LIMIT $3;`;

    const documents = await this.appDataSource.query(queryString, [
      embeddingString,
      _filter,
      k,
    ]);

    const results = [] as [TypeORMVectorStoreDocument, number][];
    for (const doc of documents) {
      // Skip rows that lack a distance or content.
      if (doc._distance != null && doc.pageContent != null) {
        const document = new Document(doc) as TypeORMVectorStoreDocument;
        document.id = doc.id;
        results.push([document, doc._distance]);
      }
    }

    return results;
  }

  /**
   * Method to ensure the existence of the table in the database. It enables
   * the `vector` and `uuid-ossp` extensions and creates the table if it does
   * not already exist.
   * @returns Promise that resolves when the table has been ensured.
   */
  async ensureTableInDatabase(): Promise<void> {
    await this.appDataSource.query("CREATE EXTENSION IF NOT EXISTS vector;");
    await this.appDataSource.query(
      'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'
    );

    await this.appDataSource.query(`
      CREATE TABLE IF NOT EXISTS ${this.tableName} (
        "id" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY,
        "pageContent" text,
        metadata jsonb,
        embedding vector
      );
    `);
  }

  /**
   * Static method to create a new `TypeORMVectorStore` instance from an
   * array of texts and their metadata. It converts the texts into
   * `Document` instances and adds them to the store.
   * @param texts Array of texts.
   * @param metadatas Array of metadata objects (index-aligned with `texts`) or a single metadata object shared by all texts.
   * @param embeddings Embeddings instance.
   * @param dbConfig `TypeORMVectorStoreArgs` instance.
   * @returns Promise that resolves with a new instance of `TypeORMVectorStore`.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig: TypeORMVectorStoreArgs
  ): Promise<TypeORMVectorStore> {
    const docs = texts.map(
      (pageContent, i) =>
        new Document({
          pageContent,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );

    return TypeORMVectorStore.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Static method to create a new `TypeORMVectorStore` instance from an
   * array of `Document` instances. It adds the documents to the store.
   * @param docs Array of `Document` instances.
   * @param embeddings Embeddings instance.
   * @param dbConfig `TypeORMVectorStoreArgs` instance.
   * @returns Promise that resolves with a new instance of `TypeORMVectorStore`.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig: TypeORMVectorStoreArgs
  ): Promise<TypeORMVectorStore> {
    const instance = await TypeORMVectorStore.fromDataSource(
      embeddings,
      dbConfig
    );
    await instance.addDocuments(docs);

    return instance;
  }

  /**
   * Static method to create a new `TypeORMVectorStore` instance from an
   * existing index. No documents are added.
   * @param embeddings Embeddings instance.
   * @param dbConfig `TypeORMVectorStoreArgs` instance.
   * @returns Promise that resolves with a new instance of `TypeORMVectorStore`.
   */
  static async fromExistingIndex(
    embeddings: Embeddings,
    dbConfig: TypeORMVectorStoreArgs
  ): Promise<TypeORMVectorStore> {
    return TypeORMVectorStore.fromDataSource(embeddings, dbConfig);
  }
}
/**
 * Typesense vector store.
 */
export class Typesense extends VectorStore {
  declare FilterType: Partial<MultiSearchRequestSchema>;

  private client: Client;

  private schemaName: string;

  private searchParams: MultiSearchRequestSchema;

  private vectorColumnName: string;

  private pageContentColumnName: string;

  private metadataColumnNames: string[];

  private caller: AsyncCaller;

  private import: (
    data: Record<string, unknown>[],
    collectionName: string
  ) => Promise<void>;

  _vectorstoreType(): string {
    return "typesense";
  }

  /**
   * @param embeddings Embeddings used by `addDocuments`.
   * @param config Typesense configuration; unset options fall back to the
   * defaults assigned below.
   */
  constructor(embeddings: Embeddings, config: TypesenseConfig) {
    super(embeddings, config);

    // Assign config values to class properties.
    this.client = config.typesenseClient;
    this.schemaName = config.schemaName;
    this.searchParams = config.searchParams || {
      q: "*",
      per_page: 5,
      query_by: "",
    };
    this.vectorColumnName = config.columnNames?.vector || "vec";
    this.pageContentColumnName = config.columnNames?.pageContent || "text";
    this.metadataColumnNames = config.columnNames?.metadataColumnNames || [];

    // Assign import function.
    this.import = config.import || this.importToTypesense.bind(this);

    this.caller = new AsyncCaller(config);
  }

  /**
   * Default function to import data to typesense. Records are sent in
   * chunks of 2000 with the `emplace` action, each chunk going through the
   * store's `AsyncCaller`.
   * @param data Records to import.
   * @param collectionName Collection that receives the records.
   */
  private async importToTypesense<
    T extends Record<string, unknown> = Record<string, unknown>
  >(data: T[], collectionName: string) {
    const chunkSize = 2000;
    for (let i = 0; i < data.length; i += chunkSize) {
      const chunk = data.slice(i, i + chunkSize);

      await this.caller.call(async () => {
        await this.client
          .collections(collectionName)
          .documents()
          .import(chunk, { action: "emplace", dirty_values: "drop" });
      });
    }
  }

  /**
   * Transform documents to Typesense records.
   * @param documents Documents to convert.
   * @param vectors Embeddings, index-aligned with `documents`.
   * @returns Typesense records holding the page content, the vector and the
   * metadata keys listed in `metadataColumnNames`.
   */
  _documentsToTypesenseRecords(
    documents: Document[],
    vectors: number[][]
  ): Record<string, unknown>[] {
    return documents.map((doc, index) => {
      const objectWithMetadatas: Record<string, unknown> = {};

      this.metadataColumnNames.forEach((metadataColumnName) => {
        objectWithMetadatas[metadataColumnName] =
          doc.metadata[metadataColumnName];
      });

      return {
        [this.pageContentColumnName]: doc.pageContent,
        [this.vectorColumnName]: vectors[index],
        ...objectWithMetadatas,
      };
    });
  }

  /**
   * Transform the Typesense records to documents.
   * @param typesenseRecords Search hits, or `undefined` when there are none.
   * @returns `[document, vector_distance]` tuples; empty when no hits were given.
   */
  _typesenseRecordsToDocuments(
    typesenseRecords:
      | { document?: Record<string, unknown>; vector_distance: number }[]
      | undefined
  ): [Document, number][] {
    const documents: [Document, number][] =
      typesenseRecords?.map((hit) => {
        const objectWithMetadatas: Record<string, unknown> = {};
        const hitDoc = hit.document || {};
        this.metadataColumnNames.forEach((metadataColumnName) => {
          objectWithMetadatas[metadataColumnName] = hitDoc[metadataColumnName];
        });

        const document: Document = {
          pageContent: (hitDoc[this.pageContentColumnName] as string) || "",
          metadata: objectWithMetadatas,
        };
        return [document, hit.vector_distance];
      }) || [];

    return documents;
  }

  /**
   * Add documents to the vector store.
   * With the default import function (`emplace`), a document whose id already
   * exists is updated instead of duplicated.
   * Metadata is written to the schema columns listed in metadataColumnNames.
   * @param documents Documents to add.
   */
  async addDocuments(documents: Document[]) {
    const typesenseDocuments = this._documentsToTypesenseRecords(
      documents,
      await this.embeddings.embedDocuments(
        documents.map((doc) => doc.pageContent)
      )
    );
    await this.import(typesenseDocuments, this.schemaName);
  }

  /**
   * Adds vectors to the vector store.
   * @param vectors Vectors to add.
   * @param documents Documents associated with the vectors.
   */
  async addVectors(vectors: number[][], documents: Document[]) {
    const typesenseDocuments = this._documentsToTypesenseRecords(
      documents,
      vectors
    );
    await this.import(typesenseDocuments, this.schemaName);
  }

  /**
   * Search for similar documents with their similarity score.
   * @param vectorPrompt vector to search for
   * @param k amount of results to return; falls back to `searchParams.per_page`, then 5
   * @param filter extra search parameters merged over the configured `searchParams`
   * @returns similar documents with their vector distance
   */
  async similaritySearchVectorWithScore(
    vectorPrompt: number[],
    k?: number,
    filter: this["FilterType"] = {}
  ) {
    const amount = k || this.searchParams.per_page || 5;
    const vector_query = `${this.vectorColumnName}:([${vectorPrompt}], k:${amount})`;
    const typesenseResponse = await this.client.multiSearch.perform(
      {
        searches: [
          {
            ...this.searchParams,
            ...filter,
            per_page: amount,
            vector_query,
            collection: this.schemaName,
          },
        ],
      },
      {}
    );
    const results = typesenseResponse.results[0].hits;

    const hits = results?.map((hit: VectorSearchResponseHit<object>) => ({
      document: hit?.document || {},
      // `??` rather than `||`: a distance of 0 is a perfect match and must
      // not be replaced by the fallback used for hits without a distance.
      vector_distance: hit?.vector_distance ?? 2,
    })) as
      | { document: Record<string, unknown>; vector_distance: number }[]
      | undefined;

    return this._typesenseRecordsToDocuments(hits);
  }

  /**
   * Delete documents from the vector store.
   * Does nothing when `documentIds` is empty.
   * @param documentIds ids of the documents to delete
   */
  async deleteDocuments(documentIds: string[]) {
    // An empty id list would produce an invalid (empty) filter expression.
    if (documentIds.length === 0) {
      return;
    }
    await this.client
      .collections(this.schemaName)
      .documents()
      .delete({
        // Multiple values must use the bracketed list form; without the
        // brackets "a,b" is treated as one single id.
        filter_by: `id:=[${documentIds.join(",")}]`,
      });
  }

  /**
   * Create a vector store from documents.
   * @param docs documents
   * @param embeddings embeddings
   * @param config Typesense configuration
   * @returns Typesense vector store
   * @warning You can omit this method, and only use the constructor and addDocuments.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    config: TypesenseConfig
  ): Promise<Typesense> {
    const instance = new Typesense(embeddings, config);
    await instance.addDocuments(docs);

    return instance;
  }

  /**
   * Create a vector store from texts.
   * @param texts texts to embed and store
   * @param metadatas metadata objects, index-aligned with `texts`; a missing entry becomes `{}`
   * @param embeddings embeddings
   * @param config Typesense configuration
   * @returns Typesense vector store
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[],
    embeddings: Embeddings,
    config: TypesenseConfig
  ) {
    const instance = new Typesense(embeddings, config);
    const documents: Document[] = texts.map((text, i) => ({
      pageContent: text,
      metadata: metadatas[i] || {},
    }));
    await instance.addDocuments(documents);

    return instance;
  }
}
/**
 * Class that extends `SaveableVectorStore` and provides methods for
 * adding documents and vectors to a `usearch` index, performing
 * similarity searches, and saving the index.
 */
export class USearch extends SaveableVectorStore {
  _index?: usearch.Index;

  /** Maps the numeric key used inside the index to the docstore id. */
  _mapping: Record<number, string>;

  docstore: SynchronousInMemoryDocstore;

  args: USearchArgs;

  _vectorstoreType(): string {
    return "usearch";
  }

  /**
   * @param embeddings Embeddings used by `addDocuments`.
   * @param args Optional pre-built `index`, `mapping` and `docstore`.
   */
  constructor(embeddings: Embeddings, args: USearchArgs) {
    super(embeddings, args);
    this.args = args;
    this._index = args.index;
    this._mapping = args.mapping ?? {};
    this.embeddings = embeddings;
    this.docstore = args?.docstore ?? new SynchronousInMemoryDocstore();
  }

  /**
   * Method that adds documents to the `usearch` index. It generates
   * embeddings for the documents and adds them to the index.
   * @param documents An array of `Document` instances to be added to the index.
   * @returns A promise that resolves with an array of document IDs.
   */
  async addDocuments(documents: Document[]) {
    const texts = documents.map(({ pageContent }) => pageContent);
    return this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents
    );
  }

  /**
   * The underlying `usearch` index.
   * @throws Error when no index has been supplied or created yet.
   */
  public get index(): usearch.Index {
    if (!this._index) {
      throw new Error(
        "Vector store not initialised yet. Try calling `fromTexts` or `fromDocuments` first."
      );
    }
    return this._index;
  }

  private set index(index: usearch.Index) {
    this._index = index;
  }

  /**
   * Method that adds vectors to the `usearch` index. It also updates the
   * mapping between vector IDs and document IDs. The index is created on
   * first use, sized to the first vector.
   * @param vectors An array of vectors to be added to the index.
   * @param documents An array of `Document` instances corresponding to the vectors.
   * @returns A promise that resolves with an array of document IDs.
   * @throws Error when the array lengths differ or the vector size does not
   * match the dimensions of the index.
   */
  async addVectors(vectors: number[][], documents: Document[]) {
    if (vectors.length === 0) {
      return [];
    }
    if (vectors.length !== documents.length) {
      throw new Error(`Vectors and documents must have the same length`);
    }
    const dv = vectors[0].length;
    if (!this._index) {
      this._index = new usearch.Index({
        metric: "l2sq",
        connectivity: BigInt(16),
        dimensions: BigInt(dv),
      });
    }
    const d = this.index.dimensions();
    if (BigInt(dv) !== d) {
      throw new Error(
        `Vectors must have the same length as the number of dimensions (${d})`
      );
    }

    // New keys continue from the current index size.
    const docstoreSize = this.index.size();
    const documentIds = [];
    for (let i = 0; i < vectors.length; i += 1) {
      const documentId = uuid.v4();
      documentIds.push(documentId);
      const id = Number(docstoreSize) + i;
      this.index.add(BigInt(id), new Float32Array(vectors[i]));
      this._mapping[id] = documentId;
      this.docstore.add({ [documentId]: documents[i] });
    }
    return documentIds;
  }

  /**
   * Method that performs a similarity search in the `usearch` index. It
   * returns the `k` closest documents to a given query vector, along
   * with the distances reported by the index.
   * @param query The query vector.
   * @param k The number of documents to return; capped (with a warning) at the index size.
   * @returns A promise that resolves with an array of tuples, each containing a `Document` and its distance.
   * @throws Error when the query length does not match the index dimensions.
   */
  async similaritySearchVectorWithScore(query: number[], k: number) {
    const d = this.index.dimensions();
    if (BigInt(query.length) !== d) {
      throw new Error(
        `Query vector must have the same length as the number of dimensions (${d})`
      );
    }
    // Work on a local copy instead of reassigning the parameter.
    let limit = k;
    if (limit > this.index.size()) {
      const total = this.index.size();
      console.warn(
        `k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`
      );
      limit = Number(total);
    }
    const result = this.index.search(new Float32Array(query), BigInt(limit));

    const return_list: [Document, number][] = [];
    for (let i = 0; i < result.count; i += 1) {
      // Named `documentId` so it does not shadow the `uuid` module import.
      const documentId = this._mapping[Number(result.keys[i])];
      return_list.push([this.docstore.search(documentId), result.distances[i]]);
    }

    return return_list;
  }

  /**
   * Method that saves the `usearch` index and the document store to disk.
   * Writes `usearch.index` plus `docstore.json` (docstore entries and the
   * key mapping) into `directory`, creating the directory if needed.
   * @param directory The directory where the index and document store should be saved.
   * @returns A promise that resolves when the save operation is complete.
   */
  async save(directory: string) {
    const fs = await import("node:fs/promises");
    const path = await import("node:path");
    await fs.mkdir(directory, { recursive: true });
    // Both operations are handed to Promise.all un-awaited, so they are
    // observed together and a failure of either one is reported.
    await Promise.all([
      this.index.save(path.join(directory, "usearch.index")),
      fs.writeFile(
        path.join(directory, "docstore.json"),
        JSON.stringify([
          Array.from(this.docstore._docs.entries()),
          this._mapping,
        ])
      ),
    ]);
  }

  /**
   * Static method that creates a new `USearch` instance from a list of
   * texts. It generates embeddings for the texts and adds them to the
   * `usearch` index.
   * @param texts An array of texts to be added to the index.
   * @param metadatas Metadata for the texts: an index-aligned array, or one object shared by all texts.
   * @param embeddings An instance of `Embeddings` used to generate embeddings for the texts.
   * @param dbConfig Optional configuration for the document store.
   * @returns A promise that resolves with a new `USearch` instance.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig?: {
      docstore?: SynchronousInMemoryDocstore;
    }
  ): Promise<USearch> {
    const docs: Document[] = texts.map(
      (pageContent, i) =>
        new Document({
          pageContent,
          metadata: Array.isArray(metadatas) ? metadatas[i] : metadatas,
        })
    );
    return this.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Static method that creates a new `USearch` instance from a list of
   * documents. It generates embeddings for the documents and adds them to
   * the `usearch` index.
   * @param docs An array of `Document` instances to be added to the index.
   * @param embeddings An instance of `Embeddings` used to generate embeddings for the documents.
   * @param dbConfig Optional configuration for the document store.
   * @returns A promise that resolves with a new `USearch` instance.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig?: {
      docstore?: SynchronousInMemoryDocstore;
    }
  ): Promise<USearch> {
    const args: USearchArgs = {
      docstore: dbConfig?.docstore,
    };
    const instance = new this(embeddings, args);
    await instance.addDocuments(docs);
    return instance;
  }
}
+ filter?: string; + // Improve retrieval accuracy by adjusting the balance (from 0 to 1), known as lambda, + // between neural search and keyword-based search factors. Values between 0.01 and 0.2 tend to work well. + // see https://docs.vectara.com/docs/api-reference/search-apis/lexical-matching for more details. + lambda?: number; + // The number of sentences before/after the matching segment to add to the context. + contextConfig?: VectaraContextConfig; +} + +/** + * Interface for the context configuration used in Vectara API calls. + */ +export interface VectaraContextConfig { + // The number of sentences before the matching segment to add. Default is 2. + sentencesBefore?: number; + // The number of sentences after the matching segment to add. Default is 2. + sentencesAfter?: number; +} + +/** + * Class for interacting with the Vectara API. Extends the VectorStore + * class. + */ +export class VectaraStore extends VectorStore { + get lc_secrets(): { [key: string]: string } { + return { + apiKey: "VECTARA_API_KEY", + corpusId: "VECTARA_CORPUS_ID", + customerId: "VECTARA_CUSTOMER_ID", + }; + } + + get lc_aliases(): { [key: string]: string } { + return { + apiKey: "vectara_api_key", + corpusId: "vectara_corpus_id", + customerId: "vectara_customer_id", + }; + } + + declare FilterType: VectaraFilter; + + private apiEndpoint = "api.vectara.io"; + + private apiKey: string; + + private corpusId: number[]; + + private customerId: number; + + private verbose: boolean; + + private source: string; + + private vectaraApiTimeoutSeconds = 60; + + _vectorstoreType(): string { + return "vectara"; + } + + constructor(args: VectaraLibArgs) { + // Vectara doesn't need embeddings, but we need to pass something to the parent constructor + // The embeddings are abstracted out from the user in Vectara. + super(new FakeEmbeddings(), args); + + const apiKey = args.apiKey ?? 
getEnvironmentVariable("VECTARA_API_KEY"); + if (!apiKey) { + throw new Error("Vectara api key is not provided."); + } + this.apiKey = apiKey; + this.source = args.source ?? "langchainjs"; + + const corpusId = + args.corpusId ?? + getEnvironmentVariable("VECTARA_CORPUS_ID") + ?.split(",") + .map((id) => { + const num = Number(id); + if (Number.isNaN(num)) + throw new Error("Vectara corpus id is not a number."); + return num; + }); + if (!corpusId) { + throw new Error("Vectara corpus id is not provided."); + } + + if (typeof corpusId === "number") { + this.corpusId = [corpusId]; + } else { + if (corpusId.length === 0) + throw new Error("Vectara corpus id is not provided."); + this.corpusId = corpusId; + } + + const customerId = + args.customerId ?? getEnvironmentVariable("VECTARA_CUSTOMER_ID"); + if (!customerId) { + throw new Error("Vectara customer id is not provided."); + } + this.customerId = customerId; + + this.verbose = args.verbose ?? false; + } + + /** + * Returns a header for Vectara API calls. + * @returns A Promise that resolves to a VectaraCallHeader object. + */ + async getJsonHeader(): Promise { + return { + headers: { + "x-api-key": this.apiKey, + "Content-Type": "application/json", + "customer-id": this.customerId.toString(), + "X-Source": this.source, + }, + }; + } + + /** + * Throws an error, as this method is not implemented. Use addDocuments + * instead. + * @param _vectors Not used. + * @param _documents Not used. + * @returns Does not return a value. + */ + async addVectors( + _vectors: number[][], + _documents: Document[] + ): Promise { + throw new Error( + "Method not implemented. Please call addDocuments instead." + ); + } + + /** + * Method to delete data from the Vectara corpus. + * @param params an array of document IDs to be deleted + * @returns Promise that resolves when the deletion is complete. 
+ */ + async deleteDocuments(ids: string[]): Promise { + if (ids && ids.length > 0) { + const headers = await this.getJsonHeader(); + for (const id of ids) { + const data = { + customer_id: this.customerId, + corpus_id: this.corpusId[0], + document_id: id, + }; + + try { + const controller = new AbortController(); + const timeout = setTimeout( + () => controller.abort(), + this.vectaraApiTimeoutSeconds * 1000 + ); + const response = await fetch( + `https://${this.apiEndpoint}/v1/delete-doc`, + { + method: "POST", + headers: headers?.headers, + body: JSON.stringify(data), + signal: controller.signal, + } + ); + clearTimeout(timeout); + if (response.status !== 200) { + throw new Error( + `Vectara API returned status code ${response.status} when deleting document ${id}` + ); + } + } catch (e) { + const error = new Error(`Error ${(e as Error).message}`); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).code = 500; + throw error; + } + } + } else { + throw new Error(`no "ids" specified for deletion`); + } + } + + /** + * Adds documents to the Vectara store. + * @param documents An array of Document objects to add to the Vectara store. + * @returns A Promise that resolves to an array of document IDs indexed in Vectara. + */ + async addDocuments(documents: Document[]): Promise { + if (this.corpusId.length > 1) + throw new Error("addDocuments does not support multiple corpus ids"); + + const headers = await this.getJsonHeader(); + const doc_ids: string[] = []; + let countAdded = 0; + for (const document of documents) { + const doc_id: string = document.metadata?.document_id ?? uuid.v4(); + const data = { + customer_id: this.customerId, + corpus_id: this.corpusId[0], + document: { + document_id: doc_id, + title: document.metadata?.title ?? "", + metadata_json: JSON.stringify(document.metadata ?? 
{}), + section: [ + { + text: document.pageContent, + }, + ], + }, + }; + + try { + const controller = new AbortController(); + const timeout = setTimeout( + () => controller.abort(), + this.vectaraApiTimeoutSeconds * 1000 + ); + const response = await fetch(`https://${this.apiEndpoint}/v1/index`, { + method: "POST", + headers: headers?.headers, + body: JSON.stringify(data), + signal: controller.signal, + }); + clearTimeout(timeout); + const result = await response.json(); + if ( + result.status?.code !== "OK" && + result.status?.code !== "ALREADY_EXISTS" + ) { + const error = new Error( + `Vectara API returned status code ${ + result.status?.code + }: ${JSON.stringify(result.message)}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).code = 500; + throw error; + } else { + countAdded += 1; + doc_ids.push(doc_id); + } + } catch (e) { + const error = new Error( + `Error ${(e as Error).message} while adding document` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).code = 500; + throw error; + } + } + if (this.verbose) { + console.log(`Added ${countAdded} documents to Vectara`); + } + + return doc_ids; + } + + /** + * Vectara provides a way to add documents directly via their API. This API handles + * pre-processing and chunking internally in an optimal manner. This method is a wrapper + * to utilize that API within LangChain. + * + * @param files An array of VectaraFile objects representing the files and their respective file names to be uploaded to Vectara. + * @param metadata Optional. An array of metadata objects corresponding to each file in the `filePaths` array. + * @returns A Promise that resolves to the number of successfully uploaded files. 
+ */ + async addFiles( + files: VectaraFile[], + metadatas: Record | undefined = undefined + ) { + if (this.corpusId.length > 1) + throw new Error("addFiles does not support multiple corpus ids"); + + const doc_ids: string[] = []; + + for (const [index, file] of files.entries()) { + const md = metadatas ? metadatas[index] : {}; + + const data = new FormData(); + data.append("file", file.blob, file.fileName); + data.append("doc-metadata", JSON.stringify(md)); + + const response = await fetch( + `https://api.vectara.io/v1/upload?c=${this.customerId}&o=${this.corpusId[0]}&d=true`, + { + method: "POST", + headers: { + "x-api-key": this.apiKey, + "X-Source": this.source, + }, + body: data, + } + ); + + const { status } = response; + if (status === 409) { + throw new Error(`File at index ${index} already exists in Vectara`); + } else if (status !== 200) { + throw new Error(`Vectara API returned status code ${status}`); + } else { + const result = await response.json(); + const doc_id = result.document.documentId; + doc_ids.push(doc_id); + } + } + + if (this.verbose) { + console.log(`Uploaded ${files.length} files to Vectara`); + } + + return doc_ids; + } + + /** + * Performs a similarity search and returns documents along with their + * scores. + * @param query The query string for the similarity search. + * @param k Optional. The number of results to return. Default is 10. + * @param filter Optional. A VectaraFilter object to refine the search results. + * @returns A Promise that resolves to an array of tuples, each containing a Document and its score. + */ + async similaritySearchWithScore( + query: string, + k = 10, + filter: VectaraFilter | undefined = undefined + ): Promise<[Document, number][]> { + const headers = await this.getJsonHeader(); + + const corpusKeys = this.corpusId.map((corpusId) => ({ + customerId: this.customerId, + corpusId, + metadataFilter: filter?.filter ?? "", + lexicalInterpolationConfig: { lambda: filter?.lambda ?? 
0.025 }, + })); + + const data = { + query: [ + { + query, + numResults: k, + contextConfig: { + sentencesAfter: filter?.contextConfig?.sentencesAfter ?? 2, + sentencesBefore: filter?.contextConfig?.sentencesBefore ?? 2, + }, + corpusKey: corpusKeys, + }, + ], + }; + + const controller = new AbortController(); + const timeout = setTimeout( + () => controller.abort(), + this.vectaraApiTimeoutSeconds * 1000 + ); + const response = await fetch(`https://${this.apiEndpoint}/v1/query`, { + method: "POST", + headers: headers?.headers, + body: JSON.stringify(data), + signal: controller.signal, + }); + clearTimeout(timeout); + if (response.status !== 200) { + throw new Error(`Vectara API returned status code ${response.status}`); + } + + const result = await response.json(); + const responses = result.responseSet[0].response; + const documents = result.responseSet[0].document; + + for (let i = 0; i < responses.length; i += 1) { + const responseMetadata = responses[i].metadata; + const documentMetadata = documents[responses[i].documentIndex].metadata; + const combinedMetadata: Record = {}; + + responseMetadata.forEach((item: { name: string; value: unknown }) => { + combinedMetadata[item.name] = item.value; + }); + + documentMetadata.forEach((item: { name: string; value: unknown }) => { + combinedMetadata[item.name] = item.value; + }); + + responses[i].metadata = combinedMetadata; + } + + const documentsAndScores = responses.map( + (response: { + text: string; + metadata: Record; + score: number; + }) => [ + new Document({ + pageContent: response.text, + metadata: response.metadata, + }), + response.score, + ] + ); + return documentsAndScores; + } + + /** + * Performs a similarity search and returns documents. + * @param query The query string for the similarity search. + * @param k Optional. The number of results to return. Default is 10. + * @param filter Optional. A VectaraFilter object to refine the search results. 
* @returns A Promise that resolves to an array of Document objects.
   */
  async similaritySearch(
    query: string,
    k = 10,
    filter: VectaraFilter | undefined = undefined
  ): Promise<Document[]> {
    // Delegate to the scored search and drop the scores.
    const scored = await this.similaritySearchWithScore(query, k, filter);
    return scored.map(([document]) => document);
  }

  /**
   * Always rejects: searching by raw vector is not supported by this store.
   * Callers should use similaritySearch or similaritySearchWithScore.
   * @param _query Ignored.
   * @param _k Ignored.
   * @param _filter Ignored.
   * @returns Never resolves; the returned promise is rejected.
   */
  async similaritySearchVectorWithScore(
    _query: number[],
    _k: number,
    _filter?: VectaraFilter | undefined
  ): Promise<[Document, number][]> {
    const message =
      "Method not implemented. Please call similaritySearch or similaritySearchWithScore instead.";
    throw new Error(message);
  }

  /**
   * Builds a VectaraStore from raw texts.
   * @param texts The texts to wrap in Document objects.
   * @param metadatas Either one metadata object shared by every text, or an array indexed like `texts`.
   * @param _embeddings Ignored; a FakeEmbeddings instance is passed on instead.
   * @param args A VectaraLibArgs object for initializing the VectaraStore instance.
   * @returns A Promise that resolves to a VectaraStore instance.
   */
  static fromTexts(
    texts: string[],
    metadatas: object | object[],
    _embeddings: Embeddings,
    args: VectaraLibArgs
  ): Promise<VectaraStore> {
    const docs: Document[] = texts.map(
      (pageContent, position) =>
        new Document({
          pageContent,
          metadata: Array.isArray(metadatas) ? metadatas[position] : metadatas,
        })
    );

    return VectaraStore.fromDocuments(docs, new FakeEmbeddings(), args);
  }

  /**
   * Creates a VectaraStore instance from documents.
   * @param docs An array of Document objects.
   * @param _embeddings Not used.
   * @param args A VectaraLibArgs object for initializing the VectaraStore instance.
   * @returns A Promise that resolves to a VectaraStore instance.
+ */ + static async fromDocuments( + docs: Document[], + _embeddings: Embeddings, + args: VectaraLibArgs + ): Promise { + const instance = new this(args); + await instance.addDocuments(docs); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/vercel_postgres.ts b/libs/langchain-community/src/vectorstores/vercel_postgres.ts new file mode 100644 index 000000000000..782428673b2a --- /dev/null +++ b/libs/langchain-community/src/vectorstores/vercel_postgres.ts @@ -0,0 +1,393 @@ +import { + type VercelPool, + type VercelPoolClient, + type VercelPostgresPoolConfig, + createPool, +} from "@vercel/postgres"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +type Metadata = Record>; + +/** + * Interface that defines the arguments required to create a + * `VercelPostgres` instance. It includes Postgres connection options, + * table name, filter, and verbosity level. + */ +export interface VercelPostgresFields { + pool: VercelPool; + client: VercelPoolClient; + tableName?: string; + columns?: { + idColumnName?: string; + vectorColumnName?: string; + contentColumnName?: string; + metadataColumnName?: string; + }; + filter?: Metadata; + verbose?: boolean; +} + +/** + * Class that provides an interface to a Vercel Postgres vector database. It + * extends the `VectorStore` base class and implements methods for adding + * documents and vectors and performing similarity searches. 
*/
export class VercelPostgres extends VectorStore {
  declare FilterType: Metadata;

  tableName: string;

  idColumnName: string;

  vectorColumnName: string;

  contentColumnName: string;

  metadataColumnName: string;

  filter?: Metadata;

  _verbose?: boolean;

  pool: VercelPool;

  client: VercelPoolClient;

  _vectorstoreType(): string {
    return "vercel";
  }

  /**
   * Private: instances are created through `initialize`, which also makes
   * sure the backing table exists.
   *
   * @param embeddings - Embeddings instance.
   * @param config - Pool, client and optional table/column names.
   */
  private constructor(embeddings: Embeddings, config: VercelPostgresFields) {
    super(embeddings, config);
    this.tableName = config.tableName ?? "langchain_vectors";
    this.filter = config.filter;

    this.vectorColumnName = config.columns?.vectorColumnName ?? "embedding";
    this.contentColumnName = config.columns?.contentColumnName ?? "text";
    this.idColumnName = config.columns?.idColumnName ?? "id";
    this.metadataColumnName = config.columns?.metadataColumnName ?? "metadata";

    this.pool = config.pool;
    this.client = config.client;

    // A comparison is never null/undefined, so `??` could never fall through
    // to `config.verbose`. Verbose is on when either source asks for it.
    this._verbose =
      getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ||
      !!config.verbose;
  }

  /**
   * Static method to create a new `VercelPostgres` instance from a
   * connection. It creates the pool and client when they are not supplied,
   * then creates the table if one does not exist.
   *
   * @param embeddings - Embeddings instance.
   * @param config - `VercelPostgres` configuration options.
   * @returns A new instance of `VercelPostgres`.
   */
  static async initialize(
    embeddings: Embeddings,
    config?: Partial<VercelPostgresFields> & {
      postgresConnectionOptions?: VercelPostgresPoolConfig;
    }
  ): Promise<VercelPostgres> {
    // Default maxUses to 1 for edge environments:
    // https://github.com/vercel/storage/tree/main/packages/postgres#a-note-on-edge-environments
    const pool =
      config?.pool ??
      createPool({ maxUses: 1, ...config?.postgresConnectionOptions });
    const client = config?.client ?? (await pool.connect());
    const postgresqlVectorStore = new VercelPostgres(embeddings, {
      ...config,
      pool,
      client,
    });

    await postgresqlVectorStore.ensureTableInDatabase();

    return postgresqlVectorStore;
  }

  /**
   * Method to add documents to the vector store. It converts the documents into
   * vectors, and adds them to the store.
   *
   * @param documents - Array of `Document` instances.
   * @param options - Optional `ids`, one per document.
   * @returns Promise that resolves with the ids of the stored rows.
   */
  async addDocuments(
    documents: Document[],
    options?: { ids?: string[] }
  ): Promise<string[]> {
    const texts = documents.map(({ pageContent }) => pageContent);

    return this.addVectors(
      await this.embeddings.embedDocuments(texts),
      documents,
      options
    );
  }

  /**
   * Generates the SQL placeholders for a specific row at the provided index.
   *
   * @param row - The row values; only the number of values is used.
   * @param index - The index of the row within the statement being built.
   * @returns The SQL placeholders for the row values, e.g. `($4,$5,$6)`.
   */
  protected generatePlaceholderForRowAt(
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    row: (string | Record<string, any>)[],
    index: number
  ): string {
    // Placeholders are 1-based and continue across the rows of one statement.
    const base = index * row.length;
    return `(${row.map((_, j) => `$${base + 1 + j}`).join(",")})`;
  }

  /**
   * Builds and runs the upsert statement for one chunk of rows.
   *
   * @param rows - The rows to insert; every row must hold the same number of values.
   * @param useIdColumn - Whether each row starts with an explicit id value.
   * @returns The query result, whose rows carry the id column of each upserted row.
   */
  protected async runInsertQuery(
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    rows: (string | Record<string, any>)[][],
    useIdColumn: boolean
  ) {
    const values = rows.map((row, j) =>
      this.generatePlaceholderForRowAt(row, j)
    );
    const flatValues = rows.flat();
    return this.client.query(
      `
      INSERT INTO ${this.tableName} (
        ${useIdColumn ? `${this.idColumnName},` : ""}
        ${this.contentColumnName},
        ${this.vectorColumnName},
        ${this.metadataColumnName}
      ) VALUES ${values.join(", ")}
      ON CONFLICT (${this.idColumnName})
      DO UPDATE
      SET
        ${this.contentColumnName} = EXCLUDED.${this.contentColumnName},
        ${this.vectorColumnName} = EXCLUDED.${this.vectorColumnName},
        ${this.metadataColumnName} = EXCLUDED.${this.metadataColumnName}
      RETURNING ${this.idColumnName}`,
      flatValues
    );
  }

  /**
   * Method to add vectors to the vector store. It converts the vectors into
   * rows and upserts them in chunks of 500 rows.
   *
   * @param vectors - Array of vectors.
   * @param documents - Array of `Document` instances, one per vector.
   * @param options - Optional `ids`, one per vector.
   * @returns Promise that resolves with the ids of the stored rows.
   * @throws Error when the input lengths disagree or an insert fails.
   */
  async addVectors(
    vectors: number[][],
    documents: Document[],
    options?: { ids?: string[] }
  ): Promise<string[]> {
    if (options?.ids !== undefined && options?.ids.length !== vectors.length) {
      throw new Error(
        `If provided, the length of "ids" must be the same as the number of vectors.`
      );
    }
    // Without this check a short `documents` array fails later with an
    // unhelpful TypeError on `documents[idx].pageContent`.
    if (documents.length !== vectors.length) {
      throw new Error(
        `The number of documents must be the same as the number of vectors.`
      );
    }

    const rows = vectors.map((embedding, idx) => {
      // The vector is written as a text literal, e.g. "[0.1,0.2]".
      const embeddingString = `[${embedding.join(",")}]`;
      const row = [
        documents[idx].pageContent,
        embeddingString,
        documents[idx].metadata,
      ];
      if (options?.ids) {
        return [options.ids[idx], ...row];
      }
      return row;
    });

    const chunkSize = 500;

    const ids: string[] = [];

    for (let i = 0; i < rows.length; i += chunkSize) {
      const chunk = rows.slice(i, i + chunkSize);
      try {
        const result = await this.runInsertQuery(
          chunk,
          options?.ids !== undefined
        );
        ids.push(...result.rows.map((row) => row[this.idColumnName]));
      } catch (e) {
        console.error(e);
        throw new Error(`Error inserting: ${(e as Error).message}`);
      }
    }
    return ids;
  }

  /**
   * Method to perform a similarity search in the vector store. It returns
   * the `k` most similar documents to the query vector, along with their
   * similarity scores.
   *
   * @param query - Query vector.
* @param k - Number of most similar documents to return.
   * @param filter - Optional filter to apply to the search.
   * @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: this["FilterType"]
  ): Promise<[Document, number][]> {
    const embeddingString = `[${query.join(",")}]`;
    const _filter: this["FilterType"] = filter ?? {};
    const whereClauses: string[] = [];
    // $1 is the query vector and $2 the limit; filter values follow from $3.
    const values: (string | number)[] = [embeddingString, k];
    let paramCount = values.length;

    for (const [key, value] of Object.entries(_filter)) {
      // The key is embedded in a SQL string literal, so double any single
      // quote; otherwise a key containing `'` breaks the statement.
      const keyLiteral = `'${key.replace(/'/g, "''")}'`;
      const column = `${this.metadataColumnName}->>${keyLiteral}`;

      if (typeof value === "object" && value !== null) {
        if (value.in.length === 0) {
          // `IN ()` is invalid SQL; an empty candidate list matches nothing.
          whereClauses.push("FALSE");
          continue;
        }
        const currentParamCount = paramCount;
        const placeholders = value.in
          .map((_, index) => `$${currentParamCount + index + 1}`)
          .join(",");
        whereClauses.push(`${column} IN (${placeholders})`);
        values.push(...value.in);
        paramCount += value.in.length;
      } else {
        paramCount += 1;
        whereClauses.push(`${column} = $${paramCount}`);
        values.push(value);
      }
    }

    const whereClause = whereClauses.length
      ? `WHERE ${whereClauses.join(" AND ")}`
      : "";

    const queryString = `
      SELECT *, ${this.vectorColumnName} <=> $1 as "_distance"
      FROM ${this.tableName}
      ${whereClause}
      ORDER BY "_distance" ASC
      LIMIT $2;`;

    const documents = (await this.client.query(queryString, values)).rows;
    const results: [Document, number][] = [];
    for (const doc of documents) {
      // Rows without a distance or without content are skipped.
      if (doc._distance != null && doc[this.contentColumnName] != null) {
        const document = new Document({
          pageContent: doc[this.contentColumnName],
          metadata: doc[this.metadataColumnName],
        });
        results.push([document, doc._distance]);
      }
    }
    return results;
  }

  /**
   * Deletes rows by id, or empties the table when `deleteAll` is set.
   * When `ids` is provided it takes precedence over `deleteAll`; an empty
   * `ids` array deletes nothing.
   *
   * @param params - `ids` of the rows to delete, or `deleteAll` to truncate the table.
   * @returns Promise that resolves when the deletion is complete.
   */
  async delete(params: { ids?: string[]; deleteAll?: boolean }): Promise<void> {
    if (params.ids !== undefined) {
      if (params.ids.length === 0) {
        // Nothing to delete; this also avoids generating invalid `IN ()` SQL.
        return;
      }
      const placeholders = params.ids.map((_, idx) => `$${idx + 1}`).join(",");
      await this.client.query(
        `DELETE FROM ${this.tableName} WHERE ${this.idColumnName} IN (${placeholders})`,
        params.ids
      );
    } else if (params.deleteAll) {
      await this.client.query(`TRUNCATE TABLE ${this.tableName}`);
    }
  }

  /**
   * Method to ensure the existence of the table in the database. It enables
   * the `vector` and `uuid-ossp` extensions and creates the table if it does
   * not already exist.
   *
   * @returns Promise that resolves when the table has been ensured.
   */
  async ensureTableInDatabase(): Promise<void> {
    await this.client.query(`CREATE EXTENSION IF NOT EXISTS vector;`);
    await this.client.query(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`);
    await this.client.query(`CREATE TABLE IF NOT EXISTS "${this.tableName}" (
      "${this.idColumnName}" uuid NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY,
      "${this.contentColumnName}" text,
      "${this.metadataColumnName}" jsonb,
      "${this.vectorColumnName}" vector
    );`);
  }

  /**
   * Static method to create a new `VercelPostgres` instance from an
   * array of texts and their metadata. It converts the texts into
   * `Document` instances and adds them to the store.
   *
   * @param texts - Array of texts.
   * @param metadatas - Array of metadata objects or a single metadata object.
* @param embeddings - Embeddings instance.
   * @param dbConfig - `VercelPostgres` configuration options.
   * @returns Promise that resolves with a new instance of `VercelPostgres`.
   */
  static async fromTexts(
    texts: string[],
    metadatas: object[] | object,
    embeddings: Embeddings,
    dbConfig?: Partial<VercelPostgresFields> & {
      postgresConnectionOptions?: VercelPostgresPoolConfig;
    }
  ): Promise<VercelPostgres> {
    // A single metadata object is shared by every text; an array is read
    // position by position.
    const docs = texts.map(
      (pageContent, position) =>
        new Document({
          pageContent,
          metadata: Array.isArray(metadatas) ? metadatas[position] : metadatas,
        })
    );

    return this.fromDocuments(docs, embeddings, dbConfig);
  }

  /**
   * Builds a store through `initialize` and then inserts the given
   * documents into it.
   *
   * @param docs - Array of `Document` instances.
   * @param embeddings - Embeddings instance.
   * @param dbConfig - `VercelPostgres` configuration options.
   * @returns Promise that resolves with a new instance of `VercelPostgres`.
   */
  static async fromDocuments(
    docs: Document[],
    embeddings: Embeddings,
    dbConfig?: Partial<VercelPostgresFields> & {
      postgresConnectionOptions?: VercelPostgresPoolConfig;
    }
  ): Promise<VercelPostgres> {
    const store = await this.initialize(embeddings, dbConfig);
    await store.addDocuments(docs);

    return store;
  }

  /**
   * Closes all the clients in the pool and terminates the pool.
   *
   * @returns Promise that resolves when all clients are closed and the pool is terminated.
+ */ + async end(): Promise { + await this.client?.release(); + return this.pool.end(); + } +} diff --git a/libs/langchain-community/src/vectorstores/voy.ts b/libs/langchain-community/src/vectorstores/voy.ts new file mode 100644 index 000000000000..a1a341eb7701 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/voy.ts @@ -0,0 +1,191 @@ +import type { Voy as VoyOriginClient, SearchResult } from "voy-search"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +export type VoyClient = Omit< + VoyOriginClient, + "remove" | "size" | "serialize" | "free" +>; + +/** + * Internal interface for storing documents mappings. + */ +interface InternalDoc { + embeddings: number[]; + document: Document; +} + +/** + * Class that extends `VectorStore`. It allows to perform similarity search using + * Voi similarity search engine. The class requires passing Voy Client as an input parameter. + */ +export class VoyVectorStore extends VectorStore { + client: VoyClient; + + numDimensions: number | null = null; + + docstore: InternalDoc[] = []; + + _vectorstoreType(): string { + return "voi"; + } + + constructor(client: VoyClient, embeddings: Embeddings) { + super(embeddings, {}); + this.client = client; + this.embeddings = embeddings; + } + + /** + * Adds documents to the Voy database. The documents are embedded using embeddings provided while instantiating the class. + * @param documents An array of `Document` instances associated with the vectors. 
*/
  async addDocuments(documents: Document[]): Promise<void> {
    if (documents.length === 0) {
      return;
    }
    const texts = documents.map(({ pageContent }) => pageContent);

    // Embed one text first so a dimension mismatch is detected before the
    // rest of the batch is embedded.
    const firstVector = (
      await this.embeddings.embedDocuments(texts.slice(0, 1))
    )[0];
    if (this.numDimensions === null) {
      this.numDimensions = firstVector.length;
    } else if (this.numDimensions !== firstVector.length) {
      throw new Error(
        `Vectors must have the same length as the number of dimensions (${this.numDimensions})`
      );
    }
    // Skip the second embedding call when there is nothing left to embed.
    const restResults =
      texts.length > 1
        ? await this.embeddings.embedDocuments(texts.slice(1))
        : [];
    await this.addVectors([firstVector, ...restResults], documents);
  }

  /**
   * Adds vectors to the Voy database. The vectors are associated with
   * the provided documents. All inputs are validated before any state is
   * changed, so a rejected call leaves the store untouched.
   * @param vectors An array of vectors to be added to the database.
   * @param documents An array of `Document` instances associated with the vectors.
   * @throws Error when the arrays differ in length or a vector has the wrong dimension.
   */
  async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
    if (vectors.length === 0) {
      return;
    }
    if (vectors.length !== documents.length) {
      throw new Error(`Vectors and metadata must have the same length`);
    }

    // The first vector ever added fixes the dimension of the store.
    const expectedDimensions = this.numDimensions ?? vectors[0].length;
    if (!vectors.every((v) => v.length === expectedDimensions)) {
      throw new Error(
        `Vectors must have the same length as the number of dimensions (${expectedDimensions})`
      );
    }
    this.numDimensions = expectedDimensions;

    vectors.forEach((item, idx) => {
      this.docstore.push({ embeddings: item, document: documents[idx] });
    });

    // The whole docstore is handed to `index`, and each id is the entry's
    // position in `docstore`; search results are resolved through that id.
    const embeddings = this.docstore.map((item, idx) => ({
      id: String(idx),
      embeddings: item.embeddings,
      title: "",
      url: "",
    }));
    this.client.index({ embeddings });
  }

  /**
   * Searches for vectors in the Voy database that are similar to the
   * provided query vector.
   * @param query The query vector.
   * @param k The number of similar vectors to return; capped at the number of stored documents.
   * @returns A promise that resolves with an array of tuples, each containing a `Document` instance and its rank in the result list (0 is the closest match), used as the score.
   * @throws Error when the store is empty or the query has the wrong dimension.
   */
  async similaritySearchVectorWithScore(query: number[], k: number) {
    if (this.numDimensions === null) {
      throw new Error("There aren't any elements in the index yet.");
    }
    if (query.length !== this.numDimensions) {
      throw new Error(
        `Query vector must have the same length as the number of dimensions (${this.numDimensions})`
      );
    }
    const itemsToQuery = Math.min(this.docstore.length, k);
    // Compare the requested `k`: the capped value can never exceed the
    // docstore size, so testing it would make this warning unreachable.
    if (k > this.docstore.length) {
      console.warn(
        `k (${k}) is greater than the number of elements in the index (${this.docstore.length}), setting k to ${itemsToQuery}`
      );
    }
    const results: SearchResult = this.client.search(
      new Float32Array(query),
      itemsToQuery
    );
    return results.neighbors.map(
      ({ id }, idx) =>
        [this.docstore[parseInt(id, 10)].document, idx] as [Document, number]
    );
  }

  /**
   * Method to delete data from the Voy index. Only clearing the whole index
   * is supported; the local document store is reset along with it.
   * @param params Object whose `deleteAll` flag must be `true`.
   * @returns Promise that resolves when the deletion is complete.
   * @throws Error when `deleteAll` is not `true`.
   */
  async delete(params: { deleteAll?: boolean }): Promise<void> {
    if (params.deleteAll === true) {
      await this.client.clear();
      // Ids are positions in `docstore`; keeping the old entries would
      // re-index the deleted documents on the next `addVectors` call.
      this.docstore = [];
      this.numDimensions = null;
    } else {
      throw new Error(`You must provide a "deleteAll" parameter.`);
    }
  }

  /**
   * Creates a new `VoyVectorStore` instance from an array of text strings. The text
   * strings are converted to `Document` instances and added to the Voy
   * database.
   * @param texts An array of text strings.
   * @param metadatas An array of metadata objects or a single metadata object. If an array is provided, it must have the same length as the `texts` array.
   * @param embeddings An `Embeddings` instance used to generate embeddings for the documents.
   * @param client An instance of Voy client to use in the underlying operations.
+ * @returns A promise that resolves with a new `VoyVectorStore` instance. + */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + client: VoyClient + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return VoyVectorStore.fromDocuments(docs, embeddings, client); + } + + /** + * Creates a new `VoyVectorStore` instance from an array of `Document` instances. + * The documents are added to the Voy database. + * @param docs An array of `Document` instances. + * @param embeddings An `Embeddings` instance used to generate embeddings for the documents. + * @param client An instance of Voy client to use in the underlying operations. + * @returns A promise that resolves with a new `VoyVectorStore` instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + client: VoyClient + ): Promise { + const instance = new VoyVectorStore(client, embeddings); + await instance.addDocuments(docs); + return instance; + } +} diff --git a/libs/langchain-community/src/vectorstores/weaviate.ts b/libs/langchain-community/src/vectorstores/weaviate.ts new file mode 100644 index 000000000000..cf45fcdda489 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/weaviate.ts @@ -0,0 +1,435 @@ +import * as uuid from "uuid"; +import type { + WeaviateClient, + WeaviateObject, + WhereFilter, +} from "weaviate-ts-client"; +import { + MaxMarginalRelevanceSearchOptions, + VectorStore, +} from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { maximalMarginalRelevance } from "@langchain/core/utils/math"; + +// Note this function is not generic, it is designed specifically for Weaviate +// 
// https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction
/**
 * Flattens a metadata object into the single-level shape stored as Weaviate
 * properties.
 *
 * - Nested objects are flattened recursively; keys are joined with `_`
 *   (`{ a: { b: 1 } }` becomes `{ a_b: 1 }`).
 * - Arrays are kept only when non-empty, made of non-object values, and all
 *   elements share the same `typeof`; other arrays are dropped.
 * - `null` values, and objects without own enumerable keys, contribute nothing.
 * - Only own enumerable keys are read, at every nesting level.
 *
 * @param obj The object to flatten.
 * @returns A new flat object; the input is not modified.
 */
export const flattenObjectForWeaviate = (
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  obj: Record<string, any>
) => {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const flattenedObject: Record<string, any> = {};

  // Object.entries yields own enumerable keys only. `?? {}` keeps the
  // tolerance for null/undefined input, which recursion relies on for
  // nested null values.
  for (const [key, value] of Object.entries(obj ?? {})) {
    if (Array.isArray(value)) {
      if (
        value.length > 0 &&
        typeof value[0] !== "object" &&
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        value.every((el: any) => typeof el === typeof value[0])
      ) {
        // Weaviate only supports arrays of primitive types,
        // where all elements are of the same type
        flattenedObject[key] = value;
      }
    } else if (typeof value === "object") {
      // Includes null, which flattens to nothing.
      const recursiveResult = flattenObjectForWeaviate(value);
      for (const deepKey in recursiveResult) {
        // Guard the nested result itself so inherited keys are not copied.
        if (Object.hasOwn(recursiveResult, deepKey)) {
          flattenedObject[`${key}_${deepKey}`] = recursiveResult[deepKey];
        }
      }
    } else {
      flattenedObject[key] = value;
    }
  }

  return flattenedObject;
};

/**
 * Interface that defines the arguments required to create a new instance
 * of the `WeaviateStore` class. It includes the Weaviate client, the name
 * of the class in Weaviate, and optional keys for text and metadata.
 */
export interface WeaviateLibArgs {
  client: WeaviateClient;
  /**
   * The name of the class in Weaviate. Must start with a capital letter.
   */
  indexName: string;
  textKey?: string;
  metadataKeys?: string[];
  tenant?: string;
}

interface ResultRow {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  [key: string]: any;
}

/**
 * Interface that defines a filter for querying data from Weaviate. It
 * includes a distance and a `WhereFilter`.
 */
export interface WeaviateFilter {
  distance?: number;
  where: WhereFilter;
}

/**
 * Class that extends the `VectorStore` base class.
* It provides methods to
 * interact with a Weaviate index, including adding vectors and documents,
 * deleting data, and performing similarity searches.
 */
export class WeaviateStore extends VectorStore {
  declare FilterType: WeaviateFilter;

  private client: WeaviateClient;

  private indexName: string;

  private textKey: string;

  private queryAttrs: string[];

  private tenant?: string;

  _vectorstoreType(): string {
    return "weaviate";
  }

  /**
   * @param embeddings Embeddings used to vectorize documents and queries.
   * @param args Client, class name and optional text/metadata keys and tenant.
   */
  constructor(public embeddings: Embeddings, args: WeaviateLibArgs) {
    super(embeddings, args);

    this.client = args.client;
    this.indexName = args.indexName;
    this.textKey = args.textKey || "text";
    this.queryAttrs = [this.textKey];
    this.tenant = args.tenant;

    if (args.metadataKeys) {
      // The Set removes duplicates, e.g. when the text key is also listed
      // as a metadata key.
      this.queryAttrs = [
        ...new Set([
          ...this.queryAttrs,
          ...args.metadataKeys.filter((k) => {
            // https://spec.graphql.org/June2018/#sec-Names
            // queryAttrs need to be valid GraphQL Names
            const keyIsValid = /^[_A-Za-z][_0-9A-Za-z]*$/.test(k);
            if (!keyIsValid) {
              console.warn(
                `Skipping metadata key ${k} as it is not a valid GraphQL Name`
              );
            }
            return keyIsValid;
          }),
        ]),
      ];
    }
  }

  /**
   * Method to add vectors and corresponding documents to the Weaviate
   * index.
   * @param vectors Array of vectors to be added, one per document.
   * @param documents Array of documents corresponding to the vectors.
   * @param options Optional parameter that can include specific IDs for the documents, one per document.
   * @returns An array of document IDs.
   * @throws Error when input lengths disagree, when a document carries an `id` metadata key, or when the batch request reports errors.
   */
  async addVectors(
    vectors: number[][],
    documents: Document[],
    options?: { ids?: string[] }
  ) {
    // Fail fast on mismatched inputs; otherwise objects would be sent with
    // an undefined `vector` or `id`, and the returned ids would not line up
    // with the documents.
    if (vectors.length !== documents.length) {
      throw new Error(`Vectors and documents must have the same length`);
    }
    if (options?.ids !== undefined && options.ids.length !== vectors.length) {
      throw new Error(
        `If provided, the length of "ids" must be the same as the number of vectors.`
      );
    }

    const documentIds = options?.ids ?? documents.map((_) => uuid.v4());
    const batch: WeaviateObject[] = documents.map((document, index) => {
      if (Object.hasOwn(document.metadata, "id"))
        throw new Error(
          "Document inserted to Weaviate vectorstore should not have `id` in their metadata."
        );

      const flattenedMetadata = flattenObjectForWeaviate(document.metadata);
      return {
        ...(this.tenant ? { tenant: this.tenant } : {}),
        class: this.indexName,
        id: documentIds[index],
        vector: vectors[index],
        properties: {
          [this.textKey]: document.pageContent,
          ...flattenedMetadata,
        },
      };
    });

    try {
      const responses = await this.client.batch
        .objectsBatcher()
        .withObjects(...batch)
        .do();
      // if storing vectors fails, we need to know why
      const errorMessages: string[] = [];
      responses.forEach((response) => {
        if (response?.result?.errors?.error) {
          errorMessages.push(
            ...response.result.errors.error.map(
              (err) =>
                err.message ??
                "!! Unfortunately no error message was presented in the API response !!"
            )
          );
        }
      });
      if (errorMessages.length > 0) {
        throw new Error(errorMessages.join("\n"));
      }
    } catch (e) {
      throw Error(`Error adding vectors: ${e}`);
    }
    return documentIds;
  }

  /**
   * Method to add documents to the Weaviate index. It first generates
   * vectors for the documents using the embeddings, then adds the vectors
   * and documents to the index.
   * @param documents Array of documents to be added.
   * @param options Optional parameter that can include specific IDs for the documents.
   * @returns An array of document IDs.
   */
  async addDocuments(documents: Document[], options?: { ids?: string[] }) {
    return this.addVectors(
      await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)),
      documents,
      options
    );
  }

  /**
   * Method to delete data from the Weaviate index. It can delete data based
   * on specific IDs or a filter.
   * @param params Object that includes either an array of IDs or a filter for the data to be deleted.
   * @returns Promise that resolves when the deletion is complete.
*/
  async delete(params: {
    ids?: string[];
    filter?: WeaviateFilter;
  }): Promise<void> {
    const hasIds = params.ids !== undefined && params.ids.length > 0;

    // A non-empty id list wins over a filter; ids are removed one request
    // at a time, in order.
    if (hasIds) {
      for (const objectId of params.ids ?? []) {
        const base = this.client.data
          .deleter()
          .withClassName(this.indexName)
          .withId(objectId);
        const request = this.tenant ? base.withTenant(this.tenant) : base;

        await request.do();
      }
      return;
    }

    if (params.filter) {
      const base = this.client.batch
        .objectsBatchDeleter()
        .withClassName(this.indexName)
        .withWhere(params.filter.where);
      const request = this.tenant ? base.withTenant(this.tenant) : base;

      await request.do();
      return;
    }

    throw new Error(
      `This method requires either "ids" or "filter" to be set in the input object`
    );
  }

  /**
   * Runs a vector similarity search and returns each match with its score.
   * Thin wrapper over `similaritySearchVectorWithScoreAndEmbedding` that
   * drops the embedding from every result.
   * @param query The query vector.
   * @param k The number of most similar documents to return.
   * @param filter Optional filter to apply to the search.
   * @returns An array of tuples, where each tuple contains a document and its similarity score.
   */
  async similaritySearchVectorWithScore(
    query: number[],
    k: number,
    filter?: WeaviateFilter
  ): Promise<[Document, number][]> {
    const matches = await this.similaritySearchVectorWithScoreAndEmbedding(
      query,
      k,
      filter
    );
    return matches.map((match) => [match[0], match[1]]);
  }

  /**
   * Method to perform a similarity search on the stored vectors in the
   * Weaviate index. It returns the top k most similar documents, their
   * similarity scores and embedding vectors.
   * @param query The query vector.
   * @param k The number of most similar documents to return.
   * @param filter Optional filter to apply to the search.
* @returns An array of tuples, where each tuple contains a document, its similarity score and its embedding vector.
   */
  async similaritySearchVectorWithScoreAndEmbedding(
    query: number[],
    k: number,
    filter?: WeaviateFilter
  ): Promise<[Document, number, number[]][]> {
    try {
      // Request the configured attributes plus distance and stored vector.
      const fields = `${this.queryAttrs.join(" ")} _additional { distance vector }`;
      let request = this.client.graphql
        .get()
        .withClassName(this.indexName)
        .withFields(fields)
        .withNearVector({
          vector: query,
          distance: filter?.distance,
        })
        .withLimit(k);

      if (this.tenant) {
        request = request.withTenant(this.tenant);
      }

      if (filter?.where) {
        request = request.withWhere(filter.where);
      }

      const response = await request.do();
      const rows: ResultRow[] = response.data.Get[this.indexName];

      // The text key becomes the page content; every other returned
      // attribute (minus `_additional`) becomes metadata.
      return rows.map((row): [Document, number, number[]] => {
        const { [this.textKey]: pageContent, _additional, ...metadata } = row;
        return [
          new Document({ pageContent, metadata }),
          _additional.distance,
          _additional.vector,
        ];
      });
    } catch (e) {
      throw Error(`'Error in similaritySearch' ${e}`);
    }
  }

  /**
   * Return documents selected using the maximal marginal relevance.
   * Maximal marginal relevance optimizes for similarity to the query AND diversity
   * among selected documents.
   *
   * @param {string} query - Text to look up documents similar to.
   * @param {number} options.k - Number of documents to return.
   * @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm.
   * @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results,
   *                 where 0 corresponds to maximum diversity and 1 to minimum diversity.
   * @param {this["FilterType"]} options.filter - Optional filter
   * @param _callbacks
   *
   * @returns {Promise<Document[]>} - List of documents selected by maximal marginal relevance.
*/
  override async maxMarginalRelevanceSearch(
    query: string,
    options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>,
    _callbacks?: undefined
  ): Promise<Document[]> {
    const fetchK = options.fetchK ?? 20;
    const lambda = options.lambda ?? 0.5;

    const queryEmbedding: number[] = await this.embeddings.embedQuery(query);

    // Fetch a wider candidate pool (with stored vectors) for MMR to rerank.
    const candidates: [Document, number, number[]][] =
      await this.similaritySearchVectorWithScoreAndEmbedding(
        queryEmbedding,
        fetchK,
        options.filter
      );
    const candidateEmbeddings = candidates.map((candidate) => candidate[2]);

    const selected = maximalMarginalRelevance(
      queryEmbedding,
      candidateEmbeddings,
      lambda,
      options.k
    );

    // Indexes equal to -1 are skipped.
    const picked: Document[] = [];
    for (const position of selected) {
      if (position !== -1) {
        picked.push(candidates[position][0]);
      }
    }
    return picked;
  }

  /**
   * Builds a `WeaviateStore` from raw texts: wraps each text in a
   * `Document` and hands the documents to `fromDocuments`.
   * @param texts Array of texts.
   * @param metadatas Either one metadata object shared by every text, or an array indexed like `texts`.
   * @param embeddings Embeddings to be used for the texts.
   * @param args Arguments required to create a new `WeaviateStore` instance.
   * @returns A new `WeaviateStore` instance.
   */
  static fromTexts(
    texts: string[],
    metadatas: object | object[],
    embeddings: Embeddings,
    args: WeaviateLibArgs
  ): Promise<WeaviateStore> {
    const docs: Document[] = texts.map(
      (pageContent, position) =>
        new Document({
          pageContent,
          metadata: Array.isArray(metadatas) ? metadatas[position] : metadatas,
        })
    );
    return WeaviateStore.fromDocuments(docs, embeddings, args);
  }

  /**
   * Static method to create a new `WeaviateStore` instance from a list of
   * documents. It adds the documents to the Weaviate index.
   * @param docs Array of documents.
   * @param embeddings Embeddings to be used for the documents.
+ * @param args Arguments required to create a new `WeaviateStore` instance. + * @returns A new `WeaviateStore` instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + args: WeaviateLibArgs + ): Promise { + const instance = new this(embeddings, args); + await instance.addDocuments(docs); + return instance; + } + + /** + * Static method to create a new `WeaviateStore` instance from an existing + * Weaviate index. + * @param embeddings Embeddings to be used for the Weaviate index. + * @param args Arguments required to create a new `WeaviateStore` instance. + * @returns A new `WeaviateStore` instance. + */ + static async fromExistingIndex( + embeddings: Embeddings, + args: WeaviateLibArgs + ): Promise { + return new this(embeddings, args); + } +} diff --git a/libs/langchain-community/src/vectorstores/xata.ts b/libs/langchain-community/src/vectorstores/xata.ts new file mode 100644 index 000000000000..5b5c1d23bc5c --- /dev/null +++ b/libs/langchain-community/src/vectorstores/xata.ts @@ -0,0 +1,149 @@ +import { BaseClient } from "@xata.io/client"; +import { Embeddings } from "@langchain/core/embeddings"; +import { VectorStore } from "@langchain/core/vectorstores"; +import { Document } from "@langchain/core/documents"; + +/** + * Interface for the arguments required to create a XataClient. Includes + * the client instance and the table name. + */ +export interface XataClientArgs { + readonly client: XataClient; + readonly table: string; +} + +/** + * Type for the filter object used in Xata database queries. + */ +type XataFilter = object; + +/** + * Class for interacting with a Xata database as a VectorStore. Provides + * methods to add documents and vectors to the database, delete entries, + * and perform similarity searches. 
+ */ +export class XataVectorSearch< + XataClient extends BaseClient +> extends VectorStore { + declare FilterType: XataFilter; + + private readonly client: XataClient; + + private readonly table: string; + + _vectorstoreType(): string { + return "xata"; + } + + constructor(embeddings: Embeddings, args: XataClientArgs) { + super(embeddings, args); + + this.client = args.client; + this.table = args.table; + } + + /** + * Method to add documents to the Xata database. Maps the page content of + * each document, embeds the documents using the embeddings, and adds the + * vectors to the database. + * @param documents Array of documents to be added. + * @param options Optional object containing an array of ids. + * @returns Promise resolving to an array of ids of the added documents. + */ + async addDocuments(documents: Document[], options?: { ids?: string[] }) { + const texts = documents.map(({ pageContent }) => pageContent); + return this.addVectors( + await this.embeddings.embedDocuments(texts), + documents, + options + ); + } + + /** + * Method to add vectors to the Xata database. Maps each vector to a row + * with the document's content, embedding, and metadata. Creates or + * replaces these rows in the Xata database. + * @param vectors Array of vectors to be added. + * @param documents Array of documents corresponding to the vectors. + * @param options Optional object containing an array of ids. + * @returns Promise resolving to an array of ids of the added vectors. 
+ */ + async addVectors( + vectors: number[][], + documents: Document[], + options?: { ids?: string[] } + ) { + const rows = vectors + .map((embedding, idx) => ({ + content: documents[idx].pageContent, + embedding, + ...documents[idx].metadata, + })) + .map((row, idx) => { + if (options?.ids) { + return { id: options.ids[idx], ...row }; + } + return row; + }); + + const res = await this.client.db[this.table].createOrReplace(rows); + // Since we have an untyped BaseClient, it doesn't know the + // actual return type of the overload. + const results = res as unknown as { id: string }[]; + const returnedIds = results.map((row) => row.id); + return returnedIds; + } + + /** + * Method to delete entries from the Xata database. Deletes the entries + * with the provided ids. + * @param params Object containing an array of ids of the entries to be deleted. + * @returns Promise resolving to void. + */ + async delete(params: { ids: string[] }): Promise { + const { ids } = params; + await this.client.db[this.table].delete(ids); + } + + /** + * Method to perform a similarity search in the Xata database. Returns the + * k most similar documents along with their scores. + * @param query Query vector for the similarity search. + * @param k Number of most similar documents to return. + * @param filter Optional filter for the search. + * @returns Promise resolving to an array of tuples, each containing a Document and its score. 
+ */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: XataFilter | undefined + ): Promise<[Document, number][]> { + const { records } = await this.client.db[this.table].vectorSearch( + "embedding", + query, + { + size: k, + filter, + } + ); + + return ( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + records?.map((record: any) => [ + new Document({ + pageContent: record.content, + metadata: Object.fromEntries( + Object.entries(record).filter( + ([key]) => + key !== "content" && + key !== "embedding" && + key !== "xata" && + key !== "id" + ) + ), + }), + record.xata.score, + ]) ?? [] + ); + } +} diff --git a/libs/langchain-community/src/vectorstores/zep.ts b/libs/langchain-community/src/vectorstores/zep.ts new file mode 100644 index 000000000000..0606c3437495 --- /dev/null +++ b/libs/langchain-community/src/vectorstores/zep.ts @@ -0,0 +1,427 @@ +import { + DocumentCollection, + IDocument, + NotFoundError, + ZepClient, +} from "@getzep/zep-js"; + +import { + MaxMarginalRelevanceSearchOptions, + VectorStore, +} from "@langchain/core/vectorstores"; +import { Embeddings } from "@langchain/core/embeddings"; +import { Document } from "@langchain/core/documents"; +import { Callbacks } from "@langchain/core/callbacks/manager"; +import { maximalMarginalRelevance } from "@langchain/core/utils/math"; +import { FakeEmbeddings } from "../utils/testing.js"; + +/** + * Interface for the arguments required to initialize a ZepVectorStore + * instance. + */ +export interface IZepArgs { + collection: DocumentCollection; +} + +/** + * Interface for the configuration options for a ZepVectorStore instance. + */ +export interface IZepConfig { + apiUrl: string; + apiKey?: string; + collectionName: string; + description?: string; + metadata?: Record; + embeddingDimensions?: number; + isAutoEmbedded?: boolean; +} + +/** + * Interface for the parameters required to delete documents from a + * ZepVectorStore instance. 
+ */ +export interface IZepDeleteParams { + uuids: string[]; +} + +/** + * ZepVectorStore is a VectorStore implementation that uses the Zep long-term memory store as a backend. + * + * If the collection does not exist, it will be created automatically. + * + * Requires `zep-js` to be installed: + * ```bash + * npm install @getzep/zep-js + * ``` + * + * @property {ZepClient} client - The ZepClient instance used to interact with Zep's API. + * @property {Promise} initPromise - A promise that resolves when the collection is initialized. + * @property {DocumentCollection} collection - The Zep document collection. + */ +export class ZepVectorStore extends VectorStore { + public client: ZepClient; + + public collection: DocumentCollection; + + private initPromise: Promise; + + private autoEmbed = false; + + constructor(embeddings: Embeddings, args: IZepConfig) { + super(embeddings, args); + + this.embeddings = embeddings; + + // eslint-disable-next-line no-instanceof/no-instanceof + if (this.embeddings instanceof FakeEmbeddings) { + this.autoEmbed = true; + } + + this.initPromise = this.initCollection(args).catch((err) => { + console.error("Error initializing collection:", err); + throw err; + }); + } + + /** + * Initializes the document collection. If the collection does not exist, it creates a new one. + * + * @param {IZepConfig} args - The configuration object for the Zep API. 
+ */ + private async initCollection(args: IZepConfig) { + this.client = await ZepClient.init(args.apiUrl, args.apiKey); + try { + this.collection = await this.client.document.getCollection( + args.collectionName + ); + + // If the Embedding passed in is fake, but the collection is not auto embedded, throw an error + // eslint-disable-next-line no-instanceof/no-instanceof + if (!this.collection.is_auto_embedded && this.autoEmbed) { + throw new Error(`You can't pass in FakeEmbeddings when collection ${args.collectionName} + is not set to auto-embed.`); + } + } catch (err) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (err instanceof Error) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (err instanceof NotFoundError || err.name === "NotFoundError") { + await this.createCollection(args); + } else { + throw err; + } + } + } + } + + /** + * Creates a new document collection. + * + * @param {IZepConfig} args - The configuration object for the Zep API. + */ + private async createCollection(args: IZepConfig) { + if (!args.embeddingDimensions) { + throw new Error(`Collection ${args.collectionName} not found. + You can create a new Collection by providing embeddingDimensions.`); + } + + this.collection = await this.client.document.addCollection({ + name: args.collectionName, + description: args.description, + metadata: args.metadata, + embeddingDimensions: args.embeddingDimensions, + isAutoEmbedded: this.autoEmbed, + }); + + console.info("Created new collection:", args.collectionName); + } + + /** + * Adds vectors and corresponding documents to the collection. + * + * @param {number[][]} vectors - The vectors to add. + * @param {Document[]} documents - The corresponding documents to add. + * @returns {Promise} - A promise that resolves with the UUIDs of the added documents. 
+ */ + async addVectors( + vectors: number[][], + documents: Document[] + ): Promise { + if (!this.autoEmbed && vectors.length === 0) { + throw new Error(`Vectors must be provided if autoEmbed is false`); + } + if (!this.autoEmbed && vectors.length !== documents.length) { + throw new Error(`Vectors and documents must have the same length`); + } + + const docs: Array = []; + for (let i = 0; i < documents.length; i += 1) { + const doc: IDocument = { + content: documents[i].pageContent, + metadata: documents[i].metadata, + embedding: vectors.length > 0 ? vectors[i] : undefined, + }; + docs.push(doc); + } + // Wait for collection to be initialized + await this.initPromise; + return await this.collection.addDocuments(docs); + } + + /** + * Adds documents to the collection. The documents are first embedded into vectors + * using the provided embedding model. + * + * @param {Document[]} documents - The documents to add. + * @returns {Promise} - A promise that resolves with the UUIDs of the added documents. + */ + async addDocuments(documents: Document[]): Promise { + const texts = documents.map(({ pageContent }) => pageContent); + let vectors: number[][] = []; + if (!this.autoEmbed) { + vectors = await this.embeddings.embedDocuments(texts); + } + return this.addVectors(vectors, documents); + } + + _vectorstoreType(): string { + return "zep"; + } + + /** + * Deletes documents from the collection. + * + * @param {IZepDeleteParams} params - The list of Zep document UUIDs to delete. + * @returns {Promise} + */ + async delete(params: IZepDeleteParams): Promise { + // Wait for collection to be initialized + await this.initPromise; + for (const uuid of params.uuids) { + await this.collection.deleteDocument(uuid); + } + } + + /** + * Performs a similarity search in the collection and returns the results with their scores. + * + * @param {number[]} query - The query vector. + * @param {number} k - The number of results to return. 
+ * @param {Record} filter - The filter to apply to the search. Zep only supports Record as filter. + * @returns {Promise<[Document, number][]>} - A promise that resolves with the search results and their scores. + */ + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: Record | undefined + ): Promise<[Document, number][]> { + await this.initPromise; + const results = await this.collection.search( + { + embedding: new Float32Array(query), + metadata: assignMetadata(filter), + }, + k + ); + return zepDocsToDocumentsAndScore(results); + } + + async _similaritySearchWithScore( + query: string, + k: number, + filter?: Record | undefined + ): Promise<[Document, number][]> { + await this.initPromise; + const results = await this.collection.search( + { + text: query, + metadata: assignMetadata(filter), + }, + k + ); + return zepDocsToDocumentsAndScore(results); + } + + async similaritySearchWithScore( + query: string, + k = 4, + filter: Record | undefined = undefined, + _callbacks = undefined // implement passing to embedQuery later + ): Promise<[Document, number][]> { + if (this.autoEmbed) { + return this._similaritySearchWithScore(query, k, filter); + } else { + return this.similaritySearchVectorWithScore( + await this.embeddings.embedQuery(query), + k, + filter + ); + } + } + + /** + * Performs a similarity search on the Zep collection. + * + * @param {string} query - The query string to search for. + * @param {number} [k=4] - The number of results to return. Defaults to 4. + * @param {this["FilterType"] | undefined} [filter=undefined] - An optional set of JSONPath filters to apply to the search. + * @param {Callbacks | undefined} [_callbacks=undefined] - Optional callbacks. Currently not implemented. + * @returns {Promise} - A promise that resolves to an array of Documents that are similar to the query. 
+ * + * @async + */ + async similaritySearch( + query: string, + k = 4, + filter: this["FilterType"] | undefined = undefined, + _callbacks: Callbacks | undefined = undefined // implement passing to embedQuery later + ): Promise { + await this.initPromise; + + let results: [Document, number][]; + if (this.autoEmbed) { + const zepResults = await this.collection.search( + { text: query, metadata: assignMetadata(filter) }, + k + ); + results = zepDocsToDocumentsAndScore(zepResults); + } else { + results = await this.similaritySearchVectorWithScore( + await this.embeddings.embedQuery(query), + k, + assignMetadata(filter) + ); + } + + return results.map((result) => result[0]); + } + + /** + * Return documents selected using the maximal marginal relevance. + * Maximal marginal relevance optimizes for similarity to the query AND diversity + * among selected documents. + * + * @param {string} query - Text to look up documents similar to. + * @param options + * @param {number} options.k - Number of documents to return. + * @param {number} options.fetchK=20- Number of documents to fetch before passing to the MMR algorithm. + * @param {number} options.lambda=0.5 - Number between 0 and 1 that determines the degree of diversity among the results, + * where 0 corresponds to maximum diversity and 1 to minimum diversity. + * @param {Record} options.filter - Optional Zep JSONPath query to pre-filter on document metadata field + * + * @returns {Promise} - List of documents selected by maximal marginal relevance. 
+ */ + async maxMarginalRelevanceSearch( + query: string, + options: MaxMarginalRelevanceSearchOptions + ): Promise { + const { k, fetchK = 20, lambda = 0.5, filter } = options; + + let queryEmbedding: number[]; + let zepResults: IDocument[]; + if (!this.autoEmbed) { + queryEmbedding = await this.embeddings.embedQuery(query); + zepResults = await this.collection.search( + { + embedding: new Float32Array(queryEmbedding), + metadata: assignMetadata(filter), + }, + fetchK + ); + } else { + let queryEmbeddingArray: Float32Array; + [zepResults, queryEmbeddingArray] = + await this.collection.searchReturnQueryVector( + { text: query, metadata: assignMetadata(filter) }, + fetchK + ); + queryEmbedding = Array.from(queryEmbeddingArray); + } + + const results = zepDocsToDocumentsAndScore(zepResults); + + const embeddingList = zepResults.map((doc) => + Array.from(doc.embedding ? doc.embedding : []) + ); + + const mmrIndexes = maximalMarginalRelevance( + queryEmbedding, + embeddingList, + lambda, + k + ); + + return mmrIndexes.filter((idx) => idx !== -1).map((idx) => results[idx][0]); + } + + /** + * Creates a new ZepVectorStore instance from an array of texts. Each text is converted into a Document and added to the collection. + * + * @param {string[]} texts - The texts to convert into Documents. + * @param {object[] | object} metadatas - The metadata to associate with each Document. If an array is provided, each element is associated with the corresponding Document. If an object is provided, it is associated with all Documents. + * @param {Embeddings} embeddings - The embeddings to use for vectorizing the texts. + * @param {IZepConfig} zepConfig - The configuration object for the Zep API. + * @returns {Promise} - A promise that resolves with the new ZepVectorStore instance. 
+ */ + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: Embeddings, + zepConfig: IZepConfig + ): Promise { + const docs: Document[] = []; + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas; + const newDoc = new Document({ + pageContent: texts[i], + metadata, + }); + docs.push(newDoc); + } + return ZepVectorStore.fromDocuments(docs, embeddings, zepConfig); + } + + /** + * Creates a new ZepVectorStore instance from an array of Documents. Each Document is added to a Zep collection. + * + * @param {Document[]} docs - The Documents to add. + * @param {Embeddings} embeddings - The embeddings to use for vectorizing the Document contents. + * @param {IZepConfig} zepConfig - The configuration object for the Zep API. + * @returns {Promise} - A promise that resolves with the new ZepVectorStore instance. + */ + static async fromDocuments( + docs: Document[], + embeddings: Embeddings, + zepConfig: IZepConfig + ): Promise { + const instance = new this(embeddings, zepConfig); + // Wait for collection to be initialized + await instance.initPromise; + await instance.addDocuments(docs); + return instance; + } +} + +function zepDocsToDocumentsAndScore( + results: IDocument[] +): [Document, number][] { + return results.map((d) => [ + new Document({ + pageContent: d.content, + metadata: d.metadata, + }), + d.score ? 
d.score : 0, + ]); +} + +function assignMetadata( + value: string | Record | object | undefined +): Record | undefined { + if (typeof value === "object" && value !== null) { + return value as Record; + } + if (value !== undefined) { + console.warn("Metadata filters must be an object, Record, or undefined."); + } + return undefined; +} diff --git a/libs/langchain-community/tsconfig.cjs.json b/libs/langchain-community/tsconfig.cjs.json new file mode 100644 index 000000000000..3b7026ea406c --- /dev/null +++ b/libs/langchain-community/tsconfig.cjs.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "module": "commonjs", + "declaration": false + }, + "exclude": ["node_modules", "dist", "docs", "**/tests"] +} diff --git a/libs/langchain-community/tsconfig.json b/libs/langchain-community/tsconfig.json new file mode 100644 index 000000000000..bc85d83b6229 --- /dev/null +++ b/libs/langchain-community/tsconfig.json @@ -0,0 +1,23 @@ +{ + "extends": "@tsconfig/recommended", + "compilerOptions": { + "outDir": "../dist", + "rootDir": "./src", + "target": "ES2021", + "lib": ["ES2021", "ES2022.Object", "DOM"], + "module": "ES2020", + "moduleResolution": "nodenext", + "esModuleInterop": true, + "declaration": true, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "useDefineForClassFields": true, + "strictPropertyInitialization": false, + "allowJs": true, + "strict": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "docs"] +} diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index 6e5aea4a189c..49e5f04823f3 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/openai", - "version": "0.0.1", + "version": "0.0.2-rc.0", "description": "OpenAI integrations for LangChain.js", "type": "module", "engines": { @@ -34,14 +34,13 @@ "author": "LangChain", 
"license": "MIT", "dependencies": { - "@langchain/core": "~0.0.1", + "@langchain/core": "~0.0.11-rc.1", "js-tiktoken": "^1.0.7", "openai": "^4.19.0", "zod-to-json-schema": "3.20.3" }, "devDependencies": { "@jest/globals": "^29.5.0", - "@langchain/core": "workspace:*", "@swc/core": "^1.3.90", "@swc/jest": "^0.2.29", "dpdm": "^3.12.0", diff --git a/package.json b/package.json index 44b3f60cc84a..f8dce25ba422 100644 --- a/package.json +++ b/package.json @@ -19,7 +19,7 @@ "packageManager": "yarn@3.4.1", "scripts": { "build": "turbo run build --filter=\"!test-exports-*\" --concurrency 1", - "build:deps": "yarn workspace @langchain/core build && yarn workspace @langchain/anthropic build && yarn workspace @langchain/openai build", + "build:deps": "yarn workspace @langchain/core build && yarn workspace @langchain/anthropic build && yarn workspace @langchain/openai build && yarn workspace @langchain/community build", "format": "turbo run format", "format:check": "turbo run format:check", "lint": "turbo run lint --concurrency 1", diff --git a/turbo.json b/turbo.json index 01a50e732e26..18bd38254a0f 100644 --- a/turbo.json +++ b/turbo.json @@ -3,14 +3,17 @@ "globalDependencies": ["**/.env"], "pipeline": { "@langchain/core#build": {}, - "libs/langchain-anthropic#build": { + "@langchain/anthropic#build": { "dependsOn": ["@langchain/core#build"] }, - "libs/langchain-openai#build": { + "@langchain/openai#build": { "dependsOn": ["@langchain/core#build"] }, + "@langchain/community#build": { + "dependsOn": ["@langchain/openai#build"] + }, "build": { - "dependsOn": ["@langchain/core#build", "^build"], + "dependsOn": ["@langchain/core#build", "@langchain/community#build", "^build"], "outputs": ["dist/**", "dist-cjs/**", "*.js", "*.cjs", "*.d.ts"], "inputs": ["src/**", "scripts/**", "package.json", "tsconfig.json"] }, diff --git a/yarn.lock b/yarn.lock index b8ef29229cb9..4796b182f1fd 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7997,7 +7997,398 @@ __metadata: languageName: unknown 
linkType: soft -"@langchain/core@workspace:*, @langchain/core@workspace:langchain-core, @langchain/core@~0.0.10": +"@langchain/community@workspace:libs/langchain-community, @langchain/community@~0.0.0": + version: 0.0.0-use.local + resolution: "@langchain/community@workspace:libs/langchain-community" + dependencies: + "@aws-crypto/sha256-js": ^5.0.0 + "@aws-sdk/client-bedrock-runtime": ^3.422.0 + "@aws-sdk/client-dynamodb": ^3.310.0 + "@aws-sdk/client-kendra": ^3.352.0 + "@aws-sdk/client-lambda": ^3.310.0 + "@aws-sdk/client-sagemaker-runtime": ^3.414.0 + "@aws-sdk/client-sfn": ^3.362.0 + "@aws-sdk/credential-provider-node": ^3.388.0 + "@aws-sdk/types": ^3.357.0 + "@clickhouse/client": ^0.2.5 + "@cloudflare/ai": ^1.0.12 + "@cloudflare/workers-types": ^4.20230922.0 + "@elastic/elasticsearch": ^8.4.0 + "@getmetal/metal-sdk": ^4.0.0 + "@getzep/zep-js": ^0.9.0 + "@gomomento/sdk": ^1.51.1 + "@gomomento/sdk-core": ^1.51.1 + "@google-ai/generativelanguage": ^0.2.1 + "@google-cloud/storage": ^6.10.1 + "@gradientai/nodejs-sdk": ^1.2.0 + "@huggingface/inference": ^2.6.4 + "@jest/globals": ^29.5.0 + "@langchain/core": ~0.0.11-rc.1 + "@langchain/openai": ~0.0.2-rc.0 + "@mozilla/readability": ^0.4.4 + "@notionhq/client": ^2.2.10 + "@opensearch-project/opensearch": ^2.2.0 + "@pinecone-database/pinecone": ^1.1.0 + "@planetscale/database": ^1.8.0 + "@qdrant/js-client-rest": ^1.2.0 + "@raycast/api": ^1.55.2 + "@rockset/client": ^0.9.1 + "@smithy/eventstream-codec": ^2.0.5 + "@smithy/protocol-http": ^3.0.6 + "@smithy/signature-v4": ^2.0.10 + "@smithy/util-utf8": ^2.0.0 + "@supabase/postgrest-js": ^1.1.1 + "@supabase/supabase-js": ^2.10.0 + "@swc/core": ^1.3.90 + "@swc/jest": ^0.2.29 + "@tensorflow-models/universal-sentence-encoder": ^1.3.3 + "@tensorflow/tfjs-backend-cpu": ^3 + "@tensorflow/tfjs-converter": ^3.6.0 + "@tensorflow/tfjs-core": ^3.6.0 + "@tsconfig/recommended": ^1.0.2 + "@types/flat": ^5.0.2 + "@types/html-to-text": ^9 + "@types/jsdom": ^21.1.1 + "@types/lodash": ^4 + 
"@types/mozilla-readability": ^0.2.1 + "@types/pg": ^8 + "@types/pg-copy-streams": ^1.2.2 + "@types/uuid": ^9 + "@types/ws": ^8 + "@typescript-eslint/eslint-plugin": ^5.58.0 + "@typescript-eslint/parser": ^5.58.0 + "@upstash/redis": ^1.20.6 + "@vercel/kv": ^0.2.3 + "@vercel/postgres": ^0.5.0 + "@writerai/writer-sdk": ^0.40.2 + "@xata.io/client": ^0.28.0 + "@xenova/transformers": ^2.5.4 + "@zilliz/milvus2-sdk-node": ">=2.2.11" + axios: ^0.26.0 + cassandra-driver: ^4.7.2 + chromadb: ^1.5.3 + closevector-common: 0.1.0-alpha.1 + closevector-node: 0.1.0-alpha.10 + closevector-web: 0.1.0-alpha.15 + cohere-ai: ">=6.0.0" + convex: ^1.3.1 + d3-dsv: ^2.0.0 + dotenv: ^16.0.3 + dpdm: ^3.12.0 + eslint: ^8.33.0 + eslint-config-airbnb-base: ^15.0.0 + eslint-config-prettier: ^8.6.0 + eslint-plugin-import: ^2.27.5 + eslint-plugin-jest: ^27.6.0 + eslint-plugin-no-instanceof: ^1.0.1 + eslint-plugin-prettier: ^4.2.1 + faiss-node: ^0.5.1 + fast-xml-parser: ^4.2.7 + firebase-admin: ^11.9.0 + flat: ^5.0.2 + google-auth-library: ^8.9.0 + googleapis: ^126.0.1 + graphql: ^16.6.0 + hnswlib-node: ^1.4.2 + html-to-text: ^9.0.5 + ignore: ^5.2.0 + ioredis: ^5.3.2 + jest: ^29.5.0 + jest-environment-node: ^29.6.4 + jsdom: ^22.1.0 + langsmith: ~0.0.48 + llmonitor: ^0.5.9 + lodash: ^4.17.21 + mammoth: ^1.5.1 + ml-distance: ^4.0.0 + mongodb: ^5.2.0 + mysql2: ^3.3.3 + neo4j-driver: ^5.12.0 + node-llama-cpp: 2.7.3 + pg: ^8.11.0 + pg-copy-streams: ^6.0.5 + pickleparser: ^0.2.1 + portkey-ai: ^0.1.11 + prettier: ^2.8.3 + pyodide: ^0.24.1 + redis: ^4.6.6 + release-it: ^15.10.1 + replicate: ^0.18.0 + rollup: ^3.19.1 + sqlite3: ^5.1.4 + ts-jest: ^29.1.0 + typeorm: ^0.3.12 + typescript: ~5.1.6 + typesense: ^1.5.3 + usearch: ^1.1.1 + uuid: ^9.0.0 + vectordb: ^0.1.4 + voy-search: 0.6.2 + weaviate-ts-client: ^1.4.0 + web-auth-library: ^1.0.3 + zod: ^3.22.3 + peerDependencies: + "@aws-crypto/sha256-js": ^5.0.0 + "@aws-sdk/client-bedrock-runtime": ^3.422.0 + "@aws-sdk/client-dynamodb": ^3.310.0 + 
"@aws-sdk/client-kendra": ^3.352.0 + "@aws-sdk/client-lambda": ^3.310.0 + "@aws-sdk/client-sagemaker-runtime": ^3.310.0 + "@aws-sdk/client-sfn": ^3.310.0 + "@aws-sdk/credential-provider-node": ^3.388.0 + "@clickhouse/client": ^0.2.5 + "@cloudflare/ai": ^1.0.12 + "@elastic/elasticsearch": ^8.4.0 + "@faker-js/faker": ^7.6.0 + "@getmetal/metal-sdk": "*" + "@getzep/zep-js": ^0.9.0 + "@gomomento/sdk": ^1.51.1 + "@gomomento/sdk-core": ^1.51.1 + "@gomomento/sdk-web": ^1.51.1 + "@google-ai/generativelanguage": ^0.2.1 + "@google-cloud/storage": ^6.10.1 + "@gradientai/nodejs-sdk": ^1.2.0 + "@huggingface/inference": ^2.6.4 + "@mozilla/readability": "*" + "@notionhq/client": ^2.2.10 + "@opensearch-project/opensearch": "*" + "@pinecone-database/pinecone": ^1.1.0 + "@planetscale/database": ^1.8.0 + "@qdrant/js-client-rest": ^1.2.0 + "@raycast/api": ^1.55.2 + "@rockset/client": ^0.9.1 + "@smithy/eventstream-codec": ^2.0.5 + "@smithy/protocol-http": ^3.0.6 + "@smithy/signature-v4": ^2.0.10 + "@smithy/util-utf8": ^2.0.0 + "@supabase/postgrest-js": ^1.1.1 + "@supabase/supabase-js": ^2.10.0 + "@tensorflow-models/universal-sentence-encoder": "*" + "@tensorflow/tfjs-converter": "*" + "@tensorflow/tfjs-core": "*" + "@upstash/redis": ^1.20.6 + "@vercel/kv": ^0.2.3 + "@vercel/postgres": ^0.5.0 + "@writerai/writer-sdk": ^0.40.2 + "@xata.io/client": ^0.28.0 + "@xenova/transformers": ^2.5.4 + "@zilliz/milvus2-sdk-node": ">=2.2.7" + axios: "*" + cassandra-driver: ^4.7.2 + chromadb: "*" + closevector-common: 0.1.0-alpha.1 + closevector-node: 0.1.0-alpha.10 + closevector-web: 0.1.0-alpha.16 + cohere-ai: ">=6.0.0" + convex: ^1.3.1 + d3-dsv: ^2.0.0 + faiss-node: ^0.5.1 + fast-xml-parser: ^4.2.7 + firebase-admin: ^11.9.0 + google-auth-library: ^8.9.0 + googleapis: ^126.0.1 + hnswlib-node: ^1.4.2 + html-to-text: ^9.0.5 + ignore: ^5.2.0 + ioredis: ^5.3.2 + jsdom: "*" + llmonitor: ^0.5.9 + lodash: ^4.17.21 + mammoth: "*" + mongodb: ^5.2.0 + mysql2: ^3.3.3 + neo4j-driver: "*" + node-llama-cpp: "*" + 
pg: ^8.11.0 + pg-copy-streams: ^6.0.5 + pickleparser: ^0.2.1 + portkey-ai: ^0.1.11 + pyodide: ^0.24.1 + redis: ^4.6.4 + replicate: ^0.18.0 + typeorm: ^0.3.12 + typesense: ^1.5.3 + usearch: ^1.1.1 + vectordb: ^0.1.4 + voy-search: 0.6.2 + weaviate-ts-client: ^1.4.0 + web-auth-library: ^1.0.3 + ws: ^8.14.2 + peerDependenciesMeta: + "@aws-crypto/sha256-js": + optional: true + "@aws-sdk/client-bedrock-runtime": + optional: true + "@aws-sdk/client-dynamodb": + optional: true + "@aws-sdk/client-kendra": + optional: true + "@aws-sdk/client-lambda": + optional: true + "@aws-sdk/client-sagemaker-runtime": + optional: true + "@aws-sdk/client-sfn": + optional: true + "@aws-sdk/credential-provider-node": + optional: true + "@clickhouse/client": + optional: true + "@cloudflare/ai": + optional: true + "@elastic/elasticsearch": + optional: true + "@getmetal/metal-sdk": + optional: true + "@getzep/zep-js": + optional: true + "@gomomento/sdk": + optional: true + "@gomomento/sdk-core": + optional: true + "@gomomento/sdk-web": + optional: true + "@google-ai/generativelanguage": + optional: true + "@google-cloud/storage": + optional: true + "@gradientai/nodejs-sdk": + optional: true + "@huggingface/inference": + optional: true + "@mozilla/readability": + optional: true + "@notionhq/client": + optional: true + "@opensearch-project/opensearch": + optional: true + "@pinecone-database/pinecone": + optional: true + "@planetscale/database": + optional: true + "@qdrant/js-client-rest": + optional: true + "@raycast/api": + optional: true + "@rockset/client": + optional: true + "@smithy/eventstream-codec": + optional: true + "@smithy/protocol-http": + optional: true + "@smithy/signature-v4": + optional: true + "@smithy/util-utf8": + optional: true + "@supabase/postgrest-js": + optional: true + "@supabase/supabase-js": + optional: true + "@tensorflow-models/universal-sentence-encoder": + optional: true + "@tensorflow/tfjs-converter": + optional: true + "@tensorflow/tfjs-core": + optional: true + 
"@upstash/redis": + optional: true + "@vercel/kv": + optional: true + "@vercel/postgres": + optional: true + "@writerai/writer-sdk": + optional: true + "@xata.io/client": + optional: true + "@xenova/transformers": + optional: true + "@zilliz/milvus2-sdk-node": + optional: true + axios: + optional: true + cassandra-driver: + optional: true + chromadb: + optional: true + closevector-common: + optional: true + closevector-node: + optional: true + closevector-web: + optional: true + cohere-ai: + optional: true + convex: + optional: true + d3-dsv: + optional: true + faiss-node: + optional: true + fast-xml-parser: + optional: true + firebase-admin: + optional: true + google-auth-library: + optional: true + googleapis: + optional: true + hnswlib-node: + optional: true + html-to-text: + optional: true + ignore: + optional: true + ioredis: + optional: true + jsdom: + optional: true + llmonitor: + optional: true + lodash: + optional: true + mammoth: + optional: true + mongodb: + optional: true + mysql2: + optional: true + neo4j-driver: + optional: true + node-llama-cpp: + optional: true + pg: + optional: true + pg-copy-streams: + optional: true + pickleparser: + optional: true + portkey-ai: + optional: true + pyodide: + optional: true + redis: + optional: true + replicate: + optional: true + typeorm: + optional: true + typesense: + optional: true + usearch: + optional: true + vectordb: + optional: true + voy-search: + optional: true + weaviate-ts-client: + optional: true + web-auth-library: + optional: true + ws: + optional: true + languageName: unknown + linkType: soft + +"@langchain/core@workspace:*, @langchain/core@workspace:langchain-core, @langchain/core@~0.0.11-rc.1": version: 0.0.0-use.local resolution: "@langchain/core@workspace:langchain-core" dependencies: @@ -8019,6 +8410,8 @@ __metadata: jest-environment-node: ^29.6.4 js-tiktoken: ^1.0.8 langsmith: ~0.0.48 + ml-distance: ^4.0.0 + ml-matrix: ^6.10.4 p-queue: ^6.6.2 p-retry: 4 prettier: ^2.8.3 @@ -8031,12 +8424,12 
@@ __metadata: languageName: unknown linkType: soft -"@langchain/openai@workspace:libs/langchain-openai": +"@langchain/openai@workspace:libs/langchain-openai, @langchain/openai@~0.0.2-rc.0": version: 0.0.0-use.local resolution: "@langchain/openai@workspace:libs/langchain-openai" dependencies: "@jest/globals": ^29.5.0 - "@langchain/core": "workspace:*" + "@langchain/core": ~0.0.11-rc.1 "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 dpdm: ^3.12.0 @@ -22624,10 +23017,6 @@ __metadata: resolution: "langchain@workspace:langchain" dependencies: "@anthropic-ai/sdk": ^0.9.1 - "@aws-crypto/sha256-js": ^5.0.0 - "@aws-sdk/client-bedrock-runtime": ^3.422.0 - "@aws-sdk/client-dynamodb": ^3.310.0 - "@aws-sdk/client-kendra": ^3.352.0 "@aws-sdk/client-lambda": ^3.310.0 "@aws-sdk/client-s3": ^3.310.0 "@aws-sdk/client-sagemaker-runtime": ^3.414.0 @@ -22649,7 +23038,9 @@ __metadata: "@gradientai/nodejs-sdk": ^1.2.0 "@huggingface/inference": ^2.6.4 "@jest/globals": ^29.5.0 - "@langchain/core": ~0.0.10 + "@langchain/community": ~0.0.0 + "@langchain/core": ~0.0.11-rc.1 + "@langchain/openai": ~0.0.2-rc.0 "@mozilla/readability": ^0.4.4 "@notionhq/client": ^2.2.10 "@opensearch-project/opensearch": ^2.2.0 @@ -22673,7 +23064,6 @@ __metadata: "@tsconfig/recommended": ^1.0.2 "@types/d3-dsv": ^2 "@types/decamelize": ^1.2.0 - "@types/flat": ^5.0.2 "@types/html-to-text": ^9 "@types/js-yaml": ^4 "@types/jsdom": ^21.1.1 @@ -22720,7 +23110,6 @@ __metadata: faiss-node: ^0.5.1 fast-xml-parser: ^4.2.7 firebase-admin: ^11.9.0 - flat: ^5.0.2 google-auth-library: ^8.9.0 googleapis: ^126.0.1 graphql: ^16.6.0 @@ -22740,14 +23129,12 @@ __metadata: lodash: ^4.17.21 mammoth: ^1.5.1 ml-distance: ^4.0.0 - ml-matrix: ^6.10.4 mongodb: ^5.2.0 mysql2: ^3.3.3 neo4j-driver: ^5.12.0 node-llama-cpp: 2.7.3 notion-to-md: ^3.1.0 officeparser: ^4.0.4 - openai: ^4.19.0 openapi-types: ^12.1.3 p-retry: 4 pdf-parse: 1.1.1 @@ -22784,10 +23171,6 @@ __metadata: zod: ^3.22.3 zod-to-json-schema: 3.20.3 peerDependencies: - 
"@aws-crypto/sha256-js": ^5.0.0 - "@aws-sdk/client-bedrock-runtime": ^3.422.0 - "@aws-sdk/client-dynamodb": ^3.310.0 - "@aws-sdk/client-kendra": ^3.352.0 "@aws-sdk/client-lambda": ^3.310.0 "@aws-sdk/client-s3": ^3.310.0 "@aws-sdk/client-sagemaker-runtime": ^3.310.0 @@ -22886,14 +23269,6 @@ __metadata: youtube-transcript: ^1.0.6 youtubei.js: ^5.8.0 peerDependenciesMeta: - "@aws-crypto/sha256-js": - optional: true - "@aws-sdk/client-bedrock-runtime": - optional: true - "@aws-sdk/client-dynamodb": - optional: true - "@aws-sdk/client-kendra": - optional: true "@aws-sdk/client-lambda": optional: true "@aws-sdk/client-s3":