Skip to content

Commit

Permalink
[Inference] Minor cleanup and restructure (#191069)
Browse files Browse the repository at this point in the history
## Summary

Fixing and improving a few things I noticed while discovering / ramping
up on the existing code
- address some nits
- extract / reuse some low level functions
- move things around
- add unit tests
  • Loading branch information
pgayvallet authored Aug 26, 2024
1 parent 7eb7e97 commit 56730e8
Show file tree
Hide file tree
Showing 27 changed files with 571 additions and 211 deletions.
1 change: 1 addition & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';
Expand Down
12 changes: 7 additions & 5 deletions x-pack/plugins/inference/common/connectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ export enum InferenceConnectorType {
Gemini = '.gemini',
}

const allSupportedConnectorTypes = Object.values(InferenceConnectorType);

export interface InferenceConnector {
type: InferenceConnectorType;
name: string;
connectorId: string;
}

export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
return (
id === InferenceConnectorType.OpenAI ||
id === InferenceConnectorType.Bedrock ||
id === InferenceConnectorType.Gemini
);
return allSupportedConnectorTypes.includes(id as InferenceConnectorType);
}

export interface GetConnectorsResponseBody {
connectors: InferenceConnector[];
}
4 changes: 2 additions & 2 deletions x-pack/plugins/inference/public/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
* 2.0.
*/

import type { HttpStart } from '@kbn/core/public';
import { from } from 'rxjs';
import { ChatCompleteAPI } from '../../common/chat_complete';
import type { HttpStart } from '@kbn/core/public';
import type { ChatCompleteAPI } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { httpResponseIntoObservable } from '../util/http_response_into_observable';

Expand Down
12 changes: 9 additions & 3 deletions x-pack/plugins/inference/public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';

