Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Security solution] Fix streaming on cloud #171578

Merged
merged 18 commits into from
Nov 23, 2023
Merged
1 change: 1 addition & 0 deletions x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export const getGenAiTokenTracking = async ({
try {
const { total, prompt, completion } = await getTokenCountFromInvokeStream({
responseStream: result.data.pipe(new PassThrough()),
actionTypeId,
body: (validatedParams as { subActionParams: InvokeBody }).subActionParams,
logger,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,15 @@
import { Transform } from 'stream';
import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream';
import { loggerMock } from '@kbn/logging-mocks';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

interface StreamMock {
write: (data: string) => void;
fail: () => void;
complete: () => void;
transform: Transform;
}

function createStreamMock(): StreamMock {
function createStreamMock() {
const transform: Transform = new Transform({});

return {
write: (data: string) => {
transform.push(`${data}\n`);
write: (data: unknown) => {
transform.push(data);
},
fail: () => {
transform.emit('error', new Error('Stream failed'));
Expand All @@ -34,7 +29,10 @@ function createStreamMock(): StreamMock {
}
const logger = loggerMock.create();
describe('getTokenCountFromInvokeStream', () => {
let stream: StreamMock;
beforeEach(() => {
jest.resetAllMocks();
});
let stream: ReturnType<typeof createStreamMock>;
const body = {
messages: [
{
Expand All @@ -48,36 +46,79 @@ describe('getTokenCountFromInvokeStream', () => {
],
};

const chunk = {
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: 'Single.',
},
},
],
};

const PROMPT_TOKEN_COUNT = 34;
const COMPLETION_TOKEN_COUNT = 2;
describe('OpenAI stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(`data: ${JSON.stringify(chunk)}`);
});

beforeEach(() => {
stream = createStreamMock();
stream.write('Single');
});

describe('when a stream completes', () => {
beforeEach(async () => {
it('counts the prompt + completion tokens for OpenAI response', async () => {
stream.complete();
});
it('counts the prompt tokens', async () => {
const tokens = await getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.gen-ai',
});
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
});
it('resolves the promise with the correct prompt tokens', async () => {
const tokenPromise = getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.gen-ai',
});

stream.fail();

await expect(tokenPromise).resolves.toEqual({
prompt: PROMPT_TOKEN_COUNT,
total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT,
completion: COMPLETION_TOKEN_COUNT,
});
expect(logger.error).toHaveBeenCalled();
});
});
describe('Bedrock stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(encodeBedrockResponse('Simple.'));
});

describe('when a stream fails', () => {
it('counts the prompt + completion tokens for OpenAI response', async () => {
stream.complete();
const tokens = await getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.bedrock',
});
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
});
it('resolves the promise with the correct prompt tokens', async () => {
const tokenPromise = getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.bedrock',
});

stream.fail();
Expand All @@ -91,3 +132,16 @@ describe('getTokenCountFromInvokeStream', () => {
});
});
});

/**
 * Builds a Bedrock-style event-stream frame for tests.
 *
 * Mirrors the real Bedrock wire format: the completion is JSON-encoded,
 * base64-wrapped under a `bytes` key, and framed with the AWS
 * event-stream codec.
 */
function encodeBedrockResponse(completion: string) {
  const payload = JSON.stringify({
    bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'),
  });
  return new EventStreamCodec(toUtf8, fromUtf8).encode({
    headers: {},
    body: Uint8Array.from(Buffer.from(payload)),
  });
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { Logger } from '@kbn/logging';
import { encode } from 'gpt-tokenizer';
import { Readable } from 'stream';
import { finished } from 'stream/promises';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

export interface InvokeBody {
messages: Array<{
Expand All @@ -26,10 +28,12 @@ export interface InvokeBody {
* @param logger the logger
*/
export async function getTokenCountFromInvokeStream({
actionTypeId,
responseStream,
body,
logger,
}: {
actionTypeId: string;
responseStream: Readable;
body: InvokeBody;
logger: Logger;
Expand All @@ -47,22 +51,147 @@ export async function getTokenCountFromInvokeStream({
.join('\n')
).length;

let responseBody: string = '';
const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
const parsedResponse = await parser(responseStream, logger);

const completionTokens = encode(parsedResponse).length;
return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
}

// Shared shape of the per-provider parsers: drain the response stream and resolve with the concatenated completion text.
type StreamParser = (responseStream: Readable, logger: Logger) => Promise<string>;

responseStream.on('data', (chunk: string) => {
/**
 * Collects a Bedrock response stream and extracts its completion text.
 *
 * Bedrock chunks are AWS event-stream binary frames, so the raw bytes are
 * buffered untouched (no string conversion) and decoded once the stream ends.
 */
const parseBedrockStream: StreamParser = async (responseStream, logger) => {
  const collected: Uint8Array[] = [];
  responseStream.on('data', (chunk) => {
    // Keep the binary chunk as-is; converting to string would corrupt the framing.
    collected.push(chunk);
  });
  try {
    await finished(responseStream);
  } catch (e) {
    // Best-effort: log and fall through to count whatever was received.
    logger.error('An error occurred while calculating streaming response tokens');
  }
  return parseBedrockBuffer(collected);
};

/**
 * Collects an OpenAI response stream and extracts its completion text.
 *
 * OpenAI emits plain server-sent-event text, so chunks can safely be
 * stringified and concatenated before parsing.
 */
const parseOpenAIStream: StreamParser = async (responseStream, logger) => {
  const pieces: string[] = [];
  responseStream.on('data', (chunk) => {
    // Plain text chunks: toString is safe here.
    pieces.push(chunk.toString());
  });
  try {
    await finished(responseStream);
  } catch (e) {
    // Best-effort: log and fall through to count whatever was received.
    logger.error('An error occurred while calculating streaming response tokens');
  }
  return parseOpenAIResponse(pieces.join(''));
};

const completionTokens = encode(responseBody).length;
/**
* Parses a Bedrock buffer from an array of chunks.
*
* @param {Uint8Array[]} chunks - Array of Uint8Array chunks to be parsed.
* @returns {string} - Parsed string from the Bedrock buffer.
*/
const parseBedrockBuffer = (chunks: Uint8Array[]): string => {
// Initialize an empty Uint8Array to store the concatenated buffer.
let bedrockBuffer: Uint8Array = new Uint8Array(0);

return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
// Map through each chunk to process the Bedrock buffer.
return chunks
.map((chunk) => {
// Concatenate the current chunk to the existing buffer.
bedrockBuffer = concatChunks(bedrockBuffer, chunk);
// Get the length of the next message in the buffer.
let messageLength = getMessageLength(bedrockBuffer);
// Initialize an array to store fully formed message chunks.
const buildChunks = [];
// Process the buffer until no complete messages are left.
while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) {
// Extract a chunk of the specified length from the buffer.
const extractedChunk = bedrockBuffer.slice(0, messageLength);
// Add the extracted chunk to the array of fully formed message chunks.
buildChunks.push(extractedChunk);
// Remove the processed chunk from the buffer.
bedrockBuffer = bedrockBuffer.slice(messageLength);
// Get the length of the next message in the updated buffer.
messageLength = getMessageLength(bedrockBuffer);
}

const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);

// Decode and parse each message chunk, extracting the 'completion' property.
return buildChunks
.map((bChunk) => {
const event = awsDecoder.decode(bChunk);
const body = JSON.parse(
Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString()
);
return body.completion;
})
.join('');
})
.join('');
};

/**
 * Concatenates two Uint8Array buffers into a freshly allocated one.
 *
 * @param {Uint8Array} a - First buffer.
 * @param {Uint8Array} b - Second buffer.
 * @returns {Uint8Array} - Concatenated buffer.
 */
function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array {
  // Allocate once, then copy both inputs back-to-back.
  const merged = new Uint8Array(a.length + b.length);
  merged.set(a, 0);
  merged.set(b, a.length);
  return merged;
}

/**
 * Gets the length of the next event-stream frame from the buffer.
 *
 * The first four bytes of a frame are a big-endian uint32 holding the total
 * frame length. When the buffer holds 1-3 bytes, the length prefix itself is
 * still incomplete (split across transport chunks), so a sentinel larger than
 * any real buffer is returned: the caller's `byteLength >= messageLength`
 * loop then exits and keeps the partial bytes for the next chunk, instead of
 * crashing on a DataView RangeError as the unguarded read would.
 *
 * @param {Uint8Array} buffer - Buffer containing zero or more frames.
 * @returns {number} - Length of the next frame; 0 for an empty buffer.
 */
function getMessageLength(buffer: Uint8Array): number {
  // If the buffer is empty, return 0 (no message pending).
  if (buffer.byteLength === 0) return 0;
  // Truncated length prefix: signal "incomplete" rather than throwing.
  if (buffer.byteLength < 4) return Number.MAX_SAFE_INTEGER;
  // Create a DataView to read the big-endian Uint32 at the start of the buffer.
  const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  // Read and return the Uint32 value (message length).
  return view.getUint32(0, false);
}

/**
 * Extracts the completion text from a raw OpenAI SSE response body.
 *
 * Each `data: ` line carries a JSON chunk; the terminal `[DONE]` marker is
 * skipped, non-chunk objects are ignored, and the `delta.content` of every
 * `chat.completion.chunk` is concatenated in order.
 */
const parseOpenAIResponse = (responseBody: string) => {
  let content = '';
  for (const rawLine of responseBody.split('\n')) {
    // Only SSE data lines count; the [DONE] sentinel carries no chunk.
    if (!rawLine.startsWith('data: ') || rawLine.endsWith('[DONE]')) {
      continue;
    }
    const parsed = JSON.parse(rawLine.replace('data: ', ''));
    if (!('object' in parsed) || parsed.object !== 'chat.completion.chunk') {
      continue;
    }
    const delta: { content?: string; function_call?: { name?: string; arguments: string } } =
      parsed.choices[0].delta;
    content += delta.content || '';
  }
  return content;
};
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export const getComments = ({
regenerateMessage(currentConversation.id);
};

const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle ?? '';

const extraLoadingComment = isFetchingResponse
? [
{
Expand All @@ -75,6 +77,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
content=""
regenerateMessage={regenerateMessageOfConversation}
isLastComment
Expand Down Expand Up @@ -122,6 +125,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
index={index}
isLastComment={isLastComment}
isError={message.isError}
Expand All @@ -142,6 +146,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
content={transformedMessage.content}
index={index}
isLastComment={isLastComment}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const testProps = {
content,
index: 1,
isLastComment: true,
connectorTypeTitle: 'OpenAI',
regenerateMessage: jest.fn(),
transformMessage: jest.fn(),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ interface Props {
isFetching?: boolean;
isLastComment: boolean;
index: number;
connectorTypeTitle: string;
reader?: ReadableStreamDefaultReader<Uint8Array>;
regenerateMessage: () => void;
transformMessage: (message: string) => ContentMessage;
Expand All @@ -29,6 +30,7 @@ interface Props {
export const StreamComment = ({
amendMessage,
content,
connectorTypeTitle,
index,
isError = false,
isFetching = false,
Expand All @@ -40,6 +42,7 @@ export const StreamComment = ({
const { error, isLoading, isStreaming, pendingMessage, setComplete } = useStream({
amendMessage,
content,
connectorTypeTitle,
reader,
isError,
});
Expand Down
Loading
Loading