Skip to content

Commit

Permalink
[Connectors] Make defaultModel a property of the connector (#162754)
Browse files Browse the repository at this point in the history
Co-authored-by: kibanamachine <[email protected]>
  • Loading branch information
dgieselaar and kibanamachine authored Aug 7, 2023
1 parent 6cd3b6a commit 444d183
Show file tree
Hide file tree
Showing 17 changed files with 276 additions and 65 deletions.
1 change: 1 addition & 0 deletions x-pack/plugins/actions/server/actions_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ export class ActionsClient {

const actionType = this.actionTypeRegistry.get(actionTypeId);
const configurationUtilities = this.actionTypeRegistry.getUtils();

const validatedActionTypeConfig = validateConfig(actionType, config, {
configurationUtilities,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import type { PublicMethodsOf } from '@kbn/utility-types';
import { internal, notFound } from '@hapi/boom';
import { compact, isEmpty, merge, omit } from 'lodash';
import type { SearchHit } from '@elastic/elasticsearch/lib/api/types';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/gen_ai/constants';
import {
type Conversation,
type ConversationCreateRequest,
Expand Down Expand Up @@ -134,12 +133,7 @@ export class ObservabilityAIAssistantClient implements IObservabilityAIAssistant
})
);

const connector = await this.dependencies.actionsClient.get({
id: connectorId,
});

const request: Omit<CreateChatCompletionRequest, 'model'> & { model?: string } = {
...(connector.config?.apiProvider === OpenAiProviderType.OpenAi ? { model: 'gpt-4' } : {}),
messages: messagesForOpenAI,
stream: true,
};
Expand Down
1 change: 0 additions & 1 deletion x-pack/plugins/observability_ai_assistant/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"@kbn/spaces-plugin",
"@kbn/kibana-react-plugin",
"@kbn/shared-ux-utility",
"@kbn/stack-connectors-plugin",
"@kbn/alerting-plugin"
],
"exclude": [
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugins/stack_connectors/common/gen_ai/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export enum OpenAiProviderType {
AzureAi = 'Azure OpenAI',
}

// Fallback model for OpenAI-provider connectors when the connector config
// does not specify a `defaultModel` of its own.
export const DEFAULT_OPENAI_MODEL = 'gpt-4';

export const OPENAI_CHAT_URL = 'https://api.openai.com/v1/chat/completions' as const;
export const OPENAI_LEGACY_COMPLETION_URL = 'https://api.openai.com/v1/completions' as const;
export const AZURE_OPENAI_CHAT_URL =
Expand Down
20 changes: 12 additions & 8 deletions x-pack/plugins/stack_connectors/common/gen_ai/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
*/

import { schema } from '@kbn/config-schema';
import { OpenAiProviderType } from './constants';
import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from './constants';

// Connector schema
export const GenAiConfigSchema = schema.object({
apiProvider: schema.oneOf([
schema.literal(OpenAiProviderType.OpenAi as string),
schema.literal(OpenAiProviderType.AzureAi as string),
]),
apiUrl: schema.string(),
});
// Azure OpenAI connectors are configured with a provider tag and an API URL only.
const azureProviderConfig = schema.object({
  apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.AzureAi)]),
  apiUrl: schema.string(),
});

// OpenAI connectors additionally carry a default model, falling back to
// DEFAULT_OPENAI_MODEL when the user does not set one.
const openAiProviderConfig = schema.object({
  apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.OpenAi)]),
  apiUrl: schema.string(),
  defaultModel: schema.string({ defaultValue: DEFAULT_OPENAI_MODEL }),
});

// Connector config schema: one variant per supported API provider.
export const GenAiConfigSchema = schema.oneOf([azureProviderConfig, openAiProviderConfig]);

