Skip to content

Commit

Permalink
fix bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic committed Nov 22, 2023
1 parent 781d3d6 commit 555ccc8
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ export async function getTokenCountFromInvokeStream({

let responseBody: string = '';

const responseBuffer: Uint8Array[] = [];

const isBedrock = actionTypeId === '.bedrock';

responseStream.on('data', (chunk) => {
if (isBedrock) {
responseBody += parseBedrockChunk(chunk);
// special encoding for bedrock, do not attempt to convert to string
responseBuffer.push(chunk);
} else {
// no special encoding, can safely use toString and append to responseBody
responseBody += chunk.toString();
}
});
Expand All @@ -70,26 +74,62 @@ export async function getTokenCountFromInvokeStream({

// parse openai response once responseBody is fully built
// They send the response in sometimes incomplete chunks of JSON
const parsedResponse = isBedrock ? responseBody : parseOpenAIResponse(responseBody);
const parsedResponse = isBedrock
? parseBedrockBuffer(responseBuffer)
: parseOpenAIResponse(responseBody);

const completionTokens = encode(parsedResponse).length;

return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
}

const parseBedrockChunk = (chunk: ArrayBufferView) => {
const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);
const event = awsDecoder.decode(chunk);
const parsed = JSON.parse(
Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString()
);
return parsed.completion;
const parseBedrockBuffer = (chunks: Uint8Array[]) => {
let bedrockBuffer: Uint8Array = new Uint8Array(0);
return chunks
.map((chunk) => {
bedrockBuffer = concatChunks(bedrockBuffer, chunk);
let messageLength = getMessageLength(bedrockBuffer);

const buildChunks = [];
while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) {
const extractedChunk = bedrockBuffer.slice(0, messageLength);
buildChunks.push(extractedChunk);
bedrockBuffer = bedrockBuffer.slice(messageLength);
messageLength = getMessageLength(bedrockBuffer);
}

const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);

return buildChunks
.map((bChunk) => {
const event = awsDecoder.decode(bChunk);
const body = JSON.parse(
Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString()
);
return body.completion;
})
.join('');
})
.join('');
};

function concatChunks(a: Uint8Array, b: Uint8Array) {
const newBuffer = new Uint8Array(a.length + b.length);
newBuffer.set(a);
newBuffer.set(b, a.length);
return newBuffer;
}

function getMessageLength(buffer: Uint8Array) {
if (buffer.byteLength === 0) return 0;
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);

return view.getUint32(0, false);
}

const parseOpenAIResponse = (responseBody: string) =>
responseBody
.split('\n')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@ export const getStreamObservable = ({
observer.next({ chunks: [], loading: true });
const decoder = new TextDecoder();
const chunks: string[] = [];
let lineBuffer: string = '';
let openAIBuffer: string = '';

let bedrockBuffer: Uint8Array = new Uint8Array(0);
function readOpenAI() {
reader
.read()
.then(({ done, value }: { done: boolean; value?: Uint8Array }) => {
try {
if (done) {
if (lineBuffer) {
chunks.push(lineBuffer);
if (openAIBuffer) {
chunks.push(openAIBuffer);
}
observer.next({
chunks,
Expand All @@ -62,8 +64,8 @@ export const getStreamObservable = ({
nextChunks = [`${API_ERROR}\n\n${JSON.parse(decoded).message}`];
} else {
const lines = decoded.split('\n');
lines[0] = lineBuffer + lines[0];
lineBuffer = lines.pop() || '';
lines[0] = openAIBuffer + lines[0];
openAIBuffer = lines.pop() || '';
nextChunks = getOpenAIChunks(lines);
}
nextChunks.forEach((chunk: string) => {
Expand Down Expand Up @@ -98,24 +100,45 @@ export const getStreamObservable = ({
observer.complete();
return;
}
const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);

let content;
if (isError) {
content = `${API_ERROR}\n\n${JSON.parse(decoder.decode(value)).message}`;
chunks.push(content);
observer.next({
chunks,
message: chunks.join(''),
loading: true,
});
} else if (value != null) {
const event = awsDecoder.decode(value);
const body = JSON.parse(
Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString()
);
content = body.completion;
const chunk: Uint8Array = value;

bedrockBuffer = concatChunks(bedrockBuffer, chunk);
let messageLength = getMessageLength(bedrockBuffer);

const buildChunks = [];
while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) {
const extractedChunk = bedrockBuffer.slice(0, messageLength);
buildChunks.push(extractedChunk);
bedrockBuffer = bedrockBuffer.slice(messageLength);
messageLength = getMessageLength(bedrockBuffer);
}

const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);
buildChunks.forEach((bChunk) => {
const event = awsDecoder.decode(bChunk);
const body = JSON.parse(
Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString()
);
content = body.completion;
chunks.push(content);
observer.next({
chunks,
message: chunks.join(''),
loading: true,
});
});
}
chunks.push(content);
observer.next({
chunks,
message: chunks.join(''),
loading: true,
});
} catch (err) {
observer.error(err);
return;
Expand All @@ -126,6 +149,7 @@ export const getStreamObservable = ({
observer.error(err);
});
}

if (connectorTypeTitle === 'Amazon Bedrock') readBedrock();
else if (connectorTypeTitle === 'OpenAI') readOpenAI();
return () => {
Expand Down Expand Up @@ -177,4 +201,18 @@ const getOpenAIChunks = (lines: string[]): string[] => {
return nextChunk;
};

function concatChunks(a: Uint8Array, b: Uint8Array) {
const newBuffer = new Uint8Array(a.length + b.length);
newBuffer.set(a);
newBuffer.set(b, a.length);
return newBuffer;
}

function getMessageLength(buffer: Uint8Array) {
if (buffer.byteLength === 0) return 0;
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);

return view.getUint32(0, false);
}

export const getPlaceholderObservable = () => new Observable<PromptObservableState>();

0 comments on commit 555ccc8

Please sign in to comment.