Skip to content

Commit

Permalink
[inference] Add cancelation support for chatComplete and output (elas…
Browse files Browse the repository at this point in the history
…tic#203108)

## Summary

Fix elastic#200757

Add cancelation support for `chatComplete` and `output`, based on an
abort signal.


### Examples

#### response mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

try {
  const abortController = new AbortController();
  const chatResponse = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch(e) {
  if(isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}

// elsewhere
abortController.abort()
```

#### stream mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();
const events$ = inferenceClient.chatComplete({
  stream: true,
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
  next: (event) => {
    // do something
  },
  error: (err) => {
    if(isInferenceRequestAbortedError(e)) {
      // request was aborted, do something
    } else {
      // was another error, do something else
    }
  }
});

abortController.abort();
```
  • Loading branch information
pgayvallet authored Dec 17, 2024
1 parent 78f1d17 commit 0b74f62
Show file tree
Hide file tree
Showing 27 changed files with 688 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ packages/kbn-monaco/src/esql @elastic/kibana-esql
#CC# /x-pack/plugins/global_search_providers/ @elastic/kibana-core

# AppEx AI Infra
/x-pack/plugins/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/test/functional_gen_ai/inference @elastic/appex-ai-infra

# AppEx Platform Services Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ export {
type InferenceTaskErrorEvent,
type InferenceTaskInternalError,
type InferenceTaskRequestError,
type InferenceTaskAbortedError,
createInferenceInternalError,
createInferenceRequestError,
createInferenceRequestAbortedError,
isInferenceError,
isInferenceInternalError,
isInferenceRequestError,
isInferenceRequestAbortedError,
} from './src/errors';

export { truncateList } from './src/truncate_list';
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ export type ChatCompleteOptions<
* Function calling mode, defaults to "native".
*/
functionCalling?: FunctionCallingMode;
/**
* Optional signal that can be used to forcefully abort the request.
*/
abortSignal?: AbortSignal;
} & TToolOptions;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task
export enum InferenceTaskErrorCode {
internalError = 'internalError',
requestError = 'requestError',
abortedError = 'requestAborted',
}

/**
Expand Down Expand Up @@ -46,16 +47,37 @@ export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventT
};
};

/**
* Inference error thrown when an unexpected internal error occurs while handling the request.
*/
export type InferenceTaskInternalError = InferenceTaskError<
InferenceTaskErrorCode.internalError,
Record<string, any>
>;

/**
* Inference error thrown when the request was considered invalid.
*
* Some example of reasons for invalid requests would be:
* - no connector matching the provided connectorId
* - invalid connector type for the provided connectorId
*/
export type InferenceTaskRequestError = InferenceTaskError<
InferenceTaskErrorCode.requestError,
{ status: number }
>;

/**
* Inference error thrown when the request was aborted.
*
* Request abortion occurs when providing an abort signal and firing it
* before the call to the LLM completes.
*/
export type InferenceTaskAbortedError = InferenceTaskError<
InferenceTaskErrorCode.abortedError,
{ status: number }
>;

export function createInferenceInternalError(
message = 'An internal error occurred',
meta?: Record<string, any>
Expand All @@ -72,16 +94,38 @@ export function createInferenceRequestError(
});
}

export function createInferenceRequestAbortedError(): InferenceTaskAbortedError {
return new InferenceTaskError(InferenceTaskErrorCode.abortedError, 'Request was aborted', {
status: 499,
});
}

/**
* Check if the given error is an {@link InferenceTaskError}
*/
export function isInferenceError(
error: unknown
): error is InferenceTaskError<string, Record<string, any> | undefined> {
return error instanceof InferenceTaskError;
}

/**
* Check if the given error is an {@link InferenceTaskInternalError}
*/
export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError;
}

/**
* Check if the given error is an {@link InferenceTaskRequestError}
*/
export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError;
}

/**
* Check if the given error is an {@link InferenceTaskAbortedError}
*/
export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError;
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ export interface OutputOptions<
* Defaults to false.
*/
stream?: TStream;