// Connector secrets schema: only the provider API key is stored as a secret.
export const GenAiSecretsSchema = schema.object({ apiKey: schema.string() });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import GenerativeAiConnectorFields from './connector';
import { ConnectorFormTestProvider } from '../lib/test_utils';
import { act, fireEvent, render, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { OpenAiProviderType } from '../../../common/gen_ai/constants';
import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants';
import { useKibana } from '@kbn/triggers-actions-ui-plugin/public';
import { useGetDashboard } from './use_get_dashboard';

Expand All @@ -26,6 +26,7 @@ const openAiConnector = {
config: {
apiUrl: 'https://openaiurl.com',
apiProvider: OpenAiProviderType.OpenAi,
defaultModel: DEFAULT_OPENAI_MODEL,
},
secrets: {
apiKey: 'thats-a-nice-looking-key',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ import React from 'react';
import { ConfigFieldSchema, SecretsFieldSchema } from '@kbn/triggers-actions-ui-plugin/public';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiLink } from '@elastic/eui';
import { OpenAiProviderType } from '../../../common/gen_ai/constants';
import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants';
import * as i18n from './translations';

export const DEFAULT_URL = 'https://api.openai.com/v1/chat/completions' as const;
export const DEFAULT_URL_AZURE =
'https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version={api-version}' as const;

export const DEFAULT_BODY = `{
"model":"gpt-3.5-turbo",
"messages": [{
"role":"user",
"content":"Hello world"
Expand Down Expand Up @@ -54,6 +53,28 @@ export const openAiConfig: ConfigFieldSchema[] = [
/>
),
},
{
id: 'defaultModel',
label: i18n.DEFAULT_MODEL_LABEL,
helpText: (
<FormattedMessage
defaultMessage='The model can be set on a per request basis by including a "model" parameter in the request body. If no model is provided, the fallback will be the default model. For more information, refer to the {genAiAPIModelDocs}.'
id="xpack.stackConnectors.components.genAi.openAiDocumentationModel"
values={{
genAiAPIModelDocs: (
<EuiLink
data-test-subj="open-ai-api-doc"
href="https://platform.openai.com/docs/api-reference/models"
target="_blank"
>
{`${i18n.OPEN_AI} ${i18n.DOCUMENTATION}`}
</EuiLink>
),
}}
/>
),
defaultValue: DEFAULT_OPENAI_MODEL,
},
];

export const azureAiConfig: ConfigFieldSchema[] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ export const API_KEY_LABEL = i18n.translate('xpack.stackConnectors.components.ge
defaultMessage: 'API Key',
});

// Label for the "Default model" text field on the OpenAI connector form.
export const DEFAULT_MODEL_LABEL = i18n.translate(
  'xpack.stackConnectors.components.genAi.defaultModelTextFieldLabel',
  {
    defaultMessage: 'Default model',
  }
);

// Tooltip explaining that a per-request "model" parameter in the request body
// takes precedence over the connector's configured default model.
export const DEFAULT_MODEL_TOOLTIP_CONTENT = i18n.translate(
  'xpack.stackConnectors.components.genAi.defaultModelTooltipContent',
  {
    defaultMessage:
      'The model can be set on a per request basis by including a "model" parameter in the request body. If no model is provided, the fallback will be the default model.',
  }
);

export const API_PROVIDER_LABEL = i18n.translate(
'xpack.stackConnectors.components.genAi.apiProviderLabel',
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import { GenAiConnector } from './gen_ai';
import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock';
import { GEN_AI_CONNECTOR_ID, OpenAiProviderType } from '../../../common/gen_ai/constants';
import {
DEFAULT_OPENAI_MODEL,
GEN_AI_CONNECTOR_ID,
OpenAiProviderType,
} from '../../../common/gen_ai/constants';
import { loggingSystemMock } from '@kbn/core-logging-server-mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
import {
Expand Down Expand Up @@ -35,14 +39,14 @@ describe('GenAiConnector', () => {
config: {
apiUrl: 'https://api.openai.com/v1/chat/completions',
apiProvider: OpenAiProviderType.OpenAi,
defaultModel: DEFAULT_OPENAI_MODEL,
},
secrets: { apiKey: '123' },
logger: loggingSystemMock.createLogger(),
services: actionsMock.createServices(),
});

const sampleOpenAiBody = {
model: 'gpt-3.5-turbo',
messages: [
{
role: 'user',
Expand All @@ -58,14 +62,47 @@ describe('GenAiConnector', () => {
});

describe('runApi', () => {
it('uses the default model if none is supplied', async () => {
const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: GenAiRunActionResponseSchema,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'content-type': 'application/json',
},
});
expect(response).toEqual({ result: 'success' });
});

it('overrides the default model with the default model specified in the body', async () => {
const requestBody = { model: 'gpt-3.5-turbo', ...sampleOpenAiBody };
const response = await connector.runApi({ body: JSON.stringify(requestBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: GenAiRunActionResponseSchema,
data: JSON.stringify({ ...requestBody, stream: false }),
headers: {
Authorization: 'Bearer 123',
'content-type': 'application/json',
},
});
expect(response).toEqual({ result: 'success' });
});

it('the OpenAI API call is successful with correct parameters', async () => {
const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: GenAiRunActionResponseSchema,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false }),
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'content-type': 'application/json',
Expand Down Expand Up @@ -128,7 +165,7 @@ describe('GenAiConnector', () => {
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: GenAiRunActionResponseSchema,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false }),
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'content-type': 'application/json',
Expand All @@ -148,7 +185,7 @@ describe('GenAiConnector', () => {
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: GenAiStreamingResponseSchema,
data: JSON.stringify({ ...sampleOpenAiBody, stream: true }),
data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'content-type': 'application/json',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ export class GenAiConnector extends SubActionConnector<GenAiConfig, GenAiSecrets
}

public async runApi({ body }: GenAiRunActionParams): Promise<GenAiRunActionResponse> {
const sanitizedBody = sanitizeRequest(this.provider, this.url, body);
const sanitizedBody = sanitizeRequest(
this.provider,
this.url,
body,
...('defaultModel' in this.config ? [this.config.defaultModel] : [])
);
const axiosOptions = getAxiosOptions(this.provider, this.key, false);
const response = await this.request({
url: this.url,
Expand All @@ -104,7 +109,13 @@ export class GenAiConnector extends SubActionConnector<GenAiConfig, GenAiSecrets
body,
stream,
}: GenAiStreamActionParams): Promise<GenAiRunActionResponse> {
const executeBody = getRequestWithStreamOption(this.provider, this.url, body, stream);
const executeBody = getRequestWithStreamOption(
this.provider,
this.url,
body,
stream,
...('defaultModel' in this.config ? [this.config.defaultModel] : [])
);
const axiosOptions = getAxiosOptions(this.provider, this.key, stream);
const response = await this.request({
url: this.url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import axios from 'axios';
import { configValidator, getConnectorType } from '.';
import { GenAiConfig, GenAiSecrets } from '../../../common/gen_ai/types';
import { SubActionConnectorType } from '@kbn/actions-plugin/server/sub_action_framework/types';
import { OpenAiProviderType } from '../../../common/gen_ai/constants';
import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants';

jest.mock('axios');
jest.mock('@kbn/actions-plugin/server/lib/axios_utils', () => {
Expand Down Expand Up @@ -44,6 +44,7 @@ describe('Generative AI Connector', () => {
const config: GenAiConfig = {
apiUrl: 'https://api.openai.com/v1/chat/completions',
apiProvider: OpenAiProviderType.OpenAi,
defaultModel: DEFAULT_OPENAI_MODEL,
};

expect(configValidator(config, { configurationUtilities })).toEqual(config);
Expand All @@ -53,6 +54,7 @@ describe('Generative AI Connector', () => {
const config: GenAiConfig = {
apiUrl: 'example.com/do-something',
apiProvider: OpenAiProviderType.OpenAi,
defaultModel: DEFAULT_OPENAI_MODEL,
};
expect(() => {
configValidator(config, { configurationUtilities });
Expand All @@ -64,7 +66,8 @@ describe('Generative AI Connector', () => {
test('config validation failed when the OpenAI API provider is empty', () => {
const config: GenAiConfig = {
apiUrl: 'https://api.openai.com/v1/chat/completions',
apiProvider: '',
apiProvider: '' as OpenAiProviderType,
defaultModel: DEFAULT_OPENAI_MODEL,
};
expect(() => {
configValidator(config, { configurationUtilities });
Expand All @@ -76,7 +79,8 @@ describe('Generative AI Connector', () => {
test('config validation failed when the OpenAI API provider is invalid', () => {
const config: GenAiConfig = {
apiUrl: 'https://api.openai.com/v1/chat/completions',
apiProvider: 'bad-one',
apiProvider: 'bad-one' as OpenAiProviderType,
defaultModel: DEFAULT_OPENAI_MODEL,
};
expect(() => {
configValidator(config, { configurationUtilities });
Expand All @@ -96,6 +100,7 @@ describe('Generative AI Connector', () => {
const config: GenAiConfig = {
apiUrl: 'http://mylisteningserver.com:9200/endpoint',
apiProvider: OpenAiProviderType.OpenAi,
defaultModel: DEFAULT_OPENAI_MODEL,
};

expect(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,12 @@ export const configValidator = (
assertURL(configObject.apiUrl);
urlAllowListValidator('apiUrl')(configObject, validatorServices);

if (
configObject.apiProvider !== OpenAiProviderType.OpenAi &&
configObject.apiProvider !== OpenAiProviderType.AzureAi
) {
const { apiProvider } = configObject;

if (apiProvider !== OpenAiProviderType.OpenAi && apiProvider !== OpenAiProviderType.AzureAi) {
throw new Error(
`API Provider is not supported${
configObject.apiProvider && configObject.apiProvider.length
? `: ${configObject.apiProvider}`
: ``
apiProvider && (apiProvider as OpenAiProviderType).length ? `: ${apiProvider}` : ``
}`
);
}
Expand Down
Loading

0 comments on commit 444d183

Please sign in to comment.