From 95e97f132496a5188bc9ef9216b8196226b5e019 Mon Sep 17 00:00:00 2001
From: Dario Gieselaar
Date: Wed, 1 May 2024 14:36:40 +0200
Subject: [PATCH] [8.14] [Obs AI Assistant] Refactor
 ObservabilityAIAssistantClient (#181255) (#182237)

# Backport

This will backport the following commits from `main` to `8.14`:

- [[Obs AI Assistant] Refactor ObservabilityAIAssistantClient (#181255)](https://github.com/elastic/kibana/pull/181255)

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

---
 .../common/conversation_complete.ts           |   6 +-
 .../utils/create_function_request_message.ts  |   3 +-
 .../utils/create_function_response_error.ts   |   4 +-
 .../utils/emit_with_concatenated_message.ts   |  26 +-
 .../utils/without_token_count_events.ts       |  23 +
 .../public/mock.tsx                           |   2 +-
 .../public/service/complete.test.ts           |   1 +
 .../public/service/create_chat_service.ts     |   9 +-
 .../public/storybook_mock.tsx                 |   2 +-
 .../server/functions/context.ts               |  24 +-
 .../get_relevant_field_names.ts               |  19 +-
 .../functions/get_dataset_info/index.ts       |  16 +-
 .../server/routes/chat/route.ts               |  74 +-
 .../chat_function_client/index.test.ts        |   2 -
 .../service/chat_function_client/index.ts     |   7 +-
 .../bedrock/process_bedrock_stream.test.ts    |  12 +-
 .../fail_on_non_existing_function_call.ts     |  47 +-
 .../server/service/client/adapters/types.ts   |   9 +-
 .../get_context_function_request_if_needed.ts |  35 +
 .../server/service/client/index.test.ts       | 112 +-
 .../server/service/client/index.ts            | 995 ++++++------
 .../client/operators/continue_conversation.ts | 294 ++++++
 .../server/service/client/operators/debug.ts  |  21 +
 .../client/operators/extract_messages.ts      |  24 +
 .../client/operators/extract_token_count.ts   |  36 +
 .../client/operators/get_generated_title.ts   | 105 ++
 .../operators/hide_token_count_events.ts      |  38 +
 .../operators/instrument_and_count_tokens.ts  |  71 ++
 .../server/service/types.ts                   |  28 +-
 .../catch_function_limit_exceeded_error.ts    |   7 +-
 .../service/util/observable_into_stream.ts    |   3 +-
 .../service/util/reject_token_count_events.ts |  26 -
 .../service/util/with_assistant_span.ts       |  25 +
 .../observability_ai_assistant/tsconfig.json  |   1 +
 .../server/functions/query/index.ts           |   4 +-
 .../server/functions/visualize_esql.ts        |  35 +-
 .../tests/chat/chat.spec.ts                   |   5 +-
 37 files changed, 1271 insertions(+), 880 deletions(-)
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/without_token_count_events.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/debug.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_messages.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_token_count.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/get_generated_title.ts
 create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/hide_token_count_events.ts
 create mode 100644
x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/instrument_and_count_tokens.ts delete mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/reject_token_count_events.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/with_assistant_span.ts diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/conversation_complete.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/conversation_complete.ts index eed16e9c8ddb4..cc0a487331e61 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/conversation_complete.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/conversation_complete.ts @@ -103,13 +103,17 @@ export type StreamingChatResponseEvent = | ConversationCreateEvent | ConversationUpdateEvent | MessageAddEvent - | ChatCompletionErrorEvent; + | ChatCompletionErrorEvent + | TokenCountEvent; export type StreamingChatResponseEventWithoutError = Exclude< StreamingChatResponseEvent, ChatCompletionErrorEvent >; +export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent; +export type MessageOrChatEvent = ChatEvent | MessageAddEvent; + export enum ChatCompletionErrorCode { InternalError = 'internalError', NotFoundError = 'notFoundError', diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts index 45399ea651bb3..01a4a5c12537b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts @@ -14,7 +14,7 @@ export function createFunctionRequestMessage({ args, }: { name: string; - args: unknown; + args?: Record; }): MessageAddEvent { return { id: v4(), @@ -28,6 +28,7 @@ export function createFunctionRequestMessage({ trigger: MessageRole.Assistant as const, }, role: MessageRole.Assistant, + content: '', }, }, }; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts index 79f6e5d4ff6df..bfb4021894273 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts @@ -24,9 +24,11 @@ export function createFunctionResponseError({ name: error.name, message: error.message, cause: error.cause, - stack: error.stack, }, message: message || error.message, }, + data: { + stack: error.stack, + }, }); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/emit_with_concatenated_message.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/emit_with_concatenated_message.ts index af283b78698f1..47370cc48cf00 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/emit_with_concatenated_message.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/emit_with_concatenated_message.ts @@ -5,9 +5,20 @@ * 2.0. 
*/ -import { concat, from, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs'; +import { + concat, + from, + last, + mergeMap, + Observable, + OperatorFunction, + shareReplay, + withLatestFrom, +} from 'rxjs'; +import { withoutTokenCountEvents } from './without_token_count_events'; import { ChatCompletionChunkEvent, + ChatEvent, MessageAddEvent, StreamingChatResponseEventType, } from '../conversation_complete'; @@ -40,20 +51,21 @@ function mergeWithEditedMessage( ); } -export function emitWithConcatenatedMessage( +export function emitWithConcatenatedMessage( callback?: ConcatenateMessageCallback -): ( - source$: Observable -) => Observable { - return (source$: Observable) => { +): OperatorFunction { + return (source$) => { const shared = source$.pipe(shareReplay()); + const withoutTokenCount$ = shared.pipe(withoutTokenCountEvents()); + const response$ = concat( shared, shared.pipe( + withoutTokenCountEvents(), concatenateChatCompletionChunks(), last(), - withLatestFrom(source$), + withLatestFrom(withoutTokenCount$), mergeMap(([message, chunkEvent]) => { return mergeWithEditedMessage(message, chunkEvent, callback); }) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/without_token_count_events.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/without_token_count_events.ts new file mode 100644 index 0000000000000..137b1140fbdcd --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/without_token_count_events.ts @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { filter, OperatorFunction } from 'rxjs'; +import { + StreamingChatResponseEvent, + StreamingChatResponseEventType, + TokenCountEvent, +} from '../conversation_complete'; + +export function withoutTokenCountEvents(): OperatorFunction< + T, + Exclude +> { + return filter( + (event): event is Exclude => + event.type !== StreamingChatResponseEventType.TokenCount + ); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx index 28b05433b2e1e..4775ad1b551b1 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx @@ -38,7 +38,7 @@ export const mockChatService: ObservabilityAIAssistantChatService = { '@timestamp': new Date().toISOString(), message: { role: MessageRole.System, - content: '', + content: 'System', }, }), }; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts index a0b7b8fe1447e..f02471e8090d6 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts @@ -284,6 +284,7 @@ describe('complete', () => { '@timestamp': expect.any(String), message: { content: expect.any(String), + data: expect.any(String), name: 'my_action', role: MessageRole.User, }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts index 4995aa1b584ba..c0b897a134dd5 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts @@ -28,7 +28,6 @@ import { StreamingChatResponseEventType, type StreamingChatResponseEventWithoutError, type StreamingChatResponseEvent, - TokenCountEvent, } from '../../common/conversation_complete'; import { FunctionRegistry, @@ -163,13 +162,7 @@ export async function createChatService({ const subscription = toObservable(response) .pipe( - map( - (line) => - JSON.parse(line) as - | StreamingChatResponseEvent - | BufferFlushEvent - | TokenCountEvent - ), + map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), filter( (line): line is StreamingChatResponseEvent => line.type !== StreamingChatResponseEventType.BufferFlush && diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx index 1d9d79838bd3a..6cad5a52ed2f8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx @@ -33,7 +33,7 @@ export const createStorybookChatService = (): ObservabilityAIAssistantChatServic '@timestamp': new Date().toISOString(), message: { role: MessageRole.System, - content: '', + content: 'System', }, }), }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts 
b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index 7c785392dfaf4..a64e63ad49c4c 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -21,7 +21,7 @@ import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_ import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message'; import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking'; import type { ObservabilityAIAssistantClient } from '../service/client'; -import { ChatFn } from '../service/types'; +import { FunctionCallChatFunction } from '../service/types'; import { parseSuggestionScores } from './parse_suggestion_scores'; const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000; @@ -61,7 +61,7 @@ export function registerContextFunction({ required: ['queries', 'categories'], } as const, }, - async ({ arguments: args, messages, connectorId, screenContexts, chat }, signal) => { + async ({ arguments: args, messages, screenContexts, chat }, signal) => { const { analytics } = (await resources.context.core).coreStart; const { queries, categories } = args; @@ -118,7 +118,6 @@ export function registerContextFunction({ queries: queriesOrUserPrompt, messages, chat, - connectorId, signal, logger: resources.logger, }); @@ -209,15 +208,13 @@ async function scoreSuggestions({ messages, queries, chat, - connectorId, signal, logger, }: { suggestions: Awaited>; messages: Message[]; queries: string[]; - chat: ChatFn; - connectorId: string; + chat: FunctionCallChatFunction; signal: AbortSignal; logger: Logger; }) { @@ -274,15 +271,12 @@ async function scoreSuggestions({ }; const response = await lastValueFrom( - ( - await chat('score_suggestions', { - connectorId, - messages: [...messages.slice(0, -2), newUserMessage], - functions: [scoreFunction], - functionCall: 'score', - signal, - }) - ).pipe(concatenateChatCompletionChunks()) + chat('score_suggestions', { + messages: [...messages.slice(0, -2), newUserMessage], + functions: [scoreFunction], + functionCall: 'score', + signal, + }).pipe(concatenateChatCompletionChunks()) ); const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts index 9fc0ad4056870..543641098836f 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts @@ -5,13 +5,13 @@ * 2.0. 
*/ import datemath from '@elastic/datemath'; -import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server'; import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/server'; +import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server'; import { castArray, chunk, groupBy, uniq } from 'lodash'; -import { lastValueFrom, Observable } from 'rxjs'; -import type { ObservabilityAIAssistantClient } from '../../service/client'; -import { type ChatCompletionChunkEvent, type Message, MessageRole } from '../../../common'; +import { lastValueFrom } from 'rxjs'; +import { MessageRole, type Message } from '../../../common'; import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; +import { FunctionCallChatFunction } from '../../service/types'; export async function getRelevantFieldNames({ index, @@ -22,6 +22,7 @@ export async function getRelevantFieldNames({ savedObjectsClient, chat, messages, + signal, }: { index: string | string[]; start?: string; @@ -30,13 +31,8 @@ export async function getRelevantFieldNames({ esClient: ElasticsearchClient; savedObjectsClient: SavedObjectsClientContract; messages: Message[]; - chat: ( - name: string, - {}: Pick< - Parameters[1], - 'functionCall' | 'functions' | 'messages' - > - ) => Promise>; + chat: FunctionCallChatFunction; + signal: AbortSignal; }): Promise<{ fields: string[] }> { const dataViewsService = await dataViews.dataViewsServiceFactory(savedObjectsClient, esClient); @@ -79,6 +75,7 @@ export async function getRelevantFieldNames({ chunk(fieldNames, 500).map(async (fieldsInChunk) => { const chunkResponse$ = ( await chat('get_relevent_dataset_names', { + signal, messages: [ { '@timestamp': new Date().toISOString(), diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/index.ts index 1554df10175a2..e5b4e21195003 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/index.ts @@ -37,7 +37,7 @@ export function registerGetDatasetInfoFunction({ required: ['index'], } as const, }, - async ({ arguments: { index }, messages, connectorId, chat }, signal) => { + async ({ arguments: { index }, messages, chat }, signal) => { const coreContext = await resources.context.core; const esClient = coreContext.elasticsearch.client.asCurrentUser; @@ -83,18 +83,8 @@ export function registerGetDatasetInfoFunction({ esClient, dataViews: await resources.plugins.dataViews.start(), savedObjectsClient, - chat: ( - operationName, - { messages: nextMessages, functionCall, functions: nextFunctions } - ) => { - return chat(operationName, { - messages: nextMessages, - functionCall, - functions: nextFunctions, - connectorId, - signal, - }); - }, + signal, + chat, }); return { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts index ae96e633b3278..f5e9ca339e9e8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts @@ -8,12 +8,15 @@ import { notImplemented } 
from '@hapi/boom'; import { toBooleanRt } from '@kbn/io-ts-utils'; import * as t from 'io-ts'; import { Readable } from 'stream'; +import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { KibanaRequest } from '@kbn/core/server'; import { aiAssistantSimulatedFunctionCalling } from '../..'; import { flushBuffer } from '../../service/util/flush_buffer'; import { observableIntoStream } from '../../service/util/observable_into_stream'; import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; import { screenContextRt, messageRt, functionRt } from '../runtime_types'; import { ObservabilityAIAssistantRouteHandlerResources } from '../types'; +import { withAssistantSpan } from '../../service/util/with_assistant_span'; const chatCompleteBaseRt = t.type({ body: t.intersection([ @@ -57,6 +60,27 @@ const chatCompletePublicRt = t.intersection([ }), ]); +async function guardAgainstInvalidConnector({ + actions, + request, + connectorId, +}: { + actions: ActionsPluginStart; + request: KibanaRequest; + connectorId: string; +}) { + return withAssistantSpan('guard_against_invalid_connector', async () => { + const actionsClient = await actions.getActionsClientWithRequest(request); + + const connector = await actionsClient.get({ + id: connectorId, + throwIfSystemAction: true, + }); + + return connector; + }); +} + const chatRoute = createObservabilityAIAssistantServerRoute({ endpoint: 'POST /internal/observability_ai_assistant/chat', options: { @@ -76,7 +100,17 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ ]), }), handler: async (resources): Promise => { - const { request, params, service, context } = resources; + const { request, params, service, context, plugins } = resources; + + const { + body: { name, messages, connectorId, functions, functionCall }, + } = params; + + await guardAgainstInvalidConnector({ + actions: await plugins.actions.start(), + request, + connectorId, + }); const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ service.getClient({ request }), @@ -88,17 +122,13 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ throw notImplemented(); } - const { - body: { name, messages, connectorId, functions, functionCall }, - } = params; - const controller = new AbortController(); request.events.aborted$.subscribe(() => { controller.abort(); }); - const response$ = await client.chat(name, { + const response$ = client.chat(name, { messages, connectorId, signal: controller.signal, @@ -120,19 +150,7 @@ async function chatComplete( params: t.TypeOf; } ) { - const { request, params, service } = resources; - - const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ - service.getClient({ request }), - resources.plugins.cloud?.start() || Promise.resolve(undefined), - ( - await resources.context.core - ).uiSettings.client.get(aiAssistantSimulatedFunctionCalling), - ]); - - if (!client) { - throw notImplemented(); - } + const { request, params, service, plugins } = resources; const { body: { @@ -147,6 +165,24 @@ async function chatComplete( }, } = params; + await guardAgainstInvalidConnector({ + actions: await plugins.actions.start(), + request, + connectorId, + }); + + const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ + service.getClient({ request }), + resources.plugins.cloud?.start() || Promise.resolve(undefined), + ( + await resources.context.core + ).uiSettings.client.get(aiAssistantSimulatedFunctionCalling), + ]); 
+ + if (!client) { + throw notImplemented(); + } + const controller = new AbortController(); request.events.aborted$.subscribe(() => { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts index 9ecbd450cba30..bdce24d65d2c7 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts @@ -48,7 +48,6 @@ describe('chatFunctionClient', () => { }), messages: [], signal: new AbortController().signal, - connectorId: '', }); }).rejects.toThrowError(`Function arguments are invalid`); @@ -107,7 +106,6 @@ describe('chatFunctionClient', () => { name: 'get_data_on_screen', args: JSON.stringify({ data: ['my_dummy_data'] }), messages: [], - connectorId: '', signal: new AbortController().signal, }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts index d0b019d635c12..e882616e202cc 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts @@ -13,7 +13,7 @@ import { FunctionVisibility, type FunctionResponse } from '../../../common/funct import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types'; import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions'; import type { - ChatFn, + FunctionCallChatFunction, FunctionHandler, FunctionHandlerRegistry, RegisteredInstruction, @@ -144,14 +144,12 @@ export class ChatFunctionClient { args, messages, signal, - connectorId, }: { - chat: ChatFn; + chat: FunctionCallChatFunction; name: string; args: string | undefined; messages: Message[]; signal: AbortSignal; - connectorId: string; }): Promise { const fn = this.functionRegistry.get(name); @@ -167,7 +165,6 @@ export class ChatFunctionClient { { arguments: parsedArguments, messages, - connectorId, screenContexts: this.screenContexts, chat, }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock/process_bedrock_stream.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock/process_bedrock_stream.test.ts index 90f7d6f5ee69c..6aef5fb091185 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock/process_bedrock_stream.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock/process_bedrock_stream.test.ts @@ -11,9 +11,9 @@ import { Logger } from '@kbn/logging'; import { concatenateChatCompletionChunks } from '../../../../../common/utils/concatenate_chat_completion_chunks'; import { processBedrockStream } from './process_bedrock_stream'; import { MessageRole } from '../../../../../common'; -import { rejectTokenCountEvents } from '../../../util/reject_token_count_events'; import { TOOL_USE_END, TOOL_USE_START } from '../simulate_function_calling/constants'; import { 
parseInlineFunctionCalls } from '../simulate_function_calling/parse_inline_function_calls'; +import { withoutTokenCountEvents } from '../../../../../common/utils/without_token_count_events'; describe('processBedrockStream', () => { const encodeChunk = (body: unknown) => { @@ -69,7 +69,7 @@ describe('processBedrockStream', () => { parseInlineFunctionCalls({ logger: getLoggerMock(), }), - rejectTokenCountEvents(), + withoutTokenCountEvents(), concatenateChatCompletionChunks() ) ) @@ -101,7 +101,7 @@ describe('processBedrockStream', () => { parseInlineFunctionCalls({ logger: getLoggerMock(), }), - rejectTokenCountEvents(), + withoutTokenCountEvents(), concatenateChatCompletionChunks() ) ) @@ -135,7 +135,7 @@ describe('processBedrockStream', () => { parseInlineFunctionCalls({ logger: getLoggerMock(), }), - rejectTokenCountEvents(), + withoutTokenCountEvents(), concatenateChatCompletionChunks() ) ) @@ -167,7 +167,7 @@ describe('processBedrockStream', () => { parseInlineFunctionCalls({ logger: getLoggerMock(), }), - rejectTokenCountEvents(), + withoutTokenCountEvents(), concatenateChatCompletionChunks() ) ); @@ -193,7 +193,7 @@ describe('processBedrockStream', () => { parseInlineFunctionCalls({ logger: getLoggerMock(), }), - rejectTokenCountEvents(), + withoutTokenCountEvents(), concatenateChatCompletionChunks() ) ) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/fail_on_non_existing_function_call.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/fail_on_non_existing_function_call.ts index 1e99fd623052b..d4c7c40e440ce 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/fail_on_non_existing_function_call.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/fail_on_non_existing_function_call.ts @@ -5,28 +5,24 @@ * 2.0. 
*/ -import { noop } from 'lodash'; -import { forkJoin, last, Observable, shareReplay, tap } from 'rxjs'; -import { - ChatCompletionChunkEvent, - createFunctionNotFoundError, - FunctionDefinition, -} from '../../../../common'; -import { TokenCountEvent } from '../../../../common/conversation_complete'; +import { ignoreElements, last, merge, Observable, shareReplay, tap } from 'rxjs'; +import { createFunctionNotFoundError, FunctionDefinition } from '../../../../common'; +import { ChatEvent } from '../../../../common/conversation_complete'; import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks'; -import { rejectTokenCountEvents } from '../../util/reject_token_count_events'; +import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events'; export function failOnNonExistingFunctionCall({ functions, }: { functions?: Array>; }) { - return (source$: Observable) => { - return new Observable((subscriber) => { - const shared = source$.pipe(shareReplay()); + return (source$: Observable) => { + const shared$ = source$.pipe(shareReplay()); - const checkFunctionCallResponse$ = shared.pipe( - rejectTokenCountEvents(), + return merge( + shared$, + shared$.pipe( + withoutTokenCountEvents(), concatenateChatCompletionChunks(), last(), tap((event) => { @@ -36,24 +32,9 @@ export function failOnNonExistingFunctionCall({ ) { throw createFunctionNotFoundError(event.message.function_call.name); } - }) - ); - - source$.subscribe({ - next: (val) => { - subscriber.next(val); - }, - error: noop, - }); - - forkJoin([source$, checkFunctionCallResponse$]).subscribe({ - complete: () => { - subscriber.complete(); - }, - error: (error) => { - subscriber.error(error); - }, - }); - }); + }), + ignoreElements() + ) + ); }; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts index 3dd2c4bbed5f3..2a292035acdb2 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts @@ -9,10 +9,7 @@ import type { Readable } from 'node:stream'; import type { Observable } from 'rxjs'; import type { Logger } from '@kbn/logging'; import type { Message } from '../../../../common'; -import type { - ChatCompletionChunkEvent, - TokenCountEvent, -} from '../../../../common/conversation_complete'; +import type { ChatEvent } from '../../../../common/conversation_complete'; import { CompatibleJSONSchema } from '../../../../common/functions/types'; export interface LlmFunction { @@ -31,7 +28,5 @@ export type LlmApiAdapterFactory = (options: { export interface LlmApiAdapter { getSubAction: () => { subAction: string; subActionParams: Record }; - streamIntoObservable: ( - readable: Readable - ) => Observable; + streamIntoObservable: (readable: Readable) => Observable; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts new file mode 100644 index 0000000000000..8f05cf144a33b --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts 
@@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { findLastIndex } from 'lodash'; +import { Message, MessageAddEvent, MessageRole } from '../../../common'; +import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message'; + +export function getContextFunctionRequestIfNeeded( + messages: Message[] +): MessageAddEvent | undefined { + const indexOfLastUserMessage = findLastIndex( + messages, + (message) => message.message.role === MessageRole.User && !message.message.name + ); + + const hasContextSinceLastUserMessage = messages + .slice(indexOfLastUserMessage) + .some((message) => message.message.name === 'context'); + + if (hasContextSinceLastUserMessage) { + return undefined; + } + + return createFunctionRequestMessage({ + name: 'context', + args: { + queries: [], + categories: [], + }, + }); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index a35e50d538bcb..e22f63cf92eb2 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -39,12 +39,14 @@ const nextTick = () => { return new Promise(process.nextTick); }; -const waitForNextWrite = async (stream: Readable): Promise => { +const waitForNextWrite = async (stream: Readable): Promise => { // this will fire before the client's internal write() promise is // resolved - await new Promise((resolve) => stream.once('data', resolve)); + const response = await new Promise((resolve) => stream.once('data', resolve)); // so we wait another tick to let the client move to the next step await nextTick(); + + return response; }; function createLlmSimulator() { @@ -108,12 +110,7 @@ describe('Observability AI Assistant client', () => { getInstructions: jest.fn(), } as any; - const loggerMock: DeeplyMockedKeys = { - log: jest.fn(), - error: jest.fn(), - debug: jest.fn(), - trace: jest.fn(), - } as any; + let loggerMock: DeeplyMockedKeys = {} as any; const functionClientMock: DeeplyMockedKeys = { executeFunction: jest.fn(), @@ -130,6 +127,18 @@ describe('Observability AI Assistant client', () => { function createClient() { jest.resetAllMocks(); + // uncomment this line for debugging + // const consoleOrPassThrough = console.log.bind(console); + const consoleOrPassThrough = () => {}; + + loggerMock = { + log: jest.fn().mockImplementation(consoleOrPassThrough), + error: jest.fn().mockImplementation(consoleOrPassThrough), + debug: jest.fn().mockImplementation(consoleOrPassThrough), + trace: jest.fn().mockImplementation(consoleOrPassThrough), + isLevelEnabled: jest.fn().mockReturnValue(true), + } as any; + functionClientMock.getFunctions.mockReturnValue([]); functionClientMock.hasFunction.mockImplementation((name) => { return name !== 'context'; @@ -214,24 +223,27 @@ describe('Observability AI Assistant client', () => { beforeEach(async () => { client = createClient(); actionsClientMock.execute - .mockImplementationOnce(() => { + .mockImplementationOnce((body) => { return new Promise((resolve, reject) => { titleLlmPromiseResolve = (title: string) => { const 
titleLlmSimulator = createLlmSimulator(); - titleLlmSimulator.next({ content: title }); - titleLlmSimulator.complete(); - resolve({ - actionId: '', - status: 'ok', - data: titleLlmSimulator.stream, - }); + titleLlmSimulator + .next({ content: title }) + .then(() => titleLlmSimulator.complete()) + .then(() => { + resolve({ + actionId: '', + status: 'ok', + data: titleLlmSimulator.stream, + }); + }); }; - titleLlmPromiseReject = () => { - reject(); + titleLlmPromiseReject = (error: Error) => { + reject(error); }; }); }) - .mockImplementationOnce(async () => { + .mockImplementationOnce(async (body) => { llmSimulator = createLlmSimulator(); return { actionId: '', @@ -260,6 +272,8 @@ describe('Observability AI Assistant client', () => { stream.on('data', dataHandler); await llmSimulator.next({ content: 'Hello' }); + + await nextTick(); }); it('calls the actions client with the messages', () => { @@ -346,9 +360,9 @@ describe('Observability AI Assistant client', () => { id: expect.any(String), last_updated: expect.any(String), token_count: { - completion: 2, - prompt: 156, - total: 158, + completion: 1, + prompt: 78, + total: 79, }, }, type: StreamingChatResponseEventType.ConversationCreate, @@ -364,8 +378,6 @@ describe('Observability AI Assistant client', () => { titleLlmPromiseResolve('An auto-generated title'); - await nextTick(); - await llmSimulator.complete(); await finished(stream); @@ -405,9 +417,9 @@ describe('Observability AI Assistant client', () => { id: expect.any(String), last_updated: expect.any(String), token_count: { - completion: 8, - prompt: 340, - total: 348, + completion: 6, + prompt: 262, + total: 268, }, }, type: StreamingChatResponseEventType.ConversationCreate, @@ -423,9 +435,9 @@ describe('Observability AI Assistant client', () => { last_updated: expect.any(String), title: 'An auto-generated title', token_count: { - completion: 8, - prompt: 340, - total: 348, + completion: 6, + prompt: 262, + total: 268, }, }, labels: {}, @@ -477,7 +489,7 @@ describe('Observability AI Assistant client', () => { beforeEach(async () => { client = createClient(); - actionsClientMock.execute.mockImplementationOnce(async () => { + actionsClientMock.execute.mockImplementationOnce(async (body) => { llmSimulator = createLlmSimulator(); return { actionId: '', @@ -499,6 +511,11 @@ describe('Observability AI Assistant client', () => { id: 'my-conversation-id', title: 'My stored conversation', last_updated: new Date().toISOString(), + token_count: { + completion: 1, + prompt: 78, + total: 79, + }, }, labels: {}, numeric_labels: {}, @@ -694,7 +711,7 @@ describe('Observability AI Assistant client', () => { beforeEach(async () => { client = createClient(); - actionsClientMock.execute.mockImplementationOnce(async () => { + actionsClientMock.execute.mockImplementationOnce(async (body) => { llmSimulator = createLlmSimulator(); return { actionId: '', @@ -794,7 +811,6 @@ describe('Observability AI Assistant client', () => { it('executes the function', () => { expect(functionClientMock.executeFunction).toHaveBeenCalledWith({ - connectorId: 'foo', name: 'myFunction', chat: expect.any(Function), args: JSON.stringify({ foo: 'bar' }), @@ -832,6 +848,7 @@ describe('Observability AI Assistant client', () => { afterEach(async () => { fnResponseResolve({ content: { my: 'content' } }); + await waitForNextWrite(stream); await llmSimulator.complete(); @@ -993,7 +1010,12 @@ describe('Observability AI Assistant client', () => { }); it('appends the function response', () => { - 
expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + const parsed = JSON.parse(dataHandler.mock.lastCall!); + + parsed.message.message.content = JSON.parse(parsed.message.message.content); + parsed.message.message.data = JSON.parse(parsed.message.message.data); + + expect(parsed).toEqual({ type: StreamingChatResponseEventType.MessageAdd, id: expect.any(String), message: { @@ -1001,10 +1023,16 @@ describe('Observability AI Assistant client', () => { message: { role: MessageRole.User, name: 'myFunction', - content: JSON.stringify({ - message: 'Error: Function failed', - error: {}, - }), + content: { + message: 'Function failed', + error: { + name: 'Error', + message: 'Function failed', + }, + }, + data: { + stack: expect.any(String), + }, }, }, }); @@ -1138,7 +1166,7 @@ describe('Observability AI Assistant client', () => { let dataHandler: jest.Mock; beforeEach(async () => { client = createClient(); - actionsClientMock.execute.mockImplementationOnce(async () => { + actionsClientMock.execute.mockImplementationOnce(async (body) => { llmSimulator = createLlmSimulator(); return { actionId: '', @@ -1149,7 +1177,7 @@ describe('Observability AI Assistant client', () => { functionClientMock.hasFunction.mockReturnValue(true); - functionClientMock.executeFunction.mockImplementationOnce(async () => { + functionClientMock.executeFunction.mockImplementationOnce(async (body) => { return { content: [ { @@ -1327,14 +1355,14 @@ describe('Observability AI Assistant client', () => { await nextTick(); - for (let i = 0; i <= maxFunctionCalls + 1; i++) { + for (let i = 0; i <= maxFunctionCalls; i++) { await requestAlertsFunctionCall(); } await finished(stream); }); - it('executed the function no more than three times', () => { + it(`executed the function no more than ${maxFunctionCalls} times`, () => { expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(maxFunctionCalls); }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index e4cb53be99754..c00c060bcf138 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -5,25 +5,27 @@ * 2.0. 
*/ import type { SearchHit } from '@elastic/elasticsearch/lib/api/types'; -import { internal, notFound } from '@hapi/boom'; +import { notFound } from '@hapi/boom'; import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { ElasticsearchClient } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; import type { PublicMethodsOf } from '@kbn/utility-types'; -import apm from 'elastic-apm-node'; -import { decode, encode } from 'gpt-tokenizer'; -import { findLastIndex, last, merge, noop, omit, pick, take } from 'lodash'; +import { merge, omit } from 'lodash'; import { filter, - identity, - isObservable, - last as lastOperator, - lastValueFrom, + forkJoin, + from, + merge as mergeOperator, map, Observable, + of, shareReplay, + switchMap, + throwError, + combineLatest, tap, - toArray, + catchError, + defer, } from 'rxjs'; import { Readable } from 'stream'; import { v4 } from 'uuid'; @@ -31,20 +33,17 @@ import { ObservabilityAIAssistantConnectorType } from '../../../common/connector import { ChatCompletionChunkEvent, ChatCompletionErrorEvent, + ConversationCreateEvent, + ConversationUpdateEvent, createConversationNotFoundError, + createInternalServerError, createTokenLimitReachedError, - MessageAddEvent, StreamingChatResponseEventType, TokenCountEvent, type StreamingChatResponseEvent, } from '../../../common/conversation_complete'; +import { CompatibleJSONSchema } from '../../../common/functions/types'; import { - CompatibleJSONSchema, - FunctionResponse, - FunctionVisibility, -} from '../../../common/functions/types'; -import { - MessageRole, UserInstruction, type Conversation, type ConversationCreateRequest, @@ -52,25 +51,30 @@ import { type KnowledgeBaseEntry, type Message, } from '../../../common/types'; -import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; -import { createFunctionResponseError } from '../../../common/utils/create_function_response_error'; -import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message'; +import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events'; import type { ChatFunctionClient } from '../chat_function_client'; import { KnowledgeBaseEntryOperationType, KnowledgeBaseService, RecalledEntry, } from '../knowledge_base_service'; -import type { ChatFn, ObservabilityAIAssistantResourceNames } from '../types'; -import { catchFunctionLimitExceededError } from '../util/catch_function_limit_exceeded_error'; +import type { ObservabilityAIAssistantResourceNames } from '../types'; import { getAccessQuery } from '../util/get_access_query'; import { getSystemMessageFromInstructions } from '../util/get_system_message_from_instructions'; -import { rejectTokenCountEvents } from '../util/reject_token_count_events'; import { replaceSystemMessage } from '../util/replace_system_message'; +import { withAssistantSpan } from '../util/with_assistant_span'; import { createBedrockClaudeAdapter } from './adapters/bedrock/bedrock_claude_adapter'; import { failOnNonExistingFunctionCall } from './adapters/fail_on_non_existing_function_call'; import { createOpenAiAdapter } from './adapters/openai_adapter'; import { LlmApiAdapter } from './adapters/types'; +import { getContextFunctionRequestIfNeeded } from './get_context_function_request_if_needed'; +import { extractMessages } from './operators/extract_messages'; +import { extractTokenCount } from './operators/extract_token_count'; +import { instrumentAndCountTokens } from 
'./operators/instrument_and_count_tokens'; +import { continueConversation } from './operators/continue_conversation'; +import { getGeneratedTitle } from './operators/get_generated_title'; + +const MAX_FUNCTION_CALLS = 8; export class ObservabilityAIAssistantClient { constructor( @@ -161,472 +165,273 @@ export class ObservabilityAIAssistantClient { instructions?: Array; simulateFunctionCalling?: boolean; }): Observable> => { - return new Observable>( - (subscriber) => { - const { - messages, - connectorId, - signal, - functionClient, - persist, - kibanaPublicUrl, - simulateFunctionCalling, - isPublic = false, - instructions: requestInstructions = [], - } = params; - - const isConversationUpdate = persist && !!params.conversationId; - const conversationId = persist ? params.conversationId || v4() : ''; - const title = params.title || ''; - const responseLanguage = params.responseLanguage || 'English'; - - const registeredInstructions = functionClient.getInstructions(); - - const knowledgeBaseInstructions: UserInstruction[] = []; - - if (responseLanguage) { - requestInstructions.push( - `You MUST respond in the users preferred language which is: ${responseLanguage}.` - ); - } - - let storedSystemMessage: string = ''; // will be set as soon as kb instructions are loaded - - if (persist && !isConversationUpdate && kibanaPublicUrl) { - registeredInstructions.push( - `This conversation will be persisted in Kibana and available at this url: ${ - kibanaPublicUrl + `/app/observabilityAIAssistant/conversations/${conversationId}` - }.` - ); - } - - const tokenCountResult = { - prompt: 0, - completion: 0, - total: 0, - }; - - const chatWithTokenCountIncrement: ChatFn = async (name, options) => { - const response$ = await this.chat(name, { - ...options, - simulateFunctionCalling, - }); - - const incrementTokenCount = () => { - return ( - source: Observable - ): Observable => { - return source.pipe( - tap((event) => { - if (event.type === StreamingChatResponseEventType.TokenCount) { - tokenCountResult.prompt += event.tokens.prompt; - tokenCountResult.completion += event.tokens.completion; - tokenCountResult.total += event.tokens.total; - } - }) - ); - }; - }; - - return response$.pipe(incrementTokenCount(), rejectTokenCountEvents()); - }; - - let numFunctionsCalled: number = 0; - - const MAX_FUNCTION_CALLS = 8; - const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000; - - const allFunctions = functionClient - .getFunctions() - .filter((fn) => { - const visibility = fn.definition.visibility ?? 
FunctionVisibility.All; - return ( - visibility === FunctionVisibility.All || - visibility === FunctionVisibility.AssistantOnly - ); - }) - .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')); - - const allActions = functionClient.getActions(); + const { + functionClient, + connectorId, + simulateFunctionCalling, + instructions: requestInstructions = [], + messages: initialMessages, + signal, + responseLanguage = 'English', + persist, + kibanaPublicUrl, + isPublic, + title: predefinedTitle, + conversationId: predefinedConversationId, + } = params; + + if (responseLanguage) { + requestInstructions.push( + `You MUST respond in the users preferred language which is: ${responseLanguage}.` + ); + } - const next = async (nextMessages: Message[]): Promise => { - const lastMessage = last(nextMessages); - const isUserMessage = lastMessage?.message.role === MessageRole.User; + const isConversationUpdate = persist && !!predefinedConversationId; - const indexOfLastUserMessage = findLastIndex( - nextMessages, - ({ message }) => message.role === MessageRole.User && !message.name - ); + const conversationId = persist ? predefinedConversationId || v4() : ''; - const hasNoContextRequestAfterLastUserMessage = - indexOfLastUserMessage !== -1 && - nextMessages - .slice(indexOfLastUserMessage) - .every(({ message }) => message.function_call?.name !== 'context'); + if (persist && !isConversationUpdate && kibanaPublicUrl) { + requestInstructions.push( + `This conversation will be persisted in Kibana and available at this url: ${ + kibanaPublicUrl + `/app/observabilityAIAssistant/conversations/${conversationId}` + }.` + ); + } - const shouldInjectContext = - functionClient.hasFunction('context') && hasNoContextRequestAfterLastUserMessage; + const kbInstructions$ = from(this.fetchKnowledgeBaseInstructions()).pipe(shareReplay()); + + // from the initial messages, override any system message with + // the one that is based on the instructions (registered, request, kb) + const messagesWithUpdatedSystemMessage$ = kbInstructions$.pipe( + map((knowledgeBaseInstructions) => { + // this is what we eventually store in the conversation + const messagesWithUpdatedSystemMessage = replaceSystemMessage( + getSystemMessageFromInstructions({ + registeredInstructions: functionClient.getInstructions(), + knowledgeBaseInstructions, + requestInstructions, + availableFunctionNames: functionClient.getFunctions().map((fn) => fn.definition.name), + }), + initialMessages + ); + + return messagesWithUpdatedSystemMessage; + }), + shareReplay() + ); - if (shouldInjectContext) { - const contextFunctionRequest = { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'context', - arguments: JSON.stringify({ - queries: [], - categories: [], - }), - trigger: MessageRole.Assistant as const, + // if it is: + // - a new conversation + // - no predefined title is given + // - we need to store the conversation + // we generate a title + // if not, we complete with an empty string + const title$ = + predefinedTitle || isConversationUpdate || !persist + ? 
of(predefinedTitle || '').pipe(shareReplay()) + : messagesWithUpdatedSystemMessage$.pipe( + switchMap((messages) => + getGeneratedTitle({ + messages, + responseLanguage, + logger: this.dependencies.logger, + chat: (name, chatParams) => { + return this.chat(name, { + ...chatParams, + simulateFunctionCalling, + connectorId, + signal, + }); }, - }, - }; - - subscriber.next({ - type: StreamingChatResponseEventType.MessageAdd, - id: v4(), - message: contextFunctionRequest, - }); - - return await next(nextMessages.concat(contextFunctionRequest)); - } else if (isUserMessage) { - const functionCallsExceeded = numFunctionsCalled > MAX_FUNCTION_CALLS; - const functions = functionCallsExceeded ? [] : allFunctions.concat(allActions); - - const spanName = - lastMessage.message.name && lastMessage.message.name !== 'context' - ? 'function_response' - : 'user_message'; - - const systemMessageForChatRequest = getSystemMessageFromInstructions({ - registeredInstructions, - requestInstructions, - knowledgeBaseInstructions, - availableFunctionNames: functions.map((fn) => fn.name) || [], - }); - - const response$ = ( - await chatWithTokenCountIncrement(spanName, { - messages: replaceSystemMessage(systemMessageForChatRequest, nextMessages), - connectorId, - signal, - functions, }) - ).pipe( - emitWithConcatenatedMessage(), - shareReplay(), - Boolean(functions.length) ? identity : catchFunctionLimitExceededError() - ); - - response$.subscribe({ - next: (val) => subscriber.next(val), - // we handle the error below - error: noop, - }); - - const emittedMessageEvents = await lastValueFrom( - response$.pipe( - filter( - (event): event is MessageAddEvent => - event.type === StreamingChatResponseEventType.MessageAdd - ), - // LLMs like to hallucinate parameters if the function does not define - // them, and it can lead to other hallicunations down the line - map((messageEvent) => { - const fnName = messageEvent.message.message.function_call?.name; - - if (fnName && !functions.find((fn) => fn.name === fnName)?.parameters) { - const clone = { ...messageEvent }; - clone.message.message.function_call!.arguments = ''; - return clone; - } - return messageEvent; - }), - toArray() - ) - ); - - return await next( - nextMessages.concat(emittedMessageEvents.map((event) => event.message)) - ); - } - - const functionCallName = lastMessage?.message.function_call?.name; - const isAssistantMessage = lastMessage?.message.role === MessageRole.Assistant; - - if (isAssistantMessage && functionCallName) { - if (functionClient.hasAction(functionCallName)) { - this.dependencies.logger.debug(`Executing client-side action: ${functionCallName}`); - - // if validation fails, return the error to the LLM. - // otherwise, close the stream. - - try { - functionClient.validate( - functionCallName, - JSON.parse(lastMessage.message.function_call!.arguments || '{}') - ); - } catch (error) { - const functionResponseMessage = createFunctionResponseError({ - name: functionCallName, - error, - }); - nextMessages = nextMessages.concat(functionResponseMessage.message); - - subscriber.next(functionResponseMessage); - - return await next(nextMessages); - } - - subscriber.complete(); - - return; - } - - const span = apm.startSpan(`execute_function ${functionCallName}`); - - span?.addLabels({ - ai_assistant_args: JSON.stringify(lastMessage.message.function_call!.arguments ?? {}), - }); - - const functionResponse = - numFunctionsCalled >= MAX_FUNCTION_CALLS - ? 
{ - content: { - error: {}, - message: 'Function limit exceeded, ask the user what to do next', - }, - } - : await functionClient - .executeFunction({ - chat: chatWithTokenCountIncrement, - connectorId, - name: functionCallName, - messages: replaceSystemMessage(storedSystemMessage, nextMessages), - args: lastMessage.message.function_call!.arguments, - signal, - }) - .then((response) => { - if (isObservable(response)) { - return response; - } - - span?.setOutcome('success'); - - const encoded = encode(JSON.stringify(response.content || {})); - - if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) { - return response; - } - - return { - data: response.data, - content: { - message: - 'Function response exceeded the maximum length allowed and was truncated', - truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)), - }, - }; - }) - .catch((error): FunctionResponse => { - span?.setOutcome('failure'); - return { - content: { - message: error.toString(), - error, - }, - }; - }); - - numFunctionsCalled++; - - if (signal.aborted) { - span?.end(); - return; - } - - if (isObservable(functionResponse)) { - const shared = functionResponse.pipe(shareReplay()); + ), + shareReplay() + ); - shared.subscribe({ - next: (val) => subscriber.next(val), - // we handle the error below - error: noop, + // we continue the conversation here, after resolving both the materialized + // messages and the knowledge base instructions + const nextEvents$ = combineLatest([messagesWithUpdatedSystemMessage$, kbInstructions$]).pipe( + switchMap(([messagesWithUpdatedSystemMessage, knowledgeBaseInstructions]) => { + // if needed, inject a context function request here + const contextRequest = functionClient.hasFunction('context') + ? getContextFunctionRequestIfNeeded(messagesWithUpdatedSystemMessage) + : undefined; + + return mergeOperator( + // if we have added a context function request, also emit + // the messageAdd event for it, so we can notify the consumer + // and add it to the conversation + ...(contextRequest ? [of(contextRequest)] : []), + continueConversation({ + messages: [ + ...messagesWithUpdatedSystemMessage, + ...(contextRequest ? [contextRequest.message] : []), + ], + chat: (name, chatParams) => { + // inject a chat function with predefined parameters + return this.chat(name, { + ...chatParams, + signal, + simulateFunctionCalling, + connectorId, }); + }, + // start out with the max number of function calls + functionCallsLeft: MAX_FUNCTION_CALLS, + functionClient, + knowledgeBaseInstructions, + requestInstructions, + signal, + }) + ); + }), + shareReplay() + ); - const messageEvents = await lastValueFrom( - shared.pipe( - filter( - (event): event is MessageAddEvent => - event.type === StreamingChatResponseEventType.MessageAdd - ), - toArray() - ) - ); - - span?.end(); - - return await next(nextMessages.concat(messageEvents.map((event) => event.message))); - } - - const functionResponseMessage = { - '@timestamp': new Date().toISOString(), - message: { - name: lastMessage.message.function_call!.name, - - content: JSON.stringify(functionResponse.content || {}), - data: functionResponse.data ? 
JSON.stringify(functionResponse.data) : undefined, - role: MessageRole.User, - }, - }; - - this.dependencies.logger.debug( - `Function response: ${JSON.stringify(functionResponseMessage, null, 2)}` - ); - nextMessages = nextMessages.concat(functionResponseMessage); - - subscriber.next({ - type: StreamingChatResponseEventType.MessageAdd, - message: functionResponseMessage, - id: v4(), - }); - - span?.end(); - - return await next(nextMessages); - } - - this.dependencies.logger.debug(`Conversation: ${JSON.stringify(nextMessages, null, 2)}`); - - if (!persist) { - subscriber.complete(); - return; + const output$ = mergeOperator( + // get all the events from continuing the conversation + nextEvents$, + // wait until all dependencies have completed + forkJoin([ + messagesWithUpdatedSystemMessage$, + // get just the new messages + nextEvents$.pipe(withoutTokenCountEvents(), extractMessages()), + // count all the token count events emitted during completion + mergeOperator( + nextEvents$, + title$.pipe(filter((value): value is TokenCountEvent => typeof value !== 'string')) + ).pipe(extractTokenCount()), + // get just the title, and drop the token count events + title$.pipe(filter((value): value is string => typeof value === 'string')), + ]).pipe( + switchMap(([messagesWithUpdatedSystemMessage, addedMessages, tokenCountResult, title]) => { + const initialMessagesWithAddedMessages = + messagesWithUpdatedSystemMessage.concat(addedMessages); + + const lastMessage = + initialMessagesWithAddedMessages[initialMessagesWithAddedMessages.length - 1]; + + // if a function request is at the very end, close the stream to consumer + // without persisting or updating the conversation. we need to wait + // on the function response to have a valid conversation + const isFunctionRequest = lastMessage.message.function_call?.name; + + if (!persist || isFunctionRequest) { + return of(); } - this.dependencies.logger.debug( - `Token count for conversation: ${JSON.stringify(tokenCountResult)}` - ); - - apm.currentTransaction?.addLabels({ - tokenCountPrompt: tokenCountResult.prompt, - tokenCountCompletion: tokenCountResult.completion, - tokenCountTotal: tokenCountResult.total, - }); - - // store the updated conversation and close the stream if (isConversationUpdate) { - const conversation = await this.getConversationWithMetaFields(conversationId); - if (!conversation) { - throw createConversationNotFoundError(); - } - - if (signal.aborted) { - return; - } - - const persistedTokenCount = conversation._source?.conversation.token_count; - - const updatedConversation = await this.update( - conversationId, - - merge( - {}, - - // base conversation without messages - omit(conversation._source, 'messages'), - - // update messages - { messages: replaceSystemMessage(storedSystemMessage, nextMessages) }, + return from(this.getConversationWithMetaFields(conversationId)) + .pipe( + switchMap((conversation) => { + if (!conversation) { + return throwError(() => createConversationNotFoundError()); + } - // update token count - { - conversation: { - token_count: { - prompt: (persistedTokenCount?.prompt || 0) + tokenCountResult.prompt, - completion: - (persistedTokenCount?.completion || 0) + tokenCountResult.completion, - total: (persistedTokenCount?.total || 0) + tokenCountResult.total, - }, - }, - } + const persistedTokenCount = conversation._source?.conversation.token_count ?? 
{ + prompt: 0, + completion: 0, + total: 0, + }; + + return from( + this.update( + conversationId, + + merge( + {}, + + // base conversation without messages + omit(conversation._source, 'messages'), + + // update messages + { messages: initialMessagesWithAddedMessages }, + + // update token count + { + conversation: { + title: title || conversation._source?.conversation.title, + token_count: { + prompt: persistedTokenCount.prompt + tokenCountResult.prompt, + completion: + persistedTokenCount.completion + tokenCountResult.completion, + total: persistedTokenCount.total + tokenCountResult.total, + }, + }, + } + ) + ) + ); + }) ) - ); - - subscriber.next({ - type: StreamingChatResponseEventType.ConversationUpdate, - conversation: updatedConversation.conversation, - }); - } else { - const generatedTitle = await titlePromise; - if (signal.aborted) { - return; - } + .pipe( + map((conversation): ConversationUpdateEvent => { + return { + conversation: conversation.conversation, + type: StreamingChatResponseEventType.ConversationUpdate, + }; + }) + ); + } - const conversation = await this.create({ + return from( + this.create({ '@timestamp': new Date().toISOString(), conversation: { - title: generatedTitle || title || 'New conversation', - token_count: tokenCountResult, + title, id: conversationId, + token_count: tokenCountResult, }, - messages: replaceSystemMessage(storedSystemMessage, nextMessages), + public: !!isPublic, labels: {}, numeric_labels: {}, - public: isPublic, - }); - - subscriber.next({ - type: StreamingChatResponseEventType.ConversationCreate, - conversation: conversation.conversation, - }); - } - - subscriber.complete(); - }; - - this.fetchKnowledgeBaseInstructions() - .then((loadedKnowledgeBaseInstructions) => { - knowledgeBaseInstructions.push(...loadedKnowledgeBaseInstructions); - - storedSystemMessage = getSystemMessageFromInstructions({ - registeredInstructions, - requestInstructions, - knowledgeBaseInstructions, - availableFunctionNames: allFunctions.map((fn) => fn.name), - }); + messages: initialMessagesWithAddedMessages, + }) + ).pipe( + map((conversation): ConversationCreateEvent => { + return { + conversation: conversation.conversation, + type: StreamingChatResponseEventType.ConversationCreate, + }; + }) + ); + }) + ) + ); - return next(messages); - }) - .catch((error) => { - if (!signal.aborted) { - this.dependencies.logger.error(error); - } - subscriber.error(error); - }); + return output$.pipe( + instrumentAndCountTokens('complete'), + withoutTokenCountEvents(), + catchError((error) => { + this.dependencies.logger.error(error); + return throwError(() => error); + }), + tap((event) => { + if (this.dependencies.logger.isLevelEnabled('debug')) { + switch (event.type) { + case StreamingChatResponseEventType.MessageAdd: + this.dependencies.logger.debug(`Added message: ${JSON.stringify(event.message)}`); + break; + + case StreamingChatResponseEventType.ConversationCreate: + this.dependencies.logger.debug( + `Created conversation: ${JSON.stringify(event.conversation)}` + ); + break; - const titlePromise = - !isConversationUpdate && !title && persist - ? 
this.getGeneratedTitle({ - chat: chatWithTokenCountIncrement, - messages, - connectorId, - signal, - responseLanguage, - }).catch((error) => { - this.dependencies.logger.error( - 'Could not generate title, falling back to default title' - ); - this.dependencies.logger.error(error); - return Promise.resolve(undefined); - }) - : Promise.resolve(undefined); - } - ).pipe(shareReplay()); + case StreamingChatResponseEventType.ConversationUpdate: + this.dependencies.logger.debug( + `Updated conversation: ${JSON.stringify(event.conversation)}` + ); + break; + } + } + }), + shareReplay() + ); }; - chat = async ( + chat = ( name: string, { messages, @@ -643,123 +448,96 @@ export class ObservabilityAIAssistantClient { signal: AbortSignal; simulateFunctionCalling?: boolean; } - ): Promise<Observable<ChatCompletionChunkEvent | TokenCountEvent>> => { - const span = apm.startSpan(`chat ${name}`); - - const spanId = (span?.ids['span.id'] || '').substring(0, 6); - - const loggerPrefix = `${name}${spanId ? ` (${spanId})` : ''}`; - - try { - const connector = await this.dependencies.actionsClient.get({ - id: connectorId, - }); - - let adapter: LlmApiAdapter; - - this.dependencies.logger.debug(`Creating "${connector.actionTypeId}" adapter`); - - switch (connector.actionTypeId) { - case ObservabilityAIAssistantConnectorType.OpenAI: - adapter = createOpenAiAdapter({ - messages, - functions, - functionCall, - logger: this.dependencies.logger, - simulateFunctionCalling, - }); - break; - - case ObservabilityAIAssistantConnectorType.Bedrock: - adapter = createBedrockClaudeAdapter({ - messages, - functions, - functionCall, - logger: this.dependencies.logger, - }); - break; - - default: - throw new Error(`Connector type is not supported: ${connector.actionTypeId}`); - } - - const subAction = adapter.getSubAction(); - - this.dependencies.logger.debug(`${loggerPrefix}: Sending conversation to connector`); - this.dependencies.logger.trace( - `${loggerPrefix}:\n${JSON.stringify(subAction.subActionParams, null, 2)}` - ); - - const now = performance.now(); - - const executeResult = await this.dependencies.actionsClient.execute({ - actionId: connectorId, - params: subAction, - }); - - this.dependencies.logger.debug( - `${loggerPrefix}: Received action client response: ${ - executeResult.status - } (took: ${Math.round(performance.now() - now)}ms)` - ); - - if (executeResult.status === 'error' && executeResult?.serviceMessage) { - const tokenLimitRegex = - /This model's maximum context length is (\d+) tokens\. 
However, your messages resulted in (\d+) tokens/g; - const tokenLimitRegexResult = tokenLimitRegex.exec(executeResult.serviceMessage); + ): Observable<ChatEvent> => { + return defer(() => + from( + withAssistantSpan('get_connector', () => + this.dependencies.actionsClient.get({ id: connectorId, throwIfSystemAction: true }) + ) + ) + ).pipe( + switchMap((connector) => { + this.dependencies.logger.debug(`Creating "${connector.actionTypeId}" adapter`); + + let adapter: LlmApiAdapter; + + switch (connector.actionTypeId) { + case ObservabilityAIAssistantConnectorType.OpenAI: + adapter = createOpenAiAdapter({ + messages, + functions, + functionCall, + logger: this.dependencies.logger, + simulateFunctionCalling, + }); + break; + + case ObservabilityAIAssistantConnectorType.Bedrock: + adapter = createBedrockClaudeAdapter({ + messages, + functions, + functionCall, + logger: this.dependencies.logger, + }); + break; - if (tokenLimitRegexResult) { - const [, tokenLimit, tokenCount] = tokenLimitRegexResult; - throw createTokenLimitReachedError(parseInt(tokenLimit, 10), parseInt(tokenCount, 10)); + default: + throw new Error(`Connector type is not supported: ${connector.actionTypeId}`); } - } - if (executeResult.status === 'error') { - throw internal(`${executeResult?.message} - ${executeResult?.serviceMessage}`); - } + const subAction = adapter.getSubAction(); + + this.dependencies.logger.trace(JSON.stringify(subAction.subActionParams, null, 2)); + + return from( + withAssistantSpan('get_execute_result', () => + this.dependencies.actionsClient.execute({ + actionId: connectorId, + params: subAction, + }) + ) + ).pipe( + switchMap((executeResult) => { + if (executeResult.status === 'error' && executeResult?.serviceMessage) { + const tokenLimitRegex = + /This model's maximum context length is (\d+) tokens\. 
However, your messages resulted in (\d+) tokens/g; + const tokenLimitRegexResult = tokenLimitRegex.exec(executeResult.serviceMessage); + + if (tokenLimitRegexResult) { + const [, tokenLimit, tokenCount] = tokenLimitRegexResult; + throw createTokenLimitReachedError( + parseInt(tokenLimit, 10), + parseInt(tokenCount, 10) + ); + } + } - const response = executeResult.data as Readable; + if (executeResult.status === 'error') { + throw createInternalServerError( + `${executeResult?.message} - ${executeResult?.serviceMessage}` + ); + } - signal.addEventListener('abort', () => response.destroy()); + const response = executeResult.data as Readable; - const response$ = adapter.streamIntoObservable(response).pipe( - shareReplay(), - failOnNonExistingFunctionCall({ functions }), - tap((event) => { - if (event.type === StreamingChatResponseEventType.TokenCount) { - span?.addLabels({ - tokenCountPrompt: event.tokens.prompt, - tokenCountCompletion: event.tokens.completion, - tokenCountTotal: event.tokens.total, - }); - } - }) - ); + signal.addEventListener('abort', () => response.destroy()); - response$ - .pipe(rejectTokenCountEvents(), concatenateChatCompletionChunks(), lastOperator()) - .subscribe({ - error: (error) => { - this.dependencies.logger.debug('Error in chat response'); - this.dependencies.logger.debug(error); - span?.setOutcome('failure'); - span?.end(); - }, - next: (message) => { - this.dependencies.logger.debug(`Received message:\n${JSON.stringify(message)}`); - }, - complete: () => { - span?.setOutcome('success'); - span?.end(); - }, - }); - - return response$; - } catch (error) { - span?.setOutcome('failure'); - span?.end(); - throw error; - } + return adapter.streamIntoObservable(response); + }) + ); + }), + instrumentAndCountTokens(name), + failOnNonExistingFunctionCall({ functions }), + tap((event) => { + if ( + event.type === StreamingChatResponseEventType.ChatCompletionChunk && + this.dependencies.logger.isLevelEnabled('trace') + ) { + this.dependencies.logger.trace(`Received chunk: ${JSON.stringify(event.message)}`); + } + }), + shareReplay() + ); }; find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => { @@ -813,79 +591,6 @@ export class ObservabilityAIAssistantClient { return updatedConversation; }; - getGeneratedTitle = async ({ - chat, - messages, - connectorId, - signal, - responseLanguage, - }: { - chat: ( - ...chatParams: Parameters<ObservabilityAIAssistantClient['chat']> - ) => Promise<Observable<ChatCompletionChunkEvent | TokenCountEvent>>; - messages: Message[]; - connectorId: string; - signal: AbortSignal; - responseLanguage: string; - }) => { - const response$ = await chat('generate_title', { - messages: [ - { - '@timestamp': new Date().toString(), - message: { - role: MessageRole.System, - content: `You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. 
Please create the title in ${responseLanguage}.`, - }, - }, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - content: messages.slice(1).reduce((acc, curr) => { - return `${acc} ${curr.message.role}: ${curr.message.content}`; - }, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'), - }, - }, - ], - functions: [ - { - name: 'title_conversation', - description: - 'Use this function to title the conversation. Do not wrap the title in quotes', - parameters: { - type: 'object', - properties: { - title: { - type: 'string', - }, - }, - required: ['title'], - }, - }, - ], - functionCall: 'title_conversation', - connectorId, - signal, - }); - - const response = await lastValueFrom(response$.pipe(concatenateChatCompletionChunks())); - - const input = - (response.message.function_call.name - ? JSON.parse(response.message.function_call.arguments).title - : response.message?.content) || ''; - - // This regular expression captures a string enclosed in single or double quotes. - // It extracts the string content without the quotes. - // Example matches: - // - "Hello, World!" => Captures: Hello, World! - // - 'Another Example' => Captures: Another Example - // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes - const match = input.match(/^["']?([^"']+)["']?$/); - const title = match ? match[1] : input; - return title; - }; - setTitle = async ({ conversationId, title }: { conversationId: string; title: string }) => { const document = await this.getConversationWithMetaFields(conversationId); if (!document) { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts new file mode 100644 index 0000000000000..a4a87c4dbe545 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts @@ -0,0 +1,294 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { decode, encode } from 'gpt-tokenizer'; +import { pick, take } from 'lodash'; +import { + catchError, + concat, + EMPTY, + from, + isObservable, + Observable, + of, + OperatorFunction, + shareReplay, + switchMap, + throwError, +} from 'rxjs'; +import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common'; +import { + createFunctionLimitExceededError, + MessageOrChatEvent, +} from '../../../../common/conversation_complete'; +import { FunctionVisibility } from '../../../../common/functions/types'; +import { UserInstruction } from '../../../../common/types'; +import { createFunctionResponseError } from '../../../../common/utils/create_function_response_error'; +import { createFunctionResponseMessage } from '../../../../common/utils/create_function_response_message'; +import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message'; +import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events'; +import type { ChatFunctionClient } from '../../chat_function_client'; +import type { ChatFunctionWithoutConnector } from '../../types'; +import { getSystemMessageFromInstructions } from '../../util/get_system_message_from_instructions'; +import { replaceSystemMessage } from '../../util/replace_system_message'; +import { extractMessages } from './extract_messages'; +import { hideTokenCountEvents } from './hide_token_count_events'; + +const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000; + +function executeFunctionAndCatchError({ + name, + args, + functionClient, + messages, + chat, + signal, +}: { + name: string; + args: string | undefined; + functionClient: ChatFunctionClient; + messages: Message[]; + chat: ChatFunctionWithoutConnector; + signal: AbortSignal; +}): Observable<MessageOrChatEvent> { + // hide token count events from functions so they don't + // have to deal with them as well + return hideTokenCountEvents((hide) => { + const executeFunctionResponse$ = from( + functionClient.executeFunction({ + name, + chat: (operationName, params) => { + return chat(operationName, params).pipe(hide()); + }, + args, + signal, + messages, + }) + ); + + return executeFunctionResponse$.pipe( + catchError((error) => { + // We want to catch the error only when it occurs in the promise; + // if it occurs in the Observable, we cannot easily recover + // from it because the function may have already emitted + // values which could lead to an invalid conversation state, + // so in that case we let the stream fail. + return of(createFunctionResponseError({ name, error })); + }), + switchMap((response) => { + if (isObservable(response)) { + return response; + } + + // the response is already a messageAdd event + if ('type' in response) { + return of(response); + } + + const encoded = encode(JSON.stringify(response.content || {})); + + const exceededTokenLimit = encoded.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT; + + return of( + createFunctionResponseMessage({ + name, + content: exceededTokenLimit + ? { + message: + 'Function response exceeded the maximum length allowed and was truncated', + truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)), + } + : response.content, + data: response.data, + }) + ); + }) + ); + }); +} + +function getFunctionDefinitions({ + functionClient, + functionLimitExceeded, +}: { + functionClient: ChatFunctionClient; + functionLimitExceeded: boolean; +}) { + const systemFunctions = functionLimitExceeded + ? 
[] : functionClient + .getFunctions() + .map((fn) => fn.definition) + .filter( + (def) => + !def.visibility || + [FunctionVisibility.AssistantOnly, FunctionVisibility.All].includes(def.visibility) + ); + + const actions = functionLimitExceeded ? [] : functionClient.getActions(); + + const allDefinitions = systemFunctions + .concat(actions) + .map((definition) => pick(definition, 'name', 'description', 'parameters')); + + return allDefinitions; +} + +export function continueConversation({ + messages: initialMessages, + functionClient, + chat, + signal, + functionCallsLeft, + requestInstructions, + knowledgeBaseInstructions, +}: { + messages: Message[]; + functionClient: ChatFunctionClient; + chat: ChatFunctionWithoutConnector; + signal: AbortSignal; + functionCallsLeft: number; + requestInstructions: Array<string | UserInstruction>; + knowledgeBaseInstructions: UserInstruction[]; +}): Observable<MessageOrChatEvent> { + let nextFunctionCallsLeft = functionCallsLeft; + + const definitions = getFunctionDefinitions({ + functionLimitExceeded: functionCallsLeft <= 0, + functionClient, + }); + + const messagesWithUpdatedSystemMessage = replaceSystemMessage( + getSystemMessageFromInstructions({ + registeredInstructions: functionClient.getInstructions(), + knowledgeBaseInstructions, + requestInstructions, + availableFunctionNames: definitions.map((def) => def.name), + }), + initialMessages + ); + + const lastMessage = + messagesWithUpdatedSystemMessage[messagesWithUpdatedSystemMessage.length - 1].message; + + const isUserMessage = lastMessage.role === MessageRole.User; + + return executeNextStep().pipe(handleEvents()); + + function executeNextStep() { + if (isUserMessage) { + const operationName = + lastMessage.name && lastMessage.name !== 'context' + ? `function_response ${lastMessage.name}` + : 'user_message'; + + return chat(operationName, { + messages: messagesWithUpdatedSystemMessage, + functions: definitions, + }).pipe(emitWithConcatenatedMessage()); + } + + const functionCallName = lastMessage.function_call?.name; + + if (!functionCallName) { + // reply from the LLM without a function request, + // so we can close the stream and wait for input from the user + return EMPTY; + } + + // we know we are executing a function here, so we can already + // subtract one, and keep the old count around for the checks below + const currentFunctionCallsLeft = nextFunctionCallsLeft; + + nextFunctionCallsLeft--; + + const isAction = functionCallName && functionClient.hasAction(functionCallName); + + if (currentFunctionCallsLeft === 0) { + // create a function call response error so the LLM knows it needs to stop calling functions + return of( + createFunctionResponseError({ + name: functionCallName, + error: createFunctionLimitExceededError(), + }) + ); + } + + if (currentFunctionCallsLeft < 0) { + // the LLM tried calling a function anyway; throw an error + return throwError(() => createFunctionLimitExceededError()); + } + + // if it's an action, we close the stream and wait for the action response + // from the client/browser + if (isAction) { + try { + functionClient.validate( + functionCallName, + JSON.parse(lastMessage.function_call!.arguments || '{}') + ); + } catch (error) { + // return a function response error for the LLM to handle + return of( + createFunctionResponseError({ + name: functionCallName, + error, + }) + ); + } + + return EMPTY; + } + + if (!functionClient.hasFunction(functionCallName)) { + // tell the LLM the function was not found + return of( + createFunctionResponseError({ + name: functionCallName, + error: 
createFunctionNotFoundError(functionCallName), + }) + ); + } + + return executeFunctionAndCatchError({ + name: functionCallName, + args: lastMessage.function_call!.arguments, + chat, + functionClient, + messages: messagesWithUpdatedSystemMessage, + signal, + }); + } + + function handleEvents(): OperatorFunction<MessageOrChatEvent, MessageOrChatEvent> { + return (events$) => { + const shared$ = events$.pipe(shareReplay()); + + return concat( + shared$, + shared$.pipe( + withoutTokenCountEvents(), + extractMessages(), + switchMap((extractedMessages) => { + if (!extractedMessages.length) { + return EMPTY; + } + return continueConversation({ + messages: messagesWithUpdatedSystemMessage.concat(extractedMessages), + chat, + functionCallsLeft: nextFunctionCallsLeft, + functionClient, + signal, + knowledgeBaseInstructions, + requestInstructions, + }); + }) + ) + ); + }; + } +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/debug.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/debug.ts new file mode 100644 index 0000000000000..4c097bcc28c1f --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/debug.ts @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { inspect } from 'util'; +import { dematerialize, materialize, OperatorFunction, tap } from 'rxjs'; + +export function debug<T>(prefix: string): OperatorFunction<T, T> { + return (source$) => { + return source$.pipe( + materialize(), + tap((event) => { + // eslint-disable-next-line no-console + console.log(prefix + ':\n' + inspect(event, { depth: 10 })); + }), + dematerialize() + ); + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_messages.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_messages.ts new file mode 100644 index 0000000000000..97c12ae9f7cbb --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_messages.ts @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { filter, last, map, OperatorFunction, toArray } from 'rxjs'; +import { Message, MessageAddEvent, StreamingChatResponseEventType } from '../../../../common'; +import type { MessageOrChatEvent } from '../../../../common/conversation_complete'; + +export function extractMessages(): OperatorFunction<MessageOrChatEvent, Message[]> { + return (source$) => { + return source$.pipe( + filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ), + map((event) => event.message), + toArray(), + last() + ); + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_token_count.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_token_count.ts new file mode 100644 index 0000000000000..0d11db24732f3 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/extract_token_count.ts @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { filter, OperatorFunction, scan } from 'rxjs'; +import { + StreamingChatResponseEvent, + StreamingChatResponseEventType, + TokenCountEvent, +} from '../../../../common/conversation_complete'; + +export function extractTokenCount(): OperatorFunction< + StreamingChatResponseEvent, + TokenCountEvent['tokens'] +> { + return (events$) => { + return events$.pipe( + filter( + (event): event is TokenCountEvent => + event.type === StreamingChatResponseEventType.TokenCount + ), + scan( + (acc, event) => { + acc.completion += event.tokens.completion; + acc.prompt += event.tokens.prompt; + acc.total += event.tokens.total; + return acc; + }, + { completion: 0, prompt: 0, total: 0 } + ) + ); + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/get_generated_title.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/get_generated_title.ts new file mode 100644 index 0000000000000..f35e0716f1051 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/get_generated_title.ts @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { catchError, map, Observable, of, tap } from 'rxjs'; +import { Logger } from '@kbn/logging'; +import type { ObservabilityAIAssistantClient } from '..'; +import { Message, MessageRole } from '../../../../common'; +import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks'; +import { hideTokenCountEvents } from './hide_token_count_events'; +import { ChatEvent, TokenCountEvent } from '../../../../common/conversation_complete'; + +type ChatFunctionWithoutConnectorAndTokenCount = ( + name: string, + params: Omit< + Parameters<ObservabilityAIAssistantClient['chat']>[1], + 'connectorId' | 'signal' | 'simulateFunctionCalling' + > +) => Observable<ChatEvent>; + +export function getGeneratedTitle({ + responseLanguage, + messages, + chat, + logger, +}: { + responseLanguage?: string; + messages: Message[]; + chat: ChatFunctionWithoutConnectorAndTokenCount; + logger: Logger; +}): Observable<string | TokenCountEvent> { + return hideTokenCountEvents((hide) => + chat('generate_title', { + messages: [ + { + '@timestamp': new Date().toString(), + message: { + role: MessageRole.System, + content: `You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}.`, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: messages.slice(1).reduce((acc, curr) => { + return `${acc} ${curr.message.role}: ${curr.message.content}`; + }, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'), + }, + }, + ], + functions: [ + { + name: 'title_conversation', + description: + 'Use this function to title the conversation. Do not wrap the title in quotes', + parameters: { + type: 'object', + properties: { + title: { + type: 'string', + }, + }, + required: ['title'], + }, + }, + ], + functionCall: 'title_conversation', + }).pipe( + hide(), + concatenateChatCompletionChunks(), + map((concatenatedMessage) => { + const input = + (concatenatedMessage.message.function_call.name + ? JSON.parse(concatenatedMessage.message.function_call.arguments).title + : concatenatedMessage.message?.content) || ''; + + // This regular expression captures a string enclosed in single or double quotes. + // It extracts the string content without the quotes. + // Example matches: + // - "Hello, World!" => Captures: Hello, World! + // - 'Another Example' => Captures: Another Example + // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes + const match = input.match(/^["']?([^"']+)["']?$/); + const title = match ? 
match[1] : input; + return title; + }), + tap((event) => { + if (typeof event === 'string') { + logger.debug(`Generated title: ${event}`); + } + }) + ) + ).pipe( + catchError((error) => { + logger.error(`Error generating title`); + logger.error(error); + // TODO: i18n + return of('New conversation'); + }) + ); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/hide_token_count_events.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/hide_token_count_events.ts new file mode 100644 index 0000000000000..7aabb1448382f --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/hide_token_count_events.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { merge, Observable, partition } from 'rxjs'; +import type { StreamingChatResponseEvent } from '../../../../common'; +import { + StreamingChatResponseEventType, + TokenCountEvent, +} from '../../../../common/conversation_complete'; + +type Hide = <T extends StreamingChatResponseEvent>() => ( + source$: Observable<T> +) => Observable<Exclude<T, TokenCountEvent>>; + +export function hideTokenCountEvents<T>( + cb: (hide: Hide) => Observable<Exclude<T, TokenCountEvent>> +): Observable<T | TokenCountEvent> { + // `hide` can be called multiple times, so we keep track of each invocation + const allInterceptors: Array<Observable<TokenCountEvent>> = []; + + const hide: Hide = () => (source$) => { + const [tokenCountEvents$, otherEvents$] = partition( + source$, + (value): value is TokenCountEvent => value.type === StreamingChatResponseEventType.TokenCount + ); + + allInterceptors.push(tokenCountEvents$); + + return otherEvents$; + }; + + // merge the intercepted token count events back into the output + return merge(cb(hide), ...allInterceptors); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/instrument_and_count_tokens.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/instrument_and_count_tokens.ts new file mode 100644 index 0000000000000..094b2606ae533 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/instrument_and_count_tokens.ts @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import apm from 'elastic-apm-node'; +import { + catchError, + ignoreElements, + merge, + OperatorFunction, + shareReplay, + tap, + last, + throwError, + finalize, +} from 'rxjs'; +import type { StreamingChatResponseEvent } from '../../../../common/conversation_complete'; +import { extractTokenCount } from './extract_token_count'; + +export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>( + name: string +): OperatorFunction<T, T> { + return (source$) => { + const span = apm.startSpan(name); + + if (!span) { + return source$; + } + span?.addLabels({ + plugin: 'observability_ai_assistant', + }); + + const shared$ = source$.pipe(shareReplay()); + + let tokenCount = { + prompt: 0, + completion: 0, + total: 0, + }; + + return merge( + shared$, + shared$.pipe( + extractTokenCount(), + tap((nextTokenCount) => { + tokenCount = nextTokenCount; + }), + last(), + tap(() => { + span?.setOutcome('success'); + }), + catchError((error) => { + span?.setOutcome('failure'); + return throwError(() => error); + }), + finalize(() => { + span?.addLabels({ + tokenCountPrompt: tokenCount.prompt, + tokenCountCompletion: tokenCount.completion, + tokenCountTotal: tokenCount.total, + }); + span?.end(); + }), + ignoreElements() + ) + ); + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts index 241ecd1350c68..00ca82521875c 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts @@ -7,7 +7,7 @@ import type { FromSchema } from 'json-schema-to-ts'; import { Observable } from 'rxjs'; -import { ChatCompletionChunkEvent } from '../../common/conversation_complete'; +import { ChatCompletionChunkEvent, ChatEvent } from '../../common/conversation_complete'; import type { CompatibleJSONSchema, FunctionDefinition, @@ -27,17 +27,33 @@ export type RespondFunctionResources = Pick< 'context' | 'logger' | 'plugins' | 'request' >; -export type ChatFn = ( - ...args: Parameters<ObservabilityAIAssistantClient['chat']> -) => Promise<Observable<ChatCompletionChunkEvent>>; +export type ChatFunction = ( + name: string, + params: Parameters<ObservabilityAIAssistantClient['chat']>[1] +) => Observable<ChatEvent>; + +export type ChatFunctionWithoutConnector = ( + name: string, + params: Omit< + Parameters<ObservabilityAIAssistantClient['chat']>[1], + 'connectorId' | 'simulateFunctionCalling' | 'signal' + > +) => Observable<ChatEvent>; + +export type FunctionCallChatFunction = ( + name: string, + params: Omit< + Parameters<ObservabilityAIAssistantClient['chat']>[1], + 'connectorId' | 'simulateFunctionCalling' + > +) => Observable<ChatEvent>; type RespondFunction<TArguments, TResponse extends FunctionResponse> = ( options: { arguments: TArguments; messages: Message[]; - connectorId: string; screenContexts: ObservabilityAIAssistantScreenContextRequest[]; - chat: ChatFn; + chat: FunctionCallChatFunction; }, signal: AbortSignal ) => Promise<TResponse>; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts index 25eecc7e7723e..01c9713f0f8e6 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts @@ -9,16 +9,15 @@ import { i18n } from '@kbn/i18n'; import { catchError, filter, of, OperatorFunction, shareReplay, throwError } from 'rxjs'; import { 
ChatCompletionChunkEvent, - MessageAddEvent, MessageRole, StreamingChatResponseEventType, } from '../../../common'; -import { isFunctionNotFoundError } from '../../../common/conversation_complete'; +import { isFunctionNotFoundError, MessageOrChatEvent } from '../../../common/conversation_complete'; import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message'; export function catchFunctionLimitExceededError(): OperatorFunction< - ChatCompletionChunkEvent | MessageAddEvent, - ChatCompletionChunkEvent | MessageAddEvent + MessageOrChatEvent, + MessageOrChatEvent > { return (source$) => { const shared$ = source$.pipe(shareReplay()); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_stream.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_stream.ts index f6fe506367f2f..3ca09acde2b6f 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_stream.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/observable_into_stream.ts @@ -13,11 +13,10 @@ import { isChatCompletionError, StreamingChatResponseEventType, StreamingChatResponseEventWithoutError, - TokenCountEvent, } from '../../../common/conversation_complete'; export function observableIntoStream( - source: Observable<StreamingChatResponseEventWithoutError | TokenCountEvent> + source: Observable<StreamingChatResponseEventWithoutError> ) { const stream = new PassThrough(); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/reject_token_count_events.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/reject_token_count_events.ts deleted file mode 100644 index b8e563495d1d7..0000000000000 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/reject_token_count_events.ts +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { filter, Observable } from 'rxjs'; -import { - ChatCompletionChunkEvent, - StreamingChatResponseEventType, - TokenCountEvent, -} from '../../../common/conversation_complete'; - -export function rejectTokenCountEvents() { - return <T extends ChatCompletionChunkEvent | TokenCountEvent>( - source: Observable<T> - ): Observable<Exclude<T, TokenCountEvent>> => { - return source.pipe( - filter( - (event): event is Exclude<T, TokenCountEvent> => - event.type !== StreamingChatResponseEventType.TokenCount - ) - ); - }; -} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/with_assistant_span.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/with_assistant_span.ts new file mode 100644 index 0000000000000..44494e978d804 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/with_assistant_span.ts @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import { withSpan, SpanOptions, parseSpanOptions } from '@kbn/apm-utils'; + +export function withAssistantSpan<T>( + optionsOrName: SpanOptions | string, + cb: () => Promise<T> +): Promise<T> { + const options = parseSpanOptions(optionsOrName); + + const optionsWithDefaults = { + ...(options.intercept ? {} : { type: 'plugin:observability_ai_assistant' }), + ...options, + labels: { + plugin: 'observability_ai_assistant', + ...options.labels, + }, + }; + + return withSpan(optionsWithDefaults, cb); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/tsconfig.json b/x-pack/plugins/observability_solution/observability_ai_assistant/tsconfig.json index eee8ea0d56911..cc66498fa5d4e 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/tsconfig.json +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/tsconfig.json @@ -48,6 +48,7 @@ "@kbn/cloud-plugin", "@kbn/serverless", "@kbn/triggers-actions-ui-plugin", + "@kbn/apm-utils" ], "exclude": ["target/**/*"] } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts index 2cf8600d9db2f..ab0964fdc6216 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts @@ -142,7 +142,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra description: `This function generates, executes and/or visualizes a query based on the user's request. It also explains how ES|QL works and how to convert queries from one language to another. Make sure you call one of the get_dataset functions first if you need index or field names. 
This function takes no input.`, visibility: FunctionVisibility.AssistantOnly, }, - async ({ messages, connectorId, chat }, signal) => { + async ({ messages, chat }, signal) => { const [systemMessage, esqlDocs] = await Promise.all([loadSystemMessage(), loadEsqlDocs()]); const withEsqlSystemMessage = (message?: string) => [ @@ -155,7 +155,6 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra const source$ = ( await chat('classify_esql', { - connectorId, messages: withEsqlSystemMessage().concat({ '@timestamp': new Date().toISOString(), message: { @@ -382,7 +381,6 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra }, }, ], - connectorId, signal, functions: functions.getActions(), }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts index 1a7d64c0d324f..1523ca510238a 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts @@ -27,25 +27,22 @@ export function registerVisualizeESQLFunction({ functions, resources, }: FunctionRegistrationParameters) { - functions.registerFunction( - visualizeESQLFunction, - async ({ arguments: { query, intention }, connectorId, messages }, signal) => { - const { columns, errorMessages } = await validateEsqlQuery({ - query, - client: (await resources.context.core).elasticsearch.client.asCurrentUser, - }); + functions.registerFunction(visualizeESQLFunction, async ({ arguments: { query, intention } }) => { + const { columns, errorMessages } = await validateEsqlQuery({ + query, + client: (await resources.context.core).elasticsearch.client.asCurrentUser, + }); - const message = getMessageForLLM(intention, query, Boolean(errorMessages?.length)); + const message = getMessageForLLM(intention, query, Boolean(errorMessages?.length)); - return { - data: { - columns, - }, - content: { - message, - errorMessages, - }, - }; - } - ); + return { + data: { + columns, + }, + content: { + message, + errorMessages, + }, + }; + }); } diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts index b69ef1c512fa1..a9cf749b3d761 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts @@ -160,6 +160,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { connectorId, functions: [], }) + .expect(200) .pipe(passThrough); let data: string = ''; @@ -188,9 +189,9 @@ export default function ApiTest({ getService }: FtrProviderContext) { await new Promise<void>((resolve) => passThrough.on('end', () => resolve())); - const response = JSON.parse(data); + const response = JSON.parse(data.trim()); - expect(response.message).to.be( + expect(response.error.message).to.be( `Token limit reached. Token limit is 8192, but the current conversation has 11036 tokens.` ); });
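
The following notes are illustrative sketches and are not part of the patch itself.

The new extract_token_count.ts operator folds every TokenCountEvent on the stream into a single running total with scan(). A minimal standalone sketch of that accumulation behaviour, assuming rxjs v7 and simplified stand-ins for the plugin's event types:

import { filter, from, lastValueFrom, scan } from 'rxjs';

interface TokenCountEvent {
  type: 'tokenCount';
  tokens: { prompt: number; completion: number; total: number };
}
type StreamEvent = TokenCountEvent | { type: 'chatCompletionChunk' };

const events: StreamEvent[] = [
  { type: 'chatCompletionChunk' },
  { type: 'tokenCount', tokens: { prompt: 10, completion: 2, total: 12 } },
  { type: 'tokenCount', tokens: { prompt: 5, completion: 1, total: 6 } },
];

async function main() {
  // scan() emits a running total per token count event; the last
  // emission is the total for the whole stream
  const total = await lastValueFrom(
    from(events).pipe(
      filter((event): event is TokenCountEvent => event.type === 'tokenCount'),
      scan(
        (acc, event) => ({
          prompt: acc.prompt + event.tokens.prompt,
          completion: acc.completion + event.tokens.completion,
          total: acc.total + event.tokens.total,
        }),
        { prompt: 0, completion: 0, total: 0 }
      )
    )
  );
  console.log(total); // { prompt: 15, completion: 3, total: 18 }
}

main().catch(console.error);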
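
hide_token_count_events.ts relies on the rxjs partition() function to split token count events off before a callback sees the stream, then merges them back into the final output so upstream accounting still works. A sketch of that partition/merge idea, again with simplified event types in place of the plugin's real ones:

import { from, merge, partition } from 'rxjs';

type TokenCountEvent = { type: 'tokenCount'; tokens: { total: number } };
type ChunkEvent = { type: 'chatCompletionChunk'; content: string };

const events: Array<TokenCountEvent | ChunkEvent> = [
  { type: 'chatCompletionChunk', content: 'Hello' },
  { type: 'tokenCount', tokens: { total: 12 } },
  { type: 'chatCompletionChunk', content: ' world' },
];

// split off the token count events before a callback sees the stream
const [tokenCountEvents$, otherEvents$] = partition(
  from(events),
  (event): event is TokenCountEvent => event.type === 'tokenCount'
);

// a function that must never see token counts would be handed
// otherEvents$ only; the counts are merged back in afterwards so
// they still reach the accounting operators upstream
merge(otherEvents$, tokenCountEvents$).subscribe((event) => {
  console.log(event.type);
});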
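
get_generated_title.ts strips wrapping quotes from the LLM-generated title with a single regex; the helper name below is made up for the example:

// stripTitleQuotes is a hypothetical name, used here only to exercise
// the same regex in isolation
const stripTitleQuotes = (input: string): string => {
  const match = input.match(/^["']?([^"']+)["']?$/);
  return match ? match[1] : input;
};

console.log(stripTitleQuotes('"Hello, World!"')); // Hello, World!
console.log(stripTitleQuotes("'Another Example'")); // Another Example
console.log(stripTitleQuotes('JustTextWithoutQuotes')); // JustTextWithoutQuotes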