Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[feat]: Add withListeners method to runnables/callbacks #3445

Closed
wants to merge 15 commits into from
26 changes: 21 additions & 5 deletions langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,9 @@ export class CallbackManager
extends BaseCallbackManager
implements BaseCallbackManagerMethods
{
handlers: BaseCallbackHandler[];
handlers: BaseCallbackHandler[] = [];

inheritableHandlers: BaseCallbackHandler[];
inheritableHandlers: BaseCallbackHandler[] = [];

tags: string[] = [];

Expand All @@ -500,10 +500,26 @@ export class CallbackManager

private readonly _parentRunId?: string;

constructor(parentRunId?: string) {
constructor(
parentRunId?: string,
options?: {
handlers?: BaseCallbackHandler[];
inheritableHandlers?: BaseCallbackHandler[];
tags?: string[];
inheritableTags?: string[];
metadata?: Record<string, unknown>;
inheritableMetadata?: Record<string, unknown>;
}
) {
super();
this.handlers = [];
this.inheritableHandlers = [];
this.handlers = options?.handlers ?? this.handlers;
this.inheritableHandlers =
options?.inheritableHandlers ?? this.inheritableHandlers;
this.tags = options?.tags ?? this.tags;
this.inheritableTags = options?.inheritableTags ?? this.inheritableTags;
this.metadata = options?.metadata ?? this.metadata;
this.inheritableMetadata =
options?.inheritableMetadata ?? this.inheritableMetadata;
this._parentRunId = parentRunId;
}

Expand Down
173 changes: 145 additions & 28 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ import {
} from "../tracers/log_stream.js";
import { Serializable } from "../load/serializable.js";
import { IterableReadableStream } from "../utils/stream.js";
import { RunnableConfig, getCallbackMangerForConfig } from "./config.js";
import {
RunnableConfig,
getCallbackMangerForConfig,
mergeConfigs,
} from "./config.js";
import { AsyncCaller } from "../utils/async_caller.js";
import { Run } from "../tracers/base.js";
import { RootListenersTracer } from "../tracers/root_listener.js";

export type RunnableFunc<RunInput, RunOutput> = (
input: RunInput
Expand Down Expand Up @@ -549,6 +555,45 @@ export abstract class Runnable<
static isRunnable(thing: any): thing is Runnable {
return thing ? thing.lc_runnable : false;
}

/**
* Bind lifecycle listeners to a Runnable, returning a new Runnable.
* The Run object contains information about the run, including its id,
* type, input, output, error, startTime, endTime, and any tags or metadata
* added to the run.
*
* @param {Object} params - The object containing the callback functions.
* @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object.
* @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object.
* @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object.
*/
withListeners({
onStart,
onEnd,
onError,
}: {
onStart?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onEnd?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onError?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
}): Runnable<RunInput, RunOutput, CallOptions> {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
return new RunnableBinding<RunInput, RunOutput, CallOptions>({
bound: this,
config: {},
configFactories: [
(config) => ({
callbacks: [
new RootListenersTracer({
config,
onStart,
onEnd,
onError,
}),
],
}),
],
});
}
}

export type RunnableBindingArgs<
Expand All @@ -557,8 +602,9 @@ export type RunnableBindingArgs<
CallOptions extends RunnableConfig
> = {
bound: Runnable<RunInput, RunOutput, CallOptions>;
kwargs: Partial<CallOptions>;
kwargs?: Partial<CallOptions>;
config: RunnableConfig;
configFactories?: Array<(config: RunnableConfig) => RunnableConfig>;
};

/**
Expand All @@ -581,31 +627,33 @@ export class RunnableBinding<

config: RunnableConfig;

protected kwargs: Partial<CallOptions>;
protected kwargs?: Partial<CallOptions>;

configFactories?: Array<
(config: RunnableConfig) => RunnableConfig | Promise<RunnableConfig>
>;

constructor(fields: RunnableBindingArgs<RunInput, RunOutput, CallOptions>) {
super(fields);
this.bound = fields.bound;
this.kwargs = fields.kwargs;
this.config = fields.config;
this.configFactories = fields.configFactories;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
_mergeConfig(options?: Record<string, any>) {
async _mergeConfig(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const copy: Record<string, any> = { ...this.config };
if (options) {
for (const key of Object.keys(options)) {
if (key === "metadata") {
copy[key] = { ...copy[key], ...options[key] };
} else if (key === "tags") {
copy[key] = (copy[key] ?? []).concat(options[key] ?? []);
} else {
copy[key] = options[key] ?? copy[key];
}
}
}
return copy as Partial<CallOptions>;
options?: Record<string, any>
): Promise<Partial<CallOptions>> {
const config = mergeConfigs(this.config, options);
return mergeConfigs(
config,
...(this.configFactories
? await Promise.all(
this.configFactories.map(async (f) => await f(config))
)
: [])
);
}

bind(
Expand Down Expand Up @@ -645,7 +693,7 @@ export class RunnableBinding<
): Promise<RunOutput> {
return this.bound.invoke(
input,
this._mergeConfig({ ...options, ...this.kwargs })
await this._mergeConfig({ ...options, ...this.kwargs })
);
}

Expand Down Expand Up @@ -673,13 +721,16 @@ export class RunnableBinding<
batchOptions?: RunnableBatchOptions
): Promise<(RunOutput | Error)[]> {
const mergedOptions = Array.isArray(options)
? options.map((individualOption) =>
this._mergeConfig({
...individualOption,
...this.kwargs,
})
? await Promise.all(
options.map(
async (individualOption) =>
await this._mergeConfig({
...individualOption,
...this.kwargs,
})
)
)
: this._mergeConfig({ ...options, ...this.kwargs });
: await this._mergeConfig({ ...options, ...this.kwargs });
return this.bound.batch(inputs, mergedOptions, batchOptions);
}

Expand All @@ -689,7 +740,7 @@ export class RunnableBinding<
) {
yield* this.bound._streamIterator(
input,
this._mergeConfig({ ...options, ...this.kwargs })
await this._mergeConfig({ ...options, ...this.kwargs })
);
}

Expand All @@ -699,7 +750,7 @@ export class RunnableBinding<
): Promise<IterableReadableStream<RunOutput>> {
return this.bound.stream(
input,
this._mergeConfig({ ...options, ...this.kwargs })
await this._mergeConfig({ ...options, ...this.kwargs })
);
}

Expand All @@ -710,7 +761,7 @@ export class RunnableBinding<
): AsyncGenerator<RunOutput> {
yield* this.bound.transform(
generator,
this._mergeConfig({ ...options, ...this.kwargs })
await this._mergeConfig({ ...options, ...this.kwargs })
);
}

Expand All @@ -721,6 +772,46 @@ export class RunnableBinding<
): thing is RunnableBinding<any, any, any> {
return thing.bound && Runnable.isRunnable(thing.bound);
}

/**
* Bind lifecycle listeners to a Runnable, returning a new Runnable.
* The Run object contains information about the run, including its id,
* type, input, output, error, startTime, endTime, and any tags or metadata
* added to the run.
*
* @param {Object} params - The object containing the callback functions.
* @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object.
* @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object.
* @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object.
*/
withListeners({
onStart,
onEnd,
onError,
}: {
onStart?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onEnd?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onError?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
}): Runnable<RunInput, RunOutput, CallOptions> {
//
return new RunnableBinding<RunInput, RunOutput, CallOptions>({
bound: this.bound,
kwargs: this.kwargs,
config: this.config,
configFactories: [
(config) => ({
callbacks: [
new RootListenersTracer({
config,
onStart,
onEnd,
onError,
}),
],
}),
],
});
}
}

/**
Expand Down Expand Up @@ -789,6 +880,32 @@ export class RunnableEach<
this._patchConfig(config, runManager?.getChild())
);
}

/**
* Bind lifecycle listeners to a Runnable, returning a new Runnable.
* The Run object contains information about the run, including its id,
* type, input, output, error, startTime, endTime, and any tags or metadata
* added to the run.
*
* @param {Object} params - The object containing the callback functions.
* @param {(run: Run) => void} params.onStart - Called before the runnable starts running, with the Run object.
* @param {(run: Run) => void} params.onEnd - Called after the runnable finishes running, with the Run object.
* @param {(run: Run) => void} params.onError - Called if the runnable throws an error, with the Run object.
*/
withListeners({
onStart,
onEnd,
onError,
}: {
onStart?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onEnd?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
onError?: (run: Run, config?: RunnableConfig) => void | Promise<void>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}): Runnable<any, any, CallOptions> {
return new RunnableEach<RunInputItem, RunOutputItem, CallOptions>({
bound: this.bound.withListeners({ onStart, onEnd, onError }),
});
}
}

