langchain[minor]: feat(LLM Integration): WatsonX AI Integration (#3399)
* base watsonx file

* base setup route

* build

* add project id lc secret

* WatsonX AI Functional MVP

* enable custom model parameters

* enable custom model selection with llama 2 default

* wrap fetch in `this.caller.call`

* run format

* add request error handling

* update typedoc string

* yarn format

* add watsonx ai example

* watsonx-ai -> watsonx_ai

* Add watson x documentation

* delete old files

* yarn format

* add error for missing project id

* Add setup note about secrets to docs

* format

* remove redundant count

* add var set on class instantiation example

* format

* update ibm cloud api key to follow common convention

* fix type cast

* update llmType casing

* add iam token caching

* use expiration field

* Small style fixes in docs

* Update watsonx_ai.ts

---------

Co-authored-by: jacoblee93 <[email protected]>
chasemcdo and jacoblee93 authored Nov 30, 2023
1 parent b557b13 commit b982b42
Showing 9 changed files with 261 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api_refs/typedoc.json
Original file line number Diff line number Diff line change
@@ -82,6 +82,7 @@
"./langchain/src/llms/googlepalm.ts",
"./langchain/src/llms/fireworks.ts",
"./langchain/src/llms/sagemaker_endpoint.ts",
"./langchain/src/llms/watsonx_ai.ts",
"./langchain/src/llms/bedrock/index.ts",
"./langchain/src/llms/bedrock/web.ts",
"./langchain/src/llms/llama_cpp.ts",
27 changes: 27 additions & 0 deletions docs/core_docs/docs/integrations/llms/watsonx_ai.mdx
@@ -0,0 +1,27 @@
# WatsonX AI

LangChain.js supports integration with IBM WatsonX AI. Check out [WatsonX AI](https://www.ibm.com/products/watsonx-ai) for a list of available models.

## Setup

You will need to set the following environment variables to use the WatsonX AI API:

1. `IBM_CLOUD_API_KEY` which can be generated via [IBM Cloud](https://cloud.ibm.com/iam/apikeys)
2. `WATSONX_PROJECT_ID` which can be found in your [project's manage tab](https://dataplatform.cloud.ibm.com/projects/?context=wx)

Alternatively, these can be set as `ibmCloudApiKey` and `projectId`, respectively, when instantiating the `WatsonxAI` class.
For example:

```typescript
const model = new WatsonxAI({
  ibmCloudApiKey: "My secret IBM Cloud API Key",
  projectId: "My secret WatsonX AI Project id",
});
```

## Usage

import CodeBlock from "@theme/CodeBlock";
import WatsonxAiExample from "@examples/llms/watsonx_ai.ts";

<CodeBlock language="typescript">{WatsonxAiExample}</CodeBlock>
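Under the hood, the integration builds the (beta) text-generation endpoint URL from a region and version and POSTs the prompt as JSON. A minimal sketch of that request assembly, with no network call — the `buildGenerationRequest` helper is illustrative, not a library export:

```typescript
// Sketch of how the generation request is assembled. URL shape and body
// fields mirror the integration's implementation; the helper is hypothetical.

interface GenerationRequest {
  url: string;
  body: {
    project_id: string;
    model_id: string;
    input: string;
    parameters?: Record<string, unknown>;
  };
}

const buildGenerationRequest = (
  region: string,
  version: string,
  projectId: string,
  modelId: string,
  prompt: string,
  parameters?: Record<string, unknown>
): GenerationRequest => ({
  url: `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`,
  body: {
    project_id: projectId,
    model_id: modelId,
    input: prompt,
    parameters,
  },
});

const req = buildGenerationRequest(
  "us-south",
  "2023-05-29",
  "my-project-id",
  "meta-llama/llama-2-70b-chat",
  "Hello, world!"
);
console.log(req.url);
// → https://us-south.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29
```

The defaults used above (`us-south`, `2023-05-29`, the Llama 2 model ID) match the defaults set in this commit's `WatsonxAI` class.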
20 changes: 20 additions & 0 deletions examples/src/llms/watsonx_ai.ts
@@ -0,0 +1,20 @@
import { WatsonxAI } from "langchain/llms/watsonx_ai";

export const run = async () => {
  // Note that modelParameters are optional
  const model = new WatsonxAI({
    modelId: "meta-llama/llama-2-70b-chat",
    modelParameters: {
      max_new_tokens: 100,
      min_new_tokens: 0,
      stop_sequences: [],
      repetition_penalty: 1,
    },
  });

  const res = await model.invoke(
    "What would be a good company name for a company that makes colorful socks?"
  );

  console.log({ res });
};
3 changes: 3 additions & 0 deletions langchain/.gitignore
@@ -190,6 +190,9 @@ llms/fireworks.d.ts
llms/sagemaker_endpoint.cjs
llms/sagemaker_endpoint.js
llms/sagemaker_endpoint.d.ts
llms/watsonx_ai.cjs
llms/watsonx_ai.js
llms/watsonx_ai.d.ts
llms/bedrock.cjs
llms/bedrock.js
llms/bedrock.d.ts
8 changes: 8 additions & 0 deletions langchain/package.json
@@ -202,6 +202,9 @@
"llms/sagemaker_endpoint.cjs",
"llms/sagemaker_endpoint.js",
"llms/sagemaker_endpoint.d.ts",
"llms/watsonx_ai.cjs",
"llms/watsonx_ai.js",
"llms/watsonx_ai.d.ts",
"llms/bedrock.cjs",
"llms/bedrock.js",
"llms/bedrock.d.ts",
@@ -1756,6 +1759,11 @@
"import": "./llms/sagemaker_endpoint.js",
"require": "./llms/sagemaker_endpoint.cjs"
},
"./llms/watsonx_ai": {
"types": "./llms/watsonx_ai.d.ts",
"import": "./llms/watsonx_ai.js",
"require": "./llms/watsonx_ai.cjs"
},
"./llms/bedrock": {
"types": "./llms/bedrock.d.ts",
"import": "./llms/bedrock.js",
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
@@ -79,6 +79,7 @@ const entrypoints = {
"llms/googlepalm": "llms/googlepalm",
"llms/fireworks": "llms/fireworks",
"llms/sagemaker_endpoint": "llms/sagemaker_endpoint",
"llms/watsonx_ai": "llms/watsonx_ai",
"llms/bedrock": "llms/bedrock/index",
"llms/bedrock/web": "llms/bedrock/web",
"llms/llama_cpp": "llms/llama_cpp",
@@ -369,6 +370,7 @@ const requiresOptionalDependency = [
"llms/raycast",
"llms/replicate",
"llms/sagemaker_endpoint",
"llms/watsonx_ai",
"llms/bedrock",
"llms/bedrock/web",
"llms/llama_cpp",
194 changes: 194 additions & 0 deletions langchain/src/llms/watsonx_ai.ts
@@ -0,0 +1,194 @@
import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js";
import { getEnvironmentVariable } from "../util/env.js";

/**
* The WatsonxAIParams interface defines the input parameters for
* the WatsonxAI class.
*/
export interface WatsonxAIParams extends BaseLLMParams {
/**
* WatsonX AI Complete Endpoint.
* Can be used if you want a fully custom endpoint.
*/
endpoint?: string;
/**
* IBM Cloud Compute Region.
* e.g. us-south, us-east, etc.
*/
region?: string;
/**
* WatsonX AI Version.
* Date representing the WatsonX AI Version.
* e.g. 2023-05-29
*/
version?: string;
/**
* IBM Cloud API Key.
* Provide API Key if you do not wish to automatically pull from env.
*/
ibmCloudApiKey?: string;
/**
* WatsonX AI Project ID.
* Provide Project ID if you do not wish to automatically pull from env.
*/
projectId?: string;
/**
* Parameters accepted by the WatsonX AI Endpoint.
*/
modelParameters?: Record<string, unknown>;
/**
* WatsonX AI Model ID.
*/
modelId?: string;
}

const endpointConstructor = (region: string, version: string) =>
`https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`;

/**
* The WatsonxAI class is used to interact with Watsonx AI
* Inference Endpoint models. It uses IBM Cloud for authentication.
* This requires your IBM Cloud API Key which is autoloaded if not specified.
*/
export class WatsonxAI extends LLM<BaseLLMCallOptions> {
static lc_name() {
return "WatsonxAI";
}

get lc_secrets(): { [key: string]: string } | undefined {
return {
ibmCloudApiKey: "IBM_CLOUD_API_KEY",
projectId: "WATSONX_PROJECT_ID",
};
}

endpoint: string;

region = "us-south";

version = "2023-05-29";

modelId = "meta-llama/llama-2-70b-chat";

modelKwargs?: Record<string, unknown>;

ibmCloudApiKey?: string;

ibmCloudToken?: string;

ibmCloudTokenExpiresAt?: number;

projectId?: string;

modelParameters?: Record<string, unknown>;

constructor(fields: WatsonxAIParams) {
super(fields);

this.region = fields?.region ?? this.region;
this.version = fields?.version ?? this.version;
this.modelId = fields?.modelId ?? this.modelId;
this.ibmCloudApiKey =
fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY");
this.projectId =
fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID");

this.endpoint =
fields?.endpoint ?? endpointConstructor(this.region, this.version);
this.modelParameters = fields.modelParameters;

if (!this.ibmCloudApiKey) {
throw new Error("Missing IBM Cloud API Key");
}

if (!this.projectId) {
throw new Error("Missing WatsonX AI Project ID");
}
}

_llmType() {
return "watsonx_ai";
}

/**
* Calls the WatsonX AI endpoint and retrieves the result.
* @param {string} prompt The input prompt.
* @returns {Promise<string>} A promise that resolves to the generated string.
*/
/** @ignore */
async _call(
prompt: string,
_options: this["ParsedCallOptions"]
): Promise<string> {
interface WatsonxAIResponse {
results: {
generated_text: string;
generated_token_count: number;
input_token_count: number;
}[];
errors: {
code: string;
message: string;
}[];
}
const response = (await this.caller.call(async () =>
fetch(this.endpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "application/json",
Authorization: `Bearer ${await this.generateToken()}`,
},
body: JSON.stringify({
project_id: this.projectId,
model_id: this.modelId,
input: prompt,
parameters: this.modelParameters,
}),
}).then((res) => res.json())
)) as WatsonxAIResponse;

/**
* Handle Errors for invalid requests.
*/
if (response.errors) {
throw new Error(response.errors[0].message);
}

return response.results[0].generated_text;
}

async generateToken(): Promise<string> {
if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) {
if (this.ibmCloudTokenExpiresAt > Date.now()) {
return this.ibmCloudToken;
}
}

interface TokenResponse {
access_token: string;
expiration: number;
}

const urlTokenParams = new URLSearchParams();
urlTokenParams.append(
"grant_type",
"urn:ibm:params:oauth:grant-type:apikey"
);
urlTokenParams.append("apikey", this.ibmCloudApiKey as string);

const data = (await fetch("https://iam.cloud.ibm.com/identity/token", {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: urlTokenParams,
}).then((res) => res.json())) as TokenResponse;

this.ibmCloudTokenExpiresAt = data.expiration * 1000;
this.ibmCloudToken = data.access_token;

return this.ibmCloudToken;
}
}
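The IAM token caching introduced in `generateToken` above follows a standard bearer-token pattern: reuse the cached token while its reported expiration is in the future, and fetch a fresh one only after it lapses. A standalone sketch of the same pattern, with the token fetcher injected so it can be exercised offline — the `TokenCache` name and stub are illustrative, not part of the library:

```typescript
// Bearer-token cache mirroring the generateToken() logic: serve the cached
// token while it is still valid, otherwise fetch and store a fresh one.
// The fetcher is injected so the pattern can be demonstrated without IAM.

interface IamToken {
  access_token: string;
  expiration: number; // Unix time in seconds, as the IAM endpoint reports it
}

class TokenCache {
  private token?: string;

  private expiresAt?: number; // milliseconds since epoch

  constructor(private fetchToken: () => Promise<IamToken>) {}

  async get(): Promise<string> {
    // Serve from the cache while the token has not expired.
    if (this.token !== undefined && this.expiresAt !== undefined && this.expiresAt > Date.now()) {
      return this.token;
    }
    const data = await this.fetchToken();
    // IAM reports expiration in seconds; Date.now() is in milliseconds.
    this.expiresAt = data.expiration * 1000;
    this.token = data.access_token;
    return this.token;
  }
}

// Stubbed fetcher: the second get() should reuse the cached token.
let fetches = 0;
const cache = new TokenCache(async () => {
  fetches += 1;
  return { access_token: `token-${fetches}`, expiration: Date.now() / 1000 + 3600 };
});

const demo = (async () => {
  const first = await cache.get();
  const second = await cache.get();
  console.log(first, second, fetches); // both calls return "token-1"; one fetch
})();
```

Caching like this avoids a round trip to `https://iam.cloud.ibm.com/identity/token` on every completion call; note the seconds-to-milliseconds conversion, which the commit's `expiration` handling also performs.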
1 change: 1 addition & 0 deletions langchain/src/load/import_constants.ts
@@ -33,6 +33,7 @@ export const optionalImportEntrypoints = [
"langchain/llms/googlevertexai/web",
"langchain/llms/googlepalm",
"langchain/llms/sagemaker_endpoint",
"langchain/llms/watsonx_ai",
"langchain/llms/bedrock",
"langchain/llms/bedrock/web",
"langchain/llms/llama_cpp",
5 changes: 5 additions & 0 deletions langchain/src/load/import_type.d.ts
@@ -97,6 +97,9 @@ export interface OptionalImportMap {
"langchain/llms/sagemaker_endpoint"?:
| typeof import("../llms/sagemaker_endpoint.js")
| Promise<typeof import("../llms/sagemaker_endpoint.js")>;
"langchain/llms/watsonx_ai"?:
| typeof import("../llms/watsonx_ai.js")
| Promise<typeof import("../llms/watsonx_ai.js")>;
"langchain/llms/bedrock"?:
| typeof import("../llms/bedrock/index.js")
| Promise<typeof import("../llms/bedrock/index.js")>;
@@ -520,6 +523,7 @@ export interface SecretMap {
GOOGLE_PALM_API_KEY?: string;
GOOGLE_VERTEX_AI_WEB_CREDENTIALS?: string;
HUGGINGFACEHUB_API_KEY?: string;
IBM_CLOUD_API_KEY?: string;
IFLYTEK_API_KEY?: string;
IFLYTEK_API_SECRET?: string;
MILVUS_PASSWORD?: string;
@@ -547,6 +551,7 @@
VECTARA_API_KEY?: string;
VECTARA_CORPUS_ID?: string;
VECTARA_CUSTOMER_ID?: string;
WATSONX_PROJECT_ID?: string;
WRITER_API_KEY?: string;
WRITER_ORG_ID?: string;
YC_API_KEY?: string;

2 comments on commit b982b42

@vercel bot commented on b982b42, Nov 30, 2023:
Successfully deployed to the following URLs:

langchainjs-docs – ./docs/core_docs/

langchainjs-docs-ruddy.vercel.app
langchainjs-docs-langchain.vercel.app
langchainjs-docs-git-main-langchain.vercel.app
js.langchain.com
