diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.test.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.test.tsx new file mode 100644 index 0000000000000..23e1e61e16dc7 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.test.tsx @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { type Message } from '../../../common/types'; +import { reverseToLastUserMessage } from './chat_body'; + +describe('', () => { + describe('reverseToLastUserMessage', () => { + const firstUserMessage = { + message: { + role: 'user', + content: 'Give me a list of my APM services', + }, + }; + const firstUserMessageIndex = 1; + + const secondUserMessage = { + message: { + role: 'user', + content: 'Can you tell me about the synth-go-0 service', + }, + }; + const secondUserMessageIndex = 7; + + const messages = [ + { + message: { + role: 'system', + content: "You're the best around", + }, + }, + firstUserMessage, + { + message: { + role: 'assistant', + function_call: { + name: 'recall', + arguments: '{"queries":[],"contexts":[]}', + trigger: 'assistant', + }, + content: '', + }, + }, + { + message: { + role: 'user', + name: 'recall', + content: '[]', + }, + }, + { + message: { + role: 'assistant', + function_call: { + name: 'get_apm_services_list', + arguments: '{\n "start": "now-1h",\n "end": "now"\n}', + trigger: 'assistant', + }, + content: '', + }, + }, + { + message: { + role: 'user', + name: 'get_apm_services_list', + content: '[{"service.name":"synth-go-0"}]', + }, + }, + { + message: { + role: 'assistant', + function_call: { + name: '', + arguments: '', + trigger: 'assistant', + }, + content: 'Here is a list of your APM services:\n\n1. synth-go-0\n', + }, + }, + secondUserMessage, + { + message: { + role: 'assistant', + function_call: { + name: 'recall', + arguments: '{"queries":[],"contexts":[]}', + trigger: 'assistant', + }, + content: '', + }, + }, + { + message: { + role: 'user', + name: 'recall', + content: '[]', + }, + }, + { + message: { + role: 'assistant', + function_call: { + name: 'get_apm_service_summary', + arguments: + '{\n "service.name": "synth-go-0",\n "start": "now-1h",\n "end": "now"\n}', + trigger: 'assistant', + }, + content: '', + }, + }, + { + message: { + role: 'user', + name: 'get_apm_service_summary', + content: '{"service.name":"synth-go-0"}', + }, + }, + { + message: { + role: 'assistant', + function_call: { + name: '', + arguments: '', + trigger: 'assistant', + }, + content: 'The service named "synth-go-0" is really neat.', + }, + }, + ] as unknown as Message[]; + + it('goes back to the last written user message when regenerating from the end of the conversation', () => { + const nextMessages = reverseToLastUserMessage(messages, messages.at(-1)!); + expect(nextMessages).toEqual(messages.slice(0, secondUserMessageIndex + 1)); + }); + + it('goes back to the last written user message when regenerating from the middle of the conversation', () => { + const nextMessages = reverseToLastUserMessage( + messages, + messages.at(secondUserMessageIndex - 1)! + ); + expect(nextMessages).toEqual(messages.slice(0, firstUserMessageIndex + 1)); + }); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index 5b7e6b509e69f..15cd44119ce28 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -20,6 +20,7 @@ import { import type { AuthenticatedUser } from '@kbn/security-plugin/common'; import { euiThemeVars } from '@kbn/ui-theme'; import { i18n } from '@kbn/i18n'; +import { findLastIndex } from 'lodash'; import { ChatState } from '../../hooks/use_chat'; import { useConversation } from '../../hooks/use_conversation'; import { useLicense } from '../../hooks/use_license'; @@ -263,8 +264,7 @@ export function ChatBody({ }} onFeedback={handleFeedback} onRegenerate={(message) => { - const indexOf = messages.indexOf(message); - next(messages.slice(0, indexOf)); + next(reverseToLastUserMessage(messages, message)); }} onSendTelemetry={(eventWithPayload) => sendEvent(chatService.analytics, eventWithPayload) @@ -410,3 +410,19 @@ export function ChatBody({ ); } + +// Exported for testing only +export function reverseToLastUserMessage(messages: Message[], message: Message) { + // Drop messages after and including the one marked for regeneration + const indexOf = messages.indexOf(message); + const previousMessages = messages.slice(0, indexOf); + + // Go back to the last written user message to fully regenerate function calls + const lastUserMessageIndex = findLastIndex( + previousMessages, + (aMessage: Message) => aMessage.message.role === 'user' && !aMessage.message.name + ); + const nextMessages = previousMessages.slice(0, lastUserMessageIndex + 1); + + return nextMessages; +}