From b982b429d46c16b144f82b946e4c2f5ca8fca694 Mon Sep 17 00:00:00 2001 From: Chase McDougall Date: Thu, 30 Nov 2023 17:17:26 -0500 Subject: [PATCH] langchain[minor]: feat(LLM Integration): WatsonX AI Integration (#3399) * base watsonx file * base setup route * build * add project id lc secret * WatsonX AI Functional MVP * enable custom model parameters * enable custom model selection with llama 2 default * wrap fetch in `this.caller.call` * run format * add request error handling * update typedoc string * yarn format * add watsonx ai example * watsonx-ai -> watsonx_ai * Add watson x documentation * delete old files * yarn format * add error for missing project id * Add setup note about secrets to docs * format * remove redundant count * add var set on class instantiation example * format * update ibm cloud api key to follow common convention * fix type cast * update llmType casing * add iam token caching * use expiration field * Small style fixes in dos * Update watsonx_ai.ts --------- Co-authored-by: jacoblee93 --- docs/api_refs/typedoc.json | 1 + .../docs/integrations/llms/watsonx_ai.mdx | 27 +++ examples/src/llms/watsonx_ai.ts | 20 ++ langchain/.gitignore | 3 + langchain/package.json | 8 + langchain/scripts/create-entrypoints.js | 2 + langchain/src/llms/watsonx_ai.ts | 194 ++++++++++++++++++ langchain/src/load/import_constants.ts | 1 + langchain/src/load/import_type.d.ts | 5 + 9 files changed, 261 insertions(+) create mode 100644 docs/core_docs/docs/integrations/llms/watsonx_ai.mdx create mode 100644 examples/src/llms/watsonx_ai.ts create mode 100644 langchain/src/llms/watsonx_ai.ts diff --git a/docs/api_refs/typedoc.json b/docs/api_refs/typedoc.json index 800d235ded73..1c7c87e352c3 100644 --- a/docs/api_refs/typedoc.json +++ b/docs/api_refs/typedoc.json @@ -82,6 +82,7 @@ "./langchain/src/llms/googlepalm.ts", "./langchain/src/llms/fireworks.ts", "./langchain/src/llms/sagemaker_endpoint.ts", + "./langchain/src/llms/watsonx_ai.ts", 
"./langchain/src/llms/bedrock/index.ts", "./langchain/src/llms/bedrock/web.ts", "./langchain/src/llms/llama_cpp.ts", diff --git a/docs/core_docs/docs/integrations/llms/watsonx_ai.mdx b/docs/core_docs/docs/integrations/llms/watsonx_ai.mdx new file mode 100644 index 000000000000..ac0168b9db05 --- /dev/null +++ b/docs/core_docs/docs/integrations/llms/watsonx_ai.mdx @@ -0,0 +1,27 @@ +# WatsonX AI + +LangChain.js supports integration with IBM WatsonX AI. Checkout [WatsonX AI](https://www.ibm.com/products/watsonx-ai) for a list of available models. + +## Setup + +You will need to set the following environment variables for using the WatsonX AI API. + +1. `IBM_CLOUD_API_KEY` which can be generated via [IBM Cloud](https://cloud.ibm.com/iam/apikeys) +2. `WATSONX_PROJECT_ID` which can be found in your [project's manage tab](https://dataplatform.cloud.ibm.com/projects/?context=wx) + +Alternatively, these can be set during the WatsonxAI Class instantiation as `ibmCloudApiKey` and `projectId` respectively. +For example: + +```typescript +const model = new WatsonxAI({ + ibmCloudApiKey: "My secret IBM Cloud API Key" + projectId: "My secret WatsonX AI Project id" +}); +``` + +## Usage + +import CodeBlock from "@theme/CodeBlock"; +import WatsonxAiExample from "@examples/llms/watsonx_ai.ts"; + +{WatsonxAiExample} diff --git a/examples/src/llms/watsonx_ai.ts b/examples/src/llms/watsonx_ai.ts new file mode 100644 index 000000000000..f40a50c6a835 --- /dev/null +++ b/examples/src/llms/watsonx_ai.ts @@ -0,0 +1,20 @@ +import { WatsonxAI } from "langchain/llms/watsonx_ai"; + +export const run = async () => { + // Note that modelParameters are optional + const model = new WatsonxAI({ + modelId: "meta-llama/llama-2-70b-chat", + modelParameters: { + max_new_tokens: 100, + min_new_tokens: 0, + stop_sequences: [], + repetition_penalty: 1, + }, + }); + + const res = await model.invoke( + "What would be a good company name for a company that makes colorful socks?" 
+ ); + + console.log({ res }); +}; diff --git a/langchain/.gitignore b/langchain/.gitignore index 093f76cd21eb..14293fe114d2 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -190,6 +190,9 @@ llms/fireworks.d.ts llms/sagemaker_endpoint.cjs llms/sagemaker_endpoint.js llms/sagemaker_endpoint.d.ts +llms/watsonx_ai.cjs +llms/watsonx_ai.js +llms/watsonx_ai.d.ts llms/bedrock.cjs llms/bedrock.js llms/bedrock.d.ts diff --git a/langchain/package.json b/langchain/package.json index 662dbfd68cea..41654a5f760f 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -202,6 +202,9 @@ "llms/sagemaker_endpoint.cjs", "llms/sagemaker_endpoint.js", "llms/sagemaker_endpoint.d.ts", + "llms/watsonx_ai.cjs", + "llms/watsonx_ai.js", + "llms/watsonx_ai.d.ts", "llms/bedrock.cjs", "llms/bedrock.js", "llms/bedrock.d.ts", @@ -1756,6 +1759,11 @@ "import": "./llms/sagemaker_endpoint.js", "require": "./llms/sagemaker_endpoint.cjs" }, + "./llms/watsonx_ai": { + "types": "./llms/watsonx_ai.d.ts", + "import": "./llms/watsonx_ai.js", + "require": "./llms/watsonx_ai.cjs" + }, "./llms/bedrock": { "types": "./llms/bedrock.d.ts", "import": "./llms/bedrock.js", diff --git a/langchain/scripts/create-entrypoints.js b/langchain/scripts/create-entrypoints.js index ff1e67984735..6e5846b3602f 100644 --- a/langchain/scripts/create-entrypoints.js +++ b/langchain/scripts/create-entrypoints.js @@ -79,6 +79,7 @@ const entrypoints = { "llms/googlepalm": "llms/googlepalm", "llms/fireworks": "llms/fireworks", "llms/sagemaker_endpoint": "llms/sagemaker_endpoint", + "llms/watsonx_ai": "llms/watsonx_ai", "llms/bedrock": "llms/bedrock/index", "llms/bedrock/web": "llms/bedrock/web", "llms/llama_cpp": "llms/llama_cpp", @@ -369,6 +370,7 @@ const requiresOptionalDependency = [ "llms/raycast", "llms/replicate", "llms/sagemaker_endpoint", + "llms/watsonx_ai", "llms/bedrock", "llms/bedrock/web", "llms/llama_cpp", diff --git a/langchain/src/llms/watsonx_ai.ts b/langchain/src/llms/watsonx_ai.ts new file 
mode 100644 index 000000000000..dca510ba21c1 --- /dev/null +++ b/langchain/src/llms/watsonx_ai.ts @@ -0,0 +1,194 @@ +import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js"; +import { getEnvironmentVariable } from "../util/env.js"; + +/** + * The WatsonxAIParams interface defines the input parameters for + * the WatsonxAI class. + */ +export interface WatsonxAIParams extends BaseLLMParams { + /** + * WatsonX AI Complete Endpoint. + * Can be used if you want a fully custom endpoint. + */ + endpoint?: string; + /** + * IBM Cloud Compute Region. + * eg. us-south, us-east, etc. + */ + region?: string; + /** + * WatsonX AI Version. + * Date representing the WatsonX AI Version. + * eg. 2023-05-29 + */ + version?: string; + /** + * WatsonX AI Key. + * Provide API Key if you do not wish to automatically pull from env. + */ + ibmCloudApiKey?: string; + /** + * WatsonX AI Project ID. + * Provide Project ID if you do not wish to automatically pull from env. + */ + projectId?: string; + /** + * Parameters accepted by the WatsonX AI Endpoint. + */ + modelParameters?: Record<string, unknown>; + /** + * WatsonX AI Model ID. + */ + modelId?: string; +} + +const endpointConstructor = (region: string, version: string) => + `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`; + +/** + * The WatsonxAI class is used to interact with Watsonx AI + * Inference Endpoint models. It uses IBM Cloud for authentication. + * This requires your IBM Cloud API Key which is autoloaded if not specified.
+ */ + +export class WatsonxAI extends LLM { + static lc_name() { + return "WatsonxAI"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + ibmCloudApiKey: "IBM_CLOUD_API_KEY", + projectId: "WATSONX_PROJECT_ID", + }; + } + + endpoint: string; + + region = "us-south"; + + version = "2023-05-29"; + + modelId = "meta-llama/llama-2-70b-chat"; + + modelKwargs?: Record<string, unknown>; + + ibmCloudApiKey?: string; + + ibmCloudToken?: string; + + ibmCloudTokenExpiresAt?: number; + + projectId?: string; + + modelParameters?: Record<string, unknown>; + + constructor(fields: WatsonxAIParams) { + super(fields); + + this.region = fields?.region ?? this.region; + this.version = fields?.version ?? this.version; + this.modelId = fields?.modelId ?? this.modelId; + this.ibmCloudApiKey = + fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY"); + this.projectId = + fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID"); + + this.endpoint = + fields?.endpoint ?? endpointConstructor(this.region, this.version); + this.modelParameters = fields.modelParameters; + + if (!this.ibmCloudApiKey) { + throw new Error("Missing IBM Cloud API Key"); + } + + if (!this.projectId) { + throw new Error("Missing WatsonX AI Project ID"); + } + } + + _llmType() { + return "watsonx_ai"; + } + + /** + * Calls the WatsonX AI endpoint and retrieves the result. + * @param {string} prompt The input prompt. + * @returns {Promise<string>} A promise that resolves to the generated string.
+ */ + /** @ignore */ + async _call( + prompt: string, + _options: this["ParsedCallOptions"] + ): Promise<string> { + interface WatsonxAIResponse { + results: { + generated_text: string; + generated_token_count: number; + input_token_count: number; + }[]; + errors: { + code: string; + message: string; + }[]; + } + const response = (await this.caller.call(async () => + fetch(this.endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + Authorization: `Bearer ${await this.generateToken()}`, + }, + body: JSON.stringify({ + project_id: this.projectId, + model_id: this.modelId, + input: prompt, + parameters: this.modelParameters, + }), + }).then((res) => res.json()) + )) as WatsonxAIResponse; + + /** + * Handle Errors for invalid requests. + */ + if (response.errors) { + throw new Error(response.errors[0].message); + } + + return response.results[0].generated_text; + } + + async generateToken(): Promise<string> { + if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) { + if (this.ibmCloudTokenExpiresAt > Date.now()) { + return this.ibmCloudToken; + } + } + + interface TokenResponse { + access_token: string; + expiration: number; + } + + const urlTokenParams = new URLSearchParams(); + urlTokenParams.append( + "grant_type", + "urn:ibm:params:oauth:grant-type:apikey" + ); + urlTokenParams.append("apikey", this.ibmCloudApiKey as string); + + const data = (await fetch("https://iam.cloud.ibm.com/identity/token", { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: urlTokenParams, + }).then((res) => res.json())) as TokenResponse; + + this.ibmCloudTokenExpiresAt = data.expiration * 1000; + this.ibmCloudToken = data.access_token; + + return this.ibmCloudToken; + } +} diff --git a/langchain/src/load/import_constants.ts b/langchain/src/load/import_constants.ts index 5ecf4d99784d..961dafb48c9a 100644 --- a/langchain/src/load/import_constants.ts +++ b/langchain/src/load/import_constants.ts @@
-33,6 +33,7 @@ export const optionalImportEntrypoints = [ "langchain/llms/googlevertexai/web", "langchain/llms/googlepalm", "langchain/llms/sagemaker_endpoint", + "langchain/llms/watsonx_ai", "langchain/llms/bedrock", "langchain/llms/bedrock/web", "langchain/llms/llama_cpp", diff --git a/langchain/src/load/import_type.d.ts b/langchain/src/load/import_type.d.ts index 9755865d9556..8018ae39b2a7 100644 --- a/langchain/src/load/import_type.d.ts +++ b/langchain/src/load/import_type.d.ts @@ -97,6 +97,9 @@ export interface OptionalImportMap { "langchain/llms/sagemaker_endpoint"?: | typeof import("../llms/sagemaker_endpoint.js") | Promise<typeof import("../llms/sagemaker_endpoint.js")>; + "langchain/llms/watsonx_ai"?: + | typeof import("../llms/watsonx_ai.js") + | Promise<typeof import("../llms/watsonx_ai.js")>; "langchain/llms/bedrock"?: | typeof import("../llms/bedrock/index.js") | Promise<typeof import("../llms/bedrock/index.js")>; @@ -520,6 +523,7 @@ export interface SecretMap { GOOGLE_PALM_API_KEY?: string; GOOGLE_VERTEX_AI_WEB_CREDENTIALS?: string; HUGGINGFACEHUB_API_KEY?: string; + IBM_CLOUD_API_KEY?: string; IFLYTEK_API_KEY?: string; IFLYTEK_API_SECRET?: string; MILVUS_PASSWORD?: string; @@ -547,6 +551,7 @@ export interface SecretMap { VECTARA_API_KEY?: string; VECTARA_CORPUS_ID?: string; VECTARA_CUSTOMER_ID?: string; + WATSONX_PROJECT_ID?: string; WRITER_API_KEY?: string; WRITER_ORG_ID?: string; YC_API_KEY?: string;