From ed8f8ffef5f4b67f0cc5cdc7c3af23cbe34b2d85 Mon Sep 17 00:00:00 2001 From: faileon Date: Thu, 7 Dec 2023 09:57:18 +0100 Subject: [PATCH 1/4] streaming, chat model --- docs/api_refs/typedoc.json | 1 + examples/src/llms/watsonx_ai-chat.ts | 20 +++ langchain/.gitignore | 3 + langchain/package.json | 8 + langchain/scripts/create-entrypoints.js | 1 + langchain/src/chat_models/watsonx_ai.ts | 154 ++++++++++++++++++ langchain/src/llms/watsonx_ai.ts | 199 ++++++++---------------- langchain/src/load/import_map.ts | 1 + langchain/src/types/watsonx-types.ts | 58 +++++++ langchain/src/util/watsonx-client.ts | 165 ++++++++++++++++++++ 10 files changed, 478 insertions(+), 132 deletions(-) create mode 100644 examples/src/llms/watsonx_ai-chat.ts create mode 100644 langchain/src/chat_models/watsonx_ai.ts create mode 100644 langchain/src/types/watsonx-types.ts create mode 100644 langchain/src/util/watsonx-client.ts diff --git a/docs/api_refs/typedoc.json b/docs/api_refs/typedoc.json index 044d8a32270f..ecee3e559d66 100644 --- a/docs/api_refs/typedoc.json +++ b/docs/api_refs/typedoc.json @@ -204,6 +204,7 @@ "./langchain/src/chat_models/llama_cpp.ts", "./langchain/src/chat_models/yandex.ts", "./langchain/src/chat_models/fake.ts", + "./langchain/src/chat_models/watsonx_ai.ts", "./langchain/src/schema/index.ts", "./langchain/src/schema/document.ts", "./langchain/src/schema/output_parser.ts", diff --git a/examples/src/llms/watsonx_ai-chat.ts b/examples/src/llms/watsonx_ai-chat.ts new file mode 100644 index 000000000000..3d98051b9a46 --- /dev/null +++ b/examples/src/llms/watsonx_ai-chat.ts @@ -0,0 +1,20 @@ +import { WatsonChatModel } from "langchain/chat_models/watsonx_ai"; + +const model = new WatsonChatModel({ + clientConfig: { + region: "eu-de", + }, + modelParameters: { + max_new_tokens: 100, + }, +}); + +const stream = await model.stream( + "What would be a good company name for a company that makes colorful socks?" +); + +let text = ""; +for await (const chunk of stream) { + text += chunk.content; + console.log(text); +} diff --git a/langchain/.gitignore b/langchain/.gitignore index 99eaf327b22a..ef66a2c0491a 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -556,6 +556,9 @@ chat_models/yandex.d.ts chat_models/fake.cjs chat_models/fake.js chat_models/fake.d.ts +chat_models/watsonx_ai.cjs +chat_models/watsonx_ai.js +chat_models/watsonx_ai.d.ts schema.cjs schema.js schema.d.ts diff --git a/langchain/package.json b/langchain/package.json index 94fd29bb9156..64820525a6d5 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -568,6 +568,9 @@ "chat_models/fake.cjs", "chat_models/fake.js", "chat_models/fake.d.ts", + "chat_models/watsonx_ai.cjs", + "chat_models/watsonx_ai.js", + "chat_models/watsonx_ai.d.ts", "schema.cjs", "schema.js", "schema.d.ts", @@ -2398,6 +2401,11 @@ "import": "./chat_models/fake.js", "require": "./chat_models/fake.cjs" }, + "./chat_models/watsonx_ai": { + "types": "./chat_models/watsonx_ai.d.ts", + "import": "./chat_models/watsonx_ai.js", + "require": "./chat_models/watsonx_ai.cjs" + }, "./schema": { "types": "./schema.d.ts", "import": "./schema.js", diff --git a/langchain/scripts/create-entrypoints.js b/langchain/scripts/create-entrypoints.js index 65cc773bdf0c..6a6649b92022 100644 --- a/langchain/scripts/create-entrypoints.js +++ b/langchain/scripts/create-entrypoints.js @@ -217,6 +217,7 @@ const entrypoints = { "chat_models/llama_cpp": "chat_models/llama_cpp", "chat_models/yandex": "chat_models/yandex", "chat_models/fake": "chat_models/fake", + "chat_models/watsonx_ai": "chat_models/watsonx_ai", // schema schema: "schema/index", "schema/document": "schema/document", diff --git a/langchain/src/chat_models/watsonx_ai.ts b/langchain/src/chat_models/watsonx_ai.ts new file mode 100644 index 000000000000..2257210c7aab --- /dev/null +++ b/langchain/src/chat_models/watsonx_ai.ts @@ -0,0 +1,154 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { CallbackManagerForLLMRun } from "../callbacks/index.js"; +import { + AIMessageChunk, + BaseMessage, + ChatGenerationChunk, + ChatMessage, +} from "../schema/index.js"; +import { + WatsonModelParameters, + WatsonxAIParams, +} from "../types/watsonx-types.js"; +import { WatsonApiClient } from "../util/watsonx-client.js"; +import { BaseChatModelParams, SimpleChatModel } from "./base.js"; + +export class WatsonChatModel extends SimpleChatModel { + private readonly watsonApiClient!: WatsonApiClient; + + readonly modelId!: string; + + readonly modelParameters?: WatsonModelParameters; + + readonly projectId!: string; + + constructor(fields: WatsonxAIParams & BaseChatModelParams) { + super(fields); + + const { + clientConfig = {}, + modelId = "meta-llama/llama-2-70b-chat", + modelParameters, + projectId = getEnvironmentVariable("WATSONX_PROJECT_ID") ?? "", + } = fields; + + this.modelId = modelId; + this.modelParameters = modelParameters; + this.projectId = projectId; + + const { + apiKey = getEnvironmentVariable("IBM_CLOUD_API_KEY"), + apiVersion = "2023-05-29", + region = "us-south", + } = clientConfig; + + if (!apiKey) { + throw new Error("Missing IBM Cloud API Key"); + } + + if (!this.projectId) { + throw new Error("Missing WatsonX AI Project ID"); + } + + this.watsonApiClient = new WatsonApiClient({ + apiKey, + apiVersion, + region, + }); + } + + protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { + return messages + .map((message) => { + let messageText; + if (message._getType() === "human") { + messageText = `[INST] ${message.content} [/INST]`; + } else if (message._getType() === "ai") { + messageText = message.content; + } else if (message._getType() === "system") { + messageText = `<> ${message.content} <>`; + } else if (ChatMessage.isInstance(message)) { + messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( + 1 + )}: ${message.content}`; + } else { + console.warn( + `Unsupported message type passed to Watson: "${message._getType()}"` + ); + messageText = ""; + } + return messageText; + }) + .join("\n"); + } + + _combineLLMOutput() { + return {}; + } + + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager: CallbackManagerForLLMRun | undefined + ): Promise { + const chunks = []; + const stream = this._streamResponseChunks(messages, options, runManager); + for await (const chunk of stream) { + chunks.push(chunk.message.content); + } + return chunks.join(""); + } + + override async *_streamResponseChunks( + _messages: BaseMessage[], + _options: this["ParsedCallOptions"], + _runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const formattedMessages = this._formatMessagesAsPrompt(_messages); + const stream = await this.caller.call(async () => + this.watsonApiClient.generateTextStream( + formattedMessages, + this.projectId, + this.modelId, + this.modelParameters + ) + ); + + for await (const data of stream) { + const [ + { + generated_text, + generated_token_count, + input_token_count, + stop_reason, + }, + ] = data.results; + const generationChunk = new ChatGenerationChunk({ + text: generated_text, + message: new AIMessageChunk({ content: generated_text }), + generationInfo: { + generated_token_count, + input_token_count, + stop_reason, + }, + }); + yield generationChunk; + await _runManager?.handleLLMNewToken(generated_text); + } + } + + static lc_name() { + return "WatsonxAIChat"; + } + + _llmType(): string { + return "watsonx_ai_chat"; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + ibmCloudApiKey: "IBM_CLOUD_API_KEY", + projectId: "WATSONX_PROJECT_ID", + }; + } +} diff --git a/langchain/src/llms/watsonx_ai.ts b/langchain/src/llms/watsonx_ai.ts index dca510ba21c1..c9793925bd03 100644 --- a/langchain/src/llms/watsonx_ai.ts +++ b/langchain/src/llms/watsonx_ai.ts @@ -1,49 +1,13 @@ -import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import { AIMessageChunk } from "@langchain/core/messages"; +import { BaseLLMCallOptions, LLM } from "./base.js"; import { getEnvironmentVariable } from "../util/env.js"; - -/** - * The WatsonxAIParams interface defines the input parameters for - * the WatsonxAI class. - */ -export interface WatsonxAIParams extends BaseLLMParams { - /** - * WatsonX AI Complete Endpoint. - * Can be used if you want a fully custom endpoint. - */ - endpoint?: string; - /** - * IBM Cloud Compute Region. - * eg. us-south, us-east, etc. - */ - region?: string; - /** - * WatsonX AI Version. - * Date representing the WatsonX AI Version. - * eg. 2023-05-29 - */ - version?: string; - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - ibmCloudApiKey?: string; - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - projectId?: string; - /** - * Parameters accepted by the WatsonX AI Endpoint. - */ - modelParameters?: Record; - /** - * WatsonX AI Model ID. - */ - modelId?: string; -} - -const endpointConstructor = (region: string, version: string) => - `https://${region}.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=${version}`; +import { WatsonApiClient } from "../util/watsonx-client.js"; +import { + WatsonModelParameters, + WatsonxAIParams, +} from "../types/watsonx-types.js"; /** * The WatsonxAI class is used to interact with Watsonx AI @@ -63,48 +27,43 @@ export class WatsonxAI extends LLM { }; } - endpoint: string; - - region = "us-south"; - - version = "2023-05-29"; - modelId = "meta-llama/llama-2-70b-chat"; modelKwargs?: Record; - ibmCloudApiKey?: string; - - ibmCloudToken?: string; - - ibmCloudTokenExpiresAt?: number; + projectId!: string; - projectId?: string; + modelParameters?: WatsonModelParameters; - modelParameters?: Record; + private readonly watsonApiClient!: WatsonApiClient; constructor(fields: WatsonxAIParams) { super(fields); - this.region = fields?.region ?? this.region; - this.version = fields?.version ?? this.version; this.modelId = fields?.modelId ?? this.modelId; - this.ibmCloudApiKey = - fields?.ibmCloudApiKey ?? getEnvironmentVariable("IBM_CLOUD_API_KEY"); this.projectId = - fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID"); - - this.endpoint = - fields?.endpoint ?? endpointConstructor(this.region, this.version); + fields?.projectId ?? getEnvironmentVariable("WATSONX_PROJECT_ID") ?? ""; this.modelParameters = fields.modelParameters; - if (!this.ibmCloudApiKey) { + const { + apiKey = getEnvironmentVariable("IBM_CLOUD_API_KEY"), + apiVersion = "2023-05-29", + region = "us-south", + } = fields.clientConfig ?? {}; + + if (!apiKey) { throw new Error("Missing IBM Cloud API Key"); } if (!this.projectId) { throw new Error("Missing WatsonX AI Project ID"); } + + this.watsonApiClient = new WatsonApiClient({ + apiKey, + region, + apiVersion, + }); } _llmType() { @@ -121,74 +80,50 @@ export class WatsonxAI extends LLM { prompt: string, _options: this["ParsedCallOptions"] ): Promise { - interface WatsonxAIResponse { - results: { - generated_text: string; - generated_token_count: number; - input_token_count: number; - }[]; - errors: { - code: string; - message: string; - }[]; - } - const response = (await this.caller.call(async () => - fetch(this.endpoint, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json", - Authorization: `Bearer ${await this.generateToken()}`, - }, - body: JSON.stringify({ - project_id: this.projectId, - model_id: this.modelId, - input: prompt, - parameters: this.modelParameters, - }), - }).then((res) => res.json()) - )) as WatsonxAIResponse; - - /** - * Handle Errors for invalid requests. - */ - if (response.errors) { - throw new Error(response.errors[0].message); - } - - return response.results[0].generated_text; + return await this.caller.call(async () => + this.watsonApiClient.generateText( + prompt, + this.projectId, + this.modelId, + this.modelParameters + ) + ); } - async generateToken(): Promise { - if (this.ibmCloudToken && this.ibmCloudTokenExpiresAt) { - if (this.ibmCloudTokenExpiresAt > Date.now()) { - return this.ibmCloudToken; - } - } - - interface TokenResponse { - access_token: string; - expiration: number; - } - - const urlTokenParams = new URLSearchParams(); - urlTokenParams.append( - "grant_type", - "urn:ibm:params:oauth:grant-type:apikey" + async *_streamResponseChunks( + input: string, + _options: this["ParsedCallOptions"], + _runManager?: CallbackManagerForLLMRun + ) { + const stream = await this.caller.call(async () => + this.watsonApiClient.generateTextStream( + input, + this.projectId, + this.modelId, + this.modelParameters + ) ); - urlTokenParams.append("apikey", this.ibmCloudApiKey as string); - const data = (await fetch("https://iam.cloud.ibm.com/identity/token", { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: urlTokenParams, - }).then((res) => res.json())) as TokenResponse; - - this.ibmCloudTokenExpiresAt = data.expiration * 1000; - this.ibmCloudToken = data.access_token; - - return this.ibmCloudToken; + for await (const data of stream) { + const [ + { + generated_text, + generated_token_count, + input_token_count, + stop_reason, + }, + ] = data.results; + const generationChunk = new ChatGenerationChunk({ + text: generated_text, + message: new AIMessageChunk({ content: generated_text }), + generationInfo: { + generated_token_count, + input_token_count, + stop_reason, + }, + }); + yield generationChunk; + await _runManager?.handleLLMNewToken(generated_text); + } } } diff --git a/langchain/src/load/import_map.ts b/langchain/src/load/import_map.ts index fa6ef233e228..8857870237f0 100644 --- a/langchain/src/load/import_map.ts +++ b/langchain/src/load/import_map.ts @@ -60,6 +60,7 @@ export * as chat_models__ollama from "../chat_models/ollama.js"; export * as chat_models__minimax from "../chat_models/minimax.js"; export * as chat_models__yandex from "../chat_models/yandex.js"; export * as chat_models__fake from "../chat_models/fake.js"; +export * as chat_models__watsonx_ai from "../chat_models/watsonx_ai.js"; export * as schema from "../schema/index.js"; export * as schema__document from "../schema/document.js"; export * as schema__output_parser from "../schema/output_parser.js"; diff --git a/langchain/src/types/watsonx-types.ts b/langchain/src/types/watsonx-types.ts new file mode 100644 index 000000000000..eabfbbb4d728 --- /dev/null +++ b/langchain/src/types/watsonx-types.ts @@ -0,0 +1,58 @@ +import { BaseLanguageModelParams } from "@langchain/core/language_models/base"; + +export interface WatsonModelParameters { + decoding_method?: "sample" | "greedy"; + max_new_tokens?: number; + min_new_tokens?: number; + stop_sequences?: string[]; + temperature?: number; + top_k?: number; + top_p?: number; + repetition_penalty?: number; +} + +export interface WatsonApiClientSettings { + /** + * IBM Cloud Compute Region. + * eg. us-south, us-east, etc. + */ + region?: string; + + /** + * WatsonX AI Key. + * Provide API Key if you do not wish to automatically pull from env. + */ + apiKey?: string; + + /** + * WatsonX AI Version. + * Date representing the WatsonX AI Version. + * eg. 2023-05-29 + */ + apiVersion?: string; +} + +/** + * The WatsonxAIParams interface defines the input parameters for + * the WatsonxAI class. + */ +export interface WatsonxAIParams extends BaseLanguageModelParams { + /** + * WatsonX AI Key. + * Provide API Key if you do not wish to automatically pull from env. + */ + projectId?: string; + /** + * Parameters accepted by the WatsonX AI Endpoint. + */ + modelParameters?: WatsonModelParameters; + /** + * WatsonX AI Model ID. + */ + modelId?: string; + + /** + * Watson rest api client config + */ + clientConfig?: WatsonApiClientSettings; +} diff --git a/langchain/src/util/watsonx-client.ts b/langchain/src/util/watsonx-client.ts new file mode 100644 index 000000000000..ea9d25f22306 --- /dev/null +++ b/langchain/src/util/watsonx-client.ts @@ -0,0 +1,165 @@ +import { + WatsonApiClientSettings, + WatsonModelParameters, +} from "../types/watsonx-types.js"; +import { convertEventStreamToIterableReadableDataStream } from "./event-source-parse.js"; + +type WatsonResult = { + generated_text: string; + generated_token_count: number; + input_token_count: number; + stop_reason: string; +}; + +type WatsonError = { + code: string; + message: string; +}; + +type WatsonResponse = { + model_id: string; + created_at: string; + results: WatsonResult[]; + errors: WatsonError[]; +}; + +type IAMTokenResponse = { + access_token: string; + refresh_token: string; + ims_user_id: number; + token_type: string; + expires_in: number; + expiration: number; + scope: string; +}; + +export class WatsonApiClient { + private iamToken?: IAMTokenResponse; + + private readonly apiKey!: string; + + private readonly apiVersion!: string; + + private readonly region!: string; + + private readonly baseUrl!: string; + + constructor({ + region, + apiKey, + apiVersion, + }: Required) { + this.apiKey = apiKey; + this.apiVersion = apiVersion; + this.region = region; + this.baseUrl = `https://${this.region}.ml.cloud.ibm.com`; + } + + private async getIAMToken() { + const url = `https://iam.cloud.ibm.com/identity/token`; + const payload = { + grant_type: "urn:ibm:params:oauth:grant-type:apikey", + apikey: this.apiKey, + }; + const headers = { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }; + + const formBody = Object.entries(payload).reduce((acc, [key, value]) => { + const encodedKey = encodeURIComponent(key as string); + const encodedValue = encodeURIComponent(value as string); + acc.push(`${encodedKey}=${encodedValue}`); + return acc; + }, [] as string[]); + const body = formBody.join("&"); + + const response = await fetch(url, { + method: "POST", + headers, + body, + }); + + return (await response.json()) as IAMTokenResponse; + } + + private async getJwt() { + // we have token and it's not expired + if (this.iamToken && this.iamToken.expiration * 1000 > Date.now()) { + return this.iamToken.access_token; + } + + // we don't have token or its expired + this.iamToken = await this.getIAMToken(); + return this.iamToken.access_token; + } + + public async *generateTextStream( + input: string, + project_id: string, + model_id: string, + parameters?: WatsonModelParameters + ) { + const jwt = await this.getJwt(); + + const url = `${this.baseUrl}/ml/v1-beta/generation/text_stream?version=${this.apiVersion}`; + const headers = { + Authorization: `Bearer ${jwt}`, + "Content-Type": "application/json", + }; + const { body, ok, statusText } = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify({ + input, + project_id, + model_id, + parameters, + }), + }); + + if (!body || !ok) { + throw new Error(statusText); + } + + const stream = convertEventStreamToIterableReadableDataStream(body); + + for await (const chunk of stream) { + yield JSON.parse(chunk) as WatsonResponse; + } + } + + public async generateText( + input: string, + project_id: string, + model_id: string, + parameters?: WatsonModelParameters + ) { + const jwt = await this.getJwt(); + + const url = `${this.baseUrl}/ml/v1-beta/generation/text?version=${this.apiVersion}`; + const headers = { + Authorization: `Bearer ${jwt}`, + "Content-Type": "application/json", + }; + + const response = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify({ + input, + project_id, + model_id, + parameters, + }), + }); + + const data = (await response.json()) as WatsonResponse; + + if (data.errors) { + throw new Error(data.errors[0].message); + } + + return data.results[0].generated_text; + } +} From 8a52a65350b473ce2b651990f4831bbea3b83a6e Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 12 Dec 2023 16:57:40 -0800 Subject: [PATCH 2/4] Format --- langchain/src/chat_models/watsonx_ai.ts | 155 +----------------- .../src/chat_models/watsonx_ai.ts | 7 +- .../src/llms/watsonx_ai.ts | 7 +- .../src/types/watsonx-types.ts | 2 +- .../src/utils/watsonx-client.ts | 2 +- 5 files changed, 13 insertions(+), 160 deletions(-) diff --git a/langchain/src/chat_models/watsonx_ai.ts b/langchain/src/chat_models/watsonx_ai.ts index 2257210c7aab..66464c608987 100644 --- a/langchain/src/chat_models/watsonx_ai.ts +++ b/langchain/src/chat_models/watsonx_ai.ts @@ -1,154 +1 @@ -import { getEnvironmentVariable } from "@langchain/core/utils/env"; -import { CallbackManagerForLLMRun } from "../callbacks/index.js"; -import { - AIMessageChunk, - BaseMessage, - ChatGenerationChunk, - ChatMessage, -} from "../schema/index.js"; -import { - WatsonModelParameters, - WatsonxAIParams, -} from "../types/watsonx-types.js"; -import { WatsonApiClient } from "../util/watsonx-client.js"; -import { BaseChatModelParams, SimpleChatModel } from "./base.js"; - -export class WatsonChatModel extends SimpleChatModel { - private readonly watsonApiClient!: WatsonApiClient; - - readonly modelId!: string; - - readonly modelParameters?: WatsonModelParameters; - - readonly projectId!: string; - - constructor(fields: WatsonxAIParams & BaseChatModelParams) { - super(fields); - - const { - clientConfig = {}, - modelId = "meta-llama/llama-2-70b-chat", - modelParameters, - projectId = getEnvironmentVariable("WATSONX_PROJECT_ID") ?? "", - } = fields; - - this.modelId = modelId; - this.modelParameters = modelParameters; - this.projectId = projectId; - - const { - apiKey = getEnvironmentVariable("IBM_CLOUD_API_KEY"), - apiVersion = "2023-05-29", - region = "us-south", - } = clientConfig; - - if (!apiKey) { - throw new Error("Missing IBM Cloud API Key"); - } - - if (!this.projectId) { - throw new Error("Missing WatsonX AI Project ID"); - } - - this.watsonApiClient = new WatsonApiClient({ - apiKey, - apiVersion, - region, - }); - } - - protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { - return messages - .map((message) => { - let messageText; - if (message._getType() === "human") { - messageText = `[INST] ${message.content} [/INST]`; - } else if (message._getType() === "ai") { - messageText = message.content; - } else if (message._getType() === "system") { - messageText = `<> ${message.content} <>`; - } else if (ChatMessage.isInstance(message)) { - messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( - 1 - )}: ${message.content}`; - } else { - console.warn( - `Unsupported message type passed to Watson: "${message._getType()}"` - ); - messageText = ""; - } - return messageText; - }) - .join("\n"); - } - - _combineLLMOutput() { - return {}; - } - - async _call( - messages: BaseMessage[], - options: this["ParsedCallOptions"], - runManager: CallbackManagerForLLMRun | undefined - ): Promise { - const chunks = []; - const stream = this._streamResponseChunks(messages, options, runManager); - for await (const chunk of stream) { - chunks.push(chunk.message.content); - } - return chunks.join(""); - } - - override async *_streamResponseChunks( - _messages: BaseMessage[], - _options: this["ParsedCallOptions"], - _runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - const formattedMessages = this._formatMessagesAsPrompt(_messages); - const stream = await this.caller.call(async () => - this.watsonApiClient.generateTextStream( - formattedMessages, - this.projectId, - this.modelId, - this.modelParameters - ) - ); - - for await (const data of stream) { - const [ - { - generated_text, - generated_token_count, - input_token_count, - stop_reason, - }, - ] = data.results; - const generationChunk = new ChatGenerationChunk({ - text: generated_text, - message: new AIMessageChunk({ content: generated_text }), - generationInfo: { - generated_token_count, - input_token_count, - stop_reason, - }, - }); - yield generationChunk; - await _runManager?.handleLLMNewToken(generated_text); - } - } - - static lc_name() { - return "WatsonxAIChat"; - } - - _llmType(): string { - return "watsonx_ai_chat"; - } - - get lc_secrets(): { [key: string]: string } | undefined { - return { - ibmCloudApiKey: "IBM_CLOUD_API_KEY", - projectId: "WATSONX_PROJECT_ID", - }; - } -} +export * from "@langchain/community/chat_models/watsonx_ai"; diff --git a/libs/langchain-community/src/chat_models/watsonx_ai.ts b/libs/langchain-community/src/chat_models/watsonx_ai.ts index 6731d3c2a890..3ebbee4076d4 100644 --- a/libs/langchain-community/src/chat_models/watsonx_ai.ts +++ b/libs/langchain-community/src/chat_models/watsonx_ai.ts @@ -1,5 +1,8 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; -import { type BaseChatModelParams, SimpleChatModel } from "@langchain/core/language_models/chat_models"; +import { + type BaseChatModelParams, + SimpleChatModel, +} from "@langchain/core/language_models/chat_models"; import { AIMessageChunk, BaseMessage, @@ -151,4 +154,4 @@ export class WatsonChatModel extends SimpleChatModel { projectId: "WATSONX_PROJECT_ID", }; } -} \ No newline at end of file +} diff --git a/libs/langchain-community/src/llms/watsonx_ai.ts b/libs/langchain-community/src/llms/watsonx_ai.ts index 8b99287fd7be..47aff32a9b4a 100644 --- a/libs/langchain-community/src/llms/watsonx_ai.ts +++ b/libs/langchain-community/src/llms/watsonx_ai.ts @@ -1,7 +1,10 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { ChatGenerationChunk } from "@langchain/core/outputs"; import { AIMessageChunk } from "@langchain/core/messages"; -import { type BaseLLMCallOptions, LLM } from "@langchain/core/language_models/llms"; +import { + type BaseLLMCallOptions, + LLM, +} from "@langchain/core/language_models/llms"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { WatsonApiClient } from "../utils/watsonx-client.js"; import type { @@ -126,4 +129,4 @@ export class WatsonxAI extends LLM { await _runManager?.handleLLMNewToken(generated_text); } } -} \ No newline at end of file +} diff --git a/libs/langchain-community/src/types/watsonx-types.ts b/libs/langchain-community/src/types/watsonx-types.ts index a51e33bf0936..eabfbbb4d728 100644 --- a/libs/langchain-community/src/types/watsonx-types.ts +++ b/libs/langchain-community/src/types/watsonx-types.ts @@ -55,4 +55,4 @@ export interface WatsonxAIParams extends BaseLanguageModelParams { * Watson rest api client config */ clientConfig?: WatsonApiClientSettings; -} \ No newline at end of file +} diff --git a/libs/langchain-community/src/utils/watsonx-client.ts b/libs/langchain-community/src/utils/watsonx-client.ts index d93a0502a093..a3e1d3ae67ea 100644 --- a/libs/langchain-community/src/utils/watsonx-client.ts +++ b/libs/langchain-community/src/utils/watsonx-client.ts @@ -162,4 +162,4 @@ export class WatsonApiClient { return data.results[0].generated_text; } -} \ No newline at end of file +} From 48e4380bec2bcb1fd4fff51d770cac02f447399d Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 12 Dec 2023 17:03:48 -0800 Subject: [PATCH 3/4] Remove unused files --- langchain/src/load/import_type.d.ts | 2 - langchain/src/types/watsonx-types.ts | 58 ---------- langchain/src/util/watsonx-client.ts | 165 --------------------------- 3 files changed, 225 deletions(-) delete mode 100644 langchain/src/types/watsonx-types.ts delete mode 100644 langchain/src/util/watsonx-client.ts diff --git a/langchain/src/load/import_type.d.ts b/langchain/src/load/import_type.d.ts index 03fee33c4121..050171f77ecd 100644 --- a/langchain/src/load/import_type.d.ts +++ b/langchain/src/load/import_type.d.ts @@ -521,10 +521,8 @@ export interface OptionalImportMap { export interface SecretMap { ANTHROPIC_API_KEY?: string; - IBM_CLOUD_API_KEY?: string; OPENAI_API_KEY?: string; PROMPTLAYER_API_KEY?: string; REMOTE_RETRIEVER_AUTH_BEARER?: string; - WATSONX_PROJECT_ID?: string; ZAPIER_NLA_API_KEY?: string; } diff --git a/langchain/src/types/watsonx-types.ts b/langchain/src/types/watsonx-types.ts deleted file mode 100644 index eabfbbb4d728..000000000000 --- a/langchain/src/types/watsonx-types.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { BaseLanguageModelParams } from "@langchain/core/language_models/base"; - -export interface WatsonModelParameters { - decoding_method?: "sample" | "greedy"; - max_new_tokens?: number; - min_new_tokens?: number; - stop_sequences?: string[]; - temperature?: number; - top_k?: number; - top_p?: number; - repetition_penalty?: number; -} - -export interface WatsonApiClientSettings { - /** - * IBM Cloud Compute Region. - * eg. us-south, us-east, etc. - */ - region?: string; - - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - apiKey?: string; - - /** - * WatsonX AI Version. - * Date representing the WatsonX AI Version. - * eg. 2023-05-29 - */ - apiVersion?: string; -} - -/** - * The WatsonxAIParams interface defines the input parameters for - * the WatsonxAI class. - */ -export interface WatsonxAIParams extends BaseLanguageModelParams { - /** - * WatsonX AI Key. - * Provide API Key if you do not wish to automatically pull from env. - */ - projectId?: string; - /** - * Parameters accepted by the WatsonX AI Endpoint. - */ - modelParameters?: WatsonModelParameters; - /** - * WatsonX AI Model ID. - */ - modelId?: string; - - /** - * Watson rest api client config - */ - clientConfig?: WatsonApiClientSettings; -} diff --git a/langchain/src/util/watsonx-client.ts b/langchain/src/util/watsonx-client.ts deleted file mode 100644 index ea9d25f22306..000000000000 --- a/langchain/src/util/watsonx-client.ts +++ /dev/null @@ -1,165 +0,0 @@ -import { - WatsonApiClientSettings, - WatsonModelParameters, -} from "../types/watsonx-types.js"; -import { convertEventStreamToIterableReadableDataStream } from "./event-source-parse.js"; - -type WatsonResult = { - generated_text: string; - generated_token_count: number; - input_token_count: number; - stop_reason: string; -}; - -type WatsonError = { - code: string; - message: string; -}; - -type WatsonResponse = { - model_id: string; - created_at: string; - results: WatsonResult[]; - errors: WatsonError[]; -}; - -type IAMTokenResponse = { - access_token: string; - refresh_token: string; - ims_user_id: number; - token_type: string; - expires_in: number; - expiration: number; - scope: string; -}; - -export class WatsonApiClient { - private iamToken?: IAMTokenResponse; - - private readonly apiKey!: string; - - private readonly apiVersion!: string; - - private readonly region!: string; - - private readonly baseUrl!: string; - - constructor({ - region, - apiKey, - apiVersion, - }: Required) { - this.apiKey = apiKey; - this.apiVersion = apiVersion; - this.region = region; - this.baseUrl = `https://${this.region}.ml.cloud.ibm.com`; - } - - private async getIAMToken() { - const url = `https://iam.cloud.ibm.com/identity/token`; - const payload = { - grant_type: "urn:ibm:params:oauth:grant-type:apikey", - apikey: this.apiKey, - }; - const headers = { - Accept: "application/json", - "Content-Type": "application/x-www-form-urlencoded", - }; - - const formBody = Object.entries(payload).reduce((acc, [key, value]) => { - const encodedKey = encodeURIComponent(key as string); - const encodedValue = encodeURIComponent(value as string); - acc.push(`${encodedKey}=${encodedValue}`); - return acc; - }, [] as string[]); - const body = formBody.join("&"); - - const response = await fetch(url, { - method: "POST", - headers, - body, - }); - - return (await response.json()) as IAMTokenResponse; - } - - private async getJwt() { - // we have token and it's not expired - if (this.iamToken && this.iamToken.expiration * 1000 > Date.now()) { - return this.iamToken.access_token; - } - - // we don't have token or its expired - this.iamToken = await this.getIAMToken(); - return this.iamToken.access_token; - } - - public async *generateTextStream( - input: string, - project_id: string, - model_id: string, - parameters?: WatsonModelParameters - ) { - const jwt = await this.getJwt(); - - const url = `${this.baseUrl}/ml/v1-beta/generation/text_stream?version=${this.apiVersion}`; - const headers = { - Authorization: `Bearer ${jwt}`, - "Content-Type": "application/json", - }; - const { body, ok, statusText } = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify({ - input, - project_id, - model_id, - parameters, - }), - }); - - if (!body || !ok) { - throw new Error(statusText); - } - - const stream = convertEventStreamToIterableReadableDataStream(body); - - for await (const chunk of stream) { - yield JSON.parse(chunk) as WatsonResponse; - } - } - - public async generateText( - input: string, - project_id: string, - model_id: string, - parameters?: WatsonModelParameters - ) { - const jwt = await this.getJwt(); - - const url = `${this.baseUrl}/ml/v1-beta/generation/text?version=${this.apiVersion}`; - const headers = { - Authorization: `Bearer ${jwt}`, - "Content-Type": "application/json", - }; - - const response = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify({ - input, - project_id, - model_id, - parameters, - }), - }); - - const data = (await response.json()) as WatsonResponse; - - if (data.errors) { - throw new Error(data.errors[0].message); - } - - return data.results[0].generated_text; - } -} From f6dba28368bb7f2b8b53c29817515562a7b8a399 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 12 Dec 2023 17:09:51 -0800 Subject: [PATCH 4/4] Rename --- examples/src/llms/watsonx_ai-chat.ts | 4 ++-- libs/langchain-community/src/chat_models/watsonx_ai.ts | 6 +++--- libs/langchain-community/src/llms/watsonx_ai.ts | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/llms/watsonx_ai-chat.ts b/examples/src/llms/watsonx_ai-chat.ts index 3d98051b9a46..8a271139878d 100644 --- a/examples/src/llms/watsonx_ai-chat.ts +++ b/examples/src/llms/watsonx_ai-chat.ts @@ -1,6 +1,6 @@ -import { WatsonChatModel } from "langchain/chat_models/watsonx_ai"; +import { WatsonxAIChat } from "langchain/chat_models/watsonx_ai"; -const model = new WatsonChatModel({ +const model = new WatsonxAIChat({ clientConfig: { region: "eu-de", }, diff --git a/libs/langchain-community/src/chat_models/watsonx_ai.ts b/libs/langchain-community/src/chat_models/watsonx_ai.ts index 3ebbee4076d4..838c35a0b8d6 100644 --- a/libs/langchain-community/src/chat_models/watsonx_ai.ts +++ b/libs/langchain-community/src/chat_models/watsonx_ai.ts @@ -16,8 +16,8 @@ import type { } from "../types/watsonx-types.js"; import { WatsonApiClient } from "../utils/watsonx-client.js"; -export class WatsonChatModel extends SimpleChatModel { - private readonly watsonApiClient!: WatsonApiClient; +export class WatsonxAIChat extends SimpleChatModel { + private readonly watsonApiClient: WatsonApiClient; readonly modelId!: string; @@ -145,7 +145,7 @@ export class WatsonChatModel extends SimpleChatModel { } _llmType(): string { - return "watsonx_ai_chat"; + return "watsonx_ai"; } get lc_secrets(): { [key: string]: string } | undefined { diff --git a/libs/langchain-community/src/llms/watsonx_ai.ts b/libs/langchain-community/src/llms/watsonx_ai.ts index 47aff32a9b4a..1e047ad20b27 100644 --- a/libs/langchain-community/src/llms/watsonx_ai.ts +++ b/libs/langchain-community/src/llms/watsonx_ai.ts @@ -38,7 +38,7 @@ export class WatsonxAI extends LLM { modelParameters?: WatsonModelParameters; - private readonly watsonApiClient!: WatsonApiClient; + private readonly watsonApiClient: WatsonApiClient; constructor(fields: WatsonxAIParams) { super(fields);