Skip to content

Commit

Permalink
Fix streaming of Bedrock on Cloudflare Workers (#3364)
Browse files Browse the repository at this point in the history
* Fix streaming of Bedrock on Cloudflare Workers

* Handle buffer 0

* Remove unnecessary break
  • Loading branch information
dqbd authored Nov 22, 2023
1 parent e14539a commit d438374
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
32 changes: 31 additions & 1 deletion langchain/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,41 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {

// eslint-disable-next-line @typescript-eslint/no-explicit-any
_readChunks(reader: any) {
function _concatChunks(a: Uint8Array, b: Uint8Array) {
const newBuffer = new Uint8Array(a.length + b.length);
newBuffer.set(a);
newBuffer.set(b, a.length);
return newBuffer;
}

function getMessageLength(buffer: Uint8Array) {
if (buffer.byteLength === 0) return 0;
const view = new DataView(
buffer.buffer,
buffer.byteOffset,
buffer.byteLength
);

return view.getUint32(0, false);
}

return {
async *[Symbol.asyncIterator]() {
let readResult = await reader.read();

let buffer: Uint8Array = new Uint8Array(0);
while (!readResult.done) {
yield readResult.value;
const chunk: Uint8Array = readResult.value;

buffer = _concatChunks(buffer, chunk);
let messageLength = getMessageLength(buffer);

while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) {
yield buffer.slice(0, messageLength);
buffer = buffer.slice(messageLength);
messageLength = getMessageLength(buffer);
}

readResult = await reader.read();
}
},
Expand Down
32 changes: 31 additions & 1 deletion langchain/src/llms/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,41 @@ export class Bedrock extends LLM implements BaseBedrockInput {

// eslint-disable-next-line @typescript-eslint/no-explicit-any
_readChunks(reader: any) {
function _concatChunks(a: Uint8Array, b: Uint8Array) {
const newBuffer = new Uint8Array(a.length + b.length);
newBuffer.set(a);
newBuffer.set(b, a.length);
return newBuffer;
}

function getMessageLength(buffer: Uint8Array) {
if (buffer.byteLength === 0) return 0;
const view = new DataView(
buffer.buffer,
buffer.byteOffset,
buffer.byteLength
);

return view.getUint32(0, false);
}

return {
async *[Symbol.asyncIterator]() {
let readResult = await reader.read();

let buffer: Uint8Array = new Uint8Array(0);
while (!readResult.done) {
yield readResult.value;
const chunk: Uint8Array = readResult.value;

buffer = _concatChunks(buffer, chunk);
let messageLength = getMessageLength(buffer);

while (buffer.byteLength > 0 && buffer.byteLength >= messageLength) {
yield buffer.slice(0, messageLength);
buffer = buffer.slice(messageLength);
messageLength = getMessageLength(buffer);
}

readResult = await reader.read();
}
},
Expand Down

0 comments on commit d438374

Please sign in to comment.