diff --git a/langchain-core/src/callbacks/manager.ts b/langchain-core/src/callbacks/manager.ts index 34fe086636a6..1bb20acf30b8 100644 --- a/langchain-core/src/callbacks/manager.ts +++ b/langchain-core/src/callbacks/manager.ts @@ -484,9 +484,9 @@ export class CallbackManager extends BaseCallbackManager implements BaseCallbackManagerMethods { - handlers: BaseCallbackHandler[]; + handlers: BaseCallbackHandler[] = []; - inheritableHandlers: BaseCallbackHandler[]; + inheritableHandlers: BaseCallbackHandler[] = []; tags: string[] = []; @@ -500,10 +500,26 @@ export class CallbackManager private readonly _parentRunId?: string; - constructor(parentRunId?: string) { + constructor( + parentRunId?: string, + options?: { + handlers?: BaseCallbackHandler[]; + inheritableHandlers?: BaseCallbackHandler[]; + tags?: string[]; + inheritableTags?: string[]; + metadata?: Record; + inheritableMetadata?: Record; + } + ) { super(); - this.handlers = []; - this.inheritableHandlers = []; + this.handlers = options?.handlers ?? this.handlers; + this.inheritableHandlers = + options?.inheritableHandlers ?? this.inheritableHandlers; + this.tags = options?.tags ?? this.tags; + this.inheritableTags = options?.inheritableTags ?? this.inheritableTags; + this.metadata = options?.metadata ?? this.metadata; + this.inheritableMetadata = + options?.inheritableMetadata ?? this.inheritableMetadata; this._parentRunId = parentRunId; } diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index 0403b91486bb..cffccc316e8e 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -12,8 +12,14 @@ import { } from "../tracers/log_stream.js"; import { Serializable } from "../load/serializable.js"; import { IterableReadableStream } from "../utils/stream.js"; -import { RunnableConfig, getCallbackMangerForConfig } from "./config.js"; +import { + RunnableConfig, + getCallbackMangerForConfig, + mergeConfigs, +} from "./config.js"; import { AsyncCaller } from "../utils/async_caller.js"; +import { Run } from "../tracers/base.js"; +import { RootListenersTracer } from "../tracers/root_listener.js"; export type RunnableFunc = ( input: RunInput @@ -549,6 +555,45 @@ export abstract class Runnable< static isRunnable(thing: any): thing is Runnable { return thing ? thing.lc_runnable : false; } + + /** + * Bind lifecycle listeners to a Runnable, returning a new Runnable. + * The Run object contains information about the run, including its id, + * type, input, output, error, startTime, endTime, and any tags or metadata + * added to the run. + * + * @param {Object} params - The object containing the callback functions. + * @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object. + * @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object. + * @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object. + */ + withListeners({ + onStart, + onEnd, + onError, + }: { + onStart?: (run: Run, config?: RunnableConfig) => void | Promise; + onEnd?: (run: Run, config?: RunnableConfig) => void | Promise; + onError?: (run: Run, config?: RunnableConfig) => void | Promise; + }): Runnable { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + return new RunnableBinding({ + bound: this, + config: {}, + configFactories: [ + (config) => ({ + callbacks: [ + new RootListenersTracer({ + config, + onStart, + onEnd, + onError, + }), + ], + }), + ], + }); + } } export type RunnableBindingArgs< @@ -557,8 +602,9 @@ export type RunnableBindingArgs< CallOptions extends RunnableConfig > = { bound: Runnable; - kwargs: Partial; + kwargs?: Partial; config: RunnableConfig; + configFactories?: Array<(config: RunnableConfig) => RunnableConfig>; }; /** @@ -581,31 +627,33 @@ export class RunnableBinding< config: RunnableConfig; - protected kwargs: Partial; + protected kwargs?: Partial; + + configFactories?: Array< + (config: RunnableConfig) => RunnableConfig | Promise + >; constructor(fields: RunnableBindingArgs) { super(fields); this.bound = fields.bound; this.kwargs = fields.kwargs; this.config = fields.config; + this.configFactories = fields.configFactories; } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - _mergeConfig(options?: Record) { + async _mergeConfig( // eslint-disable-next-line @typescript-eslint/no-explicit-any - const copy: Record = { ...this.config }; - if (options) { - for (const key of Object.keys(options)) { - if (key === "metadata") { - copy[key] = { ...copy[key], ...options[key] }; - } else if (key === "tags") { - copy[key] = (copy[key] ?? []).concat(options[key] ?? []); - } else { - copy[key] = options[key] ?? copy[key]; - } - } - } - return copy as Partial; + options?: Record + ): Promise> { + const config = mergeConfigs(this.config, options); + return mergeConfigs( + config, + ...(this.configFactories + ? await Promise.all( + this.configFactories.map(async (f) => await f(config)) + ) + : []) + ); } bind( @@ -645,7 +693,7 @@ export class RunnableBinding< ): Promise { return this.bound.invoke( input, - this._mergeConfig({ ...options, ...this.kwargs }) + await this._mergeConfig({ ...options, ...this.kwargs }) ); } @@ -673,13 +721,16 @@ export class RunnableBinding< batchOptions?: RunnableBatchOptions ): Promise<(RunOutput | Error)[]> { const mergedOptions = Array.isArray(options) - ? options.map((individualOption) => - this._mergeConfig({ - ...individualOption, - ...this.kwargs, - }) + ? await Promise.all( + options.map( + async (individualOption) => + await this._mergeConfig({ + ...individualOption, + ...this.kwargs, + }) + ) ) - : this._mergeConfig({ ...options, ...this.kwargs }); + : await this._mergeConfig({ ...options, ...this.kwargs }); return this.bound.batch(inputs, mergedOptions, batchOptions); } @@ -689,7 +740,7 @@ export class RunnableBinding< ) { yield* this.bound._streamIterator( input, - this._mergeConfig({ ...options, ...this.kwargs }) + await this._mergeConfig({ ...options, ...this.kwargs }) ); } @@ -699,7 +750,7 @@ export class RunnableBinding< ): Promise> { return this.bound.stream( input, - this._mergeConfig({ ...options, ...this.kwargs }) + await this._mergeConfig({ ...options, ...this.kwargs }) ); } @@ -710,7 +761,7 @@ export class RunnableBinding< ): AsyncGenerator { yield* this.bound.transform( generator, - this._mergeConfig({ ...options, ...this.kwargs }) + await this._mergeConfig({ ...options, ...this.kwargs }) ); } @@ -721,6 +772,46 @@ export class RunnableBinding< ): thing is RunnableBinding { return thing.bound && Runnable.isRunnable(thing.bound); } + + /** + * Bind lifecycle listeners to a Runnable, returning a new Runnable. + * The Run object contains information about the run, including its id, + * type, input, output, error, startTime, endTime, and any tags or metadata + * added to the run. + * + * @param {Object} params - The object containing the callback functions. + * @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object. + * @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object. + * @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object. + */ + withListeners({ + onStart, + onEnd, + onError, + }: { + onStart?: (run: Run, config?: RunnableConfig) => void | Promise; + onEnd?: (run: Run, config?: RunnableConfig) => void | Promise; + onError?: (run: Run, config?: RunnableConfig) => void | Promise; + }): Runnable { + // + return new RunnableBinding({ + bound: this.bound, + kwargs: this.kwargs, + config: this.config, + configFactories: [ + (config) => ({ + callbacks: [ + new RootListenersTracer({ + config, + onStart, + onEnd, + onError, + }), + ], + }), + ], + }); + } } /** @@ -789,6 +880,32 @@ export class RunnableEach< this._patchConfig(config, runManager?.getChild()) ); } + + /** + * Bind lifecycle listeners to a Runnable, returning a new Runnable. + * The Run object contains information about the run, including its id, + * type, input, output, error, startTime, endTime, and any tags or metadata + * added to the run. + * + * @param {Object} params - The object containing the callback functions. + * @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object. + * @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object. + * @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object. + */ + withListeners({ + onStart, + onEnd, + onError, + }: { + onStart?: (run: Run, config?: RunnableConfig) => void | Promise; + onEnd?: (run: Run, config?: RunnableConfig) => void | Promise; + onError?: (run: Run, config?: RunnableConfig) => void | Promise; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }): Runnable { + return new RunnableEach({ + bound: this.bound.withListeners({ onStart, onEnd, onError }), + }); + } } /** diff --git a/langchain-core/src/runnables/config.ts b/langchain-core/src/runnables/config.ts index d70cbeb328bc..6491d38ac25f 100644 --- a/langchain-core/src/runnables/config.ts +++ b/langchain-core/src/runnables/config.ts @@ -14,3 +14,76 @@ export async function getCallbackMangerForConfig(config?: RunnableConfig) { config?.metadata ); } + +export function mergeConfigs( + config: RunnableConfig, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + options?: Record +): Partial { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const copy: Record = { ...config }; + if (options) { + for (const key of Object.keys(options)) { + if (key === "metadata") { + copy[key] = { ...copy[key], ...options[key] }; + } else if (key === "tags") { + copy[key] = (copy[key] ?? []).concat(options[key] ?? []); + } else if (key === "callbacks") { + const baseCallbacks = copy.callbacks; + const theseCallbacks = options.callbacks ?? config.callbacks; + // callbacks can be either undefined, Array or manager + // so merging two callbacks values has 6 cases + if (Array.isArray(theseCallbacks)) { + if (!baseCallbacks) { + copy.callbacks = theseCallbacks; + } else if (Array.isArray(baseCallbacks)) { + copy.callbacks = baseCallbacks.concat(theseCallbacks); + } else { + // baseCallbacks is a manager + const manager = baseCallbacks.copy(); + for (const callback of theseCallbacks) { + manager.addHandler(callback, true); + } + copy.callbacks = manager; + } + } else if (theseCallbacks) { + // theseCallbacks is a manager + if (!baseCallbacks) { + copy.callbacks = theseCallbacks; + } else if (Array.isArray(baseCallbacks)) { + const manager = theseCallbacks.copy(); + for (const callback of baseCallbacks) { + manager.addHandler(callback, true); + } + copy.callbacks = manager; + } else { + // baseCallbacks is also a manager + copy.callbacks = new CallbackManager(theseCallbacks.parentRunId, { + handlers: baseCallbacks.handlers.concat(theseCallbacks.handlers), + inheritableHandlers: baseCallbacks.inheritableHandlers.concat( + theseCallbacks.inheritableHandlers + ), + tags: Array.from( + new Set(baseCallbacks.tags.concat(theseCallbacks.tags)) + ), + inheritableTags: Array.from( + new Set( + baseCallbacks.inheritableTags.concat( + theseCallbacks.inheritableTags + ) + ) + ), + metadata: { + ...baseCallbacks.metadata, + ...theseCallbacks.metadata, + }, + }); + } + } + } else { + copy[key] = options[key] ?? copy[key]; + } + } + } + return copy as Partial; +} diff --git a/langchain-core/src/tracers/root_listener.ts b/langchain-core/src/tracers/root_listener.ts new file mode 100644 index 000000000000..222c307f531f --- /dev/null +++ b/langchain-core/src/tracers/root_listener.ts @@ -0,0 +1,90 @@ +import { RunnableConfig } from "../runnables/config.js"; +import { BaseTracer, Run } from "./base.js"; + +export class RootListenersTracer extends BaseTracer { + name = "RootListenersTracer"; + + /** The Run's ID. Type UUID */ + rootId?: string; + + config: RunnableConfig; + + argOnStart?: { + (run: Run): void | Promise; + (run: Run, config: RunnableConfig): void | Promise; + }; + + argOnEnd?: { + (run: Run): void | Promise; + (run: Run, config: RunnableConfig): void | Promise; + }; + + argOnError?: { + (run: Run): void | Promise; + (run: Run, config: RunnableConfig): void | Promise; + }; + + constructor({ + config, + onStart, + onEnd, + onError, + }: { + config: RunnableConfig; + onStart?: (run: Run, config?: RunnableConfig) => void | Promise; + onEnd?: (run: Run, config?: RunnableConfig) => void | Promise; + onError?: (run: Run, config?: RunnableConfig) => void | Promise; + }) { + super(); + this.config = config; + this.argOnStart = onStart; + this.argOnEnd = onEnd; + this.argOnError = onError; + } + + /** + * This is a legacy method only called once for an entire run tree + * therefore not useful here + * @param {Run} _ Not used + */ + persistRun(_: Run): Promise { + return Promise.resolve(); + } + + async onRunCreate(run: Run) { + if (this.rootId) { + return; + } + + this.rootId = run.id; + + if (this.argOnStart) { + if (this.argOnStart.length === 1) { + await this.argOnStart(run); + } else if (this.argOnStart.length === 2) { + await this.argOnStart(run, this.config); + } + } + } + + async onRunUpdate(run: Run) { + if (run.id !== this.rootId) { + return; + } + if (!run.error) { + if (this.argOnEnd) { + if (this.argOnEnd.length === 1) { + await this.argOnEnd(run); + } else if (this.argOnEnd.length === 2) { + await this.argOnEnd(run, this.config); + } + } + } else if (this.argOnError) { + if (this.argOnError.length === 1) { + await this.argOnError(run); + } else if (this.argOnError.length === 2) { + await this.argOnError(run, this.config); + } + } + } +} diff --git a/langchain/src/schema/tests/runnable.test.ts b/langchain/src/schema/tests/runnable.test.ts index ffd78a2c50c7..223e5710ce0e 100644 --- a/langchain/src/schema/tests/runnable.test.ts +++ b/langchain/src/schema/tests/runnable.test.ts @@ -2,7 +2,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from "zod"; -import { test } from "@jest/globals"; +import { jest, test } from "@jest/globals"; import { createChatMessageChunkEncoderStream } from "../../chat_models/base.js"; import { ChatPromptTemplate, @@ -28,6 +28,7 @@ import { FakeStreamingLLM, FakeSplitIntoListParser, } from "./lib.js"; +import { FakeListChatModel } from "../../chat_models/fake.js"; test("Test batch", async () => { const llm = new FakeLLM({}); @@ -247,3 +248,71 @@ test("Runnable withConfig", async () => { expect(chunks[0]?.tags).toEqual(["a-tag", "b-tag"]); expect(chunks[0]?.metadata).toEqual({ a: "updated", b: "c" }); }); + +test("Listeners work", async () => { + const prompt = ChatPromptTemplate.fromMessages([ + SystemMessagePromptTemplate.fromTemplate("You are a nice assistant."), + ["human", "{question}"], + ]); + const model = new FakeListChatModel({ + responses: ["foo"], + }); + const chain = prompt.pipe(model); + + const mockStart = jest.fn(); + const mockEnd = jest.fn(); + + await chain + .withListeners({ + onStart: (run) => { + mockStart(run); + }, + onEnd: (run) => { + mockEnd(run); + }, + }) + .invoke({ question: "What is the meaning of life?" }); + + expect(mockStart).toHaveBeenCalledTimes(1); + expect((mockStart.mock.calls[0][0] as { name: string }).name).toBe( + "RunnableSequence" + ); + expect(mockEnd).toHaveBeenCalledTimes(1); +}); + +test("Listeners work with async handlers", async () => { + const prompt = ChatPromptTemplate.fromMessages([ + SystemMessagePromptTemplate.fromTemplate("You are a nice assistant."), + ["human", "{question}"], + ]); + const model = new FakeListChatModel({ + responses: ["foo"], + }); + const chain = prompt.pipe(model); + + const mockStart = jest.fn(); + const mockEnd = jest.fn(); + + await chain + .withListeners({ + // eslint-disable-next-line @typescript-eslint/no-misused-promises + onStart: async (run) => { + const promise = new Promise((resolve) => setTimeout(resolve, 2000)); + await promise; + mockStart(run); + }, + // eslint-disable-next-line @typescript-eslint/no-misused-promises + onEnd: async (run) => { + const promise = new Promise((resolve) => setTimeout(resolve, 2000)); + await promise; + mockEnd(run); + }, + }) + .invoke({ question: "What is the meaning of life?" }); + + expect(mockStart).toHaveBeenCalledTimes(1); + expect((mockStart.mock.calls[0][0] as { name: string }).name).toBe( + "RunnableSequence" + ); + expect(mockEnd).toHaveBeenCalledTimes(1); +});