From 2719518e8a719efe8f981ab88b34e726379bc3e5 Mon Sep 17 00:00:00 2001
From: Dario Gieselaar
Date: Thu, 2 May 2024 20:21:59 +0200
Subject: [PATCH] [8.14] [Obs AI Assistant] Option for OpenAI compatible output (#182076) (#182421)

# Backport

This will backport the following commits from `main` to `8.14`:

- [[Obs AI Assistant] Option for OpenAI compatible output (#182076)](https://github.com/elastic/kibana/pull/182076)

### Questions ?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

---
 .../server/routes/chat/route.ts               |  44 ++++---
 .../util/observable_into_openai_stream.ts     |  91 ++++++++++++++
 .../public_complete/public_complete.spec.ts   | 118 ++++++++++++++++--
 3 files changed, 227 insertions(+), 26 deletions(-)
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_openai_stream.ts

diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts
index f5e9ca339e9e8..84d604e54cc6e 100644
--- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts
+++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts
@@ -12,6 +12,7 @@ import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plu
 import { KibanaRequest } from '@kbn/core/server';
 import { aiAssistantSimulatedFunctionCalling } from '../..';
 import { flushBuffer } from '../../service/util/flush_buffer';
+import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
 import { observableIntoStream } from '../../service/util/observable_into_stream';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 import { screenContextRt, messageRt, functionRt } from '../runtime_types';
@@ -53,10 +54,13 @@ const chatCompleteInternalRt = t.intersection([
 
 const chatCompletePublicRt = t.intersection([
   chatCompleteBaseRt,
-  t.type({
+  t.partial({
     body: t.partial({
       actions: t.array(functionRt),
     }),
+    query: t.partial({
+      format: t.union([t.literal('default'), t.literal('openai')]),
+    }),
   }),
 ]);
 
@@ -230,24 +234,32 @@ const publicChatCompleteRoute = createObservabilityAIAssistantServerRoute({
   },
   params: chatCompletePublicRt,
   handler: async (resources): Promise<Readable> => {
+    const { params, logger } = resources;
+
     const {
       body: { actions, ...restOfBody },
-    } = resources.params;
-    return observableIntoStream(
-      await chatComplete({
-        ...resources,
-        params: {
-          body: {
-            ...restOfBody,
-            screenContexts: [
-              {
-                actions,
-              },
-            ],
-          },
+      query = {},
+    } = params;
+
+    const { format = 'default' } = query;
+
+    const response$ = await chatComplete({
+      ...resources,
+      params: {
+        body: {
+          ...restOfBody,
+          screenContexts: [
+            {
+              actions,
+            },
+          ],
         },
-      })
-    );
+      },
+    });
+
+    return format === 'openai'
+      ? observableIntoOpenAIStream(response$, logger)
+      : observableIntoStream(response$);
   },
 });
 
diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_openai_stream.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_openai_stream.ts
new file mode 100644
index 0000000000000..a5e6ef2d17c91
--- /dev/null
+++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_openai_stream.ts
@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { Logger } from '@kbn/logging';
+import OpenAI from 'openai';
+import {
+  catchError,
+  concatMap,
+  endWith,
+  filter,
+  from,
+  ignoreElements,
+  map,
+  Observable,
+  of,
+} from 'rxjs';
+import { PassThrough } from 'stream';
+import {
+  BufferFlushEvent,
+  ChatCompletionChunkEvent,
+  StreamingChatResponseEventType,
+  StreamingChatResponseEventWithoutError,
+  TokenCountEvent,
+} from '../../../common/conversation_complete';
+
+export function observableIntoOpenAIStream(
+  source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent | TokenCountEvent>,
+  logger: Logger
+) {
+  const stream = new PassThrough();
+
+  source
+    .pipe(
+      filter(
+        (event): event is ChatCompletionChunkEvent =>
+          event.type === StreamingChatResponseEventType.ChatCompletionChunk
+      ),
+      map((event) => {
+        const chunk: OpenAI.ChatCompletionChunk = {
+          model: 'unknown',
+          choices: [
+            {
+              delta: {
+                content: event.message.content,
+                function_call: event.message.function_call,
+              },
+              finish_reason: null,
+              index: 0,
+            },
+          ],
+          created: Math.floor(Date.now() / 1000),
+          id: event.id,
+          object: 'chat.completion.chunk',
+        };
+        return JSON.stringify(chunk);
+      }),
+      catchError((error) => {
+        return of(JSON.stringify({ error: { message: error.message } }));
+      }),
+      endWith('[DONE]'),
+      concatMap((line) => {
+        return from(
+          new Promise<void>((resolve, reject) => {
+            stream.write(`data: ${line}\n\n`, (err) => {
+              if (err) {
+                return reject(err);
+              }
+              resolve();
+            });
+          })
+        );
+      }),
+      ignoreElements()
+    )
+    .subscribe({
+      error: (error) => {
+        logger.error('Error writing stream');
+        logger.error(JSON.stringify(error));
+        stream.end();
+      },
+      complete: () => {
+        stream.end();
+      },
+    });
+
+  return stream;
+}
diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts
index 4430d0405764d..ac2fa36f6b0fd 100644
--- a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts
+++ b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts
@@ -44,12 +44,19 @@ export default function ApiTest({ getService }: FtrProviderContext) {
     let proxy: LlmProxy;
     let connectorId: string;
 
-    async function getEvents(
-      params: {
-        actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
-        instructions?: string[];
-      },
-      cb: (conversationSimulator: LlmResponseSimulator) => Promise<void>
+    interface RequestOptions {
+      actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
+      instructions?: string[];
+      format?: 'openai';
+    }
+
+    type ConversationSimulatorCallback = (
+      conversationSimulator: LlmResponseSimulator
+    ) => Promise<void>;
+
+    async function getResponseBody(
+      { actions, instructions, format }: RequestOptions,
+      conversationSimulatorCallback?: ConversationSimulatorCallback
     ) {
       const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body));
 
@@ -61,13 +68,16 @@ export default function ApiTest({ getService }: FtrProviderContext) {
       const responsePromise = new Promise((resolve, reject) => {
         supertest
           .post(PUBLIC_COMPLETE_API_URL)
+          .query({
+            format,
+          })
          .set('kbn-xsrf', 'foo')
          .send({
            messages,
            connectorId,
            persist: true,
-           actions: params.actions,
-           instructions: params.instructions,
+           actions,
+           instructions,
          })
          .end((err, response) => {
            if (err) {
@@ -87,11 +97,22 @@ export default function ApiTest({ getService }: FtrProviderContext) {
       await titleSimulator.complete();
       await conversationSimulator.status(200);
 
-      await cb(conversationSimulator);
+      if (conversationSimulatorCallback) {
+        await conversationSimulatorCallback(conversationSimulator);
+      }
 
       const response = await responsePromise;
 
-      return String(response.body)
+      return String(response.body);
+    }
+
+    async function getEvents(
+      options: RequestOptions,
+      conversationSimulatorCallback: ConversationSimulatorCallback
+    ) {
+      const responseBody = await getResponseBody(options, conversationSimulatorCallback);
+
+      return responseBody
         .split('\n')
         .map((line) => line.trim())
         .filter(Boolean)
@@ -99,6 +120,17 @@ export default function ApiTest({ getService }: FtrProviderContext) {
         .map((line) => JSON.parse(line) as StreamingChatResponseEvent)
         .slice(2); // ignore context request/response, we're testing this elsewhere
     }
 
+    async function getOpenAIResponse(conversationSimulatorCallback: ConversationSimulatorCallback) {
+      const responseBody = await getResponseBody(
+        {
+          format: 'openai',
+        },
+        conversationSimulatorCallback
+      );
+
+      return responseBody;
+    }
+
     before(async () => {
       proxy = await createLlmProxy(log);
 
@@ -209,6 +241,72 @@ export default function ApiTest({ getService }: FtrProviderContext) {
 
         expect(request.messages[0].content).to.contain('This is a random instruction');
       });
     });
+
+    describe('with openai format', () => {
+      let responseBody: string;
+
+      before(async () => {
+        responseBody = await getOpenAIResponse(async (conversationSimulator) => {
+          await conversationSimulator.next('Hello');
+          await conversationSimulator.complete();
+        });
+      });
+
+      function extractDataParts(lines: string[]) {
+        return lines.map((line) => {
+          // .replace would be easier, but we want to verify here that the line
+          // matches the SSE syntax (`data: ...`)
+          const [, dataPart] = line.match(/^data: (.*)$/) || ['', ''];
+          return dataPart.trim();
+        });
+      }
+
+      function getLines() {
+        return responseBody.split('\n\n').filter(Boolean);
+      }
+
+      it('outputs each line in an SSE-compatible format (data: ...)', () => {
+        const lines = getLines();
+
+        lines.forEach((line) => {
+          expect(line).to.match(/^data: /);
+        });
+      });
+
+      it('outputs one chunk, and one [DONE] event', () => {
+        const dataParts = extractDataParts(getLines());
+
+        expect(dataParts[0]).not.to.be.empty();
+        expect(dataParts[1]).to.be('[DONE]');
+      });
+
+      it('outputs an OpenAI-compatible chunk', () => {
+        const [dataLine] = extractDataParts(getLines());
+
+        expect(() => {
+          JSON.parse(dataLine);
+        }).not.to.throwException();
+
+        const parsedChunk = JSON.parse(dataLine);
+
+        expect(parsedChunk).to.eql({
+          model: 'unknown',
+          choices: [
+            {
+              delta: {
+                content: 'Hello',
+              },
+              finish_reason: null,
+              index: 0,
+            },
+          ],
+          object: 'chat.completion.chunk',
+          // just test that these are a string and a number
+          id: String(parsedChunk.id),
+          created: Number(parsedChunk.created),
+        });
+      });
+    });
   });
 }
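
### Usage sketch (not part of the patch)

A minimal TypeScript example of how a client might consume the new OpenAI-compatible output. The `format=openai` query parameter, the `kbn-xsrf` header, the `data: <json>` SSE framing separated by blank lines, the terminating `data: [DONE]` event, and the chunk shape (`choices[0].delta.content`) all come from the diff above; the endpoint path, the `ApiKey` authorization scheme, and the exact message shape are assumptions for illustration.

```ts
// Hypothetical consumer of the public chat/complete endpoint (Node 18+,
// which ships a global fetch). Endpoint path and auth scheme are assumptions.
async function streamChatCompletion(kibanaUrl: string, apiKey: string, connectorId: string) {
  const response = await fetch(
    `${kibanaUrl}/api/observability_ai_assistant/chat/complete?format=openai`,
    {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'kbn-xsrf': 'true',
        Authorization: `ApiKey ${apiKey}`, // assumed auth scheme
      },
      body: JSON.stringify({
        connectorId,
        persist: false,
        // Assumed message shape; see the spec file above for the fields it sends.
        messages: [
          {
            '@timestamp': new Date().toISOString(),
            message: { role: 'user', content: 'Hello' },
          },
        ],
      }),
    }
  );

  if (!response.ok || !response.body) {
    throw new Error(`Request failed with status ${response.status}`);
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = '';

  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;

    buffer += decoder.decode(value, { stream: true });

    // SSE events are separated by a blank line; keep any incomplete
    // trailing event in the buffer until the next read.
    const events = buffer.split('\n\n');
    buffer = events.pop() ?? '';

    for (const event of events) {
      const data = event.replace(/^data: /, '').trim();
      if (!data || data === '[DONE]') continue;

      // Each event mirrors OpenAI.ChatCompletionChunk, so the incremental
      // text lives at choices[0].delta.content.
      const chunk = JSON.parse(data);
      process.stdout.write(chunk.choices?.[0]?.delta?.content ?? '');
    }
  }
}
```

The per-event `data:` framing and the `[DONE]` sentinel match what OpenAI streaming clients already parse, which is the point of the `format=openai` option: an existing OpenAI-compatible consumer can be pointed at this endpoint without a custom event parser.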