Skip to content

Commit

Permalink
feat: support event.reason.computedCallers
Browse files Browse the repository at this point in the history
  • Loading branch information
himself65 committed Apr 1, 2024
1 parent f66f705 commit caef936
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 18 deletions.
13 changes: 12 additions & 1 deletion examples/chatHistory.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import { stdin as input, stdout as output } from "node:process";
import readline from "node:readline/promises";

import { OpenAI, SimpleChatEngine, SummaryChatHistory } from "llamaindex";
import {
OpenAI,
Settings,
SimpleChatEngine,
SummaryChatHistory,
} from "llamaindex";

if (process.env.NODE_ENV === "development") {
Settings.callbackManager.on("llm-end", (event) => {
console.log("callers chain", event.reason?.computedCallers);
});
}

async function main() {
// Set maxTokens to 75% of the context window size of 4096
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/agent/openai/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ export class OpenAIAgentWorker implements AgentWorker {
...llmChatKwargs,
});

const iterator = streamConverter(
const iterator = streamConverter.bind(this)(
streamReducer({
stream,
initialValue: "",
Expand Down
41 changes: 37 additions & 4 deletions packages/core/src/callbacks/CallbackManager.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,33 @@
import type { Anthropic } from "@anthropic-ai/sdk";
import { CustomEvent } from "@llamaindex/env";
import type { NodeWithScore } from "../Node.js";
import {
EventCaller,
getEventCaller,
} from "../internal/context/EventCaller.js";

export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> {
reason: EventCaller | null;
private constructor(
event: string,
options?: CustomEventInit & {
reason?: EventCaller | null;
},
) {
super(event, options);
this.reason = options?.reason ?? null;
}

static fromEvent<Type extends keyof LlamaIndexEventMaps>(
type: Type,
detail: LlamaIndexEventMaps[Type]["detail"],
) {
return new LlamaIndexCustomEvent(type, {
detail: detail,
reason: getEventCaller(),
});
}
}

/**
* This type is used to define the event maps for the Llamaindex package.
Expand Down Expand Up @@ -78,7 +105,11 @@ interface CallbackManagerMethods {

const noop: (...args: any[]) => any = () => void 0;

type EventHandler<Event extends CustomEvent> = (event: Event) => void;
type EventHandler<Event extends CustomEvent> = (
event: Event & {
reason: EventCaller | null;
},
) => void;

export class CallbackManager implements CallbackManagerMethods {
/**
Expand All @@ -90,7 +121,7 @@ export class CallbackManager implements CallbackManagerMethods {
this.#handlers
.get("stream")!
.map((handler) =>
handler(new CustomEvent("stream", { detail: response })),
handler(LlamaIndexCustomEvent.fromEvent("stream", response)),
),
);
};
Expand All @@ -105,7 +136,7 @@ export class CallbackManager implements CallbackManagerMethods {
this.#handlers
.get("retrieve")!
.map((handler) =>
handler(new CustomEvent("retrieve", { detail: response })),
handler(LlamaIndexCustomEvent.fromEvent("retrieve", response)),
),
);
};
Expand Down Expand Up @@ -168,6 +199,8 @@ export class CallbackManager implements CallbackManagerMethods {
if (!handlers) {
return;
}
handlers.forEach((handler) => handler(new CustomEvent(event, { detail })));
handlers.forEach((handler) =>
handler(LlamaIndexCustomEvent.fromEvent(event, detail)),
);
}
}
3 changes: 2 additions & 1 deletion packages/core/src/cloud/LlamaCloudRetriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import type { NodeWithScore } from "../Node.js";
import { ObjectType, jsonToNode } from "../Node.js";
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
import { Settings } from "../Settings.js";
import { wrapEventCaller } from "../internal/context/EventCaller.js";
import type { ClientParams, CloudConstructorParams } from "./types.js";
import { DEFAULT_PROJECT_NAME } from "./types.js";
import { getClient } from "./utils.js";

export type CloudRetrieveParams = Omit<
PlatformApi.RetrievalParams,
"query" | "searchFilters" | "pipelineId" | "className"
Expand Down Expand Up @@ -50,6 +50,7 @@ export class LlamaCloudRetriever implements BaseRetriever {
return this.client;
}

@wrapEventCaller
async retrieve({
query,
preFilters,
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/engines/chat/CondenseQuestionChatEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
import type { Response } from "../../Response.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { llmFromSettingsOrContext } from "../../Settings.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, LLM } from "../../llm/index.js";
import { extractText, streamReducer } from "../../llm/utils.js";
import { PromptMixin } from "../../prompts/index.js";
Expand All @@ -17,7 +18,6 @@ import type {
ChatEngineParamsNonStreaming,
ChatEngineParamsStreaming,
} from "./types.js";

/**
* CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorStoreIndex).
* It does two steps on taking a user's chat message: first, it condenses the chat message
Expand Down Expand Up @@ -82,6 +82,7 @@ export class CondenseQuestionChatEngine

chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async chat(
params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/engines/chat/ContextChatEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { getHistory } from "../../ChatHistory.js";
import type { ContextSystemPrompt } from "../../Prompt.js";
import { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
Expand Down Expand Up @@ -58,6 +59,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {

chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async chat(
params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/engines/chat/SimpleChatEngine.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { ChatHistory } from "../../ChatHistory.js";
import { getHistory } from "../../ChatHistory.js";
import { Response } from "../../Response.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatResponseChunk, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import { streamConverter, streamReducer } from "../../llm/utils.js";
Expand All @@ -25,6 +26,7 @@ export class SimpleChatEngine implements ChatEngine {

chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async chat(
params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/engines/query/RetrieverQueryEngine.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { NodeWithScore } from "../../Node.js";
import type { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/Mixin.js";
import type { BaseSynthesizer } from "../../synthesizers/index.js";
Expand Down Expand Up @@ -71,6 +72,7 @@ export class RetrieverQueryEngine

query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
query(params: QueryEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async query(
params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/engines/query/SubQuestionQueryEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type {
ToolMetadata,
} from "../../types.js";

import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { BaseQuestionGenerator, SubQuestion } from "./types.js";

/**
Expand Down Expand Up @@ -78,6 +79,7 @@ export class SubQuestionQueryEngine

query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
query(params: QueryEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async query(
params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/indices/summary/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
nodeParserFromSettingsOrContext,
} from "../../Settings.js";
import { RetrieverQueryEngine } from "../../engines/query/index.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import type { StorageContext } from "../../storage/StorageContext.js";
import { storageContextFromDefaults } from "../../storage/StorageContext.js";
Expand Down Expand Up @@ -286,6 +287,7 @@ export class SummaryIndexRetriever implements BaseRetriever {
this.index = index;
}

@wrapEventCaller
async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> {
const nodeIds = this.index.indexStruct.nodes;
const nodes = await this.index.docStore.getNodes(nodeIds);
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/indices/vectorStore/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
DocStoreStrategy,
createDocStoreStrategy,
} from "../../ingestion/strategies/index.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { BaseNodePostprocessor } from "../../postprocessors/types.js";
import type { StorageContext } from "../../storage/StorageContext.js";
import { storageContextFromDefaults } from "../../storage/StorageContext.js";
Expand Down Expand Up @@ -484,6 +485,7 @@ export class VectorIndexRetriever implements BaseRetriever {
return this.buildNodeListFromQueryResult(result);
}

@wrapEventCaller
protected sendEvent(
query: string,
nodesWithScores: NodeWithScore<Metadata>[],
Expand Down
99 changes: 99 additions & 0 deletions packages/core/src/internal/context/EventCaller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { AsyncLocalStorage, randomUUID } from "@llamaindex/env";
import { isAsyncGenerator, isGenerator } from "../utils.js";

const eventReasonAsyncLocalStorage = new AsyncLocalStorage<EventCaller>();

/**
* EventCaller is used to track the caller of an event.
*/
export class EventCaller {
public readonly id = randomUUID();

private constructor(
public readonly caller: unknown,
public readonly parent: EventCaller | null,
) {}

#computedCallers: unknown[] | null = null;

public get computedCallers(): unknown[] {
if (this.#computedCallers != null) {
return this.#computedCallers;
}
const callers = [this.caller];
let parent = this.parent;
while (parent != null) {
callers.push(parent.caller);
parent = parent.parent;
}
this.#computedCallers = callers;
return callers;
}

public static create(
caller: unknown,
parent: EventCaller | null,
): EventCaller {
return new EventCaller(caller, parent);
}
}

export function getEventCaller(): EventCaller | null {
return eventReasonAsyncLocalStorage.getStore() ?? null;
}

/**
* @param caller who is calling this function, pass in `this` if it's a class method
* @param fn
*/
function withEventCaller<T>(caller: unknown, fn: () => T) {
// create a chain of event callers
const parentCaller = getEventCaller();
return eventReasonAsyncLocalStorage.run(
EventCaller.create(caller, parentCaller),
fn,
);
}

export function wrapEventCaller<This, Result, Args extends unknown[]>(
originalMethod: (this: This, ...args: Args) => Result,
context: ClassMethodDecoratorContext<object>,
) {
const name = context.name;
context.addInitializer(function () {
// @ts-expect-error
const fn = this[name].bind(this);
// @ts-expect-error
this[name] = (...args: unknown[]) => {
return withEventCaller(this, () => fn(...args));
};
});
return function (this: This, ...args: Args): Result {
const result = originalMethod.call(this, ...args);
// patch for iterators because AsyncLocalStorage doesn't work with them
if (isAsyncGenerator(result)) {
const snapshot = AsyncLocalStorage.snapshot();
return (async function* asyncGeneratorWrapper() {
while (true) {
const { value, done } = await snapshot(() => result.next());
if (done) {
break;
}
yield value;
}
})() as Result;
} else if (isGenerator(result)) {
const snapshot = AsyncLocalStorage.snapshot();
return (function* generatorWrapper() {
while (true) {
const { value, done } = snapshot(() => result.next());
if (done) {
break;
}
yield value;
}
})() as Result;
}
return result;
};
}
7 changes: 7 additions & 0 deletions packages/core/src/internal/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export const isAsyncGenerator = (obj: unknown): obj is AsyncGenerator => {
return obj != null && typeof obj === "object" && Symbol.asyncIterator in obj;
};

export const isGenerator = (obj: unknown): obj is Generator => {
return obj != null && typeof obj === "object" && Symbol.iterator in obj;
};
Loading

0 comments on commit caef936

Please sign in to comment.