refactor: use event.reason, remove parentEvent (#681)
himself65 authored Apr 2, 2024
1 parent a6dfa30 commit d256cbe
Showing 30 changed files with 226 additions and 349 deletions.
13 changes: 12 additions & 1 deletion examples/chatHistory.ts
@@ -1,7 +1,18 @@
import { stdin as input, stdout as output } from "node:process";
import readline from "node:readline/promises";

import { OpenAI, SimpleChatEngine, SummaryChatHistory } from "llamaindex";
import {
  OpenAI,
  Settings,
  SimpleChatEngine,
  SummaryChatHistory,
} from "llamaindex";

if (process.env.NODE_ENV === "development") {
  Settings.callbackManager.on("llm-end", (event) => {
    console.log("callers chain", event.reason?.computedCallers);
  });
}

async function main() {
  // Set maxTokens to 75% of the context window size of 4096
40 changes: 0 additions & 40 deletions packages/core/src/GlobalsHelper.ts
@@ -1,12 +1,5 @@
import { encodingForModel } from "js-tiktoken";

import { randomUUID } from "@llamaindex/env";
import type {
  Event,
  EventTag,
  EventType,
} from "./callbacks/CallbackManager.js";

export enum Tokenizers {
  CL100K_BASE = "cl100k_base",
}
@@ -51,39 +44,6 @@ class GlobalsHelper {

    return this.defaultTokenizer!.decode.bind(this.defaultTokenizer);
  }

  /**
   * @deprecated createEvent will be removed in the future;
   * please use `new CustomEvent(eventType, { detail: payload })` instead.
   *
   * Also, `parentEvent` will not be used in the future;
   * use `AsyncLocalStorage` to track parent events instead.
   * @example - Usage of `AsyncLocalStorage`:
   * let id = 0;
   * const asyncLocalStorage = new AsyncLocalStorage<number>();
   * asyncLocalStorage.run(++id, async () => {
   *   setTimeout(() => {
   *     console.log('parent event id:', asyncLocalStorage.getStore()); // 1
   *   }, 1000)
   * });
   */
  createEvent({
    parentEvent,
    type,
    tags,
  }: {
    parentEvent?: Event;
    type: EventType;
    tags?: EventTag[];
  }): Event {
    return {
      id: randomUUID(),
      type,
      // inherit parent tags if tags not set
      tags: tags || parentEvent?.tags,
      parentId: parentEvent?.id,
    };
  }
}

export const globalsHelper = new GlobalsHelper();
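
The `AsyncLocalStorage` pattern that the removed deprecation notice points to, expanded into a self-contained, runnable sketch (using node:async_hooks directly; the llamaindex internals are assumed to do something similar, not identical):

import { AsyncLocalStorage } from "node:async_hooks";

let id = 0;
const asyncLocalStorage = new AsyncLocalStorage<number>();

asyncLocalStorage.run(++id, async () => {
  setTimeout(() => {
    // the store set by run() is still visible inside the timer:
    // prints "parent event id: 1"
    console.log("parent event id:", asyncLocalStorage.getStore());
  }, 1000);
});
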
5 changes: 0 additions & 5 deletions packages/core/src/Retriever.ts
@@ -1,13 +1,8 @@
import type { Event } from "./callbacks/CallbackManager.js";
import type { NodeWithScore } from "./Node.js";
import type { ServiceContext } from "./ServiceContext.js";

export type RetrieveParams = {
  query: string;
  /**
   * @deprecated will be removed in the next major version
   */
  parentEvent?: Event;
  preFilters?: unknown;
};

2 changes: 1 addition & 1 deletion packages/core/src/agent/openai/worker.ts
@@ -204,7 +204,7 @@ export class OpenAIAgentWorker implements AgentWorker {
      ...llmChatKwargs,
    });

    const iterator = streamConverter(
    const iterator = streamConverter.bind(this)(
      streamReducer({
        stream,
        initialValue: "",
65 changes: 39 additions & 26 deletions packages/core/src/callbacks/CallbackManager.ts
@@ -1,6 +1,33 @@
import type { Anthropic } from "@anthropic-ai/sdk";
import { CustomEvent } from "@llamaindex/env";
import type { NodeWithScore } from "../Node.js";
import {
  EventCaller,
  getEventCaller,
} from "../internal/context/EventCaller.js";

export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> {
  reason: EventCaller | null;
  private constructor(
    event: string,
    options?: CustomEventInit & {
      reason?: EventCaller | null;
    },
  ) {
    super(event, options);
    this.reason = options?.reason ?? null;
  }

  static fromEvent<Type extends keyof LlamaIndexEventMaps>(
    type: Type,
    detail: LlamaIndexEventMaps[Type]["detail"],
  ) {
    return new LlamaIndexCustomEvent(type, {
      detail: detail,
      reason: getEventCaller(),
    });
  }
}

/**
 * This type is used to define the event maps for the Llamaindex package.
@@ -21,26 +48,6 @@ declare module "llamaindex" {
}

//#region @deprecated remove in the next major version
/*
An event is a wrapper that groups related operations.
For example, during retrieve and synthesize,
a parent event wraps both operations, and each operation has its own
event. In this case, both sub-events will share a parentId.
*/

export type EventTag = "intermediate" | "final";
export type EventType = "retrieve" | "llmPredict" | "wrapper";
export interface Event {
  id: string;
  type: EventType;
  tags?: EventTag[];
  parentId?: string;
}

interface BaseCallbackResponse {
  event: Event;
}

//Specify StreamToken per mainstream LLM
export interface DefaultStreamToken {
  id: string;
@@ -68,13 +75,13 @@ export type AnthropicStreamToken = Anthropic.Completion;

//StreamCallbackResponse should let practitioners implement callbacks out of the box...
//When custom streaming LLMs are involved, people are expected to write their own StreamCallbackResponses
export interface StreamCallbackResponse extends BaseCallbackResponse {
export interface StreamCallbackResponse {
  index: number;
  isDone?: boolean;
  token?: DefaultStreamToken;
}

export interface RetrievalCallbackResponse extends BaseCallbackResponse {
export interface RetrievalCallbackResponse {
  query: string;
  nodes: NodeWithScore[];
}
@@ -98,7 +105,11 @@ interface CallbackManagerMethods {

const noop: (...args: any[]) => any = () => void 0;

type EventHandler<Event extends CustomEvent> = (event: Event) => void;
type EventHandler<Event extends CustomEvent> = (
  event: Event & {
    reason: EventCaller | null;
  },
) => void;

export class CallbackManager implements CallbackManagerMethods {
  /**
@@ -110,7 +121,7 @@ export class CallbackManager implements CallbackManagerMethods {
      this.#handlers
        .get("stream")!
        .map((handler) =>
          handler(new CustomEvent("stream", { detail: response })),
          handler(LlamaIndexCustomEvent.fromEvent("stream", response)),
        ),
    );
  };
@@ -125,7 +136,7 @@ export class CallbackManager implements CallbackManagerMethods {
      this.#handlers
        .get("retrieve")!
        .map((handler) =>
          handler(new CustomEvent("retrieve", { detail: response })),
          handler(LlamaIndexCustomEvent.fromEvent("retrieve", response)),
        ),
    );
  };
@@ -188,6 +199,8 @@ export class CallbackManager implements CallbackManagerMethods {
    if (!handlers) {
      return;
    }
    handlers.forEach((handler) => handler(new CustomEvent(event, { detail })));
    handlers.forEach((handler) =>
      handler(LlamaIndexCustomEvent.fromEvent(event, detail)),
    );
  }
}
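
With dispatchEvent now routing every payload through LlamaIndexCustomEvent.fromEvent, handlers read the caller chain from event.reason instead of a parentEvent field on the payload. A minimal usage sketch, assuming the Settings.callbackManager wiring shown in examples/chatHistory.ts above:

import { Settings } from "llamaindex";

// "retrieve" handlers get query and nodes in event.detail, plus the
// EventCaller (or null) that was captured at dispatch time in event.reason.
Settings.callbackManager.on("retrieve", (event) => {
  const { query, nodes } = event.detail;
  console.log(`retrieved ${nodes.length} nodes for "${query}"`);
  console.log("callers chain", event.reason?.computedCallers);
});
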
11 changes: 3 additions & 8 deletions packages/core/src/cloud/LlamaCloudRetriever.ts
@@ -1,13 +1,12 @@
import type { PlatformApi, PlatformApiClient } from "@llamaindex/cloud";
import { globalsHelper } from "../GlobalsHelper.js";
import type { NodeWithScore } from "../Node.js";
import { ObjectType, jsonToNode } from "../Node.js";
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
import { Settings } from "../Settings.js";
import { wrapEventCaller } from "../internal/context/EventCaller.js";
import type { ClientParams, CloudConstructorParams } from "./types.js";
import { DEFAULT_PROJECT_NAME } from "./types.js";
import { getClient } from "./utils.js";

export type CloudRetrieveParams = Omit<
  PlatformApi.RetrievalParams,
  "query" | "searchFilters" | "pipelineId" | "className"
@@ -51,9 +50,9 @@ export class LlamaCloudRetriever implements BaseRetriever {
    return this.client;
  }

  @wrapEventCaller
  async retrieve({
    query,
    parentEvent,
    preFilters,
  }: RetrieveParams): Promise<NodeWithScore[]> {
    const pipelines = await (
@@ -77,13 +76,9 @@

    const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);

    Settings.callbackManager.onRetrieve({
    Settings.callbackManager.dispatchEvent("retrieve", {
      query,
      nodes,
      event: globalsHelper.createEvent({
        parentEvent,
        type: "retrieve",
      }),
    });

    return nodes;
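
EventCaller.ts itself is not included in this excerpt, but the pieces the diff relies on (EventCaller, getEventCaller, wrapEventCaller) can be sketched on top of AsyncLocalStorage, in the spirit of the removed GlobalsHelper comment. A minimal sketch, assuming legacy (experimentalDecorators-style) method decorators and that @llamaindex/env re-exports AsyncLocalStorage; fields other than computedCallers are illustrative:

import { AsyncLocalStorage, randomUUID } from "@llamaindex/env";

export class EventCaller {
  readonly id = randomUUID();

  constructor(
    public readonly caller: unknown, // the object whose method was invoked
    public readonly parent: EventCaller | null,
  ) {}

  // walk up the chain so a handler can see every caller at once
  get computedCallers(): unknown[] {
    const callers: unknown[] = [this.caller];
    let current = this.parent;
    while (current !== null) {
      callers.push(current.caller);
      current = current.parent;
    }
    return callers;
  }
}

const eventCallerStorage = new AsyncLocalStorage<EventCaller>();

export function getEventCaller(): EventCaller | null {
  return eventCallerStorage.getStore() ?? null;
}

// Method decorator: while the wrapped method (and everything it awaits) runs,
// getEventCaller() returns a caller whose parent is the enclosing caller.
export function wrapEventCaller(
  target: object,
  propertyKey: string,
  descriptor: PropertyDescriptor,
) {
  const original = descriptor.value;
  descriptor.value = function (this: unknown, ...args: unknown[]) {
    const caller = new EventCaller(this, getEventCaller());
    return eventCallerStorage.run(caller, () => original.apply(this, args));
  };
}
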
3 changes: 2 additions & 1 deletion packages/core/src/engines/chat/CondenseQuestionChatEngine.ts
@@ -8,6 +8,7 @@ import {
import type { Response } from "../../Response.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { llmFromSettingsOrContext } from "../../Settings.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, LLM } from "../../llm/index.js";
import { extractText, streamReducer } from "../../llm/utils.js";
import { PromptMixin } from "../../prompts/index.js";
@@ -17,7 +18,6 @@ import type {
  ChatEngineParamsNonStreaming,
  ChatEngineParamsStreaming,
} from "./types.js";

/**
 * CondenseQuestionChatEngine is used in conjunction with an Index (for example VectorStoreIndex).
 * It performs two steps on a user's chat message: first, it condenses the chat message
@@ -82,6 +82,7 @@ export class CondenseQuestionChatEngine

  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
  @wrapEventCaller
  async chat(
    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
  ): Promise<Response | AsyncIterable<Response>> {
15 changes: 3 additions & 12 deletions packages/core/src/engines/chat/ContextChatEngine.ts
@@ -1,10 +1,9 @@
import { randomUUID } from "@llamaindex/env";
import type { ChatHistory } from "../../ChatHistory.js";
import { getHistory } from "../../ChatHistory.js";
import type { ContextSystemPrompt } from "../../Prompt.js";
import { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import type { Event } from "../../callbacks/CallbackManager.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
@@ -60,28 +59,22 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {

  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
  @wrapEventCaller
  async chat(
    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
  ): Promise<Response | AsyncIterable<Response>> {
    const { message, stream } = params;
    const chatHistory = params.chatHistory
      ? getHistory(params.chatHistory)
      : this.chatHistory;
    const parentEvent: Event = {
      id: randomUUID(),
      type: "wrapper",
      tags: ["final"],
    };
    const requestMessages = await this.prepareRequestMessages(
      message,
      chatHistory,
      parentEvent,
    );

    if (stream) {
      const stream = await this.chatModel.chat({
        messages: requestMessages.messages,
        parentEvent,
        stream: true,
      });
      return streamConverter(
@@ -98,7 +91,6 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
    }
    const response = await this.chatModel.chat({
      messages: requestMessages.messages,
      parentEvent,
    });
    chatHistory.addMessage(response.message);
    return new Response(response.message.content, requestMessages.nodes);
@@ -111,14 +103,13 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
  private async prepareRequestMessages(
    message: MessageContent,
    chatHistory: ChatHistory,
    parentEvent?: Event,
  ) {
    chatHistory.addMessage({
      content: message,
      role: "user",
    });
    const textOnly = extractText(message);
    const context = await this.contextGenerator.generate(textOnly, parentEvent);
    const context = await this.contextGenerator.generate(textOnly);
    const nodes = context.nodes.map((r) => r.node);
    const messages = await chatHistory.requestMessages(
      context ? [context.message] : undefined,
12 changes: 1 addition & 11 deletions packages/core/src/engines/chat/DefaultContextGenerator.ts
@@ -1,9 +1,7 @@
import { randomUUID } from "@llamaindex/env";
import type { NodeWithScore, TextNode } from "../../Node.js";
import type { ContextSystemPrompt } from "../../Prompt.js";
import { defaultContextSystemPrompt } from "../../Prompt.js";
import type { BaseRetriever } from "../../Retriever.js";
import type { Event } from "../../callbacks/CallbackManager.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/index.js";
import type { Context, ContextGenerator } from "./types.js";
@@ -56,17 +54,9 @@ export class DefaultContextGenerator
    return nodesWithScore;
  }

  async generate(message: string, parentEvent?: Event): Promise<Context> {
    if (!parentEvent) {
      parentEvent = {
        id: randomUUID(),
        type: "wrapper",
        tags: ["final"],
      };
    }
  async generate(message: string): Promise<Context> {
    const sourceNodesWithScore = await this.retriever.retrieve({
      query: message,
      parentEvent,
    });

    const nodes = await this.applyNodePostprocessors(
2 changes: 2 additions & 0 deletions packages/core/src/engines/chat/SimpleChatEngine.ts
@@ -1,6 +1,7 @@
import type { ChatHistory } from "../../ChatHistory.js";
import { getHistory } from "../../ChatHistory.js";
import { Response } from "../../Response.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatResponseChunk, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import { streamConverter, streamReducer } from "../../llm/utils.js";
@@ -25,6 +26,7 @@ export class SimpleChatEngine implements ChatEngine {

  chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
  chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
  @wrapEventCaller
  async chat(
    params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
  ): Promise<Response | AsyncIterable<Response>> {
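
Because chat() is now decorated with @wrapEventCaller, any event fired while a chat call is in flight should carry that call in its reason. An end-to-end sketch in the style of examples/chatHistory.ts (the exact contents of computedCallers are an implementation detail of EventCaller):

import { OpenAI, Settings, SimpleChatEngine } from "llamaindex";

Settings.callbackManager.on("llm-end", (event) => {
  // expected to include the SimpleChatEngine.chat call that triggered the LLM
  console.log("callers chain", event.reason?.computedCallers);
});

async function main() {
  const chatEngine = new SimpleChatEngine({ llm: new OpenAI() });
  const response = await chatEngine.chat({ message: "Hello!" });
  console.log(response.response);
}

main().catch(console.error);
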