-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: added ibm watsonx chat model and streaming to base llm model #3577
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import { WatsonxAIChat } from "langchain/chat_models/watsonx_ai"; | ||
|
||
const model = new WatsonxAIChat({ | ||
clientConfig: { | ||
region: "eu-de", | ||
}, | ||
modelParameters: { | ||
max_new_tokens: 100, | ||
}, | ||
}); | ||
|
||
const stream = await model.stream( | ||
"What would be a good company name for a company that makes colorful socks?" | ||
); | ||
|
||
let text = ""; | ||
for await (const chunk of stream) { | ||
text += chunk.content; | ||
console.log(text); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -217,6 +217,7 @@ const entrypoints = { | |
"chat_models/llama_cpp": "chat_models/llama_cpp", | ||
"chat_models/yandex": "chat_models/yandex", | ||
"chat_models/fake": "chat_models/fake", | ||
"chat_models/watsonx_ai": "chat_models/watsonx_ai", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We no longer are adding integrations to the |
||
// schema | ||
schema: "schema/index", | ||
"schema/document": "schema/document", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export * from "@langchain/community/chat_models/watsonx_ai"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. drop file |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,157 @@ | ||||||
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; | ||||||
import { | ||||||
type BaseChatModelParams, | ||||||
SimpleChatModel, | ||||||
} from "@langchain/core/language_models/chat_models"; | ||||||
import { | ||||||
AIMessageChunk, | ||||||
BaseMessage, | ||||||
ChatMessage, | ||||||
} from "@langchain/core/messages"; | ||||||
import { ChatGenerationChunk } from "@langchain/core/outputs"; | ||||||
import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||||||
import type { | ||||||
WatsonModelParameters, | ||||||
WatsonxAIParams, | ||||||
} from "../types/watsonx-types.js"; | ||||||
import { WatsonApiClient } from "../utils/watsonx-client.js"; | ||||||
|
||||||
export class WatsonxAIChat extends SimpleChatModel { | ||||||
private readonly watsonApiClient: WatsonApiClient; | ||||||
|
||||||
readonly modelId!: string; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use definite assignment assertions, instead make it a non nullable type and add checks in the constructor that verify the value is defined. These can come back to bite you in production (I've done this before and it was not fun to debug 😅) |
||||||
|
||||||
readonly modelParameters?: WatsonModelParameters; | ||||||
|
||||||
readonly projectId!: string; | ||||||
|
||||||
constructor(fields: WatsonxAIParams & BaseChatModelParams) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets redefine this as an interface above the class: interface WatsonxAIChatParams extends WatsonxAIParams, BaseChatModelParams {}; This way it's easy to:
Also, I believe all the properties on those two interfaces are optional, so we should be able to do:
Suggested change
|
||||||
super(fields); | ||||||
|
||||||
const { | ||||||
clientConfig = {}, | ||||||
modelId = "meta-llama/llama-2-70b-chat", | ||||||
modelParameters, | ||||||
projectId = getEnvironmentVariable("WATSONX_PROJECT_ID") ?? "", | ||||||
} = fields; | ||||||
|
||||||
this.modelId = modelId; | ||||||
this.modelParameters = modelParameters; | ||||||
this.projectId = projectId; | ||||||
|
||||||
const { | ||||||
apiKey = getEnvironmentVariable("IBM_CLOUD_API_KEY"), | ||||||
apiVersion = "2023-05-29", | ||||||
region = "us-south", | ||||||
} = clientConfig; | ||||||
|
||||||
if (!apiKey) { | ||||||
throw new Error("Missing IBM Cloud API Key"); | ||||||
} | ||||||
|
||||||
if (!this.projectId) { | ||||||
throw new Error("Missing WatsonX AI Project ID"); | ||||||
} | ||||||
|
||||||
this.watsonApiClient = new WatsonApiClient({ | ||||||
apiKey, | ||||||
apiVersion, | ||||||
region, | ||||||
}); | ||||||
} | ||||||
|
||||||
protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should make a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes definitely, I was thinking about it too. Right now this is blatantly copy pasted from the Ollama model and assumes the user only uses llama based chat model. However Watson offers the ability to run many models, including spinning any model from HF... |
||||||
return messages | ||||||
.map((message) => { | ||||||
let messageText; | ||||||
if (message._getType() === "human") { | ||||||
messageText = `[INST] ${message.content} [/INST]`; | ||||||
} else if (message._getType() === "ai") { | ||||||
messageText = message.content; | ||||||
} else if (message._getType() === "system") { | ||||||
messageText = `<<SYS>> ${message.content} <</SYS>>`; | ||||||
} else if (ChatMessage.isInstance(message)) { | ||||||
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( | ||||||
1 | ||||||
)}: ${message.content}`; | ||||||
} else { | ||||||
console.warn( | ||||||
`Unsupported message type passed to Watson: "${message._getType()}"` | ||||||
); | ||||||
messageText = ""; | ||||||
} | ||||||
return messageText; | ||||||
}) | ||||||
.join("\n"); | ||||||
} | ||||||
|
||||||
_combineLLMOutput() { | ||||||
return {}; | ||||||
} | ||||||
|
||||||
async _call( | ||||||
messages: BaseMessage[], | ||||||
options: this["ParsedCallOptions"], | ||||||
runManager: CallbackManagerForLLMRun | undefined | ||||||
): Promise<string> { | ||||||
const chunks = []; | ||||||
const stream = this._streamResponseChunks(messages, options, runManager); | ||||||
for await (const chunk of stream) { | ||||||
chunks.push(chunk.message.content); | ||||||
} | ||||||
return chunks.join(""); | ||||||
} | ||||||
|
||||||
override async *_streamResponseChunks( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think you need the |
||||||
_messages: BaseMessage[], | ||||||
_options: this["ParsedCallOptions"], | ||||||
_runManager?: CallbackManagerForLLMRun | ||||||
Comment on lines
+106
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop the underscore if they're being used. Typically, variables prefixed with an underscore are unused, and the underscore is used to bypass a lint rule for no-unused-variables |
||||||
): AsyncGenerator<ChatGenerationChunk> { | ||||||
const formattedMessages = this._formatMessagesAsPrompt(_messages); | ||||||
const stream = await this.caller.call(async () => | ||||||
this.watsonApiClient.generateTextStream( | ||||||
formattedMessages, | ||||||
this.projectId, | ||||||
this.modelId, | ||||||
this.modelParameters | ||||||
) | ||||||
); | ||||||
|
||||||
for await (const data of stream) { | ||||||
const [ | ||||||
{ | ||||||
generated_text, | ||||||
generated_token_count, | ||||||
input_token_count, | ||||||
stop_reason, | ||||||
}, | ||||||
] = data.results; | ||||||
const generationChunk = new ChatGenerationChunk({ | ||||||
text: generated_text, | ||||||
message: new AIMessageChunk({ content: generated_text }), | ||||||
generationInfo: { | ||||||
generated_token_count, | ||||||
input_token_count, | ||||||
stop_reason, | ||||||
}, | ||||||
}); | ||||||
yield generationChunk; | ||||||
await _runManager?.handleLLMNewToken(generated_text); | ||||||
} | ||||||
} | ||||||
|
||||||
static lc_name() { | ||||||
return "WatsonxAIChat"; | ||||||
} | ||||||
|
||||||
_llmType(): string { | ||||||
return "watsonx_ai"; | ||||||
} | ||||||
|
||||||
get lc_secrets(): { [key: string]: string } | undefined { | ||||||
return { | ||||||
ibmCloudApiKey: "IBM_CLOUD_API_KEY", | ||||||
projectId: "WATSONX_PROJECT_ID", | ||||||
}; | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there! I've noticed that this PR introduces new dependencies related to chat models. This comment is just to flag the change for maintainers to review the impact on peer/dev/hard dependencies. Great work on the PR!