import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';
import type { Logger } from '@kbn/logging';
import { createOutputApi } from '../common/output/create_output_api';
import type { GetConnectorsResponseBody } from '../common/connectors';
import { createChatCompleteApi } from './chat_complete';
import type {
ConfigSchema,
Expand Down Expand Up @@ -39,11 +41,15 @@ export class InferencePlugin

start(coreStart: CoreStart, pluginsStart: InferenceStartDependencies): InferencePublicStart {
const chatComplete = createChatCompleteApi({ http: coreStart.http });

return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectors: () => {
return coreStart.http.get('/internal/inference/connectors');
getConnectors: async () => {
const res = await coreStart.http.get<GetConnectorsResponseBody>(
'/internal/inference/connectors'
);
return res.connectors;
},
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { InferenceConnectorType } from '../../../common/connectors';
import { getInferenceAdapter } from './get_inference_adapter';
import { openAIAdapter } from './openai';

describe('getInferenceAdapter', () => {
it('returns the openAI adapter for OpenAI type', () => {
expect(getInferenceAdapter(InferenceConnectorType.OpenAI)).toBe(openAIAdapter);
});

it('returns undefined for Bedrock type', () => {
expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined);
});

it('returns undefined for Gemini type', () => {
expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(undefined);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { InferenceConnectorType } from '../../../common/connectors';
import type { InferenceConnectorAdapter } from '../types';
import { openAIAdapter } from './openai';

export const getInferenceAdapter = (
connectorType: InferenceConnectorType
): InferenceConnectorAdapter | undefined => {
switch (connectorType) {
case InferenceConnectorType.OpenAI:
return openAIAdapter;

case InferenceConnectorType.Bedrock:
// not implemented yet
break;

case InferenceConnectorType.Gemini:
// not implemented yet
break;
}

return undefined;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { getInferenceAdapter } from './get_inference_adapter';
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
*/

import OpenAI from 'openai';
import { openAIAdapter } from '.';
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { v4 } from 'uuid';
import { PassThrough } from 'stream';
import { pick } from 'lodash';
import { lastValueFrom, Subject, toArray } from 'rxjs';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { v4 } from 'uuid';
import { InferenceExecutor } from '../../utils/inference_executor';
import { openAIAdapter } from '.';

function createOpenAIChunk({
delta,
Expand All @@ -39,38 +39,27 @@ function createOpenAIChunk({
}

describe('openAIAdapter', () => {
const actionsClientMock = {
execute: jest.fn(),
} as ActionsClient & { execute: jest.MockedFn<ActionsClient['execute']> };
const executorMock = {
invoke: jest.fn(),
} as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> };

beforeEach(() => {
actionsClientMock.execute.mockReset();
executorMock.invoke.mockReset();
});

const defaultArgs = {
connector: {
id: 'foo',
actionTypeId: '.gen-ai',
name: 'OpenAI',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
},
actionsClient: actionsClientMock,
executor: executorMock,
};

describe('when creating the request', () => {
function getRequest() {
const params = actionsClientMock.execute.mock.calls[0][0].params.subActionParams as Record<
string,
any
>;
const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record<string, any>;

return { stream: params.stream, body: JSON.parse(params.body) };
}

beforeEach(() => {
actionsClientMock.execute.mockImplementation(async () => {
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
Expand Down Expand Up @@ -262,7 +251,7 @@ describe('openAIAdapter', () => {
beforeEach(() => {
source$ = new Subject<Record<string, any>>();

actionsClientMock.execute.mockImplementation(async () => {
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,60 +21,30 @@ import {
Message,
MessageRole,
} from '../../../../common/chat_complete';
import type { ToolOptions } from '../../../../common/chat_complete/tools';
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
import { createInferenceInternalError } from '../../../../common/errors';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { InferenceConnectorAdapter } from '../../types';
import { eventSourceStreamIntoObservable } from '../event_source_stream_into_observable';

export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ connector, actionsClient, system, messages, toolChoice, tools }) => {
const openAIMessages = messagesToOpenAI({ system, messages });

const toolChoiceForOpenAI =
typeof toolChoice === 'string'
? toolChoice
: toolChoice
? {
function: {
name: toolChoice.function,
},
type: 'function' as const,
}
: undefined;

chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
const stream = true;

const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
stream,
messages: openAIMessages,
messages: messagesToOpenAI({ system, messages }),
tool_choice: toolChoiceToOpenAI(toolChoice),
tools: toolsToOpenAI(tools),
temperature: 0,
tool_choice: toolChoiceForOpenAI,
tools: tools
? Object.entries(tools).map(([toolName, { description, schema }]) => {
return {
type: 'function',
function: {
name: toolName,
description,
parameters: (schema ?? {
type: 'object' as const,
properties: {},
}) as unknown as Record<string, unknown>,
},
};
})
: undefined,
};

return from(
actionsClient.execute({
actionId: connector.id,
params: {
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
stream,
},
executor.invoke({
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
stream,
},
})
).pipe(
Expand Down Expand Up @@ -125,6 +95,39 @@ export const openAIAdapter: InferenceConnectorAdapter = {
},
};

function toolsToOpenAI(tools: ToolOptions['tools']): OpenAI.ChatCompletionCreateParams['tools'] {
return tools
? Object.entries(tools).map(([toolName, { description, schema }]) => {
return {
type: 'function',
function: {
name: toolName,
description,
parameters: (schema ?? {
type: 'object' as const,
properties: {},
}) as unknown as Record<string, unknown>,
},
};
})
: undefined;
}

function toolChoiceToOpenAI(
toolChoice: ToolOptions['toolChoice']
): OpenAI.ChatCompletionCreateParams['tool_choice'] {
return typeof toolChoice === 'string'
? toolChoice
: toolChoice
? {
function: {
name: toolChoice.function,
},
type: 'function' as const,
}
: undefined;
}

function messagesToOpenAI({
system,
messages,
Expand Down
63 changes: 63 additions & 0 deletions x-pack/plugins/inference/server/chat_complete/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { KibanaRequest } from '@kbn/core-http-server';
import { defer, switchMap, throwError } from 'rxjs';
import type { ChatCompleteAPI, ChatCompletionResponse } from '../../common/chat_complete';
import { createInferenceRequestError } from '../../common/errors';
import type { InferenceStartDependencies } from '../types';
import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage } from './utils';

export function createChatCompleteApi({
request,
actions,
}: {
request: KibanaRequest;
actions: InferenceStartDependencies['actions'];
}) {
const chatCompleteAPI: ChatCompleteAPI = ({
connectorId,
messages,
toolChoice,
tools,
system,
}): ChatCompletionResponse => {
return defer(async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await getConnectorById({ connectorId, actionsClient });
const executor = createInferenceExecutor({ actionsClient, connector });
return { executor, connector };
}).pipe(
switchMap(({ executor, connector }) => {
const connectorType = connector.type;
const inferenceAdapter = getInferenceAdapter(connectorType);

if (!inferenceAdapter) {
return throwError(() =>
createInferenceRequestError(`Adapter for type ${connectorType} not implemented`, 400)
);
}

return inferenceAdapter.chatComplete({
system,
executor,
messages,
toolChoice,
tools,
});
}),
chunksIntoMessage({
toolChoice,
tools,
})
);
};

return chatCompleteAPI;
}
Loading

0 comments on commit 56730e8

Please sign in to comment.