From af69dd9660ff2a02484bfb00b00ba098b2f1db83 Mon Sep 17 00:00:00 2001
From: oleg
Date: Tue, 5 Mar 2024 13:53:46 +0100
Subject: [PATCH] fix(editor): Fix retrieving of messages from memory in chat
 modal (#8807)

Signed-off-by: Oleg Ivaniv
---
 .../@n8n/nodes-langchain/utils/logWrapper.ts  | 54 ++++++++++---------
 .../src/components/WorkflowLMChat.vue         | 38 +++++--------
 2 files changed, 41 insertions(+), 51 deletions(-)

diff --git a/packages/@n8n/nodes-langchain/utils/logWrapper.ts b/packages/@n8n/nodes-langchain/utils/logWrapper.ts
index 70e30b5ce932b..37b19b6324b34 100644
--- a/packages/@n8n/nodes-langchain/utils/logWrapper.ts
+++ b/packages/@n8n/nodes-langchain/utils/logWrapper.ts
@@ -15,7 +15,7 @@ import type { BaseDocumentLoader } from 'langchain/document_loaders/base';
 import type { BaseCallbackConfig, Callbacks } from 'langchain/dist/callbacks/manager';
 import { BaseLLM } from 'langchain/llms/base';
 import { BaseChatMemory } from 'langchain/memory';
-import type { MemoryVariables } from 'langchain/dist/memory/base';
+import type { MemoryVariables, OutputValues } from 'langchain/dist/memory/base';
 import { BaseRetriever } from 'langchain/schema/retriever';
 import type { FormatInstructionsOptions } from 'langchain/schema/output_parser';
 import { BaseOutputParser, OutputParserException } from 'langchain/schema/output_parser';
@@ -148,35 +148,37 @@ export function logWrapper(
 					arguments: [values],
 				})) as MemoryVariables;
 
+				const chatHistory = (response?.chat_history as BaseMessage[]) ?? response;
+
 				executeFunctions.addOutputData(connectionType, index, [
-					[{ json: { action: 'loadMemoryVariables', response } }],
+					[{ json: { action: 'loadMemoryVariables', chatHistory } }],
 				]);
 				return response;
 			};
-		} else if (
-			prop === 'outputKey' &&
-			'outputKey' in target &&
-			target.constructor.name === 'BufferWindowMemory'
-		) {
-			connectionType = NodeConnectionType.AiMemory;
-			const { index } = executeFunctions.addInputData(connectionType, [
-				[{ json: { action: 'chatHistory' } }],
-			]);
-			const response = target[prop];
-
-			target.chatHistory
-				.getMessages()
-				.then((messages) => {
-					executeFunctions.addOutputData(NodeConnectionType.AiMemory, index, [
-						[{ json: { action: 'chatHistory', chatHistory: messages } }],
-					]);
-				})
-				.catch((error: Error) => {
-					executeFunctions.addOutputData(NodeConnectionType.AiMemory, index, [
-						[{ json: { action: 'chatHistory', error } }],
-					]);
-				});
-			return response;
+		} else if (prop === 'saveContext' && 'saveContext' in target) {
+			return async (input: InputValues, output: OutputValues): Promise<MemoryVariables> => {
+				connectionType = NodeConnectionType.AiMemory;
+
+				const { index } = executeFunctions.addInputData(connectionType, [
+					[{ json: { action: 'saveContext', input, output } }],
+				]);
+
+				const response = (await callMethodAsync.call(target, {
+					executeFunctions,
+					connectionType,
+					currentNodeRunIndex: index,
+					method: target[prop],
+					arguments: [input, output],
+				})) as MemoryVariables;
+
+				const chatHistory = await target.chatHistory.getMessages();
+
+				executeFunctions.addOutputData(connectionType, index, [
+					[{ json: { action: 'saveContext', chatHistory } }],
+				]);
+
+				return response;
+			};
 		}
 	}
diff --git a/packages/editor-ui/src/components/WorkflowLMChat.vue b/packages/editor-ui/src/components/WorkflowLMChat.vue
index 7bffa1cae8173..14badd2de527d 100644
--- a/packages/editor-ui/src/components/WorkflowLMChat.vue
+++ b/packages/editor-ui/src/components/WorkflowLMChat.vue
@@ -171,6 +171,10 @@ interface LangChainMessage {
 	};
 }
 
+interface MemoryOutput {
+	action: string;
+	chatHistory?: LangChainMessage[];
+}
 // TODO:
 // - display additional information like execution time, tokens used, ...
 // - display errors better
@@ -217,7 +221,10 @@ export default defineComponent({
 		this.messages = this.getChatMessages();
 		this.setNode();
 
-		setTimeout(() => this.$refs.inputField?.focus(), 0);
+		setTimeout(() => {
+			this.scrollToLatestMessage();
+			this.$refs.inputField?.focus();
+		}, 0);
 	},
 	methods: {
 		displayExecution(executionId: string) {
@@ -353,32 +360,13 @@ export default defineComponent({
 				memoryConnection.node,
 			);
 
-			const memoryOutputData = nodeResultData
-				?.map(
-					(
-						data,
-					): {
-						action: string;
-						chatHistory?: unknown[];
-						response?: {
-							sessionId?: unknown[];
-						};
-					} => get(data, ['data', NodeConnectionType.AiMemory, 0, 0, 'json'])!,
-				)
-				?.find((data) =>
-					['chatHistory', 'loadMemoryVariables'].includes(data?.action) ? data : undefined,
-				);
-
-			let chatHistory: LangChainMessage[];
-			if (memoryOutputData?.chatHistory) {
-				chatHistory = memoryOutputData?.chatHistory as LangChainMessage[];
-			} else if (memoryOutputData?.response) {
-				chatHistory = memoryOutputData?.response.sessionId as LangChainMessage[];
-			} else {
-				return [];
-			}
+			const memoryOutputData = (nodeResultData ?? [])
+				.map(
+					(data) => get(data, ['data', NodeConnectionType.AiMemory, 0, 0, 'json']) as MemoryOutput,
+				)
+				.find((data) => data.action === 'saveContext');
 
-			return (chatHistory || []).map((message) => {
+			return (memoryOutputData?.chatHistory ?? []).map((message) => {
 				return {
 					text: message.kwargs.content,
 					sender: last(message.id) === 'HumanMessage' ? 'user' : 'bot',
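
For context, a minimal, dependency-free TypeScript sketch of the Proxy interception pattern this patch moves to: instead of piggybacking on the BufferWindowMemory-specific `outputKey` getter, the wrapper now intercepts the generic `saveContext` method, runs the real method, and only then reads `chatHistory.getMessages()` so the reported history includes the just-saved exchange. This sketch is NOT part of the patch; `SimpleMemory` and the console logging are hypothetical stand-ins for langchain's `BaseChatMemory` and n8n's `executeFunctions.addInputData`/`addOutputData` reporting shown in the hunks above.

// Standalone illustration of the saveContext interception pattern (assumptions noted above).
interface InputValues {
	[key: string]: unknown;
}
interface OutputValues {
	[key: string]: unknown;
}

// Hypothetical stand-in for a langchain-style chat memory.
class SimpleMemory {
	private messages: string[] = [];

	// Store one human/AI exchange, analogous to BaseChatMemory.saveContext().
	async saveContext(input: InputValues, output: OutputValues): Promise<void> {
		this.messages.push(`human: ${String(input.input)}`);
		this.messages.push(`ai: ${String(output.output)}`);
	}

	// Expose the accumulated history, analogous to memory.chatHistory.getMessages().
	chatHistory = {
		getMessages: async (): Promise<string[]> => [...this.messages],
	};
}

// Wrap the memory in a Proxy so every saveContext() call is observed, and the
// full history is read back AFTER the write completes, mirroring the ordering
// the patch relies on so the chat modal sees the just-saved messages.
function logWrapper(target: SimpleMemory): SimpleMemory {
	return new Proxy(target, {
		get(memory, prop, receiver) {
			if (prop === 'saveContext') {
				return async (input: InputValues, output: OutputValues): Promise<void> => {
					await memory.saveContext(input, output); // run the real method first
					const chatHistory = await memory.chatHistory.getMessages();
					console.log('saveContext ->', chatHistory); // report the updated history
				};
			}
			return Reflect.get(memory, prop, receiver); // pass everything else through
		},
	});
}

// Usage: the caller is unaware of the wrapper.
const memory = logWrapper(new SimpleMemory());
void memory.saveContext({ input: 'Hi there' }, { output: 'Hello! How can I help?' });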