Skip to content

Commit

Permalink
[Obs AI Assistant] Regenerate from last user message (elastic#174911)
Browse files Browse the repository at this point in the history
## Summary

Changes the Regenerate functionality to look back to the last written
message from the user and use that as the cut of point for where to
generate the next messages.
  • Loading branch information
miltonhultgren authored and CoenWarmer committed Feb 15, 2024
1 parent 1f8cce7 commit 602e8ad
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { type Message } from '../../../common/types';
import { reverseToLastUserMessage } from './chat_body';

describe('<ChatBody>', () => {
describe('reverseToLastUserMessage', () => {
const firstUserMessage = {
message: {
role: 'user',
content: 'Give me a list of my APM services',
},
};
const firstUserMessageIndex = 1;

const secondUserMessage = {
message: {
role: 'user',
content: 'Can you tell me about the synth-go-0 service',
},
};
const secondUserMessageIndex = 7;

const messages = [
{
message: {
role: 'system',
content: "You're the best around",
},
},
firstUserMessage,
{
message: {
role: 'assistant',
function_call: {
name: 'recall',
arguments: '{"queries":[],"contexts":[]}',
trigger: 'assistant',
},
content: '',
},
},
{
message: {
role: 'user',
name: 'recall',
content: '[]',
},
},
{
message: {
role: 'assistant',
function_call: {
name: 'get_apm_services_list',
arguments: '{\n "start": "now-1h",\n "end": "now"\n}',
trigger: 'assistant',
},
content: '',
},
},
{
message: {
role: 'user',
name: 'get_apm_services_list',
content: '[{"service.name":"synth-go-0"}]',
},
},
{
message: {
role: 'assistant',
function_call: {
name: '',
arguments: '',
trigger: 'assistant',
},
content: 'Here is a list of your APM services:\n\n1. synth-go-0\n',
},
},
secondUserMessage,
{
message: {
role: 'assistant',
function_call: {
name: 'recall',
arguments: '{"queries":[],"contexts":[]}',
trigger: 'assistant',
},
content: '',
},
},
{
message: {
role: 'user',
name: 'recall',
content: '[]',
},
},
{
message: {
role: 'assistant',
function_call: {
name: 'get_apm_service_summary',
arguments:
'{\n "service.name": "synth-go-0",\n "start": "now-1h",\n "end": "now"\n}',
trigger: 'assistant',
},
content: '',
},
},
{
message: {
role: 'user',
name: 'get_apm_service_summary',
content: '{"service.name":"synth-go-0"}',
},
},
{
message: {
role: 'assistant',
function_call: {
name: '',
arguments: '',
trigger: 'assistant',
},
content: 'The service named "synth-go-0" is really neat.',
},
},
] as unknown as Message[];

it('goes back to the last written user message when regenerating from the end of the conversation', () => {
const nextMessages = reverseToLastUserMessage(messages, messages.at(-1)!);
expect(nextMessages).toEqual(messages.slice(0, secondUserMessageIndex + 1));
});

it('goes back to the last written user message when regenerating from the middle of the conversation', () => {
const nextMessages = reverseToLastUserMessage(
messages,
messages.at(secondUserMessageIndex - 1)!
);
expect(nextMessages).toEqual(messages.slice(0, firstUserMessageIndex + 1));
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
import type { AuthenticatedUser } from '@kbn/security-plugin/common';
import { euiThemeVars } from '@kbn/ui-theme';
import { i18n } from '@kbn/i18n';
import { findLastIndex } from 'lodash';
import { ChatState } from '../../hooks/use_chat';
import { useConversation } from '../../hooks/use_conversation';
import { useLicense } from '../../hooks/use_license';
Expand Down Expand Up @@ -263,8 +264,7 @@ export function ChatBody({
}}
onFeedback={handleFeedback}
onRegenerate={(message) => {
const indexOf = messages.indexOf(message);
next(messages.slice(0, indexOf));
next(reverseToLastUserMessage(messages, message));
}}
onSendTelemetry={(eventWithPayload) =>
sendEvent(chatService.analytics, eventWithPayload)
Expand Down Expand Up @@ -410,3 +410,19 @@ export function ChatBody({
</EuiFlexGroup>
);
}

// Exported for testing only
export function reverseToLastUserMessage(messages: Message[], message: Message) {
// Drop messages after and including the one marked for regeneration
const indexOf = messages.indexOf(message);
const previousMessages = messages.slice(0, indexOf);

// Go back to the last written user message to fully regenerate function calls
const lastUserMessageIndex = findLastIndex(
previousMessages,
(aMessage: Message) => aMessage.message.role === 'user' && !aMessage.message.name
);
const nextMessages = previousMessages.slice(0, lastUserMessageIndex + 1);

return nextMessages;
}

0 comments on commit 602e8ad

Please sign in to comment.