Skip to content

Commit

Permalink
fix: handle aborts and add logging to the stream handler (#350)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Howard <[email protected]>
  • Loading branch information
stefl and codeincontext authored Nov 20, 2024
1 parent 75e64b9 commit 20f956e
Showing 1 changed file with 47 additions and 51 deletions.
98 changes: 47 additions & 51 deletions packages/aila/src/core/chat/AilaStreamHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ export class AilaStreamHandler {
private _chat: AilaChat;
private _controller?: ReadableStreamDefaultController;
private _patchEnqueuer: PatchEnqueuer;
private _isStreaming: boolean = false;
private _streamReader?: ReadableStreamDefaultReader<string>;
private _abortController?: AbortController;

constructor(chat: AilaChat) {
this._chat = chat;
Expand All @@ -23,31 +21,49 @@ export class AilaStreamHandler {

public startStreaming(abortController?: AbortController): ReadableStream {
return new ReadableStream({
start: async (controller) => {
await this.stream(controller, abortController);
start: (controller) => {
this.stream(controller, abortController).catch((error) => {
log.error("Error in stream:", error);
controller.error(error);
});
},
});
}

private logStreamingStep(step: string) {
log.info(`Streaming step: ${step}`);
}

private async stream(
controller: ReadableStreamDefaultController,
abortController?: AbortController,
) {
this.setupController(controller);
this.listenForAbort(abortController);
try {
await this._chat.setupGeneration();
this.logStreamingStep("Setup generation complete");

await this._chat.handleSettingInitialState();
this.logStreamingStep("Handle initial state complete");

await this._chat.handleSubjectWarning();
this.logStreamingStep("Handle subject warning complete");

await this.startLLMStream();
while (this._isStreaming) {
await this.readFromStream();
}
this.logStreamingStep("Start LLM stream complete");

await this.readFromStream(abortController);
this.logStreamingStep("Read from stream complete");

log.info(
"Finished reading from stream",
this._chat.iteration,
this._chat.id,
);
} catch (e) {
await this.handleStreamError(e);
this.handleStreamError(e);
log.info("Stream error", e, this._chat.iteration, this._chat.id);
} finally {
this._isStreaming = false;
try {
await this._chat.complete();
log.info("Chat completed", this._chat.iteration, this._chat.id);
Expand All @@ -68,23 +84,6 @@ export class AilaStreamHandler {
this._patchEnqueuer.setController(controller);
}

private listenForAbort(abortController?: AbortController) {
if (!abortController) {
return;
}
if (this._abortController) {
this._abortController.signal.removeEventListener(
"abort",
this.stopStreamingOnAbort,
);
}
this._abortController = abortController;
this._abortController.signal.addEventListener(
"abort",
this.stopStreamingOnAbort,
);
}

private async startLLMStream() {
await this._chat.enqueue({
type: "comment",
Expand All @@ -93,23 +92,37 @@ export class AilaStreamHandler {
const messages = this._chat.completionMessages();
this._streamReader =
await this._chat.createChatCompletionObjectStream(messages);
this._isStreaming = true;
}

private async readFromStream() {
private async readFromStream(abortController?: AbortController) {
if (!this._streamReader) {
throw new Error("Stream reader is not defined");
}
try {
await this.fetchChunkFromStream();
} catch (error) {
await this._chat.generationFailed(error);
throw error;
while (true) {
const { done, value } = await this._streamReader.read();
if (done) {
break;
}
if (value) {
this._chat.appendChunk(value);
this._controller?.enqueue(value);
}
}
} catch (e) {
if (abortController?.signal.aborted) {
log.info("Stream aborted", this._chat.iteration, this._chat.id);
} else {
throw e;
}
}
}

private async handleStreamError(error: unknown) {
for (const plugin of this._chat.aila.plugins ?? []) {
await plugin.onStreamError?.(error, {
aila: this._chat.aila,
enqueue: (patch) => this._chat.enqueue(patch),
enqueue: this._chat.enqueue,
});
}

Expand All @@ -135,21 +148,4 @@ export class AilaStreamHandler {
this._controller.close();
}
}

private async fetchChunkFromStream() {
if (this._streamReader) {
const { done, value } = await this._streamReader.read();
if (value) {
this._chat.appendChunk(value);
this._controller?.enqueue(value);
}
if (done) {
this._isStreaming = false;
}
}
}

private stopStreamingOnAbort = () => {
this._isStreaming = false;
};
}

0 comments on commit 20f956e

Please sign in to comment.