
Commit a5da838

Merge branch '8.14' into backport/8.14/pr-181616
kibanamachine authored May 2, 2024
2 parents 5f7f6e2 + 2719518
Showing 3 changed files with 227 additions and 26 deletions.
File 1 of 3: the Observability AI Assistant chat/complete route definitions (file path not preserved)
@@ -12,6 +12,7 @@ import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plu
 import { KibanaRequest } from '@kbn/core/server';
 import { aiAssistantSimulatedFunctionCalling } from '../..';
 import { flushBuffer } from '../../service/util/flush_buffer';
+import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
 import { observableIntoStream } from '../../service/util/observable_into_stream';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 import { screenContextRt, messageRt, functionRt } from '../runtime_types';
@@ -53,10 +54,13 @@ const chatCompleteInternalRt = t.intersection([

 const chatCompletePublicRt = t.intersection([
   chatCompleteBaseRt,
-  t.type({
+  t.partial({
     body: t.partial({
       actions: t.array(functionRt),
     }),
+    query: t.partial({
+      format: t.union([t.literal('default'), t.literal('openai')]),
+    }),
   }),
 ]);

@@ -230,24 +234,32 @@ const publicChatCompleteRoute = createObservabilityAIAssistantServerRoute({
   },
   params: chatCompletePublicRt,
   handler: async (resources): Promise<Readable> => {
-    const {
-      body: { actions, ...restOfBody },
-    } = resources.params;
-    return observableIntoStream(
-      await chatComplete({
-        ...resources,
-        params: {
-          body: {
-            ...restOfBody,
-            screenContexts: [
-              {
-                actions,
-              },
-            ],
-          },
-        },
-      })
-    );
+    const { params, logger } = resources;
+
+    const {
+      body: { actions, ...restOfBody },
+      query = {},
+    } = params;
+
+    const { format = 'default' } = query;
+
+    const response$ = await chatComplete({
+      ...resources,
+      params: {
+        body: {
+          ...restOfBody,
+          screenContexts: [
+            {
+              actions,
+            },
+          ],
+        },
+      },
+    });
+
+    return format === 'openai'
+      ? observableIntoOpenAIStream(response$, logger)
+      : observableIntoStream(response$);
   },
 });
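
For orientation, a minimal sketch of calling the public route with the new query parameter; this is not part of the commit. The route path (the tests below reference it only as PUBLIC_COMPLETE_API_URL), the ApiKey auth header, and the placeholder messages array are assumptions for illustration.

// Hypothetical client for the public chat/complete endpoint (format=openai).
async function streamAsOpenAI(kibanaUrl: string, apiKey: string, connectorId: string) {
  // Assumed path; confirm against your deployment's public API docs.
  const url = `${kibanaUrl}/api/observability_ai_assistant/chat/complete?format=openai`;
  const response = await fetch(url, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'kbn-xsrf': 'true', // Kibana requires this header on write requests
      Authorization: `ApiKey ${apiKey}`, // assumed auth scheme
    },
    body: JSON.stringify({
      connectorId,
      persist: false,
      messages: [], // assumed: messages in the shape chatCompleteBaseRt expects
    }),
  });
  // With format=openai the body is an SSE stream of OpenAI-style chunks ending
  // in a `data: [DONE]` frame; with format=default (or no parameter) it stays
  // the existing newline-delimited event stream.
  return response.text();
}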

File 2 of 3 (new file): service/util/observable_into_openai_stream.ts
@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { Logger } from '@kbn/logging';
+import OpenAI from 'openai';
+import {
+  catchError,
+  concatMap,
+  endWith,
+  filter,
+  from,
+  ignoreElements,
+  map,
+  Observable,
+  of,
+} from 'rxjs';
+import { PassThrough } from 'stream';
+import {
+  BufferFlushEvent,
+  ChatCompletionChunkEvent,
+  StreamingChatResponseEventType,
+  StreamingChatResponseEventWithoutError,
+  TokenCountEvent,
+} from '../../../common/conversation_complete';
+
+export function observableIntoOpenAIStream(
+  source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent | TokenCountEvent>,
+  logger: Logger
+) {
+  const stream = new PassThrough();
+
+  source
+    .pipe(
+      filter(
+        (event): event is ChatCompletionChunkEvent =>
+          event.type === StreamingChatResponseEventType.ChatCompletionChunk
+      ),
+      map((event) => {
+        const chunk: OpenAI.ChatCompletionChunk = {
+          model: 'unknown',
+          choices: [
+            {
+              delta: {
+                content: event.message.content,
+                function_call: event.message.function_call,
+              },
+              finish_reason: null,
+              index: 0,
+            },
+          ],
+          created: new Date().getTime(),
+          id: event.id,
+          object: 'chat.completion.chunk',
+        };
+        return JSON.stringify(chunk);
+      }),
+      catchError((error) => {
+        return of(JSON.stringify({ error: { message: error.message } }));
+      }),
+      endWith('[DONE]'),
+      concatMap((line) => {
+        return from(
+          new Promise<void>((resolve, reject) => {
+            stream.write(`data: ${line}\n\n`, (err) => {
+              if (err) {
+                return reject(err);
+              }
+              resolve();
+            });
+          })
+        );
+      }),
+      ignoreElements()
+    )
+    .subscribe({
+      error: (error) => {
+        logger.error('Error writing stream');
+        logger.error(JSON.stringify(error));
+        stream.end(error);
+      },
+      complete: () => {
+        stream.end();
+      },
+    });
+
+  return stream;
+}
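
The writer above emits plain server-sent events: one `data: <serialized chunk>` frame per ChatCompletionChunk event, a JSON error object if the source observable fails, and a closing `data: [DONE]` frame. As a minimal consumer-side sketch (not part of the commit, and it ignores error frames), the body can be decoded with the same OpenAI types:

import OpenAI from 'openai';

// Hypothetical helper: decode a fully buffered SSE body into OpenAI chunks,
// dropping the trailing [DONE] sentinel.
function parseOpenAISseBody(body: string): OpenAI.ChatCompletionChunk[] {
  return body
    .split('\n\n') // frames are separated by a blank line
    .map((frame) => frame.trim())
    .filter(Boolean)
    .map((frame) => frame.replace(/^data: /, ''))
    .filter((data) => data !== '[DONE]')
    .map((data) => JSON.parse(data) as OpenAI.ChatCompletionChunk);
}

// Usage: reassemble the streamed message text from the chunk deltas.
// const text = parseOpenAISseBody(rawBody)
//   .map((chunk) => chunk.choices[0].delta.content ?? '')
//   .join('');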
File 3 of 3: API integration test for the public chat/complete endpoint
@@ -44,12 +44,19 @@ export default function ApiTest({ getService }: FtrProviderContext) {
   let proxy: LlmProxy;
   let connectorId: string;

-  async function getEvents(
-    params: {
-      actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
-      instructions?: string[];
-    },
-    cb: (conversationSimulator: LlmResponseSimulator) => Promise<void>
+  interface RequestOptions {
+    actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
+    instructions?: string[];
+    format?: 'openai';
+  }
+
+  type ConversationSimulatorCallback = (
+    conversationSimulator: LlmResponseSimulator
+  ) => Promise<void>;
+
+  async function getResponseBody(
+    { actions, instructions, format }: RequestOptions,
+    conversationSimulatorCallback: ConversationSimulatorCallback
   ) {
     const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body));

@@ -61,13 +68,16 @@ export default function ApiTest({ getService }: FtrProviderContext) {
     const responsePromise = new Promise<Response>((resolve, reject) => {
       supertest
         .post(PUBLIC_COMPLETE_API_URL)
+        .query({
+          format,
+        })
         .set('kbn-xsrf', 'foo')
         .send({
           messages,
           connectorId,
           persist: true,
-          actions: params.actions,
-          instructions: params.instructions,
+          actions,
+          instructions,
         })
         .end((err, response) => {
           if (err) {
@@ -87,18 +97,40 @@ export default function ApiTest({ getService }: FtrProviderContext) {
     await titleSimulator.complete();

     await conversationSimulator.status(200);
-    await cb(conversationSimulator);
+    if (conversationSimulatorCallback) {
+      await conversationSimulatorCallback(conversationSimulator);
+    }

     const response = await responsePromise;

-    return String(response.body)
+    return String(response.body);
+  }
+
+  async function getEvents(
+    options: RequestOptions,
+    conversationSimulatorCallback: ConversationSimulatorCallback
+  ) {
+    const responseBody = await getResponseBody(options, conversationSimulatorCallback);
+
+    return responseBody
       .split('\n')
       .map((line) => line.trim())
       .filter(Boolean)
       .map((line) => JSON.parse(line) as StreamingChatResponseEvent)
       .slice(2); // ignore context request/response, we're testing this elsewhere
   }

+  async function getOpenAIResponse(conversationSimulatorCallback: ConversationSimulatorCallback) {
+    const responseBody = await getResponseBody(
+      {
+        format: 'openai',
+      },
+      conversationSimulatorCallback
+    );
+
+    return responseBody;
+  }
+
   before(async () => {
     proxy = await createLlmProxy(log);

@@ -209,6 +241,72 @@ export default function ApiTest({ getService }: FtrProviderContext) {
         expect(request.messages[0].content).to.contain('This is a random instruction');
       });
     });
+
+    describe('with openai format', async () => {
+      let responseBody: string;
+
+      before(async () => {
+        responseBody = await getOpenAIResponse(async (conversationSimulator) => {
+          await conversationSimulator.next('Hello');
+          await conversationSimulator.complete();
+        });
+      });
+
+      function extractDataParts(lines: string[]) {
+        return lines.map((line) => {
+          // .replace is easier, but we want to verify here whether
+          // it matches the SSE syntax (`data: ...`)
+          const [, dataPart] = line.match(/^data: (.*)$/) || ['', ''];
+          return dataPart.trim();
+        });
+      }
+
+      function getLines() {
+        return responseBody.split('\n\n').filter(Boolean);
+      }
+
+      it('outputs each line in an SSE-compatible format (data: ...)', () => {
+        const lines = getLines();
+
+        lines.forEach((line) => {
+          expect(line.match(/^data: /));
+        });
+      });
+
+      it('outputs one chunk, and one [DONE] event', () => {
+        const dataParts = extractDataParts(getLines());
+
+        expect(dataParts[0]).not.to.be.empty();
+        expect(dataParts[1]).to.be('[DONE]');
+      });
+
+      it('outputs an OpenAI-compatible chunk', () => {
+        const [dataLine] = extractDataParts(getLines());
+
+        expect(() => {
+          JSON.parse(dataLine);
+        }).not.to.throwException();
+
+        const parsedChunk = JSON.parse(dataLine);
+
+        expect(parsedChunk).to.eql({
+          model: 'unknown',
+          choices: [
+            {
+              delta: {
+                content: 'Hello',
+              },
+              finish_reason: null,
+              index: 0,
+            },
+          ],
+          object: 'chat.completion.chunk',
+          // just test that these are a string and a number
+          id: String(parsedChunk.id),
+          created: Number(parsedChunk.created),
+        });
+      });
+    });
   });
 }

