Skip to content

Commit

Permalink
[Inference Connector] Modified getProvider to use _inference/_service…
Browse files Browse the repository at this point in the history
…s ES API instead of hardcoded values. (elastic#199047)

@ymao1
[introduced](elastic/elasticsearch#114862) new
ES API which allows to fetch available services providers list with the
settings and task types:
`GET _inference/_services` 
This PR is changing hardcoded providers list
`x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.ts`
to use new ES API.
  • Loading branch information
YulNaumenko authored Nov 12, 2024
1 parent 4e65ae9 commit abf6a1d
Show file tree
Hide file tree
Showing 20 changed files with 301 additions and 1,761 deletions.
15 changes: 15 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
TextEmbeddingParamsSchema,
TextEmbeddingResponseSchema,
} from './schema';
import { ConfigProperties } from '../dynamic_config/types';

export type Config = TypeOf<typeof ConfigSchema>;
export type Secrets = TypeOf<typeof SecretsSchema>;
Expand All @@ -36,3 +37,17 @@ export type TextEmbeddingParams = TypeOf<typeof TextEmbeddingParamsSchema>;
export type TextEmbeddingResponse = TypeOf<typeof TextEmbeddingResponseSchema>;

export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;

export type FieldsConfiguration = Record<string, ConfigProperties>;

export interface InferenceTaskType {
task_type: string;
configuration: FieldsConfiguration;
}

export interface InferenceProvider {
provider: string;
task_types: InferenceTaskType[];
logo?: string;
configuration: FieldsConfiguration;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ import {
import { FormattedMessage } from '@kbn/i18n-react';

import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers';
import { ConfigEntryView } from '../../../common/dynamic_config/types';
import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items';
import * as i18n from './translations';
import { DEFAULT_TASK_TYPE } from './constants';
import { ConfigEntryView } from '../lib/dynamic_config/types';
import { Config } from './types';
import { TaskTypeOption } from './helpers';

Expand All @@ -52,7 +52,7 @@ interface AdditionalOptionsConnectorFieldsProps {
isEdit: boolean;
optionalProviderFormFields: ConfigEntryView[];
onSetProviderConfigEntry: (key: string, value: unknown) => Promise<void>;
onTaskTypeOptionsSelect: (taskType: string, provider?: string) => Promise<void>;
onTaskTypeOptionsSelect: (taskType: string, provider?: string) => void;
selectedTaskType?: string;
taskTypeFormFields: ConfigEntryView[];
taskTypeSchema: ConfigEntryView[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ import { ConnectorFormTestProvider } from '../lib/test_utils';
import { render, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { createStartServicesMock } from '@kbn/triggers-actions-ui-plugin/public/common/lib/kibana/kibana_react.mock';
import { DisplayType, FieldType } from '../lib/dynamic_config/types';
import { useProviders } from './providers/get_providers';
import { getTaskTypes } from './get_task_types';
import { HttpSetup } from '@kbn/core-http-browser';
import { DisplayType, FieldType } from '../../../common/dynamic_config/types';

jest.mock('./providers/get_providers');
jest.mock('./get_task_types');

const mockUseKibanaReturnValue = createStartServicesMock();
jest.mock('@kbn/triggers-actions-ui-plugin/public/common/lib/kibana', () => ({
Expand All @@ -37,13 +34,32 @@ jest.mock('@faker-js/faker', () => ({
}));

const mockProviders = useProviders as jest.Mock;
const mockTaskTypes = getTaskTypes as jest.Mock;

const providersSchemas = [
{
provider: 'openai',
logo: '', // should be openai logo here, the hardcoded uses assets/images
taskTypes: ['completion', 'text_embedding'],
task_types: [
{
task_type: 'completion',
configuration: {
user: {
display: DisplayType.TEXTBOX,
label: 'User',
order: 1,
required: false,
sensitive: false,
tooltip: 'Specifies the user issuing the request.',
type: FieldType.STRING,
validations: [],
value: '',
ui_restrictions: [],
default_value: null,
depends_on: [],
},
},
},
],
configuration: {
api_key: {
display: DisplayType.TEXTBOX,
Expand Down Expand Up @@ -106,7 +122,16 @@ const providersSchemas = [
{
provider: 'googleaistudio',
logo: '', // should be googleaistudio logo here, the hardcoded uses assets/images
taskTypes: ['completion', 'text_embedding'],
task_types: [
{
task_type: 'completion',
configuration: {},
},
{
task_type: 'text_embedding',
configuration: {},
},
],
configuration: {
api_key: {
display: DisplayType.TEXTBOX,
Expand Down Expand Up @@ -139,39 +164,6 @@ const providersSchemas = [
},
},
];
const taskTypesSchemas: Record<string, any> = {
googleaistudio: [
{
task_type: 'completion',
configuration: {},
},
{
task_type: 'text_embedding',
configuration: {},
},
],
openai: [
{
task_type: 'completion',
configuration: {
user: {
display: DisplayType.TEXTBOX,
label: 'User',
order: 1,
required: false,
sensitive: false,
tooltip: 'Specifies the user issuing the request.',
type: FieldType.STRING,
validations: [],
value: '',
ui_restrictions: [],
default_value: null,
depends_on: [],
},
},
},
],
};

const openAiConnector = {
actionTypeId: '.inference',
Expand Down Expand Up @@ -222,9 +214,6 @@ describe('ConnectorFields renders', () => {
isLoading: false,
data: providersSchemas,
});
mockTaskTypes.mockImplementation(
(http: HttpSetup, provider: string) => taskTypesSchemas[provider]
);
});
test('openai provider fields are rendered', async () => {
const { getAllByTestId } = render(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import React, { useState, useEffect, useCallback } from 'react';
import React, { useState, useEffect, useCallback, useMemo } from 'react';
import {
EuiFormRow,
EuiSpacer,
Expand All @@ -31,12 +31,12 @@ import {
import { useKibana } from '@kbn/triggers-actions-ui-plugin/public';

import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers';
import { ConfigEntryView } from '../../../common/dynamic_config/types';
import { InferenceTaskType } from '../../../common/inference/types';
import { ServiceProviderKeys } from '../../../common/inference/constants';
import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items';
import { getTaskTypes } from './get_task_types';
import * as i18n from './translations';
import { DEFAULT_TASK_TYPE } from './constants';
import { ConfigEntryView } from '../lib/dynamic_config/types';
import { SelectableProvider } from './providers/selectable';
import { Config, Secrets } from './types';
import { generateInferenceEndpointId, getTaskTypeOptions, TaskTypeOption } from './helpers';
Expand Down Expand Up @@ -116,13 +116,13 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
}, [isSubmitting, config, validateFields]);

const onTaskTypeOptionsSelect = useCallback(
async (taskType: string, provider?: string) => {
(taskType: string, provider?: string) => {
// Get task type settings
const currentTaskTypes = await getTaskTypes(http, provider ?? config?.provider);
const currentProvider = providers?.find((p) => p.provider === (provider ?? config?.provider));
const currentTaskTypes = currentProvider?.task_types;
const newTaskType = currentTaskTypes?.find((p) => p.task_type === taskType);

setSelectedTaskType(taskType);
generateInferenceEndpointId(config, setFieldValue);

// transform the schema
const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({
Expand Down Expand Up @@ -150,19 +150,23 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
taskTypeConfig: configDefaults,
},
});
generateInferenceEndpointId(
{ ...config, taskType, taskTypeConfig: configDefaults },
setFieldValue
);
},
[config, http, setFieldValue, updateFieldValues]
[config, providers, setFieldValue, updateFieldValues]
);

const onProviderChange = useCallback(
async (provider?: string) => {
(provider?: string) => {
const newProvider = providers?.find((p) => p.provider === provider);

// Update task types list available for the selected provider
const providerTaskTypes = newProvider?.taskTypes ?? [];
const providerTaskTypes = (newProvider?.task_types ?? []).map((t) => t.task_type);
setTaskTypeOptions(getTaskTypeOptions(providerTaskTypes));
if (providerTaskTypes.length > 0) {
await onTaskTypeOptionsSelect(providerTaskTypes[0], provider);
onTaskTypeOptionsSelect(providerTaskTypes[0], provider);
}

// Update connector providerSchema
Expand Down Expand Up @@ -203,9 +207,8 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
);

useEffect(() => {
const getTaskTypeSchema = async () => {
const currentTaskTypes = await getTaskTypes(http, config?.provider ?? '');
const newTaskType = currentTaskTypes?.find((p) => p.task_type === config?.taskType);
const getTaskTypeSchema = (taskTypes: InferenceTaskType[]) => {
const newTaskType = taskTypes.find((p) => p.task_type === config?.taskType);

// transform the schema
const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({
Expand All @@ -228,7 +231,7 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields

setProviderSchema(newProviderSchema);

getTaskTypeSchema();
getTaskTypeSchema(newProvider?.task_types ?? []);
}
}, [config?.provider, config?.taskType, http, isEdit, providers]);

Expand Down Expand Up @@ -309,6 +312,22 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
setFieldValue('config.provider', '');
}, [onProviderChange, setFieldValue]);

const providerIcon = useMemo(
() =>
Object.keys(SERVICE_PROVIDERS).includes(config?.provider)
? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].icon
: undefined,
[config?.provider]
);

const providerName = useMemo(
() =>
Object.keys(SERVICE_PROVIDERS).includes(config?.provider)
? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].name
: config?.provider,
[config?.provider]
);

const providerSuperSelect = useCallback(
(isInvalid: boolean) => (
<EuiFormControlLayout
Expand All @@ -317,21 +336,15 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
isDisabled={isEdit || readOnly}
isInvalid={isInvalid}
fullWidth
icon={
!config?.provider
? { type: 'sparkles', side: 'left' }
: SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].icon
}
icon={!config?.provider ? { type: 'sparkles', side: 'left' } : providerIcon}
>
<EuiFieldText
onClick={handleProviderPopover}
data-test-subj="provider-select"
isInvalid={isInvalid}
disabled={isEdit || readOnly}
onKeyDown={handleProviderKeyboardOpen}
value={
config?.provider ? SERVICE_PROVIDERS[config?.provider as ServiceProviderKeys].name : ''
}
value={config?.provider ? providerName : ''}
fullWidth
placeholder={i18n.SELECT_PROVIDER}
icon={{ type: 'arrowDown', side: 'right' }}
Expand All @@ -345,8 +358,10 @@ const InferenceAPIConnectorFields: React.FunctionComponent<ActionConnectorFields
readOnly,
onClearProvider,
config?.provider,
providerIcon,
handleProviderPopover,
handleProviderKeyboardOpen,
providerName,
isProviderPopoverOpen,
]
);
Expand Down

This file was deleted.

Loading

0 comments on commit abf6a1d

Please sign in to comment.