/**
Expand Down
73 changes: 73 additions & 0 deletions langchain-core/src/runnables/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,76 @@ export async function getCallbackMangerForConfig(config?: RunnableConfig) {
config?.metadata
);
}

export function mergeConfigs<CallOptions extends RunnableConfig>(
config: RunnableConfig,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
options?: Record<string, any>
): Partial<CallOptions> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const copy: Record<string, any> = { ...config };
if (options) {
for (const key of Object.keys(options)) {
if (key === "metadata") {
copy[key] = { ...copy[key], ...options[key] };
} else if (key === "tags") {
copy[key] = (copy[key] ?? []).concat(options[key] ?? []);
} else if (key === "callbacks") {
const baseCallbacks = copy.callbacks;
const theseCallbacks = options.callbacks ?? config.callbacks;
// callbacks can be either undefined, Array<handler> or manager
// so merging two callbacks values has 6 cases
if (Array.isArray(theseCallbacks)) {
if (!baseCallbacks) {
copy.callbacks = theseCallbacks;
} else if (Array.isArray(baseCallbacks)) {
copy.callbacks = baseCallbacks.concat(theseCallbacks);
} else {
// baseCallbacks is a manager
const manager = baseCallbacks.copy();
for (const callback of theseCallbacks) {
manager.addHandler(callback, true);
}
copy.callbacks = manager;
}
} else if (theseCallbacks) {
// theseCallbacks is a manager
if (!baseCallbacks) {
copy.callbacks = theseCallbacks;
} else if (Array.isArray(baseCallbacks)) {
const manager = theseCallbacks.copy();
for (const callback of baseCallbacks) {
manager.addHandler(callback, true);
}
copy.callbacks = manager;
} else {
// baseCallbacks is also a manager
copy.callbacks = new CallbackManager(theseCallbacks.parentRunId, {
handlers: baseCallbacks.handlers.concat(theseCallbacks.handlers),
inheritableHandlers: baseCallbacks.inheritableHandlers.concat(
theseCallbacks.inheritableHandlers
),
tags: Array.from(
new Set(baseCallbacks.tags.concat(theseCallbacks.tags))
),
inheritableTags: Array.from(
new Set(
baseCallbacks.inheritableTags.concat(
theseCallbacks.inheritableTags
)
)
),
metadata: {
...baseCallbacks.metadata,
...theseCallbacks.metadata,
},
});
}
}
} else {
copy[key] = options[key] ?? copy[key];
}
}
}
return copy as Partial<CallOptions>;
}
Loading