From bb5c3d5809c401e31a7bed10df5979f85faf9443 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 20 Nov 2023 11:50:19 -0700 Subject: [PATCH 01/15] add headers --- x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx | 1 + .../server/routes/post_actions_connector_execute.ts | 3 +++ 2 files changed, 4 insertions(+) diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index f92585cbdd011..d22a8fe205193 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -85,6 +85,7 @@ export const fetchConnectorExecuteAction = async ({ signal, asResponse: isStream, rawResponse: isStream, + headers: { Connection: 'keep-alive' }, } ); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 299d8ade24a3f..aaa4efda18366 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -48,6 +48,9 @@ export const postActionsConnectorExecuteRoute = ( const result = await executeAction({ actions, request, connectorId }); return response.ok({ body: result, + ...(request.body.params.subAction === 'invokeStream' + ? { headers: { connection: 'keep-alive', 'Transfer-Encoding': 'chunked' } } + : {}), }); } From 8bb20bb7cb89004018ad6b070877d288bdd61320 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 20 Nov 2023 14:28:57 -0700 Subject: [PATCH 02/15] guessing --- .../kbn-elastic-assistant/impl/assistant/api.tsx | 1 + .../routes/post_actions_connector_execute.ts | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index d22a8fe205193..1fed62a94dece 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -88,6 +88,7 @@ export const fetchConnectorExecuteAction = async ({ headers: { Connection: 'keep-alive' }, } ); + console.log('RESPONSE?????', response); const reader = response?.response?.body?.getReader(); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index aaa4efda18366..9e0860ad9b86d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -46,12 +46,15 @@ export const postActionsConnectorExecuteRoute = ( if (!request.body.assistantLangChain) { logger.debug('Executing via actions framework directly, assistantLangChain: false'); const result = await executeAction({ actions, request, connectorId }); - return response.ok({ - body: result, - ...(request.body.params.subAction === 'invokeStream' - ? { headers: { connection: 'keep-alive', 'Transfer-Encoding': 'chunked' } } - : {}), - }); + return request.body.params.subAction === 'invokeStream' + ? 
// a guess + response.accepted({ + body: result, + headers: { connection: 'keep-alive', 'Transfer-Encoding': 'chunked' }, + }) + : response.ok({ + body: result, + }); } // TODO: Add `traceId` to actions request when calling via langchain From 36ea81e054c7d66ca2edb445a4a3a7bc33fa55ac Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 08:01:34 -0700 Subject: [PATCH 03/15] attempt using stream without processing stream --- .../elastic_assistant/server/lib/executor.ts | 10 +++++- .../get_comments/stream/stream_observable.ts | 33 +++++++++++-------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts index 27064f3fb1961..e6cec8fa9bbb6 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -31,7 +31,15 @@ export const executeAction = async ({ const actionResult = await actionsClient.execute({ actionId: connectorId, - params: request.body.params, + params: { + ...request.body.params, + subAction: 'stream', + subActionParams: + // attempting stream without invokeStream + request.body.params.subAction === 'invokeAI' + ? request.body.params.subActionParams + : { body: JSON.stringify(request.body.params.subActionParams), stream: true }, + }, }); if (actionResult.status === 'error') { diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index b30be69b82cae..c7b8c01b684e8 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -7,7 +7,6 @@ import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs'; import type { Dispatch, SetStateAction } from 'react'; -import { API_ERROR } from '../translations'; import type { PromptObservableState } from './types'; const MIN_DELAY = 35; @@ -42,17 +41,23 @@ export const getStreamObservable = ( observer.complete(); return; } - const decoded = decoder.decode(value); - const content = isError - ? // we format errors as {message: string; status_code: number} - `${API_ERROR}\n\n${JSON.parse(decoded).message}` - : // all other responses are just strings (handled by subaction invokeStream) - decoded; - chunks.push(content); - observer.next({ - chunks, - message: getMessageFromChunks(chunks), - loading: true, + + const nextChunks = decoder + .decode(value) + .split('\n') + // every line starts with "data: ", we remove it and are left with stringified JSON or the string "[DONE]" + .map((str) => str.substring(6)) + // filter out empty lines and the "[DONE]" string + .filter((str) => !!str && str !== '[DONE]') + .map((line) => JSON.parse(line)); + + nextChunks.forEach((chunk) => { + chunks.push(chunk); + observer.next({ + chunks, + message: getMessageFromChunks(chunks), + loading: true, + }); }); } catch (err) { observer.error(err); @@ -99,8 +104,8 @@ export const getStreamObservable = ( finalize(() => setLoading(false)) ); -function getMessageFromChunks(chunks: string[]) { - return chunks.join(''); +function getMessageFromChunks(chunks) { + return chunks.map((chunk) => chunk.choices[0]?.delta.content ?? 
'').join(''); } export const getPlaceholderObservable = () => new Observable(); From dbc4da7057d0943c0a63402aff0f2db38bd64ba2 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 09:59:34 -0700 Subject: [PATCH 04/15] stream working but not smoothly --- .../get_comments/stream/stream_observable.ts | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index c7b8c01b684e8..7d1934ec7a968 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -5,8 +5,9 @@ * 2.0. */ -import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs'; +import { concatMap, delay, finalize, Observable, of, scan, shareReplay, timestamp } from 'rxjs'; import type { Dispatch, SetStateAction } from 'react'; +import { API_ERROR } from '@kbn/elastic-assistant/impl/assistant/translations'; import type { PromptObservableState } from './types'; const MIN_DELAY = 35; @@ -27,45 +28,58 @@ export const getStreamObservable = ( observer.next({ chunks: [], loading: true }); const decoder = new TextDecoder(); const chunks: string[] = []; + let lineBuffer: string = ''; function read() { reader .read() .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { try { if (done) { + if (lineBuffer) { + console.log('EXTRA LINE BUFFER!', lineBuffer); + chunks.push(lineBuffer); + } observer.next({ chunks, - message: getMessageFromChunks(chunks), + message: chunks.join(''), loading: false, }); observer.complete(); return; } - const nextChunks = decoder - .decode(value) - .split('\n') - // every line starts with "data: ", we remove it and are left with stringified JSON or the string "[DONE]" - .map((str) => str.substring(6)) - // filter out empty lines and the "[DONE]" string - .filter((str) => !!str && str !== '[DONE]') - .map((line) => JSON.parse(line)); + const decoded = decoder.decode(value); + if (isError) { + const content = `${API_ERROR}\n\n${JSON.parse(decoded).message}`; + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); + } else { + const lines = decoded.split('\n'); + + lines[0] = lineBuffer + lines[0]; + lineBuffer = lines.pop() || ''; + const content = getNextChunk(lines); + chunks.push(content); - nextChunks.forEach((chunk) => { - chunks.push(chunk); observer.next({ chunks, - message: getMessageFromChunks(chunks), + message: chunks.join(''), loading: true, }); - }); + } } catch (err) { + console.log('error caught', err); observer.error(err); return; } read(); }) .catch((err) => { + console.log('error caught 2', err); observer.error(err); }); } @@ -74,6 +88,9 @@ export const getStreamObservable = ( reader.cancel(); }; }).pipe( + // make sure the request is only triggered once, + // even with multiple subscribers + shareReplay(1), // append a timestamp of when each value was emitted timestamp(), // use the previous timestamp to calculate a target @@ -104,8 +121,21 @@ export const getStreamObservable = ( finalize(() => setLoading(false)) ); -function getMessageFromChunks(chunks) { - return chunks.map((chunk) => chunk.choices[0]?.delta.content ?? 
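          // some OpenAI chunks carry no text delta (e.g. the initial role-only chunk), so fall back to an empty string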
'').join(''); -} +const getNextChunk = (lines: string[]) => { + const nextChunk = lines + .map((str) => str.substring(6)) + .filter((str) => !!str && str !== '[DONE]') + .map((line) => { + try { + const openaiResponse = JSON.parse(line); + return openaiResponse.choices[0]?.delta.content ?? ''; + } catch (err) { + console.log('ERROR', err); + return ''; + } + }) + .join(''); + return nextChunk; +}; export const getPlaceholderObservable = () => new Observable(); From 1db8a6b869c32f5112d8bb94341ebf8edd7c07f1 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 10:15:26 -0700 Subject: [PATCH 05/15] rm logs --- .../public/assistant/get_comments/stream/stream_observable.ts | 4 ---- .../public/assistant/get_comments/stream/use_stream.tsx | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index 7d1934ec7a968..635d79b1e378d 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -36,7 +36,6 @@ export const getStreamObservable = ( try { if (done) { if (lineBuffer) { - console.log('EXTRA LINE BUFFER!', lineBuffer); chunks.push(lineBuffer); } observer.next({ @@ -72,14 +71,12 @@ export const getStreamObservable = ( }); } } catch (err) { - console.log('error caught', err); observer.error(err); return; } read(); }) .catch((err) => { - console.log('error caught 2', err); observer.error(err); }); } @@ -130,7 +127,6 @@ const getNextChunk = (lines: string[]) => { const openaiResponse = JSON.parse(line); return openaiResponse.choices[0]?.delta.content ?? 
''; } catch (err) { - console.log('ERROR', err); return ''; } }) diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx index 7de06589f87c7..d29d56f428767 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx @@ -7,7 +7,6 @@ import { useCallback, useEffect, useMemo, useState } from 'react'; import type { Subscription } from 'rxjs'; -import { share } from 'rxjs'; import { getPlaceholderObservable, getStreamObservable } from './stream_observable'; interface UseStreamProps { @@ -66,7 +65,7 @@ export const useStream = ({ } }, [complete, onCompleteStream]); useEffect(() => { - const newSubscription = observer$.pipe(share()).subscribe({ + const newSubscription = observer$.subscribe({ next: ({ message, loading: isLoading }) => { setLoading(isLoading); setPendingMessage(message); From 1eff97facf44fac696980c111a5ac40725e8cf63 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 10:53:49 -0700 Subject: [PATCH 06/15] will this work --- .../get_comments/stream/stream_observable.ts | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index 635d79b1e378d..42ea9b0e35c16 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -40,7 +40,7 @@ export const getStreamObservable = ( } observer.next({ chunks, - message: chunks.join(''), + message: chunks.join(' '), loading: false, }); observer.complete(); @@ -48,28 +48,24 @@ export const getStreamObservable = ( } const decoded = decoder.decode(value); + let content; if (isError) { - const content = `${API_ERROR}\n\n${JSON.parse(decoded).message}`; - chunks.push(content); - observer.next({ - chunks, - message: chunks.join(''), - loading: true, - }); + content = `${API_ERROR}\n\n${JSON.parse(decoded).message}`; } else { const lines = decoded.split('\n'); - lines[0] = lineBuffer + lines[0]; lineBuffer = lines.pop() || ''; - const content = getNextChunk(lines); - chunks.push(content); - + content = getNextChunk(lines); + } + const characters = content.split(' '); + characters.forEach((char) => { + chunks.push(char); observer.next({ chunks, - message: chunks.join(''), + message: chunks.join(' '), loading: true, }); - } + }); } catch (err) { observer.error(err); return; @@ -107,11 +103,9 @@ export const getStreamObservable = ( concatMap((value) => { const now = Date.now(); const delayFor = value.timestamp - now; - if (delayFor <= 0) { return of(value.value); } - return of(value.value).pipe(delay(delayFor)); }), // set loading to false when the observable completes or errors out From c2c99c37aaf26f79da64187fc7f0dc95f0d7505e Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 12:47:18 -0700 Subject: [PATCH 07/15] wip --- .../impl/assistant/api.tsx | 2 -- .../get_comments/stream/stream_observable.ts | 32 ++++++++----------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index 
1fed62a94dece..0d491a348a975 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -88,8 +88,6 @@ export const fetchConnectorExecuteAction = async ({ headers: { Connection: 'keep-alive' }, } ); - console.log('RESPONSE?????', response); - const reader = response?.response?.body?.getReader(); if (!reader) { diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index 42ea9b0e35c16..8a40b55aee374 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -5,10 +5,10 @@ * 2.0. */ -import { concatMap, delay, finalize, Observable, of, scan, shareReplay, timestamp } from 'rxjs'; +import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs'; import type { Dispatch, SetStateAction } from 'react'; -import { API_ERROR } from '@kbn/elastic-assistant/impl/assistant/translations'; import type { PromptObservableState } from './types'; +import { API_ERROR } from '../translations'; const MIN_DELAY = 35; /** @@ -35,12 +35,9 @@ export const getStreamObservable = ( .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { try { if (done) { - if (lineBuffer) { - chunks.push(lineBuffer); - } observer.next({ chunks, - message: chunks.join(' '), + message: chunks.join(''), loading: false, }); observer.complete(); @@ -48,21 +45,20 @@ export const getStreamObservable = ( } const decoded = decoder.decode(value); - let content; + let nextChunks; if (isError) { - content = `${API_ERROR}\n\n${JSON.parse(decoded).message}`; + nextChunks = [`${API_ERROR}\n\n${JSON.parse(decoded).message}`]; } else { const lines = decoded.split('\n'); lines[0] = lineBuffer + lines[0]; lineBuffer = lines.pop() || ''; - content = getNextChunk(lines); + nextChunks = getNextChunks(lines); } - const characters = content.split(' '); - characters.forEach((char) => { - chunks.push(char); + nextChunks.forEach((chunk: string) => { + chunks.push(chunk); observer.next({ chunks, - message: chunks.join(' '), + message: chunks.join(''), loading: true, }); }); @@ -81,9 +77,6 @@ export const getStreamObservable = ( reader.cancel(); }; }).pipe( - // make sure the request is only triggered once, - // even with multiple subscribers - shareReplay(1), // append a timestamp of when each value was emitted timestamp(), // use the previous timestamp to calculate a target @@ -103,16 +96,18 @@ export const getStreamObservable = ( concatMap((value) => { const now = Date.now(); const delayFor = value.timestamp - now; + if (delayFor <= 0) { return of(value.value); } + return of(value.value).pipe(delay(delayFor)); }), // set loading to false when the observable completes or errors out finalize(() => setLoading(false)) ); -const getNextChunk = (lines: string[]) => { +const getNextChunks = (lines: string[]): string[] => { const nextChunk = lines .map((str) => str.substring(6)) .filter((str) => !!str && str !== '[DONE]') @@ -123,8 +118,7 @@ const getNextChunk = (lines: string[]) => { } catch (err) { return ''; } - }) - .join(''); + }); return nextChunk; }; From 58f38fe18c4a0bf4831fc8ff669b835c4b4d0693 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 14:24:41 -0700 Subject: [PATCH 08/15] both fixed --- .../server/lib/gen_ai_token_tracking.ts | 1 + 
.../lib/get_token_count_from_invoke_stream.ts | 50 ++++++++++++- .../elastic_assistant/server/lib/executor.ts | 10 +-- .../routes/post_actions_connector_execute.ts | 12 +-- .../public/assistant/get_comments/index.tsx | 5 ++ .../assistant/get_comments/stream/index.tsx | 3 + .../get_comments/stream/stream_observable.ts | 75 ++++++++++++++++--- .../get_comments/stream/use_stream.tsx | 6 +- .../server/connector_types/bedrock/bedrock.ts | 30 +------- .../server/connector_types/openai/openai.ts | 46 +----------- 10 files changed, 134 insertions(+), 104 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index 7c104177ea36e..866580e8e7b3b 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -42,6 +42,7 @@ export const getGenAiTokenTracking = async ({ try { const { total, prompt, completion } = await getTokenCountFromInvokeStream({ responseStream: result.data.pipe(new PassThrough()), + actionTypeId, body: (validatedParams as { subActionParams: InvokeBody }).subActionParams, logger, }); diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index 594fec89d93c0..f030421cfedec 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -9,6 +9,8 @@ import { Logger } from '@kbn/logging'; import { encode } from 'gpt-tokenizer'; import { Readable } from 'stream'; import { finished } from 'stream/promises'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; export interface InvokeBody { messages: Array<{ @@ -26,10 +28,12 @@ export interface InvokeBody { * @param logger the logger */ export async function getTokenCountFromInvokeStream({ + actionTypeId, responseStream, body, logger, }: { + actionTypeId: string; responseStream: Readable; body: InvokeBody; logger: Logger; @@ -49,8 +53,19 @@ export async function getTokenCountFromInvokeStream({ let responseBody: string = ''; - responseStream.on('data', (chunk: string) => { - responseBody += chunk.toString(); + const isBedrock = actionTypeId === '.bedrock'; + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + + responseStream.on('data', (chunk) => { + if (isBedrock) { + const event = awsDecoder.decode(chunk); + const parsed = JSON.parse( + Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() + ); + responseBody += parsed.completion; + } else { + responseBody += chunk.toString(); + } }); try { await finished(responseStream); @@ -58,7 +73,9 @@ export async function getTokenCountFromInvokeStream({ logger.error('An error occurred while calculating streaming response tokens'); } - const completionTokens = encode(responseBody).length; + const parsedResponse = isBedrock ? 
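    // for Bedrock, each chunk was already decoded to plain completion text in the data handler above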
responseBody : parseOpenAIResponse(responseBody); + + const completionTokens = encode(parsedResponse).length; return { prompt: promptTokens, @@ -66,3 +83,30 @@ export async function getTokenCountFromInvokeStream({ total: promptTokens + completionTokens, }; } + +const parseOpenAIResponse = (responseBody: string) => { + return responseBody + .split('\n') + .filter((line) => { + return line.startsWith('data: ') && !line.endsWith('[DONE]'); + }) + .map((line) => { + return JSON.parse(line.replace('data: ', '')); + }) + .filter( + ( + line + ): line is { + choices: Array<{ + delta: { content?: string; function_call?: { name?: string; arguments: string } }; + }>; + } => { + return 'object' in line && line.object === 'chat.completion.chunk'; + } + ) + .reduce((prev, line) => { + const msg = line.choices[0].delta!; + prev += msg.content || ''; + return prev; + }, ''); +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts index e6cec8fa9bbb6..27064f3fb1961 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -31,15 +31,7 @@ export const executeAction = async ({ const actionResult = await actionsClient.execute({ actionId: connectorId, - params: { - ...request.body.params, - subAction: 'stream', - subActionParams: - // attempting stream without invokeStream - request.body.params.subAction === 'invokeAI' - ? request.body.params.subActionParams - : { body: JSON.stringify(request.body.params.subActionParams), stream: true }, - }, + params: request.body.params, }); if (actionResult.status === 'error') { diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 9e0860ad9b86d..299d8ade24a3f 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -46,15 +46,9 @@ export const postActionsConnectorExecuteRoute = ( if (!request.body.assistantLangChain) { logger.debug('Executing via actions framework directly, assistantLangChain: false'); const result = await executeAction({ actions, request, connectorId }); - return request.body.params.subAction === 'invokeStream' - ? // a guess - response.accepted({ - body: result, - headers: { connection: 'keep-alive', 'Transfer-Encoding': 'chunked' }, - }) - : response.ok({ - body: result, - }); + return response.ok({ + body: result, + }); } // TODO: Add `traceId` to actions request when calling via langchain diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx index 3b778013a42d1..b1acb2cf64eb3 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx @@ -66,6 +66,8 @@ export const getComments = ({ regenerateMessage(currentConversation.id); }; + const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle; + const extraLoadingComment = isFetchingResponse ? 
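    // while a response is still streaming in, append a placeholder comment that renders the partial answer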
[ { @@ -75,6 +77,7 @@ export const getComments = ({ children: ( ; regenerateMessage: () => void; transformMessage: (message: string) => ContentMessage; @@ -29,6 +30,7 @@ interface Props { export const StreamComment = ({ amendMessage, content, + connectorTypeTitle, index, isError = false, isFetching = false, @@ -40,6 +42,7 @@ export const StreamComment = ({ const { error, isLoading, isStreaming, pendingMessage, setComplete } = useStream({ amendMessage, content, + connectorTypeTitle, reader, isError, }); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index 8a40b55aee374..d86a3131b9910 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -7,10 +7,18 @@ import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs'; import type { Dispatch, SetStateAction } from 'react'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import type { PromptObservableState } from './types'; import { API_ERROR } from '../translations'; const MIN_DELAY = 35; +interface StreamObservable { + connectorTypeTitle: string; + reader: ReadableStreamDefaultReader; + setLoading: Dispatch>; + isError: boolean; +} /** * Returns an Observable that reads data from a ReadableStream and emits values representing the state of the data processing. * @@ -19,22 +27,26 @@ const MIN_DELAY = 35; * @param isError - indicates whether the reader response is an error message or not * @returns {Observable} An Observable that emits PromptObservableState */ -export const getStreamObservable = ( - reader: ReadableStreamDefaultReader, - setLoading: Dispatch>, - isError: boolean -): Observable => +export const getStreamObservable = ({ + connectorTypeTitle, + isError, + reader, + setLoading, +}: StreamObservable): Observable => new Observable((observer) => { observer.next({ chunks: [], loading: true }); const decoder = new TextDecoder(); const chunks: string[] = []; let lineBuffer: string = ''; - function read() { + function readOpenAI() { reader .read() .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { try { if (done) { + if (lineBuffer) { + chunks.push(lineBuffer); + } observer.next({ chunks, message: chunks.join(''), @@ -52,7 +64,7 @@ export const getStreamObservable = ( const lines = decoded.split('\n'); lines[0] = lineBuffer + lines[0]; lineBuffer = lines.pop() || ''; - nextChunks = getNextChunks(lines); + nextChunks = getOpenAIChunks(lines); } nextChunks.forEach((chunk: string) => { chunks.push(chunk); @@ -66,13 +78,56 @@ export const getStreamObservable = ( observer.error(err); return; } - read(); + readOpenAI(); }) .catch((err) => { observer.error(err); }); } - read(); + function readBedrock() { + reader + .read() + .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { + try { + if (done) { + observer.next({ + chunks, + message: chunks.join(''), + loading: false, + }); + observer.complete(); + return; + } + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + + let content; + if (isError) { + content = `${API_ERROR}\n\n${JSON.parse(decoder.decode(value)).message}`; + } else if (value != null) { + const event = awsDecoder.decode(value); + const body = JSON.parse( + 
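                    // the event body is base64-encoded JSON under a `bytes` key, so unwrap it before parsing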
Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString() + ); + content = body.completion; + } + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); + } catch (err) { + observer.error(err); + return; + } + readBedrock(); + }) + .catch((err) => { + observer.error(err); + }); + } + if (connectorTypeTitle === 'Amazon Bedrock') readBedrock(); + else if (connectorTypeTitle === 'OpenAI') readOpenAI(); return () => { reader.cancel(); }; @@ -107,7 +162,7 @@ export const getStreamObservable = ( finalize(() => setLoading(false)) ); -const getNextChunks = (lines: string[]): string[] => { +const getOpenAIChunks = (lines: string[]): string[] => { const nextChunk = lines .map((str) => str.substring(6)) .filter((str) => !!str && str !== '[DONE]') diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx index d29d56f428767..9271758a8558e 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx @@ -13,6 +13,7 @@ interface UseStreamProps { amendMessage: (message: string) => void; isError: boolean; content?: string; + connectorTypeTitle: string; reader?: ReadableStreamDefaultReader; } interface UseStream { @@ -38,6 +39,7 @@ interface UseStream { export const useStream = ({ amendMessage, content, + connectorTypeTitle, reader, isError, }: UseStreamProps): UseStream => { @@ -48,9 +50,9 @@ export const useStream = ({ const observer$ = useMemo( () => content == null && reader != null - ? getStreamObservable(reader, setLoading, isError) + ? getStreamObservable({ connectorTypeTitle, reader, setLoading, isError }) : getPlaceholderObservable(), - [content, isError, reader] + [content, isError, reader, connectorTypeTitle] ); const onCompleteStream = useCallback(() => { subscription?.unsubscribe(); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 70f8e121e1519..ade589e54dc14 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -9,9 +9,7 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import aws from 'aws4'; import type { AxiosError } from 'axios'; import { IncomingMessage } from 'http'; -import { PassThrough, Transform } from 'stream'; -import { EventStreamCodec } from '@smithy/eventstream-codec'; -import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; +import { PassThrough } from 'stream'; import { RunActionParamsSchema, RunActionResponseSchema, @@ -178,12 +176,12 @@ export class BedrockConnector extends SubActionConnector { * @param messages An array of messages to be sent to the API * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. 
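   * @returns the raw response stream; decoding of the AWS event stream is now left to the consumer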
*/ - public async invokeStream({ messages, model }: InvokeAIActionParams): Promise { + public async invokeStream({ messages, model }: InvokeAIActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify(formatBedrockBody({ messages })), model, })) as unknown as IncomingMessage; - return res.pipe(transformToString()); + return res; } /** @@ -222,25 +220,3 @@ const formatBedrockBody = ({ stop_sequences: ['\n\nHuman:'], }; }; - -/** - * Takes in a readable stream of data and returns a Transform stream that - * uses the AWS proprietary codec to parse the proprietary bedrock response into - * a string of the response text alone, returning the response string to the stream - */ -const transformToString = () => - new Transform({ - transform(chunk, encoding, callback) { - const encoder = new TextEncoder(); - const decoder = new EventStreamCodec(toUtf8, fromUtf8); - const event = decoder.decode(chunk); - const body = JSON.parse( - Buffer.from( - JSON.parse(new TextDecoder('utf-8').decode(event.body)).bytes, - 'base64' - ).toString() - ); - const newChunk = encoder.encode(body.completion); - callback(null, newChunk); - }, - }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 78fca4bd84198..81e7c2da48313 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -7,7 +7,6 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import type { AxiosError } from 'axios'; -import { PassThrough, Transform } from 'stream'; import { IncomingMessage } from 'http'; import { RunActionParamsSchema, @@ -198,13 +197,13 @@ export class OpenAIConnector extends SubActionConnector { * the response from the streamApi method and returns the response string alone. * @param body - the OpenAI Invoke request body */ - public async invokeStream(body: InvokeAIActionParams): Promise { + public async invokeStream(body: InvokeAIActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify(body), stream: true, })) as unknown as IncomingMessage; - return res.pipe(new PassThrough()).pipe(transformToString()); + return res; } /** @@ -229,44 +228,3 @@ export class OpenAIConnector extends SubActionConnector { }; } } - -/** - * Takes in a readable stream of data and returns a Transform stream that - * parses the proprietary OpenAI response into a string of the response text alone, - * returning the response string to the stream - */ -const transformToString = () => { - let lineBuffer: string = ''; - const decoder = new TextDecoder(); - - return new Transform({ - transform(chunk, encoding, callback) { - const chunks = decoder.decode(chunk); - const lines = chunks.split('\n'); - lines[0] = lineBuffer + lines[0]; - lineBuffer = lines.pop() || ''; - callback(null, getNextChunk(lines)); - }, - flush(callback) { - // Emit an additional chunk with the content of lineBuffer if it has length - if (lineBuffer.length > 0) { - callback(null, getNextChunk([lineBuffer])); - } else { - callback(); - } - }, - }); -}; - -const getNextChunk = (lines: string[]) => { - const encoder = new TextEncoder(); - const nextChunk = lines - .map((str) => str.substring(6)) - .filter((str) => !!str && str !== '[DONE]') - .map((line) => { - const openaiResponse = JSON.parse(line); - return openaiResponse.choices[0]?.delta.content ?? 
''; - }) - .join(''); - return encoder.encode(nextChunk); -}; From 355e9e4e74440f1ab40f48d5ed356196eb6ba848 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 14:31:28 -0700 Subject: [PATCH 09/15] rm --- x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index 0d491a348a975..f92585cbdd011 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -85,9 +85,9 @@ export const fetchConnectorExecuteAction = async ({ signal, asResponse: isStream, rawResponse: isStream, - headers: { Connection: 'keep-alive' }, } ); + const reader = response?.response?.body?.getReader(); if (!reader) { From 781d3d6110ec044dc88ebd7a2235fc49c777df2a Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 21 Nov 2023 14:36:14 -0700 Subject: [PATCH 10/15] cleanup --- .../lib/get_token_count_from_invoke_stream.ts | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index f030421cfedec..0f9b0546ff220 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -54,15 +54,10 @@ export async function getTokenCountFromInvokeStream({ let responseBody: string = ''; const isBedrock = actionTypeId === '.bedrock'; - const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); responseStream.on('data', (chunk) => { if (isBedrock) { - const event = awsDecoder.decode(chunk); - const parsed = JSON.parse( - Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() - ); - responseBody += parsed.completion; + responseBody += parseBedrockChunk(chunk); } else { responseBody += chunk.toString(); } @@ -73,6 +68,8 @@ export async function getTokenCountFromInvokeStream({ logger.error('An error occurred while calculating streaming response tokens'); } + // parse openai response once responseBody is fully built + // They send the response in sometimes incomplete chunks of JSON const parsedResponse = isBedrock ? 
responseBody : parseOpenAIResponse(responseBody); const completionTokens = encode(parsedResponse).length; @@ -84,8 +81,17 @@ export async function getTokenCountFromInvokeStream({ }; } -const parseOpenAIResponse = (responseBody: string) => { - return responseBody +const parseBedrockChunk = (chunk: ArrayBufferView) => { + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + const event = awsDecoder.decode(chunk); + const parsed = JSON.parse( + Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() + ); + return parsed.completion; +}; + +const parseOpenAIResponse = (responseBody: string) => + responseBody .split('\n') .filter((line) => { return line.startsWith('data: ') && !line.endsWith('[DONE]'); @@ -109,4 +115,3 @@ const parseOpenAIResponse = (responseBody: string) => { prev += msg.content || ''; return prev; }, ''); -}; From 555ccc8868513067c0ba96a9f0effafea07e98a9 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 22 Nov 2023 09:30:29 -0700 Subject: [PATCH 11/15] fix bedrock --- .../lib/get_token_count_from_invoke_stream.ts | 60 +++++++++++++--- .../get_comments/stream/stream_observable.ts | 72 ++++++++++++++----- 2 files changed, 105 insertions(+), 27 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index 0f9b0546ff220..3143ca1a65e93 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -53,12 +53,16 @@ export async function getTokenCountFromInvokeStream({ let responseBody: string = ''; + const responseBuffer: Uint8Array[] = []; + const isBedrock = actionTypeId === '.bedrock'; responseStream.on('data', (chunk) => { if (isBedrock) { - responseBody += parseBedrockChunk(chunk); + // special encoding for bedrock, do not attempt to convert to string + responseBuffer.push(chunk); } else { + // no special encoding, can safely use toString and append to responseBody responseBody += chunk.toString(); } }); @@ -70,10 +74,11 @@ export async function getTokenCountFromInvokeStream({ // parse openai response once responseBody is fully built // They send the response in sometimes incomplete chunks of JSON - const parsedResponse = isBedrock ? responseBody : parseOpenAIResponse(responseBody); + const parsedResponse = isBedrock + ? 
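      // a single event-stream frame can span multiple network chunks, so buffer everything and decode only once the stream has finished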
parseBedrockBuffer(responseBuffer) + : parseOpenAIResponse(responseBody); const completionTokens = encode(parsedResponse).length; - return { prompt: promptTokens, completion: completionTokens, @@ -81,15 +86,50 @@ export async function getTokenCountFromInvokeStream({ }; } -const parseBedrockChunk = (chunk: ArrayBufferView) => { - const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); - const event = awsDecoder.decode(chunk); - const parsed = JSON.parse( - Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() - ); - return parsed.completion; +const parseBedrockBuffer = (chunks: Uint8Array[]) => { + let bedrockBuffer: Uint8Array = new Uint8Array(0); + return chunks + .map((chunk) => { + bedrockBuffer = concatChunks(bedrockBuffer, chunk); + let messageLength = getMessageLength(bedrockBuffer); + + const buildChunks = []; + while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + const extractedChunk = bedrockBuffer.slice(0, messageLength); + buildChunks.push(extractedChunk); + bedrockBuffer = bedrockBuffer.slice(messageLength); + messageLength = getMessageLength(bedrockBuffer); + } + + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + + return buildChunks + .map((bChunk) => { + const event = awsDecoder.decode(bChunk); + const body = JSON.parse( + Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() + ); + return body.completion; + }) + .join(''); + }) + .join(''); }; +function concatChunks(a: Uint8Array, b: Uint8Array) { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; +} + +function getMessageLength(buffer: Uint8Array) { + if (buffer.byteLength === 0) return 0; + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + + return view.getUint32(0, false); +} + const parseOpenAIResponse = (responseBody: string) => responseBody .split('\n') diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index d86a3131b9910..3341413627897 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -37,15 +37,17 @@ export const getStreamObservable = ({ observer.next({ chunks: [], loading: true }); const decoder = new TextDecoder(); const chunks: string[] = []; - let lineBuffer: string = ''; + let openAIBuffer: string = ''; + + let bedrockBuffer: Uint8Array = new Uint8Array(0); function readOpenAI() { reader .read() .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { try { if (done) { - if (lineBuffer) { - chunks.push(lineBuffer); + if (openAIBuffer) { + chunks.push(openAIBuffer); } observer.next({ chunks, @@ -62,8 +64,8 @@ export const getStreamObservable = ({ nextChunks = [`${API_ERROR}\n\n${JSON.parse(decoded).message}`]; } else { const lines = decoded.split('\n'); - lines[0] = lineBuffer + lines[0]; - lineBuffer = lines.pop() || ''; + lines[0] = openAIBuffer + lines[0]; + openAIBuffer = lines.pop() || ''; nextChunks = getOpenAIChunks(lines); } nextChunks.forEach((chunk: string) => { @@ -98,24 +100,45 @@ export const getStreamObservable = ({ observer.complete(); return; } - const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); let content; if (isError) { content = 
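                    // error bodies arrive as plain JSON ({ message: string }), not as event-stream frames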
`${API_ERROR}\n\n${JSON.parse(decoder.decode(value)).message}`; + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); } else if (value != null) { - const event = awsDecoder.decode(value); - const body = JSON.parse( - Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString() - ); - content = body.completion; + const chunk: Uint8Array = value; + + bedrockBuffer = concatChunks(bedrockBuffer, chunk); + let messageLength = getMessageLength(bedrockBuffer); + + const buildChunks = []; + while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + const extractedChunk = bedrockBuffer.slice(0, messageLength); + buildChunks.push(extractedChunk); + bedrockBuffer = bedrockBuffer.slice(messageLength); + messageLength = getMessageLength(bedrockBuffer); + } + + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + buildChunks.forEach((bChunk) => { + const event = awsDecoder.decode(bChunk); + const body = JSON.parse( + Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString() + ); + content = body.completion; + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); + }); } - chunks.push(content); - observer.next({ - chunks, - message: chunks.join(''), - loading: true, - }); } catch (err) { observer.error(err); return; @@ -126,6 +149,7 @@ export const getStreamObservable = ({ observer.error(err); }); } + if (connectorTypeTitle === 'Amazon Bedrock') readBedrock(); else if (connectorTypeTitle === 'OpenAI') readOpenAI(); return () => { @@ -177,4 +201,18 @@ const getOpenAIChunks = (lines: string[]): string[] => { return nextChunk; }; +function concatChunks(a: Uint8Array, b: Uint8Array) { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; +} + +function getMessageLength(buffer: Uint8Array) { + if (buffer.byteLength === 0) return 0; + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + + return view.getUint32(0, false); +} + export const getPlaceholderObservable = () => new Observable(); From 6995899ef08854df05610c54142e3ed6f350bebc Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 22 Nov 2023 10:06:21 -0700 Subject: [PATCH 12/15] comment code --- .../lib/get_token_count_from_invoke_stream.ts | 44 ++++++++++++++-- .../public/assistant/get_comments/index.tsx | 2 +- .../get_comments/stream/stream_observable.ts | 52 +++++++++++++++++-- 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index 3143ca1a65e93..8c746d7a9ed7b 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -86,23 +86,40 @@ export async function getTokenCountFromInvokeStream({ }; } -const parseBedrockBuffer = (chunks: Uint8Array[]) => { +/** + * Parses a Bedrock buffer from an array of chunks. + * + * @param {Uint8Array[]} chunks - Array of Uint8Array chunks to be parsed. + * @returns {string} - Parsed string from the Bedrock buffer. + */ +const parseBedrockBuffer = (chunks: Uint8Array[]): string => { + // Initialize an empty Uint8Array to store the concatenated buffer. let bedrockBuffer: Uint8Array = new Uint8Array(0); + + // Map through each chunk to process the Bedrock buffer. 
return chunks .map((chunk) => { + // Concatenate the current chunk to the existing buffer. bedrockBuffer = concatChunks(bedrockBuffer, chunk); + // Get the length of the next message in the buffer. let messageLength = getMessageLength(bedrockBuffer); - + // Initialize an array to store fully formed message chunks. const buildChunks = []; + // Process the buffer until no complete messages are left. while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + // Extract a chunk of the specified length from the buffer. const extractedChunk = bedrockBuffer.slice(0, messageLength); + // Add the extracted chunk to the array of fully formed message chunks. buildChunks.push(extractedChunk); + // Remove the processed chunk from the buffer. bedrockBuffer = bedrockBuffer.slice(messageLength); + // Get the length of the next message in the updated buffer. messageLength = getMessageLength(bedrockBuffer); } const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + // Decode and parse each message chunk, extracting the 'completion' property. return buildChunks .map((bChunk) => { const event = awsDecoder.decode(bChunk); @@ -116,17 +133,34 @@ const parseBedrockBuffer = (chunks: Uint8Array[]) => { .join(''); }; -function concatChunks(a: Uint8Array, b: Uint8Array) { +/** + * Concatenates two Uint8Array buffers. + * + * @param {Uint8Array} a - First buffer. + * @param {Uint8Array} b - Second buffer. + * @returns {Uint8Array} - Concatenated buffer. + */ +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { const newBuffer = new Uint8Array(a.length + b.length); + // Copy the contents of the first buffer to the new buffer. newBuffer.set(a); + // Copy the contents of the second buffer to the new buffer starting from the end of the first buffer. newBuffer.set(b, a.length); return newBuffer; } -function getMessageLength(buffer: Uint8Array) { +/** + * Gets the length of the next message from the buffer. + * + * @param {Uint8Array} buffer - Buffer containing the message. + * @returns {number} - Length of the next message. + */ +function getMessageLength(buffer: Uint8Array): number { + // If the buffer is empty, return 0. if (buffer.byteLength === 0) return 0; + // Create a DataView to read the Uint32 value at the beginning of the buffer. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); - + // Read and return the Uint32 value (message length). return view.getUint32(0, false); } diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx index b1acb2cf64eb3..d8cfc46ec5a22 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx @@ -66,7 +66,7 @@ export const getComments = ({ regenerateMessage(currentConversation.id); }; - const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle; + const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle ?? ''; const extraLoadingComment = isFetchingResponse ? 
[ diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index 3341413627897..acc771e6ce7c2 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -37,8 +37,10 @@ export const getStreamObservable = ({ observer.next({ chunks: [], loading: true }); const decoder = new TextDecoder(); const chunks: string[] = []; + // Initialize an empty string to store the OpenAI buffer. let openAIBuffer: string = ''; + // Initialize an empty Uint8Array to store the Bedrock concatenated buffer. let bedrockBuffer: Uint8Array = new Uint8Array(0); function readOpenAI() { reader @@ -113,18 +115,27 @@ export const getStreamObservable = ({ } else if (value != null) { const chunk: Uint8Array = value; + // Concatenate the current chunk to the existing buffer. bedrockBuffer = concatChunks(bedrockBuffer, chunk); + // Get the length of the next message in the buffer. let messageLength = getMessageLength(bedrockBuffer); + // Initialize an array to store fully formed message chunks. const buildChunks = []; + // Process the buffer until no complete messages are left. while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + // Extract a chunk of the specified length from the buffer. const extractedChunk = bedrockBuffer.slice(0, messageLength); + // Add the extracted chunk to the array of fully formed message chunks. buildChunks.push(extractedChunk); + // Remove the processed chunk from the buffer. bedrockBuffer = bedrockBuffer.slice(messageLength); + // Get the length of the next message in the updated buffer. messageLength = getMessageLength(bedrockBuffer); } const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + // Decode and parse each message chunk, extracting the 'completion' property. buildChunks.forEach((bChunk) => { const event = awsDecoder.decode(bChunk); const body = JSON.parse( @@ -149,9 +160,22 @@ export const getStreamObservable = ({ observer.error(err); }); } + // this should never actually happen + function badConnector() { + observer.next({ + chunks: [ + `Invalid connector type - ${connectorTypeTitle} is not a supported GenAI connector.`, + ], + message: `Invalid connector type - ${connectorTypeTitle} is not a supported GenAI connector.`, + loading: false, + }); + observer.complete(); + } if (connectorTypeTitle === 'Amazon Bedrock') readBedrock(); else if (connectorTypeTitle === 'OpenAI') readOpenAI(); + else badConnector(); + return () => { reader.cancel(); }; @@ -186,6 +210,11 @@ export const getStreamObservable = ({ finalize(() => setLoading(false)) ); +/** + * Parses an OpenAI response from a string. + * @param lines + * @returns {string[]} - Parsed string array from the OpenAI response. + */ const getOpenAIChunks = (lines: string[]): string[] => { const nextChunk = lines .map((str) => str.substring(6)) @@ -201,17 +230,34 @@ const getOpenAIChunks = (lines: string[]): string[] => { return nextChunk; }; -function concatChunks(a: Uint8Array, b: Uint8Array) { +/** + * Concatenates two Uint8Array buffers. + * + * @param {Uint8Array} a - First buffer. + * @param {Uint8Array} b - Second buffer. + * @returns {Uint8Array} - Concatenated buffer. 
+ */ +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { const newBuffer = new Uint8Array(a.length + b.length); + // Copy the contents of the first buffer to the new buffer. newBuffer.set(a); + // Copy the contents of the second buffer to the new buffer starting from the end of the first buffer. newBuffer.set(b, a.length); return newBuffer; } -function getMessageLength(buffer: Uint8Array) { +/** + * Gets the length of the next message from the buffer. + * + * @param {Uint8Array} buffer - Buffer containing the message. + * @returns {number} - Length of the next message. + */ +function getMessageLength(buffer: Uint8Array): number { + // If the buffer is empty, return 0. if (buffer.byteLength === 0) return 0; + // Create a DataView to read the Uint32 value at the beginning of the buffer. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); - + // Read and return the Uint32 value (message length). return view.getUint32(0, false); } From 3fe298998b50b75a95cfee4df3f42c1980b2dd2f Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 22 Nov 2023 11:19:26 -0700 Subject: [PATCH 13/15] fix tests --- ...get_token_count_from_invoke_stream.test.ts | 96 ++++++-- .../lib/get_token_count_from_invoke_stream.ts | 52 +++-- .../stream/stream_observable.test.ts | 208 ++++++++++++++++-- .../get_comments/stream/stream_observable.ts | 2 +- .../get_comments/stream/use_stream.test.tsx | 19 +- .../connector_types/bedrock/bedrock.test.ts | 31 +-- .../connector_types/openai/openai.test.ts | 49 +---- .../server/connector_types/openai/openai.ts | 3 +- .../tests/actions/connector_types/bedrock.ts | 54 ++++- 9 files changed, 369 insertions(+), 145 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts index 3c0dd66130f3a..2d8f86b881728 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts @@ -7,20 +7,15 @@ import { Transform } from 'stream'; import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream'; import { loggerMock } from '@kbn/logging-mocks'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; -interface StreamMock { - write: (data: string) => void; - fail: () => void; - complete: () => void; - transform: Transform; -} - -function createStreamMock(): StreamMock { +function createStreamMock() { const transform: Transform = new Transform({}); return { - write: (data: string) => { - transform.push(`${data}\n`); + write: (data: unknown) => { + transform.push(data); }, fail: () => { transform.emit('error', new Error('Stream failed')); @@ -34,7 +29,10 @@ function createStreamMock(): StreamMock { } const logger = loggerMock.create(); describe('getTokenCountFromInvokeStream', () => { - let stream: StreamMock; + beforeEach(() => { + jest.resetAllMocks(); + }); + let stream: ReturnType; const body = { messages: [ { @@ -48,36 +46,79 @@ describe('getTokenCountFromInvokeStream', () => { ], }; + const chunk = { + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: 'Single.', + }, + }, + ], + }; + const PROMPT_TOKEN_COUNT = 34; const COMPLETION_TOKEN_COUNT = 2; + describe('OpenAI stream', () => { + beforeEach(() => { + stream = createStreamMock(); + stream.write(`data: ${JSON.stringify(chunk)}`); + }); - beforeEach(() => { - stream = 
createStreamMock(); - stream.write('Single'); - }); - - describe('when a stream completes', () => { - beforeEach(async () => { + it('counts the prompt + completion tokens for OpenAI response', async () => { stream.complete(); - }); - it('counts the prompt tokens', async () => { const tokens = await getTokenCountFromInvokeStream({ responseStream: stream.transform, body, logger, + actionTypeId: '.gen-ai', }); expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); }); + it('resolves the promise with the correct prompt tokens', async () => { + const tokenPromise = getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + actionTypeId: '.gen-ai', + }); + + stream.fail(); + + await expect(tokenPromise).resolves.toEqual({ + prompt: PROMPT_TOKEN_COUNT, + total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT, + completion: COMPLETION_TOKEN_COUNT, + }); + expect(logger.error).toHaveBeenCalled(); + }); }); + describe('Bedrock stream', () => { + beforeEach(() => { + stream = createStreamMock(); + stream.write(encodeBedrockResponse('Simple.')); + }); - describe('when a stream fails', () => { + it('counts the prompt + completion tokens for OpenAI response', async () => { + stream.complete(); + const tokens = await getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + actionTypeId: '.bedrock', + }); + expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); + expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); + expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); + }); it('resolves the promise with the correct prompt tokens', async () => { const tokenPromise = getTokenCountFromInvokeStream({ responseStream: stream.transform, body, logger, + actionTypeId: '.bedrock', }); stream.fail(); @@ -91,3 +132,16 @@ describe('getTokenCountFromInvokeStream', () => { }); }); }); + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: {}, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index 8c746d7a9ed7b..dfb4bae69f8cf 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -51,40 +51,46 @@ export async function getTokenCountFromInvokeStream({ .join('\n') ).length; - let responseBody: string = ''; + const parser = actionTypeId === '.bedrock' ? 
parseBedrockStream : parseOpenAIStream; + const parsedResponse = await parser(responseStream, logger); - const responseBuffer: Uint8Array[] = []; + const completionTokens = encode(parsedResponse).length; + return { + prompt: promptTokens, + completion: completionTokens, + total: promptTokens + completionTokens, + }; +} - const isBedrock = actionTypeId === '.bedrock'; +type StreamParser = (responseStream: Readable, logger: Logger) => Promise; +const parseBedrockStream: StreamParser = async (responseStream, logger) => { + const responseBuffer: Uint8Array[] = []; responseStream.on('data', (chunk) => { - if (isBedrock) { - // special encoding for bedrock, do not attempt to convert to string - responseBuffer.push(chunk); - } else { - // no special encoding, can safely use toString and append to responseBody - responseBody += chunk.toString(); - } + // special encoding for bedrock, do not attempt to convert to string + responseBuffer.push(chunk); }); try { await finished(responseStream); } catch (e) { logger.error('An error occurred while calculating streaming response tokens'); } + return parseBedrockBuffer(responseBuffer); +}; - // parse openai response once responseBody is fully built - // They send the response in sometimes incomplete chunks of JSON - const parsedResponse = isBedrock - ? parseBedrockBuffer(responseBuffer) - : parseOpenAIResponse(responseBody); - - const completionTokens = encode(parsedResponse).length; - return { - prompt: promptTokens, - completion: completionTokens, - total: promptTokens + completionTokens, - }; -} +const parseOpenAIStream: StreamParser = async (responseStream, logger) => { + let responseBody: string = ''; + responseStream.on('data', (chunk) => { + // no special encoding, can safely use toString and append to responseBody + responseBody += chunk.toString(); + }); + try { + await finished(responseStream); + } catch (e) { + logger.error('An error occurred while calculating streaming response tokens'); + } + return parseOpenAIResponse(responseBody); +}; /** * Parses a Bedrock buffer from an array of chunks. 
diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts index 764db1b3990ae..54a5684d20442 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts @@ -9,6 +9,8 @@ import { API_ERROR } from '../translations'; import type { PromptObservableState } from './types'; import { Subject } from 'rxjs'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; describe('getStreamObservable', () => { const mockReader = { read: jest.fn(), @@ -22,29 +24,102 @@ describe('getStreamObservable', () => { beforeEach(() => { jest.clearAllMocks(); }); + it('should emit loading state and chunks for Bedrock', (done) => { + const completeSubject = new Subject(); + const expectedStates: PromptObservableState[] = [ + { chunks: [], loading: true }, + { + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: false, + }, + ]; - it('should emit loading state and chunks', (done) => { + mockReader.read + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse('My'), + }) + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse(' new'), + }) + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse(' message'), + }) + .mockResolvedValue({ + done: true, + }); + + const source = getStreamObservable({ + connectorTypeTitle: 'Amazon Bedrock', + isError: false, + reader: typedReader, + setLoading, + }); + const emittedStates: PromptObservableState[] = []; + + source.subscribe({ + next: (state) => { + return emittedStates.push(state); + }, + complete: () => { + expect(emittedStates).toEqual(expectedStates); + done(); + + completeSubject.subscribe({ + next: () => { + expect(setLoading).toHaveBeenCalledWith(false); + expect(typedReader.cancel).toHaveBeenCalled(); + done(); + }, + }); + }, + error: (err) => done(err), + }); + }); + it('should emit loading state and chunks for OpenAI', (done) => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; + const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; const completeSubject = new Subject(); const expectedStates: PromptObservableState[] = [ { chunks: [], loading: true }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk ', + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', loading: true, }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new', loading: true, }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new message', loading: true, }, { - 
chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new message', loading: false, }, ]; @@ -52,11 +127,11 @@ describe('getStreamObservable', () => { mockReader.read .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`one chunk `)), + value: new Uint8Array(new TextEncoder().encode(chunk1)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`another chunk`)), + value: new Uint8Array(new TextEncoder().encode(chunk2)), }) .mockResolvedValueOnce({ done: false, @@ -66,11 +141,91 @@ describe('getStreamObservable', () => { done: true, }); - const source = getStreamObservable(typedReader, setLoading, false); + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: false, + reader: typedReader, + setLoading, + }); const emittedStates: PromptObservableState[] = []; source.subscribe({ - next: (state) => emittedStates.push(state), + next: (state) => { + return emittedStates.push(state); + }, + complete: () => { + expect(emittedStates).toEqual(expectedStates); + done(); + + completeSubject.subscribe({ + next: () => { + expect(setLoading).toHaveBeenCalledWith(false); + expect(typedReader.cancel).toHaveBeenCalled(); + done(); + }, + }); + }, + error: (err) => done(err), + }); + }); + it('should emit loading state and chunks for partial response OpenAI', (done) => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"`; + const chunk2 = `}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; + const completeSubject = new Subject(); + const expectedStates: PromptObservableState[] = [ + { chunks: [], loading: true }, + { + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: false, + }, + ]; + + mockReader.read + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode(chunk1)), + }) + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode(chunk2)), + }) + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode('')), + }) + .mockResolvedValue({ + done: true, + }); + + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: false, + reader: typedReader, + setLoading, + }); + const emittedStates: PromptObservableState[] = []; + + source.subscribe({ + next: (state) => { + return emittedStates.push(state); + }, complete: () => { expect(emittedStates).toEqual(expectedStates); done(); @@ -112,7 +267,12 @@ describe('getStreamObservable', () => { done: true, }); - const source = getStreamObservable(typedReader, setLoading, true); + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: true, + reader: typedReader, + setLoading, + }); const emittedStates: PromptObservableState[] = []; source.subscribe({ @@ -138,7 +298,12 @@ describe('getStreamObservable', () => { const error = new Error('Test Error'); // Simulate an error 
mockReader.read.mockRejectedValue(error); - const source = getStreamObservable(typedReader, setLoading, false); + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: false, + reader: typedReader, + setLoading, + }); source.subscribe({ next: (state) => {}, @@ -157,3 +322,16 @@ describe('getStreamObservable', () => { }); }); }); + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: {}, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index acc771e6ce7c2..ce7a38811f229 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -49,7 +49,7 @@ export const getStreamObservable = ({ try { if (done) { if (openAIBuffer) { - chunks.push(openAIBuffer); + chunks.push(getOpenAIChunks([openAIBuffer])[0]); } observer.next({ chunks, diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx index efbc61999f2cc..c4f99884aa045 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx @@ -11,20 +11,22 @@ import { useStream } from './use_stream'; const amendMessage = jest.fn(); const reader = jest.fn(); const cancel = jest.fn(); +const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; +const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; const readerComplete = { read: reader .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode('one chunk ')), + value: new Uint8Array(new TextEncoder().encode(chunk1)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`another chunk`)), + value: new Uint8Array(new TextEncoder().encode(chunk2)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(``)), + value: new Uint8Array(new TextEncoder().encode('')), }) .mockResolvedValue({ done: true, @@ -34,7 +36,12 @@ const readerComplete = { closed: jest.fn().mockResolvedValue(true), } as unknown as ReadableStreamDefaultReader; -const defaultProps = { amendMessage, reader: readerComplete, isError: false }; +const defaultProps = { + amendMessage, + reader: readerComplete, + isError: false, + connectorTypeTitle: 'OpenAI', +}; describe('useStream', () => { beforeEach(() => { jest.clearAllMocks(); @@ -57,7 +64,7 @@ describe('useStream', () => { error: undefined, isLoading: true, isStreaming: true, - pendingMessage: 'one chunk ', + pendingMessage: 'My', setComplete: expect.any(Function), }); }); @@ -67,7 +74,7 @@ describe('useStream', () => { error: undefined, isLoading: false, isStreaming: false, - pendingMessage: 'one chunk another chunk', + pendingMessage: 'My new message', setComplete: expect.any(Function), }); }); diff --git 
a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts index 708e8cd4e0364..0eeb309dd2257 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts @@ -5,13 +5,10 @@ * 2.0. */ import aws from 'aws4'; -import { Transform } from 'stream'; +import { PassThrough, Transform } from 'stream'; import { BedrockConnector } from './bedrock'; -import { waitFor } from '@testing-library/react'; import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; -import { EventStreamCodec } from '@smithy/eventstream-codec'; -import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/bedrock/schema'; import { @@ -105,7 +102,7 @@ describe('BedrockConnector', () => { let stream; beforeEach(() => { stream = createStreamMock(); - stream.write(encodeBedrockResponse(mockResponseString)); + stream.write(new Uint8Array([1, 2, 3])); mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: stream.transform }); // @ts-ignore connector.request = mockRequest; @@ -199,16 +196,9 @@ describe('BedrockConnector', () => { }); }); - it('transforms the response into a string', async () => { + it('responds with a readable stream', async () => { const response = await connector.invokeStream(aiAssistantBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual(mockResponseString); - }); + expect(response instanceof PassThrough).toEqual(true); }); it('errors during API calls are properly handled', async () => { @@ -364,16 +354,3 @@ function createStreamMock() { }, }; } - -function encodeBedrockResponse(completion: string) { - return new EventStreamCodec(toUtf8, fromUtf8).encode({ - headers: {}, - body: Uint8Array.from( - Buffer.from( - JSON.stringify({ - bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), - }) - ) - ), - }); -} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts index 7769dd8592faf..c7d6feb6887ad 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts @@ -17,8 +17,7 @@ import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/openai/schema'; import { initDashboard } from './create_dashboard'; -import { Transform } from 'stream'; -import { waitFor } from '@testing-library/react'; +import { PassThrough, Transform } from 'stream'; jest.mock('./create_dashboard'); describe('OpenAIConnector', () => { @@ -315,53 +314,11 @@ describe('OpenAIConnector', () => { await expect(connector.invokeStream(sampleOpenAiBody)).rejects.toThrow('API Error'); }); - it('transforms the response into a string', async () => { + it('responds with a readable stream', async () => { // @ts-ignore connector.request = mockStream(); const response = await 
connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new'); - }); - }); - it('correctly buffers stream of json lines', async () => { - const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; - const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; - - // @ts-ignore - connector.request = mockStream([chunk1, chunk2]); - - const response = await connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new message'); - }); - }); - it('correctly buffers partial lines', async () => { - const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"`; - - const chunk2 = `}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; - - // @ts-ignore - connector.request = mockStream([chunk1, chunk2]); - - const response = await connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new message'); - }); + expect(response instanceof PassThrough).toEqual(true); }); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 81e7c2da48313..1d704164ca139 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -8,6 +8,7 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import type { AxiosError } from 'axios'; import { IncomingMessage } from 'http'; +import { PassThrough } from 'stream'; import { RunActionParamsSchema, RunActionResponseSchema, @@ -203,7 +204,7 @@ export class OpenAIConnector extends SubActionConnector { stream: true, })) as unknown as IncomingMessage; - return res; + return res.pipe(new PassThrough()); } /** diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts index 70cdc0f96dfdd..60eb8b6634a35 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts @@ -13,6 +13,8 @@ import { } from '@kbn/actions-simulators-plugin/server/bedrock_simulation'; import { DEFAULT_TOKEN_LIMIT } from '@kbn/stack-connectors-plugin/common/bedrock/constants'; import { PassThrough } from 'stream'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { FtrProviderContext } from '../../../../../common/ftr_provider_context'; import { getUrlPrefix, ObjectRemover } from '../../../../../common/lib'; @@ -411,8 +413,6 @@ export default function 
bedrockTest({ getService }: FtrProviderContext) { it('should invoke stream with assistant AI body argument formatted to bedrock expectations', async () => { await new Promise((resolve, reject) => { - let responseBody: string = ''; - const passThrough = new PassThrough(); supertest @@ -434,13 +434,14 @@ export default function bedrockTest({ getService }: FtrProviderContext) { assistantLangChain: false, }) .pipe(passThrough); - + const responseBuffer: Uint8Array[] = []; passThrough.on('data', (chunk) => { - responseBody += chunk.toString(); + responseBuffer.push(chunk); }); passThrough.on('end', () => { - expect(responseBody).to.eql('Hello world, what a unique string!'); + const parsed = parseBedrockBuffer(responseBuffer); + expect(parsed).to.eql('Hello world, what a unique string!'); resolve(); }); }); @@ -517,3 +518,46 @@ export default function bedrockTest({ getService }: FtrProviderContext) { }); }); } + +const parseBedrockBuffer = (chunks: Uint8Array[]): string => { + let bedrockBuffer: Uint8Array = new Uint8Array(0); + + return chunks + .map((chunk) => { + bedrockBuffer = concatChunks(bedrockBuffer, chunk); + let messageLength = getMessageLength(bedrockBuffer); + const buildChunks = []; + while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + const extractedChunk = bedrockBuffer.slice(0, messageLength); + buildChunks.push(extractedChunk); + bedrockBuffer = bedrockBuffer.slice(messageLength); + messageLength = getMessageLength(bedrockBuffer); + } + + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + + return buildChunks + .map((bChunk) => { + const event = awsDecoder.decode(bChunk); + const body = JSON.parse( + Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() + ); + return body.completion; + }) + .join(''); + }) + .join(''); +}; + +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; +} + +function getMessageLength(buffer: Uint8Array): number { + if (buffer.byteLength === 0) return 0; + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + return view.getUint32(0, false); +} From 9d13e541910c7825a17f534f413f29e7333cced6 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 22 Nov 2023 12:31:49 -0700 Subject: [PATCH 14/15] fix type --- .../stack_connectors/server/connector_types/openai/openai.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 1d704164ca139..8dfeac0be8502 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -198,7 +198,7 @@ export class OpenAIConnector extends SubActionConnector { * the response from the streamApi method and returns the response string alone. 
* @param body - the OpenAI Invoke request body */ - public async invokeStream(body: InvokeAIActionParams): Promise { + public async invokeStream(body: InvokeAIActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify(body), stream: true, From 7385c6b764fc4d0d5426df16f5ef85dc6715551a Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 22 Nov 2023 13:06:38 -0700 Subject: [PATCH 15/15] fix type again --- .../public/assistant/get_comments/stream/index.test.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.test.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.test.tsx index 7813e45829d1c..29570959fb839 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.test.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.test.tsx @@ -19,6 +19,7 @@ const testProps = { content, index: 1, isLastComment: true, + connectorTypeTitle: 'OpenAI', regenerateMessage: jest.fn(), transformMessage: jest.fn(), };