From 444d183d8f253572173f786ddf978bc1f07bea6c Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Mon, 7 Aug 2023 07:27:30 +0200 Subject: [PATCH] [Connectors] Make defaultModel a property of the connector (#162754) Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com> --- .../plugins/actions/server/actions_client.ts | 1 + .../server/service/client/index.ts | 6 -- .../observability_ai_assistant/tsconfig.json | 1 - .../common/gen_ai/constants.ts | 2 + .../stack_connectors/common/gen_ai/schema.ts | 20 ++++--- .../connector_types/gen_ai/connector.test.tsx | 3 +- .../connector_types/gen_ai/constants.tsx | 25 ++++++++- .../connector_types/gen_ai/translations.ts | 15 +++++ .../connector_types/gen_ai/gen_ai.test.ts | 47 ++++++++++++++-- .../server/connector_types/gen_ai/gen_ai.ts | 15 ++++- .../connector_types/gen_ai/index.test.ts | 11 +++- .../server/connector_types/gen_ai/index.ts | 11 ++-- .../gen_ai/lib/openai_utils.test.ts | 56 +++++++++++++++---- .../gen_ai/lib/openai_utils.ts | 12 +++- .../connector_types/gen_ai/lib/utils.test.ts | 32 +++++++++-- .../connector_types/gen_ai/lib/utils.ts | 43 +++++++++++--- .../tests/actions/connector_types/gen_ai.ts | 41 ++++++++++++-- 17 files changed, 276 insertions(+), 65 deletions(-) diff --git a/x-pack/plugins/actions/server/actions_client.ts b/x-pack/plugins/actions/server/actions_client.ts index b409c737c8cf..9b47bd38af3c 100644 --- a/x-pack/plugins/actions/server/actions_client.ts +++ b/x-pack/plugins/actions/server/actions_client.ts @@ -247,6 +247,7 @@ export class ActionsClient { const actionType = this.actionTypeRegistry.get(actionTypeId); const configurationUtilities = this.actionTypeRegistry.getUtils(); + const validatedActionTypeConfig = validateConfig(actionType, config, { configurationUtilities, }); diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index 920bddee2a17..30e47cc34672 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -15,7 +15,6 @@ import type { PublicMethodsOf } from '@kbn/utility-types'; import { internal, notFound } from '@hapi/boom'; import { compact, isEmpty, merge, omit } from 'lodash'; import type { SearchHit } from '@elastic/elasticsearch/lib/api/types'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/gen_ai/constants'; import { type Conversation, type ConversationCreateRequest, @@ -134,12 +133,7 @@ export class ObservabilityAIAssistantClient implements IObservabilityAIAssistant }) ); - const connector = await this.dependencies.actionsClient.get({ - id: connectorId, - }); - const request: Omit & { model?: string } = { - ...(connector.config?.apiProvider === OpenAiProviderType.OpenAi ? { model: 'gpt-4' } : {}), messages: messagesForOpenAI, stream: true, }; diff --git a/x-pack/plugins/observability_ai_assistant/tsconfig.json b/x-pack/plugins/observability_ai_assistant/tsconfig.json index aa923b7a8094..fbe0b8b24537 100644 --- a/x-pack/plugins/observability_ai_assistant/tsconfig.json +++ b/x-pack/plugins/observability_ai_assistant/tsconfig.json @@ -24,7 +24,6 @@ "@kbn/spaces-plugin", "@kbn/kibana-react-plugin", "@kbn/shared-ux-utility", - "@kbn/stack-connectors-plugin", "@kbn/alerting-plugin" ], "exclude": [ diff --git a/x-pack/plugins/stack_connectors/common/gen_ai/constants.ts b/x-pack/plugins/stack_connectors/common/gen_ai/constants.ts index 0fbf91e10258..6e3d2924ca6c 100644 --- a/x-pack/plugins/stack_connectors/common/gen_ai/constants.ts +++ b/x-pack/plugins/stack_connectors/common/gen_ai/constants.ts @@ -26,6 +26,8 @@ export enum OpenAiProviderType { AzureAi = 'Azure OpenAI', } +export const DEFAULT_OPENAI_MODEL = 'gpt-4'; + export const OPENAI_CHAT_URL = 'https://api.openai.com/v1/chat/completions' as const; export const OPENAI_LEGACY_COMPLETION_URL = 'https://api.openai.com/v1/completions' as const; export const AZURE_OPENAI_CHAT_URL = diff --git a/x-pack/plugins/stack_connectors/common/gen_ai/schema.ts b/x-pack/plugins/stack_connectors/common/gen_ai/schema.ts index 6dc3413fd144..0ec454112044 100644 --- a/x-pack/plugins/stack_connectors/common/gen_ai/schema.ts +++ b/x-pack/plugins/stack_connectors/common/gen_ai/schema.ts @@ -6,16 +6,20 @@ */ import { schema } from '@kbn/config-schema'; -import { OpenAiProviderType } from './constants'; +import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from './constants'; // Connector schema -export const GenAiConfigSchema = schema.object({ - apiProvider: schema.oneOf([ - schema.literal(OpenAiProviderType.OpenAi as string), - schema.literal(OpenAiProviderType.AzureAi as string), - ]), - apiUrl: schema.string(), -}); +export const GenAiConfigSchema = schema.oneOf([ + schema.object({ + apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.AzureAi)]), + apiUrl: schema.string(), + }), + schema.object({ + apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.OpenAi)]), + apiUrl: schema.string(), + defaultModel: schema.string({ defaultValue: DEFAULT_OPENAI_MODEL }), + }), +]); export const GenAiSecretsSchema = schema.object({ apiKey: schema.string() }); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/connector.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/connector.test.tsx index 59d47bdb1b4a..5e561615f0bc 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/connector.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/connector.test.tsx @@ -10,7 +10,7 @@ import GenerativeAiConnectorFields from './connector'; import { ConnectorFormTestProvider } from '../lib/test_utils'; import { act, fireEvent, render, waitFor } from '@testing-library/react'; import userEvent from '@testing-library/user-event'; -import { OpenAiProviderType } from '../../../common/gen_ai/constants'; +import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants'; import { useKibana } from '@kbn/triggers-actions-ui-plugin/public'; import { useGetDashboard } from './use_get_dashboard'; @@ -26,6 +26,7 @@ const openAiConnector = { config: { apiUrl: 'https://openaiurl.com', apiProvider: OpenAiProviderType.OpenAi, + defaultModel: DEFAULT_OPENAI_MODEL, }, secrets: { apiKey: 'thats-a-nice-looking-key', diff --git a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/constants.tsx b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/constants.tsx index 706abf215295..fc06a88fc91f 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/constants.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/constants.tsx @@ -9,7 +9,7 @@ import React from 'react'; import { ConfigFieldSchema, SecretsFieldSchema } from '@kbn/triggers-actions-ui-plugin/public'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiLink } from '@elastic/eui'; -import { OpenAiProviderType } from '../../../common/gen_ai/constants'; +import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants'; import * as i18n from './translations'; export const DEFAULT_URL = 'https://api.openai.com/v1/chat/completions' as const; @@ -17,7 +17,6 @@ export const DEFAULT_URL_AZURE = 'https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version={api-version}' as const; export const DEFAULT_BODY = `{ - "model":"gpt-3.5-turbo", "messages": [{ "role":"user", "content":"Hello world" @@ -54,6 +53,28 @@ export const openAiConfig: ConfigFieldSchema[] = [ /> ), }, + { + id: 'defaultModel', + label: i18n.DEFAULT_MODEL_LABEL, + helpText: ( + + {`${i18n.OPEN_AI} ${i18n.DOCUMENTATION}`} + + ), + }} + /> + ), + defaultValue: DEFAULT_OPENAI_MODEL, + }, ]; export const azureAiConfig: ConfigFieldSchema[] = [ diff --git a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/translations.ts b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/translations.ts index 23c6205b0db2..a407413faa79 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/translations.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/gen_ai/translations.ts @@ -18,6 +18,21 @@ export const API_KEY_LABEL = i18n.translate('xpack.stackConnectors.components.ge defaultMessage: 'API Key', }); +export const DEFAULT_MODEL_LABEL = i18n.translate( + 'xpack.stackConnectors.components.genAi.defaultModelTextFieldLabel', + { + defaultMessage: 'Default model', + } +); + +export const DEFAULT_MODEL_TOOLTIP_CONTENT = i18n.translate( + 'xpack.stackConnectors.components.genAi.defaultModelTooltipContent', + { + defaultMessage: + 'The model can be set on a per request basis by including a "model" parameter in the request body. If no model is provided, the fallback will be the default model.', + } +); + export const API_PROVIDER_LABEL = i18n.translate( 'xpack.stackConnectors.components.genAi.apiProviderLabel', { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.test.ts index f999a01fc274..31f88dd0edb0 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.test.ts @@ -7,7 +7,11 @@ import { GenAiConnector } from './gen_ai'; import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; -import { GEN_AI_CONNECTOR_ID, OpenAiProviderType } from '../../../common/gen_ai/constants'; +import { + DEFAULT_OPENAI_MODEL, + GEN_AI_CONNECTOR_ID, + OpenAiProviderType, +} from '../../../common/gen_ai/constants'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { @@ -35,6 +39,7 @@ describe('GenAiConnector', () => { config: { apiUrl: 'https://api.openai.com/v1/chat/completions', apiProvider: OpenAiProviderType.OpenAi, + defaultModel: DEFAULT_OPENAI_MODEL, }, secrets: { apiKey: '123' }, logger: loggingSystemMock.createLogger(), @@ -42,7 +47,6 @@ describe('GenAiConnector', () => { }); const sampleOpenAiBody = { - model: 'gpt-3.5-turbo', messages: [ { role: 'user', @@ -58,6 +62,39 @@ describe('GenAiConnector', () => { }); describe('runApi', () => { + it('uses the default model if none is supplied', async () => { + const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) }); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith({ + url: 'https://api.openai.com/v1/chat/completions', + method: 'post', + responseSchema: GenAiRunActionResponseSchema, + data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }), + headers: { + Authorization: 'Bearer 123', + 'content-type': 'application/json', + }, + }); + expect(response).toEqual({ result: 'success' }); + }); + + it('overrides the default model with the default model specified in the body', async () => { + const requestBody = { model: 'gpt-3.5-turbo', ...sampleOpenAiBody }; + const response = await connector.runApi({ body: JSON.stringify(requestBody) }); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith({ + url: 'https://api.openai.com/v1/chat/completions', + method: 'post', + responseSchema: GenAiRunActionResponseSchema, + data: JSON.stringify({ ...requestBody, stream: false }), + headers: { + Authorization: 'Bearer 123', + 'content-type': 'application/json', + }, + }); + expect(response).toEqual({ result: 'success' }); + }); + it('the OpenAI API call is successful with correct parameters', async () => { const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) }); expect(mockRequest).toBeCalledTimes(1); @@ -65,7 +102,7 @@ describe('GenAiConnector', () => { url: 'https://api.openai.com/v1/chat/completions', method: 'post', responseSchema: GenAiRunActionResponseSchema, - data: JSON.stringify({ ...sampleOpenAiBody, stream: false }), + data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }), headers: { Authorization: 'Bearer 123', 'content-type': 'application/json', @@ -128,7 +165,7 @@ describe('GenAiConnector', () => { url: 'https://api.openai.com/v1/chat/completions', method: 'post', responseSchema: GenAiRunActionResponseSchema, - data: JSON.stringify({ ...sampleOpenAiBody, stream: false }), + data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }), headers: { Authorization: 'Bearer 123', 'content-type': 'application/json', @@ -148,7 +185,7 @@ describe('GenAiConnector', () => { url: 'https://api.openai.com/v1/chat/completions', method: 'post', responseSchema: GenAiStreamingResponseSchema, - data: JSON.stringify({ ...sampleOpenAiBody, stream: true }), + data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }), headers: { Authorization: 'Bearer 123', 'content-type': 'application/json', diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts index 29214d18709b..488ec8971141 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts @@ -88,7 +88,12 @@ export class GenAiConnector extends SubActionConnector { - const sanitizedBody = sanitizeRequest(this.provider, this.url, body); + const sanitizedBody = sanitizeRequest( + this.provider, + this.url, + body, + ...('defaultModel' in this.config ? [this.config.defaultModel] : []) + ); const axiosOptions = getAxiosOptions(this.provider, this.key, false); const response = await this.request({ url: this.url, @@ -104,7 +109,13 @@ export class GenAiConnector extends SubActionConnector { - const executeBody = getRequestWithStreamOption(this.provider, this.url, body, stream); + const executeBody = getRequestWithStreamOption( + this.provider, + this.url, + body, + stream, + ...('defaultModel' in this.config ? [this.config.defaultModel] : []) + ); const axiosOptions = getAxiosOptions(this.provider, this.key, stream); const response = await this.request({ url: this.url, diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.test.ts index bf279a1739f8..75611b610dbe 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.test.ts @@ -11,7 +11,7 @@ import axios from 'axios'; import { configValidator, getConnectorType } from '.'; import { GenAiConfig, GenAiSecrets } from '../../../common/gen_ai/types'; import { SubActionConnectorType } from '@kbn/actions-plugin/server/sub_action_framework/types'; -import { OpenAiProviderType } from '../../../common/gen_ai/constants'; +import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/gen_ai/constants'; jest.mock('axios'); jest.mock('@kbn/actions-plugin/server/lib/axios_utils', () => { @@ -44,6 +44,7 @@ describe('Generative AI Connector', () => { const config: GenAiConfig = { apiUrl: 'https://api.openai.com/v1/chat/completions', apiProvider: OpenAiProviderType.OpenAi, + defaultModel: DEFAULT_OPENAI_MODEL, }; expect(configValidator(config, { configurationUtilities })).toEqual(config); @@ -53,6 +54,7 @@ describe('Generative AI Connector', () => { const config: GenAiConfig = { apiUrl: 'example.com/do-something', apiProvider: OpenAiProviderType.OpenAi, + defaultModel: DEFAULT_OPENAI_MODEL, }; expect(() => { configValidator(config, { configurationUtilities }); @@ -64,7 +66,8 @@ describe('Generative AI Connector', () => { test('config validation failed when the OpenAI API provider is empty', () => { const config: GenAiConfig = { apiUrl: 'https://api.openai.com/v1/chat/completions', - apiProvider: '', + apiProvider: '' as OpenAiProviderType, + defaultModel: DEFAULT_OPENAI_MODEL, }; expect(() => { configValidator(config, { configurationUtilities }); @@ -76,7 +79,8 @@ describe('Generative AI Connector', () => { test('config validation failed when the OpenAI API provider is invalid', () => { const config: GenAiConfig = { apiUrl: 'https://api.openai.com/v1/chat/completions', - apiProvider: 'bad-one', + apiProvider: 'bad-one' as OpenAiProviderType, + defaultModel: DEFAULT_OPENAI_MODEL, }; expect(() => { configValidator(config, { configurationUtilities }); @@ -96,6 +100,7 @@ describe('Generative AI Connector', () => { const config: GenAiConfig = { apiUrl: 'http://mylisteningserver.com:9200/endpoint', apiProvider: OpenAiProviderType.OpenAi, + defaultModel: DEFAULT_OPENAI_MODEL, }; expect(() => { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.ts index 3d1841381b61..f845215ddba4 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/index.ts @@ -46,15 +46,12 @@ export const configValidator = ( assertURL(configObject.apiUrl); urlAllowListValidator('apiUrl')(configObject, validatorServices); - if ( - configObject.apiProvider !== OpenAiProviderType.OpenAi && - configObject.apiProvider !== OpenAiProviderType.AzureAi - ) { + const { apiProvider } = configObject; + + if (apiProvider !== OpenAiProviderType.OpenAi && apiProvider !== OpenAiProviderType.AzureAi) { throw new Error( `API Provider is not supported${ - configObject.apiProvider && configObject.apiProvider.length - ? `: ${configObject.apiProvider}` - : `` + apiProvider && (apiProvider as OpenAiProviderType).length ? `: ${apiProvider}` : `` }` ); } diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.test.ts index d88d739d08bb..17e9b2365ae9 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.test.ts @@ -6,7 +6,11 @@ */ import { sanitizeRequest, getRequestWithStreamOption } from './openai_utils'; -import { OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL } from '../../../../common/gen_ai/constants'; +import { + DEFAULT_OPENAI_MODEL, + OPENAI_CHAT_URL, + OPENAI_LEGACY_COMPLETION_URL, +} from '../../../../common/gen_ai/constants'; describe('Open AI Utils', () => { describe('sanitizeRequest', () => { @@ -23,7 +27,11 @@ describe('Open AI Utils', () => { }; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body)); + const sanitizedBodyString = sanitizeRequest( + url, + JSON.stringify(body), + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual( `{\"model\":\"gpt-4\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` ); @@ -42,7 +50,11 @@ describe('Open AI Utils', () => { }; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body)); + const sanitizedBodyString = sanitizeRequest( + url, + JSON.stringify(body), + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual( `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":false}` ); @@ -62,7 +74,11 @@ describe('Open AI Utils', () => { }; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body)); + const sanitizedBodyString = sanitizeRequest( + url, + JSON.stringify(body), + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual( `{\"model\":\"gpt-4\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` ); @@ -73,7 +89,7 @@ describe('Open AI Utils', () => { const bodyString = `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = sanitizeRequest(url, bodyString); + const sanitizedBodyString = sanitizeRequest(url, bodyString, DEFAULT_OPENAI_MODEL); expect(sanitizedBodyString).toEqual(bodyString); }); }); @@ -81,7 +97,11 @@ describe('Open AI Utils', () => { it('does nothing when url does not accept stream parameter', () => { const bodyString = `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}`; - const sanitizedBodyString = sanitizeRequest('https://randostring.ai', bodyString); + const sanitizedBodyString = sanitizeRequest( + 'https://randostring.ai', + bodyString, + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual(bodyString); }); }); @@ -99,7 +119,12 @@ describe('Open AI Utils', () => { }; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), true); + const sanitizedBodyString = getRequestWithStreamOption( + url, + JSON.stringify(body), + true, + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual( `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}` ); @@ -119,7 +144,12 @@ describe('Open AI Utils', () => { }; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), false); + const sanitizedBodyString = getRequestWithStreamOption( + url, + JSON.stringify(body), + false, + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual( `{\"model\":\"gpt-4\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` ); @@ -130,7 +160,12 @@ describe('Open AI Utils', () => { const bodyString = `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`; [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => { - const sanitizedBodyString = getRequestWithStreamOption(url, bodyString, false); + const sanitizedBodyString = getRequestWithStreamOption( + url, + bodyString, + false, + DEFAULT_OPENAI_MODEL + ); expect(sanitizedBodyString).toEqual(bodyString); }); }); @@ -141,7 +176,8 @@ describe('Open AI Utils', () => { const sanitizedBodyString = getRequestWithStreamOption( 'https://randostring.ai', bodyString, - true + true, + DEFAULT_OPENAI_MODEL ); expect(sanitizedBodyString).toEqual(bodyString); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.ts index 85a9779098a2..aacae0d7bc0c 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/openai_utils.ts @@ -17,8 +17,8 @@ const APIS_ALLOWING_STREAMING = new Set([OPENAI_CHAT_URL, OPENAI_LEGACY_ * The stream parameter is accepted in the ChatCompletion * API and the Completion API only */ -export const sanitizeRequest = (url: string, body: string): string => { - return getRequestWithStreamOption(url, body, false); +export const sanitizeRequest = (url: string, body: string, defaultModel: string): string => { + return getRequestWithStreamOption(url, body, false, defaultModel); }; /** @@ -27,7 +27,12 @@ export const sanitizeRequest = (url: string, body: string): string => { * The stream parameter is accepted in the ChatCompletion * API and the Completion API only */ -export const getRequestWithStreamOption = (url: string, body: string, stream: boolean): string => { +export const getRequestWithStreamOption = ( + url: string, + body: string, + stream: boolean, + defaultModel: string +): string => { if (!APIS_ALLOWING_STREAMING.has(url)) { return body; } @@ -36,6 +41,7 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo const jsonBody = JSON.parse(body); if (jsonBody) { jsonBody.stream = stream; + jsonBody.model = jsonBody.model || defaultModel; } return JSON.stringify(jsonBody); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.test.ts index 0d7fa5606673..c50b513661ba 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.test.ts @@ -6,7 +6,11 @@ */ import { sanitizeRequest, getRequestWithStreamOption, getAxiosOptions } from './utils'; -import { OpenAiProviderType, OPENAI_CHAT_URL } from '../../../../common/gen_ai/constants'; +import { + DEFAULT_OPENAI_MODEL, + OpenAiProviderType, + OPENAI_CHAT_URL, +} from '../../../../common/gen_ai/constants'; import { sanitizeRequest as openAiSanitizeRequest, getRequestWithStreamOption as openAiGetRequestWithStreamOption, @@ -39,8 +43,12 @@ describe('Utils', () => { }); it('calls openai_utils sanitizeRequest when provider is OpenAi', () => { - sanitizeRequest(OpenAiProviderType.OpenAi, OPENAI_CHAT_URL, bodyString); - expect(mockOpenAiSanitizeRequest).toHaveBeenCalledWith(OPENAI_CHAT_URL, bodyString); + sanitizeRequest(OpenAiProviderType.OpenAi, OPENAI_CHAT_URL, bodyString, DEFAULT_OPENAI_MODEL); + expect(mockOpenAiSanitizeRequest).toHaveBeenCalledWith( + OPENAI_CHAT_URL, + bodyString, + DEFAULT_OPENAI_MODEL + ); expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled(); }); @@ -65,12 +73,19 @@ describe('Utils', () => { }); it('calls openai_utils getRequestWithStreamOption when provider is OpenAi', () => { - getRequestWithStreamOption(OpenAiProviderType.OpenAi, OPENAI_CHAT_URL, bodyString, true); + getRequestWithStreamOption( + OpenAiProviderType.OpenAi, + OPENAI_CHAT_URL, + bodyString, + true, + DEFAULT_OPENAI_MODEL + ); expect(mockOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith( OPENAI_CHAT_URL, bodyString, - true + true, + DEFAULT_OPENAI_MODEL ); expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); @@ -87,7 +102,12 @@ describe('Utils', () => { }); it('does not call any helper fns when provider is unrecognized', () => { - getRequestWithStreamOption('foo', OPENAI_CHAT_URL, bodyString, true); + getRequestWithStreamOption( + 'foo' as unknown as OpenAiProviderType, + OPENAI_CHAT_URL, + bodyString, + true + ); expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.ts index 2e76961a4e17..b0c953eaa3ae 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/lib/utils.ts @@ -17,10 +17,15 @@ import { getRequestWithStreamOption as azureAiGetRequestWithStreamOption, } from './azure_openai_utils'; -export const sanitizeRequest = (provider: string, url: string, body: string): string => { +export const sanitizeRequest = ( + provider: string, + url: string, + body: string, + defaultModel?: string +): string => { switch (provider) { case OpenAiProviderType.OpenAi: - return openAiSanitizeRequest(url, body); + return openAiSanitizeRequest(url, body, defaultModel!); case OpenAiProviderType.AzureAi: return azureAiSanitizeRequest(url, body); default: @@ -28,21 +33,45 @@ export const sanitizeRequest = (provider: string, url: string, body: string): st } }; -export const getRequestWithStreamOption = ( - provider: string, +export function getRequestWithStreamOption( + provider: OpenAiProviderType.OpenAi, + url: string, + body: string, + stream: boolean, + defaultModel: string +): string; + +export function getRequestWithStreamOption( + provider: OpenAiProviderType.AzureAi, url: string, body: string, stream: boolean -): string => { +): string; + +export function getRequestWithStreamOption( + provider: OpenAiProviderType, + url: string, + body: string, + stream: boolean, + defaultModel?: string +): string; + +export function getRequestWithStreamOption( + provider: string, + url: string, + body: string, + stream: boolean, + defaultModel?: string +): string { switch (provider) { case OpenAiProviderType.OpenAi: - return openAiGetRequestWithStreamOption(url, body, stream); + return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!); case OpenAiProviderType.AzureAi: return azureAiGetRequestWithStreamOption(url, body, stream); default: return body; } -}; +} export const getAxiosOptions = ( provider: string, diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/gen_ai.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/gen_ai.ts index 63b550b1d7b3..b52d9c14cc3a 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/gen_ai.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/gen_ai.ts @@ -67,7 +67,7 @@ export default function genAiTest({ getService }: FtrProviderContext) { simulator.close(); }); - it('should return 200 when creating the connector', async () => { + it('should return 200 when creating the connector without a default model', async () => { const { body: createdAction } = await supertest .post('/api/actions/connector') .set('kbn-xsrf', 'foo') @@ -87,7 +87,40 @@ export default function genAiTest({ getService }: FtrProviderContext) { name, connector_type_id: connectorTypeId, is_missing_secrets: false, - config, + config: { + ...config, + defaultModel: 'gpt-4', + }, + }); + }); + + it('should return 200 when creating the connector with a default model', async () => { + const { body: createdAction } = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { + ...config, + defaultModel: 'gpt-3.5-turbo', + }, + secrets, + }) + .expect(200); + + expect(createdAction).to.eql({ + id: createdAction.id, + is_preconfigured: false, + is_system_action: false, + is_deprecated: false, + name, + connector_type_id: connectorTypeId, + is_missing_secrets: false, + config: { + ...config, + defaultModel: 'gpt-3.5-turbo', + }, }); }); @@ -111,7 +144,7 @@ export default function genAiTest({ getService }: FtrProviderContext) { statusCode: 400, error: 'Bad Request', message: - 'error validating action type config: [apiProvider]: expected at least one defined value but got [undefined]', + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]', }); }); }); @@ -132,7 +165,7 @@ export default function genAiTest({ getService }: FtrProviderContext) { statusCode: 400, error: 'Bad Request', message: - 'error validating action type config: [apiUrl]: expected value of type [string] but got [undefined]', + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]', }); }); });