From bf0ca07bc13f04702d9baee80a33151f80241de8 Mon Sep 17 00:00:00 2001 From: Oleg Ivaniv Date: Thu, 1 Feb 2024 17:18:20 +0100 Subject: [PATCH] fix: Improve BaseChatModel instance guard to fix Mistral Chat error Signed-off-by: Oleg Ivaniv --- .../agents/Agent/agents/ConversationalAgent/execute.ts | 9 +++------ .../nodes/agents/Agent/agents/ReActAgent/execute.ts | 5 +++-- .../nodes/chains/ChainLLM/ChainLlm.node.ts | 5 +++-- packages/@n8n/nodes-langchain/utils/helpers.ts | 7 +++++++ packages/@n8n/nodes-langchain/utils/logWrapper.ts | 3 ++- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ConversationalAgent/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ConversationalAgent/execute.ts index dd0ff5c07b328..e371c862ef23f 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ConversationalAgent/execute.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ConversationalAgent/execute.ts @@ -6,24 +6,21 @@ import { } from 'n8n-workflow'; import { initializeAgentExecutorWithOptions } from 'langchain/agents'; -import { BaseChatModel } from 'langchain/chat_models/base'; import type { Tool } from 'langchain/tools'; import type { BaseChatMemory } from 'langchain/memory'; import type { BaseOutputParser } from 'langchain/schema/output_parser'; import { PromptTemplate } from 'langchain/prompts'; import { CombiningOutputParser } from 'langchain/output_parsers'; +import { isChatInstance } from '../../../../../utils/helpers'; export async function conversationalAgentExecute( this: IExecuteFunctions, ): Promise { this.logger.verbose('Executing Conversational Agent'); - const model = (await this.getInputConnectionData( - NodeConnectionType.AiLanguageModel, - 0, - )) as BaseChatModel; + const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0); - if (!(model instanceof BaseChatModel)) { + if (!isChatInstance(model)) { throw new NodeOperationError(this.getNode(), 'Conversational Agent requires Chat Model'); } diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ReActAgent/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ReActAgent/execute.ts index 492272f5af74c..140d6444feacc 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ReActAgent/execute.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ReActAgent/execute.ts @@ -11,7 +11,8 @@ import type { Tool } from 'langchain/tools'; import type { BaseOutputParser } from 'langchain/schema/output_parser'; import { PromptTemplate } from 'langchain/prompts'; import { CombiningOutputParser } from 'langchain/output_parsers'; -import { BaseChatModel } from 'langchain/chat_models/base'; +import type { BaseChatModel } from 'langchain/chat_models/base'; +import { isChatInstance } from '../../../../../utils/helpers'; export async function reActAgentAgentExecute( this: IExecuteFunctions, @@ -38,7 +39,7 @@ export async function reActAgentAgentExecute( }; let agent: ChatAgent | ZeroShotAgent; - if (model instanceof BaseChatModel) { + if (isChatInstance(model)) { agent = ChatAgent.fromLLMAndTools(model, tools, { prefix: options.prefix, suffix: options.suffixChat, diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts index d1a5363636715..1fd27972b143b 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts @@ -19,9 +19,10 @@ import { import type { BaseOutputParser } from 'langchain/schema/output_parser'; import { CombiningOutputParser } from 'langchain/output_parsers'; import { LLMChain } from 'langchain/chains'; -import { BaseChatModel } from 'langchain/chat_models/base'; +import type { BaseChatModel } from 'langchain/chat_models/base'; import { HumanMessage } from 'langchain/schema'; import { getTemplateNoticeField } from '../../../utils/sharedFields'; +import { isChatInstance } from '../../../utils/helpers'; interface MessagesTemplate { type: string; @@ -94,7 +95,7 @@ async function getChainPromptTemplate( partialVariables: formatInstructions ? { formatInstructions } : undefined, }); - if (llm instanceof BaseChatModel) { + if (isChatInstance(llm)) { const parsedMessages = await Promise.all( (messages ?? []).map(async (message) => { const messageClass = [ diff --git a/packages/@n8n/nodes-langchain/utils/helpers.ts b/packages/@n8n/nodes-langchain/utils/helpers.ts index 0640cfbd9118f..e0f0b32087f7d 100644 --- a/packages/@n8n/nodes-langchain/utils/helpers.ts +++ b/packages/@n8n/nodes-langchain/utils/helpers.ts @@ -1,4 +1,6 @@ import type { IExecuteFunctions } from 'n8n-workflow'; +import { BaseChatModel } from 'langchain/chat_models/base'; +import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models'; export function getMetadataFiltersValues( ctx: IExecuteFunctions, @@ -14,3 +16,8 @@ export function getMetadataFiltersValues( return undefined; } + +// TODO: Remove this function once langchain package is updated to 0.1.x +export function isChatInstance(model: any): model is BaseChatModel | BaseChatModelCore { + return model instanceof BaseChatModel || model instanceof BaseChatModelCore; +} diff --git a/packages/@n8n/nodes-langchain/utils/logWrapper.ts b/packages/@n8n/nodes-langchain/utils/logWrapper.ts index 52a24ef6ac327..771c3e545250b 100644 --- a/packages/@n8n/nodes-langchain/utils/logWrapper.ts +++ b/packages/@n8n/nodes-langchain/utils/logWrapper.ts @@ -27,6 +27,7 @@ import { BaseOutputParser } from 'langchain/schema/output_parser'; import { isObject } from 'lodash'; import { N8nJsonLoader } from './N8nJsonLoader'; import { N8nBinaryLoader } from './N8nBinaryLoader'; +import { isChatInstance } from './helpers'; const errorsMap: { [key: string]: { message: string; description: string } } = { 'You exceeded your current quota, please check your plan and billing details.': { @@ -225,7 +226,7 @@ export function logWrapper( } // ========== BaseChatModel ========== - if (originalInstance instanceof BaseLLM || originalInstance instanceof BaseChatModel) { + if (originalInstance instanceof BaseLLM || isChatInstance(originalInstance)) { if (prop === '_generate' && '_generate' in target) { return async ( messages: BaseMessage[] & string[],