diff --git a/examples/chatHistory.ts b/examples/chatHistory.ts index 7ed93471da..388fd428cf 100644 --- a/examples/chatHistory.ts +++ b/examples/chatHistory.ts @@ -1,7 +1,18 @@ import { stdin as input, stdout as output } from "node:process"; import readline from "node:readline/promises"; -import { OpenAI, SimpleChatEngine, SummaryChatHistory } from "llamaindex"; +import { + OpenAI, + Settings, + SimpleChatEngine, + SummaryChatHistory, +} from "llamaindex"; + +if (process.env.NODE_ENV === "development") { + Settings.callbackManager.on("llm-end", (event) => { + console.log("callers chain", event.reason?.computedCallers); + }); +} async function main() { // Set maxTokens to 75% of the context window size of 4096 diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts index 1051f9b07f..2df512ea2a 100644 --- a/packages/core/src/GlobalsHelper.ts +++ b/packages/core/src/GlobalsHelper.ts @@ -1,12 +1,5 @@ import { encodingForModel } from "js-tiktoken"; -import { randomUUID } from "@llamaindex/env"; -import type { - Event, - EventTag, - EventType, -} from "./callbacks/CallbackManager.js"; - export enum Tokenizers { CL100K_BASE = "cl100k_base", } @@ -51,39 +44,6 @@ class GlobalsHelper { return this.defaultTokenizer!.decode.bind(this.defaultTokenizer); } - - /** - * @deprecated createEvent will be removed in the future, - * please use `new CustomEvent(eventType, { detail: payload })` instead. - * - * Also, `parentEvent` will not be used in the future, - * use `AsyncLocalStorage` to track parent events instead. - * @example - Usage of `AsyncLocalStorage`: - * let id = 0; - * const asyncLocalStorage = new AsyncLocalStorage(); - * asyncLocalStorage.run(++id, async () => { - * setTimeout(() => { - * console.log('parent event id:', asyncLocalStorage.getStore()); // 1 - * }, 1000) - * }); - */ - createEvent({ - parentEvent, - type, - tags, - }: { - parentEvent?: Event; - type: EventType; - tags?: EventTag[]; - }): Event { - return { - id: randomUUID(), - type, - // inherit parent tags if tags not set - tags: tags || parentEvent?.tags, - parentId: parentEvent?.id, - }; - } } export const globalsHelper = new GlobalsHelper(); diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index b8b942b558..bc836527b7 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,13 +1,8 @@ -import type { Event } from "./callbacks/CallbackManager.js"; import type { NodeWithScore } from "./Node.js"; import type { ServiceContext } from "./ServiceContext.js"; export type RetrieveParams = { query: string; - /** - * @deprecated will be removed in the next major version - */ - parentEvent?: Event; preFilters?: unknown; }; diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 94c9372617..6bb9946c59 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -204,7 +204,7 @@ export class OpenAIAgentWorker implements AgentWorker { ...llmChatKwargs, }); - const iterator = streamConverter( + const iterator = streamConverter.bind(this)( streamReducer({ stream, initialValue: "", diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts index cf38f49e97..71e0c3149e 100644 --- a/packages/core/src/callbacks/CallbackManager.ts +++ b/packages/core/src/callbacks/CallbackManager.ts @@ -1,6 +1,33 @@ import type { Anthropic } from "@anthropic-ai/sdk"; import { CustomEvent } from "@llamaindex/env"; import type { NodeWithScore } 
from "../Node.js"; +import { + EventCaller, + getEventCaller, +} from "../internal/context/EventCaller.js"; + +export class LlamaIndexCustomEvent extends CustomEvent { + reason: EventCaller | null; + private constructor( + event: string, + options?: CustomEventInit & { + reason?: EventCaller | null; + }, + ) { + super(event, options); + this.reason = options?.reason ?? null; + } + + static fromEvent( + type: Type, + detail: LlamaIndexEventMaps[Type]["detail"], + ) { + return new LlamaIndexCustomEvent(type, { + detail: detail, + reason: getEventCaller(), + }); + } +} /** * This type is used to define the event maps for the Llamaindex package. @@ -21,26 +48,6 @@ declare module "llamaindex" { } //#region @deprecated remove in the next major version -/* - An event is a wrapper that groups related operations. - For example, during retrieve and synthesize, - a parent event wraps both operations, and each operation has it's own - event. In this case, both sub-events will share a parentId. -*/ - -export type EventTag = "intermediate" | "final"; -export type EventType = "retrieve" | "llmPredict" | "wrapper"; -export interface Event { - id: string; - type: EventType; - tags?: EventTag[]; - parentId?: string; -} - -interface BaseCallbackResponse { - event: Event; -} - //Specify StreamToken per mainstream LLM export interface DefaultStreamToken { id: string; @@ -68,13 +75,13 @@ export type AnthropicStreamToken = Anthropic.Completion; //StreamCallbackResponse should let practitioners implement callbacks out of the box... //When custom streaming LLMs are involved, people are expected to write their own StreamCallbackResponses -export interface StreamCallbackResponse extends BaseCallbackResponse { +export interface StreamCallbackResponse { index: number; isDone?: boolean; token?: DefaultStreamToken; } -export interface RetrievalCallbackResponse extends BaseCallbackResponse { +export interface RetrievalCallbackResponse { query: string; nodes: NodeWithScore[]; } @@ -98,7 +105,11 @@ interface CallbackManagerMethods { const noop: (...args: any[]) => any = () => void 0; -type EventHandler = (event: Event) => void; +type EventHandler = ( + event: Event & { + reason: EventCaller | null; + }, +) => void; export class CallbackManager implements CallbackManagerMethods { /** @@ -110,7 +121,7 @@ export class CallbackManager implements CallbackManagerMethods { this.#handlers .get("stream")! .map((handler) => - handler(new CustomEvent("stream", { detail: response })), + handler(LlamaIndexCustomEvent.fromEvent("stream", response)), ), ); }; @@ -125,7 +136,7 @@ export class CallbackManager implements CallbackManagerMethods { this.#handlers .get("retrieve")! 
.map((handler) => - handler(new CustomEvent("retrieve", { detail: response })), + handler(LlamaIndexCustomEvent.fromEvent("retrieve", response)), ), ); }; @@ -188,6 +199,8 @@ export class CallbackManager implements CallbackManagerMethods { if (!handlers) { return; } - handlers.forEach((handler) => handler(new CustomEvent(event, { detail }))); + handlers.forEach((handler) => + handler(LlamaIndexCustomEvent.fromEvent(event, detail)), + ); } } diff --git a/packages/core/src/cloud/LlamaCloudRetriever.ts b/packages/core/src/cloud/LlamaCloudRetriever.ts index 0b4a53dc13..6f3bb745d6 100644 --- a/packages/core/src/cloud/LlamaCloudRetriever.ts +++ b/packages/core/src/cloud/LlamaCloudRetriever.ts @@ -1,13 +1,12 @@ import type { PlatformApi, PlatformApiClient } from "@llamaindex/cloud"; -import { globalsHelper } from "../GlobalsHelper.js"; import type { NodeWithScore } from "../Node.js"; import { ObjectType, jsonToNode } from "../Node.js"; import type { BaseRetriever, RetrieveParams } from "../Retriever.js"; import { Settings } from "../Settings.js"; +import { wrapEventCaller } from "../internal/context/EventCaller.js"; import type { ClientParams, CloudConstructorParams } from "./types.js"; import { DEFAULT_PROJECT_NAME } from "./types.js"; import { getClient } from "./utils.js"; - export type CloudRetrieveParams = Omit< PlatformApi.RetrievalParams, "query" | "searchFilters" | "pipelineId" | "className" @@ -51,9 +50,9 @@ export class LlamaCloudRetriever implements BaseRetriever { return this.client; } + @wrapEventCaller async retrieve({ query, - parentEvent, preFilters, }: RetrieveParams): Promise { const pipelines = await ( @@ -77,13 +76,9 @@ export class LlamaCloudRetriever implements BaseRetriever { const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes); - Settings.callbackManager.onRetrieve({ + Settings.callbackManager.dispatchEvent("retrieve", { query, nodes, - event: globalsHelper.createEvent({ - parentEvent, - type: "retrieve", - }), }); return nodes; diff --git a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts index cd66a44fcd..6e629c1d42 100644 --- a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts +++ b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts @@ -8,6 +8,7 @@ import { import type { Response } from "../../Response.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { ChatMessage, LLM } from "../../llm/index.js"; import { extractText, streamReducer } from "../../llm/utils.js"; import { PromptMixin } from "../../prompts/index.js"; @@ -17,7 +18,6 @@ import type { ChatEngineParamsNonStreaming, ChatEngineParamsStreaming, } from "./types.js"; - /** * CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorStoreIndex). 
* It does two steps on taking a user's chat message: first, it condenses the chat message @@ -82,6 +82,7 @@ export class CondenseQuestionChatEngine chat(params: ChatEngineParamsStreaming): Promise>; chat(params: ChatEngineParamsNonStreaming): Promise; + @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, ): Promise> { diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts index a7889ad494..9dc1400179 100644 --- a/packages/core/src/engines/chat/ContextChatEngine.ts +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -1,10 +1,9 @@ -import { randomUUID } from "@llamaindex/env"; import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; import type { ContextSystemPrompt } from "../../Prompt.js"; import { Response } from "../../Response.js"; import type { BaseRetriever } from "../../Retriever.js"; -import type { Event } from "../../callbacks/CallbackManager.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js"; import { OpenAI } from "../../llm/index.js"; import type { MessageContent } from "../../llm/types.js"; @@ -60,6 +59,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chat(params: ChatEngineParamsStreaming): Promise>; chat(params: ChatEngineParamsNonStreaming): Promise; + @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, ): Promise> { @@ -67,21 +67,14 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { const chatHistory = params.chatHistory ? getHistory(params.chatHistory) : this.chatHistory; - const parentEvent: Event = { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; const requestMessages = await this.prepareRequestMessages( message, chatHistory, - parentEvent, ); if (stream) { const stream = await this.chatModel.chat({ messages: requestMessages.messages, - parentEvent, stream: true, }); return streamConverter( @@ -98,7 +91,6 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { } const response = await this.chatModel.chat({ messages: requestMessages.messages, - parentEvent, }); chatHistory.addMessage(response.message); return new Response(response.message.content, requestMessages.nodes); @@ -111,14 +103,13 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { private async prepareRequestMessages( message: MessageContent, chatHistory: ChatHistory, - parentEvent?: Event, ) { chatHistory.addMessage({ content: message, role: "user", }); const textOnly = extractText(message); - const context = await this.contextGenerator.generate(textOnly, parentEvent); + const context = await this.contextGenerator.generate(textOnly); const nodes = context.nodes.map((r) => r.node); const messages = await chatHistory.requestMessages( context ? 
[context.message] : undefined, diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts index dfcb9a8dda..7e2b8e1103 100644 --- a/packages/core/src/engines/chat/DefaultContextGenerator.ts +++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -1,9 +1,7 @@ -import { randomUUID } from "@llamaindex/env"; import type { NodeWithScore, TextNode } from "../../Node.js"; import type { ContextSystemPrompt } from "../../Prompt.js"; import { defaultContextSystemPrompt } from "../../Prompt.js"; import type { BaseRetriever } from "../../Retriever.js"; -import type { Event } from "../../callbacks/CallbackManager.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import { PromptMixin } from "../../prompts/index.js"; import type { Context, ContextGenerator } from "./types.js"; @@ -56,17 +54,9 @@ export class DefaultContextGenerator return nodesWithScore; } - async generate(message: string, parentEvent?: Event): Promise { - if (!parentEvent) { - parentEvent = { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - } + async generate(message: string): Promise { const sourceNodesWithScore = await this.retriever.retrieve({ query: message, - parentEvent, }); const nodes = await this.applyNodePostprocessors( diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts index bd1c82f649..3494186c5d 100644 --- a/packages/core/src/engines/chat/SimpleChatEngine.ts +++ b/packages/core/src/engines/chat/SimpleChatEngine.ts @@ -1,6 +1,7 @@ import type { ChatHistory } from "../../ChatHistory.js"; import { getHistory } from "../../ChatHistory.js"; import { Response } from "../../Response.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { ChatResponseChunk, LLM } from "../../llm/index.js"; import { OpenAI } from "../../llm/index.js"; import { streamConverter, streamReducer } from "../../llm/utils.js"; @@ -25,6 +26,7 @@ export class SimpleChatEngine implements ChatEngine { chat(params: ChatEngineParamsStreaming): Promise>; chat(params: ChatEngineParamsNonStreaming): Promise; + @wrapEventCaller async chat( params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, ): Promise> { diff --git a/packages/core/src/engines/chat/types.ts b/packages/core/src/engines/chat/types.ts index 397d2cf51f..e2e9954f48 100644 --- a/packages/core/src/engines/chat/types.ts +++ b/packages/core/src/engines/chat/types.ts @@ -1,7 +1,6 @@ import type { ChatHistory } from "../../ChatHistory.js"; import type { BaseNode, NodeWithScore } from "../../Node.js"; import type { Response } from "../../Response.js"; -import type { Event } from "../../callbacks/CallbackManager.js"; import type { ChatMessage } from "../../llm/index.js"; import type { MessageContent } from "../../llm/types.js"; import type { ToolOutput } from "../../tools/types.js"; @@ -56,7 +55,7 @@ export interface Context { * A ContextGenerator is used to generate a context based on a message's text content */ export interface ContextGenerator { - generate(message: string, parentEvent?: Event): Promise; + generate(message: string): Promise; } export enum ChatResponseMode { diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts index d7cc96a54c..da8d617a96 100644 --- a/packages/core/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -1,8 +1,7 
@@ -import { randomUUID } from "@llamaindex/env"; import type { NodeWithScore } from "../../Node.js"; import type { Response } from "../../Response.js"; import type { BaseRetriever } from "../../Retriever.js"; -import type { Event } from "../../callbacks/CallbackManager.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import { PromptMixin } from "../../prompts/Mixin.js"; import type { BaseSynthesizer } from "../../synthesizers/index.js"; @@ -62,10 +61,9 @@ export class RetrieverQueryEngine return nodesWithScore; } - private async retrieve(query: string, parentEvent: Event) { + private async retrieve(query: string) { const nodes = await this.retriever.retrieve({ query, - parentEvent, preFilters: this.preFilters, }); @@ -74,28 +72,22 @@ export class RetrieverQueryEngine query(params: QueryEngineParamsStreaming): Promise>; query(params: QueryEngineParamsNonStreaming): Promise; + @wrapEventCaller async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, ): Promise> { const { query, stream } = params; - const parentEvent: Event = params.parentEvent || { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - const nodesWithScore = await this.retrieve(query, parentEvent); + const nodesWithScore = await this.retrieve(query); if (stream) { return this.responseSynthesizer.synthesize({ query, nodesWithScore, - parentEvent, stream: true, }); } return this.responseSynthesizer.synthesize({ query, nodesWithScore, - parentEvent, }); } } diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts index 00ff68b602..b06e40e967 100644 --- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -1,10 +1,8 @@ -import { randomUUID } from "@llamaindex/env"; import type { NodeWithScore } from "../../Node.js"; import { TextNode } from "../../Node.js"; import { LLMQuestionGenerator } from "../../QuestionGenerator.js"; import type { Response } from "../../Response.js"; import type { ServiceContext } from "../../ServiceContext.js"; -import type { Event } from "../../callbacks/CallbackManager.js"; import { PromptMixin } from "../../prompts/Mixin.js"; import type { BaseSynthesizer } from "../../synthesizers/index.js"; import { @@ -20,6 +18,7 @@ import type { ToolMetadata, } from "../../types.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { BaseQuestionGenerator, SubQuestion } from "./types.js"; /** @@ -80,29 +79,15 @@ export class SubQuestionQueryEngine query(params: QueryEngineParamsStreaming): Promise>; query(params: QueryEngineParamsNonStreaming): Promise; + @wrapEventCaller async query( params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, ): Promise> { const { query, stream } = params; const subQuestions = await this.questionGen.generate(this.metadatas, query); - // groups final retrieval+synthesis operation - const parentEvent: Event = params.parentEvent || { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - - // groups all sub-queries - const subQueryParentEvent: Event = { - id: randomUUID(), - parentId: parentEvent.id, - type: "wrapper", - tags: ["intermediate"], - }; - const subQNodes = await Promise.all( - subQuestions.map((subQ) => this.querySubQ(subQ, subQueryParentEvent)), + subQuestions.map((subQ) => this.querySubQ(subQ)), ); const nodesWithScore = subQNodes 
@@ -112,21 +97,16 @@ export class SubQuestionQueryEngine return this.responseSynthesizer.synthesize({ query, nodesWithScore, - parentEvent, stream: true, }); } return this.responseSynthesizer.synthesize({ query, nodesWithScore, - parentEvent, }); } - private async querySubQ( - subQ: SubQuestion, - parentEvent?: Event, - ): Promise { + private async querySubQ(subQ: SubQuestion): Promise { try { const question = subQ.subQuestion; @@ -140,7 +120,6 @@ export class SubQuestionQueryEngine const responseText = await queryEngine?.call?.({ query: question, - parentEvent, }); if (!responseText) { diff --git a/packages/core/src/indices/summary/index.ts b/packages/core/src/indices/summary/index.ts index 1eef0ebe5b..140347963f 100644 --- a/packages/core/src/indices/summary/index.ts +++ b/packages/core/src/indices/summary/index.ts @@ -1,5 +1,4 @@ import _ from "lodash"; -import { globalsHelper } from "../../GlobalsHelper.js"; import type { BaseNode, Document, NodeWithScore } from "../../Node.js"; import type { ChoiceSelectPrompt } from "../../Prompt.js"; import { defaultChoiceSelectPrompt } from "../../Prompt.js"; @@ -11,6 +10,7 @@ import { nodeParserFromSettingsOrContext, } from "../../Settings.js"; import { RetrieverQueryEngine } from "../../engines/query/index.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import type { StorageContext } from "../../storage/StorageContext.js"; import { storageContextFromDefaults } from "../../storage/StorageContext.js"; @@ -287,10 +287,8 @@ export class SummaryIndexRetriever implements BaseRetriever { this.index = index; } - async retrieve({ - query, - parentEvent, - }: RetrieveParams): Promise { + @wrapEventCaller + async retrieve({ query }: RetrieveParams): Promise { const nodeIds = this.index.indexStruct.nodes; const nodes = await this.index.docStore.getNodes(nodeIds); const result = nodes.map((node) => ({ @@ -298,13 +296,9 @@ export class SummaryIndexRetriever implements BaseRetriever { score: 1, })); - Settings.callbackManager.onRetrieve({ + Settings.callbackManager.dispatchEvent("retrieve", { query, nodes: result, - event: globalsHelper.createEvent({ - parentEvent, - type: "retrieve", - }), }); return result; @@ -340,10 +334,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { this.serviceContext = serviceContext || index.serviceContext; } - async retrieve({ - query, - parentEvent, - }: RetrieveParams): Promise { + async retrieve({ query }: RetrieveParams): Promise { const nodeIds = this.index.indexStruct.nodes; const results: NodeWithScore[] = []; @@ -380,13 +371,9 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { results.push(...nodeWithScores); } - Settings.callbackManager.onRetrieve({ + Settings.callbackManager.dispatchEvent("retrieve", { query, nodes: results, - event: globalsHelper.createEvent({ - parentEvent, - type: "retrieve", - }), }); return results; diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts index 06937a0946..06f5c20c97 100644 --- a/packages/core/src/indices/vectorStore/index.ts +++ b/packages/core/src/indices/vectorStore/index.ts @@ -1,4 +1,3 @@ -import { globalsHelper } from "../../GlobalsHelper.js"; import type { BaseNode, Document, @@ -18,7 +17,6 @@ import { embedModelFromSettingsOrContext, nodeParserFromSettingsOrContext, } from "../../Settings.js"; -import { type Event } from "../../callbacks/CallbackManager.js"; import { 
DEFAULT_SIMILARITY_TOP_K } from "../../constants.js"; import type { BaseEmbedding, @@ -31,6 +29,7 @@ import { DocStoreStrategy, createDocStoreStrategy, } from "../../ingestion/strategies/index.js"; +import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import type { BaseNodePostprocessor } from "../../postprocessors/types.js"; import type { StorageContext } from "../../storage/StorageContext.js"; import { storageContextFromDefaults } from "../../storage/StorageContext.js"; @@ -440,7 +439,6 @@ export class VectorIndexRetriever implements BaseRetriever { async retrieve({ query, - parentEvent, preFilters, }: RetrieveParams): Promise<NodeWithScore[]> { let nodesWithScores = await this.textRetrieve( @@ -450,7 +448,7 @@ nodesWithScores = nodesWithScores.concat( await this.textToImageRetrieve(query, preFilters as MetadataFilters), ); - this.sendEvent(query, nodesWithScores, parentEvent); + this.sendEvent(query, nodesWithScores); return nodesWithScores; } @@ -487,18 +485,14 @@ return this.buildNodeListFromQueryResult(result); } + @wrapEventCaller protected sendEvent( query: string, nodesWithScores: NodeWithScore[], - parentEvent: Event | undefined, ) { - Settings.callbackManager.onRetrieve({ + Settings.callbackManager.dispatchEvent("retrieve", { query, nodes: nodesWithScores, - event: globalsHelper.createEvent({ - parentEvent, - type: "retrieve", - }), }); } diff --git a/packages/core/src/internal/context/EventCaller.ts b/packages/core/src/internal/context/EventCaller.ts new file mode 100644 index 0000000000..26c6a631f1 --- /dev/null +++ b/packages/core/src/internal/context/EventCaller.ts @@ -0,0 +1,99 @@ +import { AsyncLocalStorage, randomUUID } from "@llamaindex/env"; +import { isAsyncGenerator, isGenerator } from "../utils.js"; + +const eventReasonAsyncLocalStorage = new AsyncLocalStorage<EventCaller>(); + +/** + * EventCaller is used to track the caller of an event. + */ +export class EventCaller { + public readonly id = randomUUID(); + + private constructor( + public readonly caller: unknown, + public readonly parent: EventCaller | null, + ) {} + + #computedCallers: unknown[] | null = null; + + public get computedCallers(): unknown[] { + if (this.#computedCallers != null) { + return this.#computedCallers; + } + const callers = [this.caller]; + let parent = this.parent; + while (parent != null) { + callers.push(parent.caller); + parent = parent.parent; + } + this.#computedCallers = callers; + return callers; + } + + public static create( + caller: unknown, + parent: EventCaller | null, + ): EventCaller { + return new EventCaller(caller, parent); + } +} + +export function getEventCaller(): EventCaller | null { + return eventReasonAsyncLocalStorage.getStore() ??
null; +} + +/** + * @param caller who is calling this function, pass in `this` if it's a class method + * @param fn + */ +function withEventCaller(caller: unknown, fn: () => T) { + // create a chain of event callers + const parentCaller = getEventCaller(); + return eventReasonAsyncLocalStorage.run( + EventCaller.create(caller, parentCaller), + fn, + ); +} + +export function wrapEventCaller( + originalMethod: (this: This, ...args: Args) => Result, + context: ClassMethodDecoratorContext, +) { + const name = context.name; + context.addInitializer(function () { + // @ts-expect-error + const fn = this[name].bind(this); + // @ts-expect-error + this[name] = (...args: unknown[]) => { + return withEventCaller(this, () => fn(...args)); + }; + }); + return function (this: This, ...args: Args): Result { + const result = originalMethod.call(this, ...args); + // patch for iterators because AsyncLocalStorage doesn't work with them + if (isAsyncGenerator(result)) { + const snapshot = AsyncLocalStorage.snapshot(); + return (async function* asyncGeneratorWrapper() { + while (true) { + const { value, done } = await snapshot(() => result.next()); + if (done) { + break; + } + yield value; + } + })() as Result; + } else if (isGenerator(result)) { + const snapshot = AsyncLocalStorage.snapshot(); + return (function* generatorWrapper() { + while (true) { + const { value, done } = snapshot(() => result.next()); + if (done) { + break; + } + yield value; + } + })() as Result; + } + return result; + }; +} diff --git a/packages/core/src/internal/utils.ts b/packages/core/src/internal/utils.ts new file mode 100644 index 0000000000..c04db0b393 --- /dev/null +++ b/packages/core/src/internal/utils.ts @@ -0,0 +1,7 @@ +export const isAsyncGenerator = (obj: unknown): obj is AsyncGenerator => { + return obj != null && typeof obj === "object" && Symbol.asyncIterator in obj; +}; + +export const isGenerator = (obj: unknown): obj is Generator => { + return obj != null && typeof obj === "object" && Symbol.iterator in obj; +}; diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 2f18ecc624..eb17a75d77 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -1,8 +1,6 @@ import type OpenAILLM from "openai"; import type { ClientOptions as OpenAIClientOptions } from "openai"; import { - type Event, - type EventType, type OpenAIStreamToken, type StreamCallbackResponse, } from "../callbacks/CallbackManager.js"; @@ -10,6 +8,7 @@ import { import type { ChatCompletionMessageParam } from "openai/resources/index.js"; import type { LLMOptions } from "portkey-ai"; import { Tokenizers } from "../GlobalsHelper.js"; +import { wrapEventCaller } from "../internal/context/EventCaller.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { AnthropicSession } from "./anthropic.js"; import { getAnthropicSession } from "./anthropic.js"; @@ -35,7 +34,7 @@ import type { LLMMetadata, MessageType, } from "./types.js"; -import { llmEvent } from "./utils.js"; +import { wrapLLMEvent } from "./utils.js"; export const GPT4_MODELS = { "gpt-4": { contextWindow: 8192 }, @@ -211,11 +210,12 @@ export class OpenAI extends BaseLLM { params: LLMChatParamsStreaming, ): Promise>; chat(params: LLMChatParamsNonStreaming): Promise; - @llmEvent + @wrapEventCaller + @wrapLLMEvent async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise> { - const { messages, parentEvent, stream, tools, toolChoice } = params; + const { messages, stream, tools, toolChoice } = params; 
const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { model: this.model, temperature: this.temperature, @@ -255,9 +255,9 @@ export class OpenAI extends BaseLLM { }; } + @wrapEventCaller protected async *streamChat({ messages, - parentEvent, }: LLMChatParamsStreaming): AsyncIterable { const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { model: this.model, @@ -274,22 +274,12 @@ export class OpenAI extends BaseLLM { ...this.additionalChatOptions, }; - //Now let's wrap our stream in a callback - const onLLMStream = getCallbackManager().onLLMStream; - const chunk_stream: AsyncIterable = await this.session.openai.chat.completions.create({ ...baseRequestParams, stream: true, }); - const event: Event = parentEvent - ? parentEvent - : { - id: "unspecified", - type: "llmPredict" as EventType, - }; - // TODO: add callback to streamConverter and use streamConverter here //Indices let idx_counter: number = 0; @@ -303,12 +293,11 @@ export class OpenAI extends BaseLLM { //onLLMStream Callback const stream_callback: StreamCallbackResponse = { - event: event, index: idx_counter, isDone: is_done, token: part, }; - onLLMStream(stream_callback); + getCallbackManager().dispatchEvent("stream", stream_callback); idx_counter++; @@ -541,11 +530,11 @@ If a question does not make any sense, or is not factually coherent, explain why params: LLMChatParamsStreaming, ): Promise>; chat(params: LLMChatParamsNonStreaming): Promise; - @llmEvent + @wrapLLMEvent async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise> { - const { messages, parentEvent, stream } = params; + const { messages, stream } = params; const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model] .replicateApi as `${string}/${string}:${string}`; @@ -683,13 +672,13 @@ export class Anthropic extends BaseLLM { params: LLMChatParamsStreaming, ): Promise>; chat(params: LLMChatParamsNonStreaming): Promise; - @llmEvent + @wrapLLMEvent async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise> { let { messages } = params; - const { parentEvent, stream } = params; + const { stream } = params; let systemPrompt: string | null = null; @@ -706,7 +695,7 @@ export class Anthropic extends BaseLLM { //Streaming if (stream) { - return this.streamChat(messages, parentEvent, systemPrompt); + return this.streamChat(messages, systemPrompt); } //Non-streaming @@ -726,7 +715,6 @@ export class Anthropic extends BaseLLM { protected async *streamChat( messages: ChatMessage[], - parentEvent?: Event | undefined, systemPrompt?: string | null, ): AsyncIterable { const stream = await this.session.anthropic.messages.create({ @@ -782,13 +770,13 @@ export class Portkey extends BaseLLM { params: LLMChatParamsStreaming, ): Promise>; chat(params: LLMChatParamsNonStreaming): Promise; - @llmEvent + @wrapLLMEvent async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise> { - const { messages, parentEvent, stream, extraParams } = params; + const { messages, stream, extraParams } = params; if (stream) { - return this.streamChat(messages, parentEvent, extraParams); + return this.streamChat(messages, extraParams); } else { const bodyParams = extraParams || {}; const response = await this.session.portkey.chatCompletions.create({ @@ -804,25 +792,14 @@ export class Portkey extends BaseLLM { async *streamChat( messages: ChatMessage[], - parentEvent?: Event, params?: Record, ): AsyncIterable { - // Wrapping the stream in a callback. 
- const onLLMStream = getCallbackManager().onLLMStream; - const chunkStream = await this.session.portkey.chatCompletions.create({ messages, ...params, stream: true, }); - const event: Event = parentEvent - ? parentEvent - : { - id: "unspecified", - type: "llmPredict" as EventType, - }; - //Indices let idx_counter: number = 0; for await (const part of chunkStream) { @@ -833,12 +810,11 @@ export class Portkey extends BaseLLM { //onLLMStream Callback const stream_callback: StreamCallbackResponse = { - event: event, index: idx_counter, isDone: is_done, // token: part, }; - onLLMStream(stream_callback); + getCallbackManager().dispatchEvent("stream", stream_callback); idx_counter++; diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts index 6dce23cb8c..04854c2049 100644 --- a/packages/core/src/llm/base.ts +++ b/packages/core/src/llm/base.ts @@ -23,11 +23,10 @@ export abstract class BaseLLM implements LLM { async complete( params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, ): Promise> { - const { prompt, parentEvent, stream } = params; + const { prompt, stream } = params; if (stream) { const stream = await this.chat({ messages: [{ content: prompt, role: "user" }], - parentEvent, stream: true, }); return streamConverter(stream, (chunk) => { @@ -38,7 +37,6 @@ export abstract class BaseLLM implements LLM { } const chatResponse = await this.chat({ messages: [{ content: prompt, role: "user" }], - parentEvent, }); return { text: chatResponse.message.content as string }; } diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts index ce3fd5b0ee..ffbf192636 100644 --- a/packages/core/src/llm/mistral.ts +++ b/packages/core/src/llm/mistral.ts @@ -1,10 +1,6 @@ import { getEnv } from "@llamaindex/env"; import { Settings } from "../Settings.js"; -import { - type Event, - type EventType, - type StreamCallbackResponse, -} from "../callbacks/CallbackManager.js"; +import { type StreamCallbackResponse } from "../callbacks/CallbackManager.js"; import { BaseLLM } from "./base.js"; import type { ChatMessage, @@ -116,21 +112,10 @@ export class MistralAI extends BaseLLM { protected async *streamChat({ messages, - parentEvent, }: LLMChatParamsStreaming): AsyncIterable { - //Now let's wrap our stream in a callback - const onLLMStream = Settings.callbackManager.onLLMStream; - const client = await this.session.getClient(); const chunkStream = await client.chatStream(this.buildParams(messages)); - const event: Event = parentEvent - ? parentEvent - : { - id: "unspecified", - type: "llmPredict" as EventType, - }; - //Indices let idx_counter: number = 0; for await (const part of chunkStream) { @@ -141,12 +126,12 @@ export class MistralAI extends BaseLLM { part.choices[0].finish_reason === "stop" ? 
true : false; const stream_callback: StreamCallbackResponse = { - event: event, index: idx_counter, isDone: isDone, token: part, }; - onLLMStream(stream_callback); + + Settings.callbackManager.dispatchEvent("stream", stream_callback); idx_counter++; diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts index 5b7f39e7c4..aad435628d 100644 --- a/packages/core/src/llm/ollama.ts +++ b/packages/core/src/llm/ollama.ts @@ -1,5 +1,4 @@ import { ok } from "@llamaindex/env"; -import type { Event } from "../callbacks/CallbackManager.js"; import { BaseEmbedding } from "../embeddings/types.js"; import type { ChatResponse, @@ -69,7 +68,7 @@ export class Ollama extends BaseEmbedding implements LLM { async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, ): Promise> { - const { messages, parentEvent, stream } = params; + const { messages, stream } = params; const payload = { model: this.model, messages: messages.map((message) => ({ @@ -106,14 +105,13 @@ export class Ollama extends BaseEmbedding implements LLM { const stream = response.body; ok(stream, "stream is null"); ok(stream instanceof ReadableStream, "stream is not readable"); - return this.streamChat(stream, messageAccessor, parentEvent); + return this.streamChat(stream, messageAccessor); } } private async *streamChat( stream: ReadableStream, accessor: (data: any) => T, - parentEvent?: Event, ): AsyncIterable { const reader = stream.getReader(); while (true) { @@ -147,7 +145,7 @@ export class Ollama extends BaseEmbedding implements LLM { async complete( params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, ): Promise> { - const { prompt, parentEvent, stream } = params; + const { prompt, stream } = params; const payload = { model: this.model, prompt: prompt, @@ -177,7 +175,7 @@ export class Ollama extends BaseEmbedding implements LLM { const stream = response.body; ok(stream, "stream is null"); ok(stream instanceof ReadableStream, "stream is not readable"); - return this.streamChat(stream, completionAccessor, parentEvent); + return this.streamChat(stream, completionAccessor); } } diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index 7bd2ebad55..8aa9548c44 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -1,5 +1,4 @@ import type { Tokenizers } from "../GlobalsHelper.js"; -import { type Event } from "../callbacks/CallbackManager.js"; type LLMBaseEvent< Type extends string, @@ -103,7 +102,6 @@ export interface LLMMetadata { export interface LLMChatParamsBase { messages: ChatMessage[]; - parentEvent?: Event; extraParams?: Record; tools?: any; toolChoice?: any; @@ -120,7 +118,6 @@ export interface LLMChatParamsNonStreaming extends LLMChatParamsBase { export interface LLMCompletionParamsBase { prompt: any; - parentEvent?: Event; } export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase { diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts index ffc40e0cbe..03725ad5e8 100644 --- a/packages/core/src/llm/utils.ts +++ b/packages/core/src/llm/utils.ts @@ -1,3 +1,4 @@ +import { AsyncLocalStorage } from "@llamaindex/env"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js"; @@ -47,7 +48,7 @@ export function extractText(message: MessageContent): string { /** * @internal */ -export function llmEvent( +export function wrapLLMEvent( originalMethod: LLMChat["chat"], _context: 
ClassMethodDecoratorContext, ) { @@ -62,6 +63,8 @@ export function llmEvent( }); const response = await originalMethod.call(this, ...params); if (Symbol.asyncIterator in response) { + // save snapshot to restore it after the response is done + const snapshot = AsyncLocalStorage.snapshot(); const originalAsyncIterator = { [Symbol.asyncIterator]: response[Symbol.asyncIterator].bind(response), }; @@ -82,10 +85,12 @@ export function llmEvent( } yield chunk; } - getCallbackManager().dispatchEvent("llm-end", { - payload: { - response: finalResponse, - }, + snapshot(() => { + getCallbackManager().dispatchEvent("llm-end", { + payload: { + response: finalResponse, + }, + }); }); }; } else { diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts index 857cbd9a72..5c93c60190 100644 --- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -55,7 +55,6 @@ export class MultiModalResponseSynthesizer async synthesize({ query, nodesWithScore, - parentEvent, stream, }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise< AsyncIterable | Response @@ -90,7 +89,6 @@ export class MultiModalResponseSynthesizer const response = await llm.complete({ prompt, - parentEvent, }); return new Response(response.text, nodes); diff --git a/packages/core/src/synthesizers/ResponseSynthesizer.ts b/packages/core/src/synthesizers/ResponseSynthesizer.ts index b837a8867d..018ed8e1e8 100644 --- a/packages/core/src/synthesizers/ResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/ResponseSynthesizer.ts @@ -62,7 +62,6 @@ export class ResponseSynthesizer async synthesize({ query, nodesWithScore, - parentEvent, stream, }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise< AsyncIterable | Response @@ -75,7 +74,6 @@ export class ResponseSynthesizer const response = await this.responseBuilder.getResponse({ query, textChunks, - parentEvent, stream, }); return streamConverter(response, (chunk) => new Response(chunk, nodes)); @@ -83,7 +81,6 @@ export class ResponseSynthesizer const response = await this.responseBuilder.getResponse({ query, textChunks, - parentEvent, }); return new Response(response, nodes); } diff --git a/packages/core/src/synthesizers/builders.ts b/packages/core/src/synthesizers/builders.ts index 25aa70ddcf..ee0cce07c8 100644 --- a/packages/core/src/synthesizers/builders.ts +++ b/packages/core/src/synthesizers/builders.ts @@ -1,4 +1,3 @@ -import type { Event } from "../callbacks/CallbackManager.js"; import type { LLM } from "../llm/index.js"; import { streamConverter } from "../llm/utils.js"; import type { @@ -55,7 +54,6 @@ export class SimpleResponseBuilder implements ResponseBuilder { async getResponse({ query, textChunks, - parentEvent, stream, }: | ResponseBuilderParamsStreaming @@ -69,10 +67,10 @@ export class SimpleResponseBuilder implements ResponseBuilder { const prompt = this.textQATemplate(input); if (stream) { - const response = await this.llm.complete({ prompt, parentEvent, stream }); + const response = await this.llm.complete({ prompt, stream }); return streamConverter(response, (chunk) => chunk.text); } else { - const response = await this.llm.complete({ prompt, parentEvent, stream }); + const response = await this.llm.complete({ prompt, stream }); return response.text; } } @@ -130,7 +128,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { async getResponse({ query, textChunks, - 
parentEvent, prevResponse, stream, }: @@ -148,7 +145,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { query, chunk, !!stream && lastChunk, - parentEvent, ); } else { response = await this.refineResponseSingle( @@ -156,7 +152,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { query, chunk, !!stream && lastChunk, - parentEvent, ); } } @@ -168,7 +163,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { queryStr: string, textChunk: string, stream: boolean, - parentEvent?: Event, ) { const textQATemplate: SimplePrompt = (input) => this.textQATemplate({ ...input, query: queryStr }); @@ -184,7 +178,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { prompt: textQATemplate({ context: chunk, }), - parentEvent, stream: stream && lastChunk, }); } else { @@ -193,7 +186,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { queryStr, chunk, stream && lastChunk, - parentEvent, ); } } @@ -207,7 +199,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { queryStr: string, textChunk: string, stream: boolean, - parentEvent?: Event, ) { const refineTemplate: SimplePrompt = (input) => this.refineTemplate({ ...input, query: queryStr }); @@ -224,7 +215,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { context: chunk, existingAnswer: response as string, }), - parentEvent, stream: stream && lastChunk, }); } @@ -234,7 +224,6 @@ export class Refine extends PromptMixin implements ResponseBuilder { async complete(params: { prompt: string; stream: boolean; - parentEvent?: Event; }): Promise | string> { if (params.stream) { const response = await this.llm.complete({ ...params, stream: true }); @@ -257,7 +246,6 @@ export class CompactAndRefine extends Refine { async getResponse({ query, textChunks, - parentEvent, prevResponse, stream, }: @@ -275,7 +263,6 @@ export class CompactAndRefine extends Refine { const params = { query, textChunks: newTexts, - parentEvent, prevResponse, }; if (stream) { @@ -328,7 +315,6 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { async getResponse({ query, textChunks, - parentEvent, stream, }: | ResponseBuilderParamsStreaming @@ -351,7 +337,6 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { context: packedTextChunks[0], query, }), - parentEvent, }; if (stream) { const response = await this.llm.complete({ ...params, stream }); @@ -366,7 +351,6 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { context: chunk, query, }), - parentEvent, }), ), ); diff --git a/packages/core/src/synthesizers/types.ts b/packages/core/src/synthesizers/types.ts index 8b8a469729..abfcd9b16b 100644 --- a/packages/core/src/synthesizers/types.ts +++ b/packages/core/src/synthesizers/types.ts @@ -1,4 +1,3 @@ -import type { Event } from "../callbacks/CallbackManager.js"; import type { NodeWithScore } from "../Node.js"; import type { PromptMixin } from "../prompts/Mixin.js"; import type { Response } from "../Response.js"; @@ -6,7 +5,6 @@ import type { Response } from "../Response.js"; export interface SynthesizeParamsBase { query: string; nodesWithScore: NodeWithScore[]; - parentEvent?: Event; } export interface SynthesizeParamsStreaming extends SynthesizeParamsBase { @@ -30,7 +28,6 @@ export interface BaseSynthesizer { export interface ResponseBuilderParamsBase { query: string; textChunks: string[]; - parentEvent?: Event; prevResponse?: string; } diff --git a/packages/core/src/types.ts 
b/packages/core/src/types.ts index 827e51420a..6942556388 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -1,7 +1,6 @@ /** * Top level types to avoid circular dependencies */ -import type { Event } from "./callbacks/CallbackManager.js"; import type { Response } from "./Response.js"; /** @@ -9,7 +8,6 @@ import type { Response } from "./Response.js"; */ export interface QueryEngineParamsBase { query: string; - parentEvent?: Event; } export interface QueryEngineParamsStreaming extends QueryEngineParamsBase { diff --git a/packages/core/tests/CallbackManager.test.ts b/packages/core/tests/CallbackManager.test.ts index 461a18bc04..86f31b1836 100644 --- a/packages/core/tests/CallbackManager.test.ts +++ b/packages/core/tests/CallbackManager.test.ts @@ -88,12 +88,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2"); expect(streamCallbackData).toEqual([ { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 0, token: { id: "id", @@ -104,12 +98,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }, }, { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 1, token: { id: "id", @@ -120,12 +108,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }, }, { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 2, isDone: true, }, @@ -134,19 +116,8 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { { query: query, nodes: expect.any(Array), - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "retrieve", - tags: ["final"], - }, }, ]); - // both retrieval and streaming should have - // the same parent event - expect(streamCallbackData[0].event.parentId).toBe( - retrieveCallbackData[0].event.parentId, - ); }); test("For SummaryIndex w/ a SummaryIndexRetriever", async () => { @@ -169,12 +140,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2"); expect(streamCallbackData).toEqual([ { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 0, token: { id: "id", @@ -185,12 +150,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }, }, { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 1, token: { id: "id", @@ -201,12 +160,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { }, }, { - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "llmPredict", - tags: ["final"], - }, index: 2, isDone: true, }, @@ -215,18 +168,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { { query: query, nodes: expect.any(Array), - event: { - id: expect.any(String), - parentId: expect.any(String), - type: "retrieve", - tags: ["final"], - }, }, ]); - // both retrieval and streaming should have - // the same parent event - expect(streamCallbackData[0].event.parentId).toBe( - retrieveCallbackData[0].event.parentId, - ); }); }); diff --git a/packages/core/tests/utility/mockOpenAI.ts b/packages/core/tests/utility/mockOpenAI.ts index 0249be5631..8fac02b5bf 100644 --- a/packages/core/tests/utility/mockOpenAI.ts +++ b/packages/core/tests/utility/mockOpenAI.ts @@ -1,4 +1,3 @@ -import 
{ globalsHelper } from "llamaindex/GlobalsHelper"; import type { CallbackManager } from "llamaindex/callbacks/CallbackManager"; import type { OpenAIEmbedding } from "llamaindex/embeddings/index"; import type { OpenAI } from "llamaindex/llm/LLM"; @@ -15,18 +14,13 @@ export function mockLlmGeneration({ callbackManager?: CallbackManager; }) { vi.spyOn(languageModel, "chat").mockImplementation( - async ({ messages, parentEvent }: LLMChatParamsBase) => { + async ({ messages }: LLMChatParamsBase) => { const text = DEFAULT_LLM_TEXT_OUTPUT; - const event = globalsHelper.createEvent({ - parentEvent, - type: "llmPredict", - }); if (callbackManager?.onLLMStream) { const chunks = text.split("-"); for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; callbackManager?.onLLMStream({ - event, index: i, token: { id: "id", @@ -46,7 +40,6 @@ export function mockLlmGeneration({ }); } callbackManager?.onLLMStream({ - event, index: chunks.length, isDone: true, }); @@ -122,18 +115,13 @@ export function mocStructuredkLlmGeneration({ callbackManager?: CallbackManager; }) { vi.spyOn(languageModel, "chat").mockImplementation( - async ({ messages, parentEvent }: LLMChatParamsBase) => { + async ({ messages }: LLMChatParamsBase) => { const text = structuredOutput; - const event = globalsHelper.createEvent({ - parentEvent, - type: "llmPredict", - }); if (callbackManager?.onLLMStream) { const chunks = text.split("-"); for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; callbackManager?.onLLMStream({ - event, index: i, token: { id: "id", @@ -153,7 +141,6 @@ export function mocStructuredkLlmGeneration({ }); } callbackManager?.onLLMStream({ - event, index: chunks.length, isDone: true, });