From d438374e8b3831b3832a40389aa598e9948bfd3a Mon Sep 17 00:00:00 2001 From: David Duong Date: Wed, 22 Nov 2023 16:30:48 +0100 Subject: [PATCH] Fix streaming of Bedrock on Cloudflare Workers (#3364) * Fix streaming of Bedrock on Cloudflare Workers * Handle buffer 0 * Remove unnecessary break --- langchain/src/chat_models/bedrock/web.ts | 32 +++++++++++++++++++++++- langchain/src/llms/bedrock/web.ts | 32 +++++++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/langchain/src/chat_models/bedrock/web.ts b/langchain/src/chat_models/bedrock/web.ts index 6ea2bf45fa5b..0431f5baa359 100644 --- a/langchain/src/chat_models/bedrock/web.ts +++ b/langchain/src/chat_models/bedrock/web.ts @@ -379,11 +379,41 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput { // eslint-disable-next-line @typescript-eslint/no-explicit-any _readChunks(reader: any) { + function _concatChunks(a: Uint8Array, b: Uint8Array) { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; + } + + function getMessageLength(buffer: Uint8Array) { + if (buffer.byteLength === 0) return 0; + const view = new DataView( + buffer.buffer, + buffer.byteOffset, + buffer.byteLength + ); + + return view.getUint32(0, false); + } + return { async *[Symbol.asyncIterator]() { let readResult = await reader.read(); + + let buffer: Uint8Array = new Uint8Array(0); while (!readResult.done) { - yield readResult.value; + const chunk: Uint8Array = readResult.value; + + buffer = _concatChunks(buffer, chunk); + let messageLength = getMessageLength(buffer); + + while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) { + yield buffer.slice(0, messageLength); + buffer = buffer.slice(messageLength); + messageLength = getMessageLength(buffer); + } + readResult = await reader.read(); } }, diff --git a/langchain/src/llms/bedrock/web.ts b/langchain/src/llms/bedrock/web.ts index d2aff9cfa97f..f10660c9feab 100644 --- a/langchain/src/llms/bedrock/web.ts +++ b/langchain/src/llms/bedrock/web.ts @@ -313,11 +313,41 @@ export class Bedrock extends LLM implements BaseBedrockInput { // eslint-disable-next-line @typescript-eslint/no-explicit-any _readChunks(reader: any) { + function _concatChunks(a: Uint8Array, b: Uint8Array) { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; + } + + function getMessageLength(buffer: Uint8Array) { + if (buffer.byteLength === 0) return 0; + const view = new DataView( + buffer.buffer, + buffer.byteOffset, + buffer.byteLength + ); + + return view.getUint32(0, false); + } + return { async *[Symbol.asyncIterator]() { let readResult = await reader.read(); + + let buffer: Uint8Array = new Uint8Array(0); while (!readResult.done) { - yield readResult.value; + const chunk: Uint8Array = readResult.value; + + buffer = _concatChunks(buffer, chunk); + let messageLength = getMessageLength(buffer); + + while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) { + yield buffer.slice(0, messageLength); + buffer = buffer.slice(messageLength); + messageLength = getMessageLength(buffer); + } + readResult = await reader.read(); } },