[Obs AI Assistant] Instructions & Claude improvements (elastic#181058)
When we send over a conversation to the LLM for completion, we include a
system message. System messages are a way for the consumer (in this
case, us as developers) to control the LLM's behavior.

This system message was previously constructed by using a concept called
`ContextDefinition` - originally this was a way to define a set of
functions and behavior for a specific context, e.g. core functionality,
APM-specific functionality, platform-specific functionality etc. However,
we never actually did anything with this, and much of its intended
functionality is now covered by the screen context API.

In elastic#179736, we added user
instructions, which let the user control the Assistant's
behaviour by appending text to the system message that we construct from
the registered context definitions.

With this PR, we are making several changes:

- Remove the concept of context definitions entirely
- Replace it with `registerInstruction`, which allows the consumer to
register pieces of text that will be included in the system message.
- `registerInstruction` _also_ takes a callback. That callback receives
the available function names for that specific chat request. For
instance, when we reach the function call limit, the LLM will have no
functions to call. This allows consumers to tailor their instructions to
this specific scenario, which reduces the chance of the LLM
calling a function that it is not allowed to call. Claude is especially
prone to this (likely because we use simulated function
calling).
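
A minimal sketch of the registration idea described above, not the actual implementation: apart from `registerInstruction`, every name here (`createInstructionRegistry`, `materialize`, the `availableFunctionNames` parameter shape) is an assumption made for illustration.

```typescript
// Hypothetical types: only the name `registerInstruction` comes from the PR description.
type InstructionCallback = (params: { availableFunctionNames: string[] }) => string | undefined;
type Instruction = string | InstructionCallback;

function createInstructionRegistry() {
  const instructions: Instruction[] = [];

  return {
    // Register either a fixed piece of text or a callback that produces text per request.
    registerInstruction(instruction: Instruction): void {
      instructions.push(instruction);
    },
    // Resolve every registered instruction against the functions available for one chat request.
    materialize(availableFunctionNames: string[]): string[] {
      return instructions
        .map((instruction) =>
          typeof instruction === 'string' ? instruction : instruction({ availableFunctionNames })
        )
        .filter((text): text is string => Boolean(text));
    },
  };
}

const registry = createInstructionRegistry();

registry.registerInstruction('You are a helpful assistant for Elastic Observability.');

// The callback can adapt its text when a function is not available,
// for example after the function call limit has been reached.
registry.registerInstruction(({ availableFunctionNames }) =>
  availableFunctionNames.includes('get_apm_timeseries')
    ? 'Use the "get_apm_timeseries" function to visualise APM metrics.'
    : 'Do not call any functions; answer from the conversation so far.'
);
```

Keeping the callback pure (function names in, text out) means the same registry can be re-evaluated for every chat invocation without side effects.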

This leads to the following functional changes:
- A system message is now constructed by combining the registered
instructions (system-specific) with the knowledge base and request
instructions (user-specific)
- `GET /internal/observability_ai_assistant/functions` no longer returns
the contexts. Instead, it returns the system message
- `POST /internal/observability_ai_assistant/chat/complete` now creates a
system message at the start, and overrides the system message from the
request.
- For each invocation of `chat`, it re-calculates the system message by
"materializing" the registered instructions with the available function
names for that chat invocation
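
The combination step in the list above could look roughly like the following sketch. It assumes the registered instructions have already been materialized into plain strings; `buildSystemMessage` and its parameter names are hypothetical, and only the `'@timestamp'` / `message.role` / `message.content` shape mirrors the mocked system message in this commit's test changes.

```typescript
// Mirrors `UserInstructionOrPlainText` loosely; only the `text` field is assumed here.
type InstructionOrPlainText = string | { text: string };

interface SystemMessage {
  '@timestamp': string;
  message: { role: 'system'; content: string };
}

function toText(instruction: InstructionOrPlainText): string {
  return typeof instruction === 'string' ? instruction : instruction.text;
}

// Hypothetical helper: system-specific instructions first, then user-specific ones.
function buildSystemMessage(params: {
  registeredInstructions: string[];
  knowledgeBaseInstructions: InstructionOrPlainText[];
  requestInstructions: InstructionOrPlainText[];
}): SystemMessage {
  const content = [
    ...params.registeredInstructions,
    ...params.knowledgeBaseInstructions,
    ...params.requestInstructions,
  ]
    .map(toText)
    .map((text) => text.trim())
    .filter((text) => text.length > 0)
    .join('\n\n');

  return {
    '@timestamp': new Date().toISOString(),
    message: { role: 'system', content },
  };
}
```

Because the result is rebuilt from its inputs each time, a system message supplied in the request can simply be discarded and replaced with the freshly built one.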

Additionally, I've made some attempted improvements to simulated
function calling:
- simplified the system message
- more emphasis on generating valid JSON (e.g. I saw multiline
delimiters `"""`, which are not supported)
- more emphasis on not providing any input if the function does not
accept any parameters (e.g. Claude was trying to provide entire search
requests or SPL-like query strings as input, which led to
hallucinations)
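
The two failure modes above can also be caught defensively when parsing the model's output. The sketch below is an illustration only and is not taken from this PR; `parseFunctionArguments`, `FunctionSpec`, and `ParseResult` are hypothetical names.

```typescript
// Hypothetical, simplified view of a function definition.
interface FunctionSpec {
  name: string;
  parameters?: { properties?: Record<string, unknown> };
}

type ParseResult =
  | { ok: true; args: Record<string, unknown> }
  | { ok: false; error: string };

function parseFunctionArguments(spec: FunctionSpec, rawInput: string): ParseResult {
  const acceptsParameters = Object.keys(spec.parameters?.properties ?? {}).length > 0;
  const trimmed = rawInput.trim();

  // A function without parameters must not receive any input.
  if (!acceptsParameters) {
    return trimmed === '' || trimmed === '{}'
      ? { ok: true, args: {} }
      : { ok: false, error: `Function "${spec.name}" does not accept any input` };
  }

  try {
    // JSON.parse rejects non-JSON constructs such as """ multiline delimiters.
    const parsed: unknown = JSON.parse(trimmed);
    if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
      return { ok: false, error: 'Input must be a JSON object' };
    }
    return { ok: true, args: parsed as Record<string, unknown> };
  } catch {
    return { ok: false, error: 'Input is not valid JSON' };
  }
}
```

Returning an error value rather than throwing would make it straightforward to feed the message back to the model so that it can retry with corrected input.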

There are also some other changes, which I've commented on in the file
changes.

**Addendum: I have pushed some more changes, related to the evaluation
framework (and running it with Claude). Will comment inline in
[`9ebd207`
(elastic#181058)](https://github.com/elastic/kibana/pull/181058/commits/9ebd207acd47c33077627356c464958240c9d446).**
dgieselaar authored Apr 24, 2024
1 parent 800289c commit ba76b50
Showing 62 changed files with 960 additions and 551 deletions.
@@ -20,7 +20,6 @@ export function registerGetApmDatasetInfoFunction({
registerFunction(
{
name: 'get_apm_dataset_info',
-contexts: ['core'],
visibility: FunctionVisibility.AssistantOnly,
description: `Use this function to get information about APM data.`,
parameters: {
@@ -16,7 +16,6 @@ export function registerGetApmDownstreamDependenciesFunction({
registerFunction(
{
name: 'get_apm_downstream_dependencies',
-contexts: ['core'],
description: `Get the downstream dependencies (services or uninstrumented backends) for a
service. This allows you to map the downstream dependency name to a service, by
returning both span.destination.service.resource and service.name. Use this to
@@ -39,7 +38,8 @@ export function registerGetApmDownstreamDependenciesFunction({
},
'service.environment': {
type: 'string',
-description: 'The environment that the service is running in',
+description:
+'The environment that the service is running in. Leave empty to query for all environments.',
},
start: {
type: 'string',
@@ -22,7 +22,6 @@ export function registerGetApmServicesListFunction({
registerFunction(
{
name: 'get_apm_services_list',
-contexts: ['apm'],
description: `Gets a list of services`,
descriptionForUser: i18n.translate(
'xpack.apm.observabilityAiAssistant.functions.registerGetApmServicesList.descriptionForUser',
@@ -122,7 +122,6 @@ export function registerGetApmTimeseriesFunction({
}: FunctionRegistrationParameters) {
registerFunction(
{
-contexts: ['core'],
name: 'get_apm_timeseries',
description: `Visualise and analyse different APM metrics, like throughput, failure rate, or latency, for any service or all services, or any or all of its dependencies, both as a timeseries and as a single statistic. A visualisation will be displayed above your reply - DO NOT attempt to display or generate an image yourself, or any other placeholder. Additionally, the function will return any changes, such as spikes, step and trend changes, or dips. You can also use it to compare data by requesting two different time ranges, or for instance two different service versions.`,
parameters,
@@ -48,7 +48,7 @@ export function registerAssistantFunctions({
ruleDataClient: IRuleDataClient;
plugins: APMRouteHandlerResources['plugins'];
}): RegistrationCallback {
-return async ({ resources, functions: { registerContext, registerFunction } }) => {
+return async ({ resources, functions: { registerFunction } }) => {
const apmRouteHandlerResources: MinimalAPMRouteHandlerResources = {
context: resources.context,
request: resources.request,
@@ -28,11 +28,6 @@ export type CompatibleJSONSchema = {
description?: string;
};

-export interface ContextDefinition {
-name: string;
-description: string;
-}
-
export type FunctionResponse =
| {
content?: any;
@@ -46,10 +41,6 @@ export interface FunctionDefinition<TParameters extends CompatibleJSONSchema = a
visibility?: FunctionVisibility;
descriptionForUser?: string;
parameters?: TParameters;
-contexts: string[];
}

-export type RegisterContextDefinition = (options: ContextDefinition) => void;
-
-export type ContextRegistry = Map<string, ContextDefinition>;
export type FunctionRegistry = Map<string, FunctionDefinition>;
@@ -94,6 +94,8 @@ export interface UserInstruction {
text: string;
}

+export type UserInstructionOrPlainText = string | UserInstruction;
+
export interface ObservabilityAIAssistantScreenContextRequest {
screenDescription?: string;
data?: Array<{
@@ -5,7 +5,7 @@
* 2.0.
*/

-import { concat, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs';
+import { concat, from, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs';
import {
ChatCompletionChunkEvent,
MessageAddEvent,
@@ -16,8 +16,32 @@ import {
ConcatenatedMessage,
} from './concatenate_chat_completion_chunks';

+type ConcatenateMessageCallback = (
+concatenatedMessage: ConcatenatedMessage
+) => Promise<ConcatenatedMessage>;
+
+function mergeWithEditedMessage(
+originalMessage: ConcatenatedMessage,
+chunkEvent: ChatCompletionChunkEvent,
+callback?: ConcatenateMessageCallback
+): Observable<MessageAddEvent> {
+return from(
+(callback ? callback(originalMessage) : Promise.resolve(originalMessage)).then((message) => {
+const next: MessageAddEvent = {
+type: StreamingChatResponseEventType.MessageAdd as const,
+id: chunkEvent.id,
+message: {
+'@timestamp': new Date().toISOString(),
+...message,
+},
+};
+return next;
+})
+);
+}
+
export function emitWithConcatenatedMessage(
-callback?: (concatenatedMessage: ConcatenatedMessage) => Promise<ConcatenatedMessage>
+callback?: ConcatenateMessageCallback
): (
source$: Observable<ChatCompletionChunkEvent>
) => Observable<ChatCompletionChunkEvent | MessageAddEvent> {
@@ -30,17 +54,8 @@ export function emitWithConcatenatedMessage(
concatenateChatCompletionChunks(),
last(),
withLatestFrom(source$),
-mergeMap(async ([message, chunkEvent]) => {
-const next: MessageAddEvent = {
-type: StreamingChatResponseEventType.MessageAdd as const,
-id: chunkEvent.id,
-message: {
-'@timestamp': new Date().toISOString(),
-...(callback ? await callback(message) : message),
-},
-};
-
-return next;
+mergeMap(([message, chunkEvent]) => {
+return mergeWithEditedMessage(message, chunkEvent, callback);
})
)
);

This file was deleted.

@@ -8,22 +8,18 @@
import type { FunctionDefinition } from '../functions/types';

export function filterFunctionDefinitions({
-contexts,
filter,
definitions,
}: {
-contexts?: string[];
filter?: string;
definitions: FunctionDefinition[];
}) {
-return contexts || filter
+return filter
? definitions.filter((fn) => {
-const matchesContext =
-!contexts || fn.contexts.some((context) => contexts.includes(context));
const matchesFilter =
!filter || fn.name.includes(filter) || fn.description.includes(filter);

-return matchesContext && matchesFilter;
+return matchesFilter;
})
: definitions;
}
@@ -5,20 +5,21 @@
* 2.0.
*/

-import { filter, Observable, tap } from 'rxjs';
+import { filter, OperatorFunction, tap } from 'rxjs';
import {
ChatCompletionError,
ChatCompletionErrorCode,
type StreamingChatResponseEvent,
StreamingChatResponseEventType,
type ChatCompletionErrorEvent,
+BufferFlushEvent,
} from '../conversation_complete';

-export function throwSerializedChatCompletionErrors() {
-return <T extends StreamingChatResponseEvent>(
-source$: Observable<StreamingChatResponseEvent>
-): Observable<Exclude<T, ChatCompletionErrorEvent>> => {
-return source$.pipe(
+export function throwSerializedChatCompletionErrors<
+T extends StreamingChatResponseEvent | BufferFlushEvent
+>(): OperatorFunction<T, Exclude<T, ChatCompletionErrorEvent>> {
+return (source$) =>
+source$.pipe(
tap((event) => {
// de-serialise error
if (event.type === StreamingChatResponseEventType.ChatCompletionError) {
@@ -33,5 +34,4 @@ export function throwSerializedChatCompletionErrors() {
event.type !== StreamingChatResponseEventType.ChatCompletionError
)
);
-};
}
@@ -10,13 +10,55 @@ import {
EuiFlexGroup,
EuiFlexItem,
EuiPanel,
+UseEuiTheme,
useEuiTheme,
} from '@elastic/eui';
import { css } from '@emotion/css';
import { i18n } from '@kbn/i18n';
import React from 'react';
import { ChatActionClickHandler, ChatActionClickType } from '../chat/types';

+const getCodeBlockClassName = (theme: UseEuiTheme) => css`
+background-color: ${theme.euiTheme.colors.lightestShade};
+.euiCodeBlock__pre {
+margin-bottom: 0;
+padding: ${theme.euiTheme.size.m};
+min-block-size: 48px;
+}
+.euiCodeBlock__controls {
+inset-block-start: ${theme.euiTheme.size.m};
+inset-inline-end: ${theme.euiTheme.size.m};
+}
+`;
+
+function CodeBlockWrapper({ children }: { children: React.ReactNode }) {
+const theme = useEuiTheme();
+return (
+<EuiPanel
+hasShadow={false}
+hasBorder={false}
+paddingSize="s"
+className={getCodeBlockClassName(theme)}
+>
+{children}
+</EuiPanel>
+);
+}
+
+export function CodeBlock({ children }: { children: React.ReactNode }) {
+return (
+<CodeBlockWrapper>
+<EuiFlexGroup direction="column" gutterSize="xs">
+<EuiFlexItem grow={false}>
+<EuiCodeBlock isCopyable fontSize="m">
+{children}
+</EuiCodeBlock>
+</EuiFlexItem>
+</EuiFlexGroup>
+</CodeBlockWrapper>
+);
+}
+
export function EsqlCodeBlock({
value,
actionsDisabled,
@@ -26,26 +68,8 @@ export function EsqlCodeBlock({
actionsDisabled: boolean;
onActionClick: ChatActionClickHandler;
}) {
-const theme = useEuiTheme();
-
return (
-<EuiPanel
-hasShadow={false}
-hasBorder={false}
-paddingSize="s"
-className={css`
-background-color: ${theme.euiTheme.colors.lightestShade};
-.euiCodeBlock__pre {
-margin-bottom: 0;
-padding: ${theme.euiTheme.size.m};
-min-block-size: 48px;
-}
-.euiCodeBlock__controls {
-inset-block-start: ${theme.euiTheme.size.m};
-inset-inline-end: ${theme.euiTheme.size.m};
-}
-`}
->
+<CodeBlockWrapper>
<EuiFlexGroup direction="column" gutterSize="xs">
<EuiFlexItem grow={false}>
<EuiCodeBlock isCopyable fontSize="m">
@@ -87,6 +111,6 @@ export function EsqlCodeBlock({
</EuiFlexGroup>
</EuiFlexItem>
</EuiFlexGroup>
-</EuiPanel>
+</CodeBlockWrapper>
);
}
@@ -21,7 +21,7 @@ import type { Code, InlineCode, Parent, Text } from 'mdast';
import React, { useMemo, useRef } from 'react';
import type { Node } from 'unist';
import { ChatActionClickHandler } from '../chat/types';
-import { EsqlCodeBlock } from './esql_code_block';
+import { CodeBlock, EsqlCodeBlock } from './esql_code_block';

interface Props {
content: string;
@@ -104,6 +104,9 @@ const esqlLanguagePlugin = () => {

if (node.type === 'code' && node.lang === 'esql') {
node.type = 'esql';
+} else if (node.type === 'code') {
+// switch to type that allows us to control rendering
+node.type = 'codeBlock';
}
};

@@ -131,6 +134,14 @@ export function MessageText({ loading, content, onActionClick }: Props) {
processingPlugins[1][1].components = {
...components,
cursor: Cursor,
+codeBlock: (props) => {
+return (
+<>
+<CodeBlock>{props.value}</CodeBlock>
+<EuiSpacer size="m" />
+</>
+);
+},
esql: (props) => {
return (
<>
@@ -26,11 +26,17 @@ const mockChatService: MockedChatService = {
chat: jest.fn(),
complete: jest.fn(),
sendAnalyticsEvent: jest.fn(),
-getContexts: jest.fn().mockReturnValue([{ name: 'core', description: '' }]),
getFunctions: jest.fn().mockReturnValue([]),
hasFunction: jest.fn().mockReturnValue(false),
hasRenderFunction: jest.fn().mockReturnValue(true),
renderFunction: jest.fn(),
+getSystemMessage: jest.fn().mockReturnValue({
+'@timestamp': new Date().toISOString(),
+message: {
+content: 'system',
+role: MessageRole.System,
+},
+}),
};

const addErrorMock = jest.fn();