langchain[minor]: feat(LLM Integration): WatsonX AI Integration (#3399)
* base watsonx file
* base setup route
* build
* add project id lc secret
* WatsonX AI Functional MVP
* enable custom model parameters
* enable custom model selection with llama 2 default
* wrap fetch in `this.caller.call`
* run format
* add request error handling
* update typedoc string
* yarn format
* add watsonx ai example
* watsonx-ai -> watsonx_ai
* Add watson x documentation
* delete old files
* yarn format
* add error for missing project id
* Add setup note about secrets to docs
* format
* remove redundant count
* add var set on class instantiation example
* format
* update ibm cloud api key to follow common convention
* fix type cast
* update llmType casing
* add iam token caching
* use expiration field
* Small style fixes in docs
* Update watsonx_ai.ts

---------

Co-authored-by: jacoblee93 <[email protected]>
1 parent b557b13 · commit b982b42 · showing 9 changed files with 261 additions and 0 deletions.
WatsonX AI documentation page (new file, 27 lines):
# WatsonX AI

LangChain.js supports integration with IBM WatsonX AI. Check out [WatsonX AI](https://www.ibm.com/products/watsonx-ai) for a list of available models.

## Setup

You will need to set the following environment variables to use the WatsonX AI API:

1. `IBM_CLOUD_API_KEY`, which can be generated via [IBM Cloud](https://cloud.ibm.com/iam/apikeys)
2. `WATSONX_PROJECT_ID`, which can be found in your [project's manage tab](https://dataplatform.cloud.ibm.com/projects/?context=wx)

Alternatively, these can be set when the WatsonxAI class is instantiated, as `ibmCloudApiKey` and `projectId` respectively. For example:

```typescript
const model = new WatsonxAI({
  ibmCloudApiKey: "My secret IBM Cloud API Key",
  projectId: "My secret WatsonX AI Project id",
});
```
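
If both environment variables are set, the credentials can be omitted from the constructor entirely. A minimal sketch of that pattern (assuming `IBM_CLOUD_API_KEY` and `WATSONX_PROJECT_ID` are exported in the environment):

```typescript
import { WatsonxAI } from "langchain/llms/watsonx_ai";

// ibmCloudApiKey and projectId are read from the IBM_CLOUD_API_KEY and
// WATSONX_PROJECT_ID environment variables; the constructor throws if
// either one is missing.
const model = new WatsonxAI({});
```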
## Usage

import CodeBlock from "@theme/CodeBlock";
import WatsonxAiExample from "@examples/llms/watsonx_ai.ts";

<CodeBlock language="typescript">{WatsonxAiExample}</CodeBlock>
Example `@examples/llms/watsonx_ai.ts` (new file, 20 lines):
```typescript
import { WatsonxAI } from "langchain/llms/watsonx_ai";

export const run = async () => {
  // Note that modelParameters are optional
  const model = new WatsonxAI({
    modelId: "meta-llama/llama-2-70b-chat",
    modelParameters: {
      max_new_tokens: 100,
      min_new_tokens: 0,
      stop_sequences: [],
      repetition_penalty: 1,
    },
  });

  const res = await model.invoke(
    "What would be a good company name for a company that makes colorful socks?"
  );

  console.log({ res });
};
```
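
The integration below forwards `modelParameters` verbatim as the `parameters` field of the request body it sends to the generation endpoint. As a sketch, the example above produces a POST body shaped roughly like this (the project ID is a placeholder, not a real value):

```typescript
// Approximate shape of the POST body that _call builds for the example
// above; project_id is a placeholder for the WATSONX_PROJECT_ID value.
const requestBody = {
  project_id: "<WATSONX_PROJECT_ID>",
  model_id: "meta-llama/llama-2-70b-chat",
  input:
    "What would be a good company name for a company that makes colorful socks?",
  parameters: {
    max_new_tokens: 100,
    min_new_tokens: 0,
    stop_sequences: [],
    repetition_penalty: 1,
  },
};
```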
LLM integration `langchain/llms/watsonx_ai` (new file, 194 lines):
```typescript
import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js";
import { getEnvironmentVariable } from "../util/env.js";

/**
 * The WatsonxAIParams interface defines the input parameters for
 * the WatsonxAI class.
 */
export interface WatsonxAIParams extends BaseLLMParams {
  /**
   * WatsonX AI Complete Endpoint.
   * Can be used if you want a fully custom endpoint.
   */
  endpoint?: string;
  /**
   * IBM Cloud Compute Region.
   * eg. us-south, us-east, etc.
   */
  region?: string;
  /**
   * WatsonX AI Version.
   * Date representing the WatsonX AI Version.
   * eg. 2023-05-29
   */
  version?: string;
  /**
   * WatsonX AI Key.
   * Provide API Key if you do not wish to automatically pull from env.
   */
  ibmCloudApiKey?: string;
  /**
   * WatsonX AI Project ID.
   * Provide Project ID if you do not wish to automatically pull from env.
   */
  projectId?: string;
  /**
   * Parameters accepted by the WatsonX AI Endpoint.
   */
  modelParameters?: Record<string, unknown>;
  /**
   * WatsonX AI Model ID.
   */
  modelId?: string;
}

const endpointConstructor = (region: string, version: string) =>
  `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`;

/**
 * The WatsonxAI class is used to interact with Watsonx AI
 * Inference Endpoint models. It uses IBM Cloud for authentication.
 * This requires your IBM Cloud API Key which is autoloaded if not specified.
 */
export class WatsonxAI extends LLM<BaseLLMCallOptions> {
  static lc_name() {
    return "WatsonxAI";
  }

  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      ibmCloudApiKey: "IBM_CLOUD_API_KEY",
      projectId: "WATSONX_PROJECT_ID",
    };
  }

  endpoint: string;

  region = "us-south";

  version = "2023-05-29";

  modelId = "meta-llama/llama-2-70b-chat";

  modelKwargs?: Record<string, unknown>;

  ibmCloudApiKey?: string;

  ibmCloudToken?: string;

  ibmCloudTokenExpiresAt?: number;

  projectId?: string;

  modelParameters?: Record<string, unknown>;

  constructor(fields: WatsonxAIParams) {
    super(fields);

    this.region = fields?.region ?? this.region;
    this.version = fields?.version ?? this.version;
    this.modelId = fields?.modelId ?? this.modelId;
    this.ibmCloudApiKey =
      fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY");
    this.projectId =
      fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID");

    this.endpoint =
      fields?.endpoint ?? endpointConstructor(this.region, this.version);
    this.modelParameters = fields.modelParameters;

    if (!this.ibmCloudApiKey) {
      throw new Error("Missing IBM Cloud API Key");
    }

    if (!this.projectId) {
      throw new Error("Missing WatsonX AI Project ID");
    }
  }

  _llmType() {
    return "watsonx_ai";
  }

  /**
   * Calls the WatsonX AI endpoint and retrieves the result.
   * @param {string} prompt The input prompt.
   * @returns {Promise<string>} A promise that resolves to the generated string.
   */
  /** @ignore */
  async _call(
    prompt: string,
    _options: this["ParsedCallOptions"]
  ): Promise<string> {
    interface WatsonxAIResponse {
      results: {
        generated_text: string;
        generated_token_count: number;
        input_token_count: number;
      }[];
      errors: {
        code: string;
        message: string;
      }[];
    }
    const response = (await this.caller.call(async () =>
      fetch(this.endpoint, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Accept: "application/json",
          Authorization: `Bearer ${await this.generateToken()}`,
        },
        body: JSON.stringify({
          project_id: this.projectId,
          model_id: this.modelId,
          input: prompt,
          parameters: this.modelParameters,
        }),
      }).then((res) => res.json())
    )) as WatsonxAIResponse;

    /**
     * Handle errors for invalid requests.
     */
    if (response.errors) {
      throw new Error(response.errors[0].message);
    }

    return response.results[0].generated_text;
  }

  async generateToken(): Promise<string> {
    // Reuse the cached IAM token until it expires.
    if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) {
      if (this.ibmCloudTokenExpiresAt > Date.now()) {
        return this.ibmCloudToken;
      }
    }

    interface TokenResponse {
      access_token: string;
      expiration: number;
    }

    // Exchange the IBM Cloud API key for an IAM bearer token.
    const urlTokenParams = new URLSearchParams();
    urlTokenParams.append(
      "grant_type",
      "urn:ibm:params:oauth:grant-type:apikey"
    );
    urlTokenParams.append("apikey", this.ibmCloudApiKey as string);

    const data = (await fetch("https://iam.cloud.ibm.com/identity/token", {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: urlTokenParams,
    }).then((res) => res.json())) as TokenResponse;

    // `expiration` is returned in seconds; convert to milliseconds.
    this.ibmCloudTokenExpiresAt = data.expiration * 1000;
    this.ibmCloudToken = data.access_token;

    return this.ibmCloudToken;
  }
}
```
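
None of the examples above touch `endpoint`, `region`, or `version`. A hedged sketch of overriding the deployment-related options (the region value is illustrative, borrowed from the docstring's examples; the commented-out endpoint is a placeholder, not a documented URL):

```typescript
import { WatsonxAI } from "langchain/llms/watsonx_ai";

// With no overrides, endpointConstructor yields
// https://us-south.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29.
// Passing `endpoint` directly bypasses the region/version-derived default.
const model = new WatsonxAI({
  region: "us-east", // illustrative; the default is "us-south"
  version: "2023-05-29",
  // endpoint: "https://<custom-host>/ml/v1-beta/generation/text?version=2023-05-29",
});
```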
b982b42
Successfully deployed to the following URLs:
langchainjs-api-refs → ./docs/api_refs
langchainjs-api-refs-git-main-langchain.vercel.app
langchainjs-api-refs-langchain.vercel.app
langchainjs-api-docs.vercel.app
api.js.langchain.com
b982b42
Successfully deployed to the following URLs:
langchainjs-docs → ./docs/core_docs/
langchainjs-docs-ruddy.vercel.app
langchainjs-docs-langchain.vercel.app
langchainjs-docs-git-main-langchain.vercel.app
js.langchain.com