diff --git a/packages/aila/src/core/chat/AilaStreamHandler.ts b/packages/aila/src/core/chat/AilaStreamHandler.ts index 4ec0c014f..656ee0370 100644 --- a/packages/aila/src/core/chat/AilaStreamHandler.ts +++ b/packages/aila/src/core/chat/AilaStreamHandler.ts @@ -12,9 +12,7 @@ export class AilaStreamHandler { private _chat: AilaChat; private _controller?: ReadableStreamDefaultController; private _patchEnqueuer: PatchEnqueuer; - private _isStreaming: boolean = false; private _streamReader?: ReadableStreamDefaultReader; - private _abortController?: AbortController; constructor(chat: AilaChat) { this._chat = chat; @@ -23,31 +21,49 @@ export class AilaStreamHandler { public startStreaming(abortController?: AbortController): ReadableStream { return new ReadableStream({ - start: async (controller) => { - await this.stream(controller, abortController); + start: (controller) => { + this.stream(controller, abortController).catch((error) => { + log.error("Error in stream:", error); + controller.error(error); + }); }, }); } + private logStreamingStep(step: string) { + log.info(`Streaming step: ${step}`); + } + private async stream( controller: ReadableStreamDefaultController, abortController?: AbortController, ) { this.setupController(controller); - this.listenForAbort(abortController); try { await this._chat.setupGeneration(); + this.logStreamingStep("Setup generation complete"); + await this._chat.handleSettingInitialState(); + this.logStreamingStep("Handle initial state complete"); + await this._chat.handleSubjectWarning(); + this.logStreamingStep("Handle subject warning complete"); + await this.startLLMStream(); - while (this._isStreaming) { - await this.readFromStream(); - } + this.logStreamingStep("Start LLM stream complete"); + + await this.readFromStream(abortController); + this.logStreamingStep("Read from stream complete"); + + log.info( + "Finished reading from stream", + this._chat.iteration, + this._chat.id, + ); } catch (e) { - await this.handleStreamError(e); + this.handleStreamError(e); log.info("Stream error", e, this._chat.iteration, this._chat.id); } finally { - this._isStreaming = false; try { await this._chat.complete(); log.info("Chat completed", this._chat.iteration, this._chat.id); @@ -68,23 +84,6 @@ export class AilaStreamHandler { this._patchEnqueuer.setController(controller); } - private listenForAbort(abortController?: AbortController) { - if (!abortController) { - return; - } - if (this._abortController) { - this._abortController.signal.removeEventListener( - "abort", - this.stopStreamingOnAbort, - ); - } - this._abortController = abortController; - this._abortController.signal.addEventListener( - "abort", - this.stopStreamingOnAbort, - ); - } - private async startLLMStream() { await this._chat.enqueue({ type: "comment", @@ -93,15 +92,29 @@ export class AilaStreamHandler { const messages = this._chat.completionMessages(); this._streamReader = await this._chat.createChatCompletionObjectStream(messages); - this._isStreaming = true; } - private async readFromStream() { + private async readFromStream(abortController?: AbortController) { + if (!this._streamReader) { + throw new Error("Stream reader is not defined"); + } try { - await this.fetchChunkFromStream(); - } catch (error) { - await this._chat.generationFailed(error); - throw error; + while (true) { + const { done, value } = await this._streamReader.read(); + if (done) { + break; + } + if (value) { + this._chat.appendChunk(value); + this._controller?.enqueue(value); + } + } + } catch (e) { + if (abortController?.signal.aborted) { + log.info("Stream aborted", this._chat.iteration, this._chat.id); + } else { + throw e; + } } } @@ -109,7 +122,7 @@ export class AilaStreamHandler { for (const plugin of this._chat.aila.plugins ?? []) { await plugin.onStreamError?.(error, { aila: this._chat.aila, - enqueue: (patch) => this._chat.enqueue(patch), + enqueue: this._chat.enqueue, }); } @@ -135,21 +148,4 @@ export class AilaStreamHandler { this._controller.close(); } } - - private async fetchChunkFromStream() { - if (this._streamReader) { - const { done, value } = await this._streamReader.read(); - if (value) { - this._chat.appendChunk(value); - this._controller?.enqueue(value); - } - if (done) { - this._isStreaming = false; - } - } - } - - private stopStreamingOnAbort = () => { - this._isStreaming = false; - }; }