/**
* Optional signal that can be used to forcefully abort the request.
*/
abortSignal?: AbortSignal;
/**
* Optional configuration for retrying the call if an error occurs.
*/
Expand Down
69 changes: 69 additions & 0 deletions x-pack/platform/plugins/shared/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,75 @@ const toolCall = toolCalls[0];
// process the tool call and eventually continue the conversation with the LLM
```

#### Request cancellation

Request cancellation can be done by passing an abort signal when calling the API. Firing the signal
before the request completes will cause the abortion, and the API call will throw an error.

```ts
const abortController = new AbortController();

const chatResponse = await inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
abortSignal: abortController.signal,
messages: [{ role: MessageRole.User, content: 'Do something' }],
});

// from elsewhere / before the request completes and the promise resolves:

abortController.abort();
```

The `isInferenceRequestAbortedError` helper function, exposed from `@kbn/inference-common`, can be used easily identify those errors:

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

try {
const abortController = new AbortController();
const chatResponse = await inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
abortSignal: abortController.signal,
messages: [{ role: MessageRole.User, content: 'Do something' }],
});
} catch(e) {
if(isInferenceRequestAbortedError(e)) {
// request was aborted, do something
} else {
// was another error, do something else
}
}
```

The approach is very similar for stream mode:

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();
const events$ = inferenceClient.chatComplete({
stream: true,
connectorId: 'some-gen-ai-connector',
abortSignal: abortController.signal,
messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
next: (event) => {
// do something
},
error: (err) => {
if(isInferenceRequestAbortedError(e)) {
// request was aborted, do something
} else {
// was another error, do something else
}
}
});

abortController.abort();
```

### `output` API

`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,26 @@ describe('createOutputApi', () => {
).toThrowError('Retry options are not supported in streaming mode');
});
});

it('propagates the abort signal when provided', async () => {
chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));

const output = createOutputApi(chatComplete);

const abortController = new AbortController();

await output({
id: 'id',
connectorId: '.my-connector',
input: 'input message',
abortSignal: abortController.signal,
});

expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal,
})
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
previousMessages,
functionCalling,
stream,
abortSignal,
retry,
}: DefaultOutputOptions): OutputCompositeResponse<string, ToolSchema | undefined, boolean> {
if (stream && retry !== undefined) {
Expand All @@ -52,6 +53,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
connectorId,
stream,
functionCalling,
abortSignal,
system,
messages,
...(schema
Expand Down Expand Up @@ -113,6 +115,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
input,
schema,
system,
abortSignal,
previousMessages: messages.concat(
{
role: MessageRole.Assistant as const,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,24 @@ describe('bedrockClaudeAdapter', () => {
expect(tools).toEqual([]);
expect(system).toEqual(addNoToolUsageDirective('some system instruction'));
});

it('propagates the abort signal when provided', () => {
const abortController = new AbortController();

bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});

expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'invokeStream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { processCompletionChunks } from './process_completion_chunks';
import { addNoToolUsageDirective } from './prompts';

export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
const noToolUsage = toolChoice === ToolChoiceType.none;

const subActionParams = {
Expand All @@ -36,6 +36,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
toolChoice: toolChoiceToBedrock(toolChoice),
temperature: 0,
stopSequences: ['\n\nHuman:'],
signal: abortSignal,
};

return from(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,5 +402,24 @@ describe('geminiAdapter', () => {
expect(tapFn).toHaveBeenCalledWith({ chunk: 1 });
expect(tapFn).toHaveBeenCalledWith({ chunk: 2 });
});

it('propagates the abort signal when provided', () => {
const abortController = new AbortController();

geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});

expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'invokeStream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';

export const geminiAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
return from(
executor.invoke({
subAction: 'invokeStream',
Expand All @@ -32,6 +32,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
tools: toolsToGemini(tools),
toolConfig: toolChoiceToConfig(toolChoice),
temperature: 0,
signal: abortSignal,
stopSequences: ['\n\nHuman:'],
},
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ describe('openAIAdapter', () => {
};
});
});

it('correctly formats messages ', () => {
openAIAdapter.chatComplete({
...defaultArgs,
Expand Down Expand Up @@ -254,6 +255,25 @@ describe('openAIAdapter', () => {
expect(getRequest().stream).toBe(true);
expect(getRequest().body.stream).toBe(true);
});

it('propagates the abort signal when provided', () => {
const abortController = new AbortController();

openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});

expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'stream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});

describe('when handling the response', () => {
Expand Down
Loading

0 comments on commit 0b74f62

Please sign in to comment.