Skip to content

Commit

Permalink
[Inference] Implement NL-to-ESQL task (#190433)
Browse files Browse the repository at this point in the history
Implements the NL-to-ESQL task and migrates the Observability AI
Assistant to use the new task. Most of the files are simply generated
documentation. I've also included two scripts: one to generate the
documentation, and another to evaluate the task against a real LLM.

TBD: run evaluation framework in Observability AI Assistant to ensure
there are no performance regressions.

---------

Co-authored-by: Elastic Machine <[email protected]>
Co-authored-by: pgayvallet <[email protected]>
Co-authored-by: kibanamachine <[email protected]>
  • Loading branch information
4 people authored Sep 4, 2024
1 parent d7e5559 commit 5c298a1
Show file tree
Hide file tree
Showing 326 changed files with 6,257 additions and 4,632 deletions.
6 changes: 3 additions & 3 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { InferenceTaskEventBase } from '../inference_task';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';

export enum MessageRole {
Expand Down Expand Up @@ -34,7 +34,7 @@ export type ToolMessage<TToolResponse extends Record<string, any> | unknown> =

export type Message = UserMessage | AssistantMessage | ToolMessage<unknown>;

export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions> =
export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions = ToolOptions> =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionMessage> & {
content: string;
} & { toolCalls: ToolCallsOf<TToolOptions>['toolCalls'] };
Expand Down Expand Up @@ -87,7 +87,7 @@ export type ChatCompletionEvent<TToolOptions extends ToolOptions = ToolOptions>
* @param {ToolChoice} [options.toolChoice] Force the LLM to call a (specific) tool, or no tool
* @param {Record<string, ToolDefinition>} [options.tools] A map of tools that can be called by the LLM
*/
export type ChatCompleteAPI<TToolOptions extends ToolOptions = ToolOptions> = (
export type ChatCompleteAPI = <TToolOptions extends ToolOptions = ToolOptions>(
options: {
connectorId: string;
system?: string;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ChatCompletionChunkEvent, ChatCompletionEvent, ChatCompletionEventType } from '.';

export function isChatCompletionChunkEvent(
event: ChatCompletionEvent
): event is ChatCompletionChunkEvent {
return event.type === ChatCompletionEventType.ChatCompletionChunk;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ChatCompletionEvent, ChatCompletionEventType } from '.';
import { InferenceTaskEvent } from '../inference_task';

export function isChatCompletionEvent(event: InferenceTaskEvent): event is ChatCompletionEvent {
return (
event.type === ChatCompletionEventType.ChatCompletionChunk ||
event.type === ChatCompletionEventType.ChatCompletionMessage ||
event.type === ChatCompletionEventType.ChatCompletionTokenCount
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ChatCompletionEvent, ChatCompletionEventType, ChatCompletionMessageEvent } from '.';
import type { ToolOptions } from './tools';

export function isChatCompletionMessageEvent<T extends ToolOptions<string>>(
event: ChatCompletionEvent<T>
): event is ChatCompletionMessageEvent<T> {
return event.type === ChatCompletionEventType.ChatCompletionMessage;
}
25 changes: 14 additions & 11 deletions x-pack/plugins/inference/common/chat_complete/tool_schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ interface ToolSchemaFragmentBase {

interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
type: 'object';
properties: Record<string, ToolSchemaType>;
properties?: Record<string, ToolSchemaType>;
required?: string[] | readonly string[];
}

Expand Down Expand Up @@ -47,16 +47,19 @@ export type ToolSchemaType =
| ToolSchemaTypeNumber
| ToolSchemaTypeArray;

type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> = Required<
{
[key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
TToolSchemaObject['properties'][key]
>;
},
TToolSchemaObject['required'] extends string[] | readonly string[]
? ValuesType<TToolSchemaObject['required']>
: never
>;
type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> =
TToolSchemaObject extends { properties: Record<string, any> }
? Required<
{
[key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
TToolSchemaObject['properties'][key]
>;
},
TToolSchemaObject['required'] extends string[] | readonly string[]
? ValuesType<TToolSchemaObject['required']>
: never
>
: never;

type FromToolSchemaArray<TToolSchemaObject extends ToolSchemaTypeArray> = Array<
FromToolSchema<TToolSchemaObject['items']>
Expand Down
8 changes: 3 additions & 5 deletions x-pack/plugins/inference/common/chat_complete/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,9 @@ export type ToolCallsOf<TToolOptions extends ToolOptions> = TToolOptions extends
? TToolOptions extends { toolChoice: ToolChoiceType.none }
? { toolCalls: [] }
: {
toolCalls: ToolResponsesOf<
Assert<ToolsOfChoice<TToolOptions>, Record<string, ToolDefinition> | undefined>
>;
toolCalls: ToolResponsesOf<ToolsOfChoice<TToolOptions>>;
}
: { toolCalls: never[] };
: { toolCalls: never };

export enum ToolChoiceType {
none = 'none',
Expand All @@ -70,7 +68,7 @@ export interface UnvalidatedToolCall {

export interface ToolCall<
TName extends string = string,
TArguments extends Record<string, any> | undefined = undefined
TArguments extends Record<string, any> | undefined = Record<string, any> | undefined
> {
toolCallId: string;
function: {
Expand Down
31 changes: 31 additions & 0 deletions x-pack/plugins/inference/common/ensure_multi_turn.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { Message, MessageRole } from './chat_complete';

function isUserMessage(message: Message): boolean {
return message.role !== MessageRole.Assistant;
}

export function ensureMultiTurn(messages: Message[]): Message[] {
const next: Message[] = [];

messages.forEach((message) => {
const prevMessage = next[next.length - 1];

if (prevMessage && isUserMessage(prevMessage) === isUserMessage(message)) {
next.push({
content: '-',
role: isUserMessage(message) ? MessageRole.Assistant : MessageRole.User,
});
}

next.push(message);
});

return next;
}
9 changes: 5 additions & 4 deletions x-pack/plugins/inference/common/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { InferenceTaskEventBase, InferenceTaskEventType } from './tasks';
import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task';

export enum InferenceTaskErrorCode {
internalError = 'internalError',
Expand Down Expand Up @@ -42,7 +42,7 @@ export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventT

export type InferenceTaskInternalError = InferenceTaskError<
InferenceTaskErrorCode.internalError,
{}
Record<string, any>
>;

export type InferenceTaskRequestError = InferenceTaskError<
Expand All @@ -53,9 +53,10 @@ export type InferenceTaskRequestError = InferenceTaskError<
export function createInferenceInternalError(
message: string = i18n.translate('xpack.inference.internalError', {
defaultMessage: 'An internal error occurred',
})
}),
meta?: Record<string, any>
): InferenceTaskInternalError {
return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, {});
return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, meta ?? {});
}

export function createInferenceRequestError(
Expand Down
31 changes: 31 additions & 0 deletions x-pack/plugins/inference/common/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export {
correctCommonEsqlMistakes,
splitIntoCommands,
} from './tasks/nl_to_esql/correct_common_esql_mistakes';

export { isChatCompletionChunkEvent } from './chat_complete/is_chat_completion_chunk_event';
export { isChatCompletionMessageEvent } from './chat_complete/is_chat_completion_message_event';
export { isChatCompletionEvent } from './chat_complete/is_chat_completion_event';

export { isOutputUpdateEvent } from './output/is_output_update_event';
export { isOutputCompleteEvent } from './output/is_output_complete_event';
export { isOutputEvent } from './output/is_output_event';

export type { ToolSchema } from './chat_complete/tool_schema';

export {
type Message,
MessageRole,
type ToolMessage,
type AssistantMessage,
type UserMessage,
} from './chat_complete';

export { generateFakeToolCallId } from './chat_complete/generate_fake_tool_call_id';
File renamed without changes.
24 changes: 18 additions & 6 deletions x-pack/plugins/inference/common/output/create_output_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,29 @@ import { map } from 'rxjs';
import { ChatCompleteAPI, ChatCompletionEventType, MessageRole } from '../chat_complete';
import { withoutTokenCountEvents } from '../chat_complete/without_token_count_events';
import { OutputAPI, OutputEvent, OutputEventType } from '.';
import { ensureMultiTurn } from '../ensure_multi_turn';

export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
return (id, { connectorId, input, schema, system }) => {
return (id, { connectorId, input, schema, system, previousMessages }) => {
return chatCompleteApi({
connectorId,
system,
messages: [
messages: ensureMultiTurn([
...(previousMessages || []),
{
role: MessageRole.User,
content: input,
},
],
]),
...(schema
? {
tools: { output: { description: `Output your response in the this format`, schema } },
toolChoice: { function: 'output' },
tools: {
output: {
description: `Use the following schema to respond to the user's request in structured data, so it can be parsed and handled.`,
schema,
},
},
toolChoice: { function: 'output' as const },
}
: {}),
}).pipe(
Expand All @@ -37,10 +44,15 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
content: event.content,
};
}

return {
id,
type: OutputEventType.OutputComplete,
output: event.toolCalls[0].function.arguments,
output:
event.toolCalls.length && 'arguments' in event.toolCalls[0].function
? event.toolCalls[0].function.arguments
: undefined,
content: event.content,
};
})
);
Expand Down
14 changes: 10 additions & 4 deletions x-pack/plugins/inference/common/output/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

import { Observable } from 'rxjs';
import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema';
import { InferenceTaskEventBase } from '../tasks';
import { InferenceTaskEventBase } from '../inference_task';
import { Message } from '../chat_complete';

export enum OutputEventType {
OutputUpdate = 'output',
OutputComplete = 'complete',
}

type Output = Record<string, any> | undefined;
type Output = Record<string, any> | undefined | unknown;

export type OutputUpdateEvent<TId extends string = string> =
InferenceTaskEventBase<OutputEventType.OutputUpdate> & {
Expand All @@ -28,6 +29,7 @@ export type OutputCompleteEvent<
> = InferenceTaskEventBase<OutputEventType.OutputComplete> & {
id: TId;
output: TOutput;
content?: string;
};

export type OutputEvent<TId extends string = string, TOutput extends Output = Output> =
Expand All @@ -39,7 +41,8 @@ export type OutputEvent<TId extends string = string, TOutput extends Output = Ou
*
* @param {string} id The id of the operation
* @param {string} options.connectorId The ID of the connector that is to be used.
* @param {string} options.input The prompt for the LLM
* @param {string} options.input The prompt for the LLM.
* @param {string} options.messages Previous messages in a conversation.
* @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to.
*/
export type OutputAPI = <
Expand All @@ -52,18 +55,21 @@ export type OutputAPI = <
system?: string;
input: string;
schema?: TOutputSchema;
previousMessages?: Message[];
}
) => Observable<
OutputEvent<TId, TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined>
>;

export function createOutputCompleteEvent<TId extends string, TOutput extends Output>(
id: TId,
output: TOutput
output: TOutput,
content?: string
): OutputCompleteEvent<TId, TOutput> {
return {
id,
type: OutputEventType.OutputComplete,
output,
content,
};
}
14 changes: 14 additions & 0 deletions x-pack/plugins/inference/common/output/is_output_complete_event.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { OutputEvent, OutputEventType, OutputUpdateEvent } from '.';

export function isOutputCompleteEvent<TOutputEvent extends OutputEvent>(
event: TOutputEvent
): event is Exclude<TOutputEvent, OutputUpdateEvent> {
return event.type === OutputEventType.OutputComplete;
}
15 changes: 15 additions & 0 deletions x-pack/plugins/inference/common/output/is_output_event.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { OutputEvent, OutputEventType } from '.';
import type { InferenceTaskEvent } from '../inference_task';

export function isOutputEvent(event: InferenceTaskEvent): event is OutputEvent {
return (
event.type === OutputEventType.OutputComplete || event.type === OutputEventType.OutputUpdate
);
}
14 changes: 14 additions & 0 deletions x-pack/plugins/inference/common/output/is_output_update_event.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { OutputEvent, OutputEventType, OutputUpdateEvent } from '.';

export function isOutputUpdateEvent<TId extends string>(
event: OutputEvent
): event is OutputUpdateEvent<TId> {
return event.type === OutputEventType.OutputComplete;
}
Loading

0 comments on commit 5c298a1

Please sign in to comment.