From fe6ead06064237b4fa7473628d283a76e7036a9c Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 5 Dec 2023 16:26:17 -0800 Subject: [PATCH 1/3] Simplify RunnableSequence transform implementation --- langchain-core/src/runnables/base.ts | 54 ++++--------------- .../runnables/tests/runnable_history.test.ts | 26 +++++++++ 2 files changed, 35 insertions(+), 45 deletions(-) diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index bd97724a15a0..efc72889d7da 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -1269,58 +1269,22 @@ export class RunnableSequence< undefined, options?.runName ); - let nextStepInput = input; const steps = [this.first, ...this.middle, this.last]; - // Find the index of the last runnable in the sequence that doesn't have an overridden .transform() method - // and start streaming from there - const streamingStartStepIndex = Math.min( - steps.length - 1, - steps.length - - [...steps].reverse().findIndex((step) => { - const isDefaultImplementation = - step.transform === Runnable.prototype.transform; - const boundRunnableIsDefaultImplementation = - RunnableBinding.isRunnableBinding(step) && - step.bound?.transform === Runnable.prototype.transform; - return ( - isDefaultImplementation || boundRunnableIsDefaultImplementation - ); - }) - - 1 - ); - - try { - const invokeSteps = steps.slice(0, streamingStartStepIndex); - for (let i = 0; i < invokeSteps.length; i += 1) { - const step = invokeSteps[i]; - nextStepInput = await step.invoke( - nextStepInput, - this._patchConfig(options, runManager?.getChild(`seq:step:${i + 1}`)) - ); - } - } catch (e) { - await runManager?.handleChainError(e); - throw e; - } let concatSupported = true; let finalOutput; + async function* inputGenerator() { + yield input; + } try { - let finalGenerator = await steps[streamingStartStepIndex]._streamIterator( - nextStepInput, - this._patchConfig( - options, - runManager?.getChild(`seq:step:${streamingStartStepIndex + 1}`) - ) + let finalGenerator = steps[0].transform( + inputGenerator(), + this._patchConfig(options, runManager?.getChild(`seq:step:0`)) ); - const finalSteps = steps.slice(streamingStartStepIndex + 1); - for (let i = 0; i < finalSteps.length; i += 1) { - const step = finalSteps[i]; + for (let i = 1; i < steps.length; i += 1) { + const step = steps[i]; finalGenerator = await step.transform( finalGenerator, - this._patchConfig( - options, - runManager?.getChild(`seq:step:${streamingStartStepIndex + i + 2}`) - ) + this._patchConfig(options, runManager?.getChild(`seq:step:${i + 1}`)) ); } for await (const chunk of finalGenerator) { diff --git a/langchain-core/src/runnables/tests/runnable_history.test.ts b/langchain-core/src/runnables/tests/runnable_history.test.ts index 3d9628478bfd..43d65366cb83 100644 --- a/langchain-core/src/runnables/tests/runnable_history.test.ts +++ b/langchain-core/src/runnables/tests/runnable_history.test.ts @@ -10,8 +10,10 @@ import { FakeChatMessageHistory, FakeLLM, FakeListChatMessageHistory, + FakeStreamingLLM, } from "../../utils/testing/index.js"; import { ChatPromptTemplate, MessagesPlaceholder } from "../../prompts/chat.js"; +import { StringOutputParser } from "../../output_parsers/string.js"; // For `BaseChatMessageHistory` async function getGetSessionHistory(): Promise< @@ -120,3 +122,27 @@ AI: AI: You are a helpful assistant Human: hello Human: good bye`); }); + +test("Runnable with message history should stream through", async () => { + const prompt = ChatPromptTemplate.fromMessages([ + ["ai", "You are a helpful assistant"], + new MessagesPlaceholder("history"), + ["human", "{input}"], + ]); + const model = new FakeStreamingLLM({}); + const chain = prompt.pipe(model); + + const getListMessageHistory = await getListSessionHistory(); + const withHistory = new RunnableWithMessageHistory({ + runnable: chain, + config: {}, + getMessageHistory: getListMessageHistory, + inputMessagesKey: "input", + historyMessagesKey: "history", + }).pipe(new StringOutputParser()); + const config: RunnableConfig = { configurable: { sessionId: "1" } }; + const stream = await withHistory.stream({ input: "hello" }, config); + for await (const chunk of stream) { + console.log("CHUNK", chunk); + } +}); From 5e9831d30f49d466c3a7bbe367b317afb0f0c737 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 5 Dec 2023 16:28:09 -0800 Subject: [PATCH 2/3] Fix test --- langchain-core/src/runnables/tests/runnable_history.test.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/langchain-core/src/runnables/tests/runnable_history.test.ts b/langchain-core/src/runnables/tests/runnable_history.test.ts index 43d65366cb83..1d1b0f45e421 100644 --- a/langchain-core/src/runnables/tests/runnable_history.test.ts +++ b/langchain-core/src/runnables/tests/runnable_history.test.ts @@ -142,7 +142,9 @@ test("Runnable with message history should stream through", async () => { }).pipe(new StringOutputParser()); const config: RunnableConfig = { configurable: { sessionId: "1" } }; const stream = await withHistory.stream({ input: "hello" }, config); + const chunks = []; for await (const chunk of stream) { - console.log("CHUNK", chunk); + chunks.push(chunk); } + expect(chunks.length).toBeGreaterThan(1); }); From 524a923a71da238a44da8c5a80529aaec9cc67e2 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 5 Dec 2023 16:42:37 -0800 Subject: [PATCH 3/3] Fix tracing tags --- langchain-core/src/runnables/base.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index efc72889d7da..477ce73716c2 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -1278,7 +1278,7 @@ export class RunnableSequence< try { let finalGenerator = steps[0].transform( inputGenerator(), - this._patchConfig(options, runManager?.getChild(`seq:step:0`)) + this._patchConfig(options, runManager?.getChild(`seq:step:1`)) ); for (let i = 1; i < steps.length; i += 1) { const step = steps[i];