[Index management] Fix to enable adding new fields with new inference endpoint created from inference Flyout (#183869)

## Summary
Fixes a bug so that new fields can be added with inference endpoints created from the inference flyout.


**Background**

This [PR](#180330) introduced the `@kbn/inference_integration_flyout` package, which provides a flyout for adding a new inference endpoint backed by Elasticsearch models (ELSER & E5), third-party models (HuggingFace, Cohere, OpenAI), or models uploaded via Eland instructions.

The Add field component uses this flyout when the field type is `semantic_text` and the user wants to create a new inference endpoint ID for the field.

**Bug description**

After a new inference endpoint is created, a field cannot be added with that inference endpoint.


https://github.com/elastic/kibana/assets/55930906/ad62fdbb-4ef6-40af-9812-5d91b7e63c26


**Expected** 

It should be possible to add a new field with the newly created inference endpoint.
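
Roughly, the fix threads a `setNewInferenceEndpoint` callback from the Add field form down into `SelectInferenceId`, so a newly created endpoint is merged into the `inferenceToModelIdMap` held in mappings state and becomes usable immediately. A simplified sketch of that update, based on the diff below (not the exact Kibana code):

```ts
// Simplified sketch of the fix, based on the diff below (not the exact Kibana code).
// Once the flyout finishes creating an endpoint, the new entry is merged into
// `inferenceToModelIdMap` in the mappings state so the Add field form can use it.
type InferenceToModelIdMap = Record<
  string,
  {
    trainedModelId: string;
    isDeployable: boolean;
    isDeployed: boolean;
    defaultInferenceEndpoint: boolean;
  }
>;

type Dispatch = (action: { type: 'inferenceToModelIdMap.update'; value: unknown }) => void;

function mergeNewInferenceEndpoint(
  current: InferenceToModelIdMap,
  newEndpoint: InferenceToModelIdMap,
  dispatch: Dispatch
): void {
  dispatch({
    type: 'inferenceToModelIdMap.update',
    value: { inferenceToModelIdMap: { ...current, ...newEndpoint } },
  });
}
```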


**Testing instructions**

**Elasticsearch changes (only to test save mappings)**

Since the Elasticsearch changes for `semantic_text` have been merged to main, this can be tested against Elasticsearch running from source or from the latest snapshot:
1. Update your local branch with the latest Elasticsearch changes from main.
2. Run Elasticsearch: `./gradlew :run -Drun.license_type=trial` (an optional smoke test against this cluster is sketched below).
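
Optionally, you can smoke-test the inference API on the locally running cluster before going through the UI, for example by creating an ELSER endpoint directly. A rough sketch (the URL, credentials, and endpoint name are placeholders for a local setup, not part of this PR):

```ts
// Rough smoke test against a locally running Elasticsearch (from source or latest snapshot).
// The URL, credentials, and endpoint name are placeholders; adjust for your local setup.
const ES_URL = 'http://localhost:9200';
const AUTH = 'Basic ' + Buffer.from('elastic:changeme').toString('base64'); // Node 18+

async function createElserEndpoint(inferenceId: string): Promise<void> {
  // PUT _inference/<task_type>/<inference_id> with the `elser` service.
  const res = await fetch(`${ES_URL}/_inference/sparse_embedding/${inferenceId}`, {
    method: 'PUT',
    headers: { 'Content-Type': 'application/json', Authorization: AUTH },
    body: JSON.stringify({
      service: 'elser',
      service_settings: { num_allocations: 1, num_threads: 1 },
    }),
  });
  if (!res.ok) {
    throw new Error(`Failed to create inference endpoint: ${res.status} ${await res.text()}`);
  }
}

createElserEndpoint('my-elser-endpoint').then(() => console.log('endpoint created'));
```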


**Manual test in UI**

1. Set [isSemanticTextEnabled](https://github.com/elastic/kibana/blob/e89b991d7473caba2a3a5b6204080f50100c67b9/x-pack/plugins/index_management/public/application/sections/home/index_list/details_page/details_page_mappings_content.tsx#L72) to `true`.
2. Add a new field with type `Semantic_text`.
3. Click the drop-down menu below `Select an inference endpoint:`.
4. Click `Add inference Endpoint`.
5. Create a new inference endpoint in the Add inference endpoint flyout.
6. The drop-down menu from step 3 should now list the new inference endpoint created in step 5.
7. Fill in the Reference field and the Field name.
8. Click Add field.
9. A new field should be created using this inference endpoint (a check for this is sketched below).
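
To double-check the result outside the UI, the saved mapping should contain the new field as `semantic_text` pointing at the newly created endpoint. A rough sketch of that check (URL, credentials, index, field, and endpoint names are placeholders):

```ts
// Rough check that the saved mapping references the newly created inference endpoint.
// URL, credentials, index, field, and endpoint names are placeholders for a local setup.
const ES_URL = 'http://localhost:9200';
const AUTH = 'Basic ' + Buffer.from('elastic:changeme').toString('base64'); // Node 18+

async function assertSemanticTextField(index: string, field: string, inferenceId: string): Promise<void> {
  const res = await fetch(`${ES_URL}/${index}/_mapping`, { headers: { Authorization: AUTH } });
  const body = await res.json();
  const mapping = body[index]?.mappings?.properties?.[field];
  if (mapping?.type !== 'semantic_text' || mapping?.inference_id !== inferenceId) {
    throw new Error(
      `Expected ${field} to be a semantic_text field using ${inferenceId}, got ${JSON.stringify(mapping)}`
    );
  }
}

assertSemanticTextField('my-index', 'my_semantic_field', 'my-elser-endpoint').then(() =>
  console.log('mapping references the new endpoint')
);
```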

---------

Co-authored-by: Kibana Machine <[email protected]>
saarikabhasi and kibanamachine authored May 27, 2024
1 parent 193668c commit a9c8e8f
Showing 4 changed files with 68 additions and 10 deletions.
select_inference_id.test.tsx
@@ -11,6 +11,7 @@ import { SelectInferenceId } from './select_inference_id';

const onChangeMock = jest.fn();
const setValueMock = jest.fn();
const setNewInferenceEndpointMock = jest.fn();

jest.mock('../../../../../app_context', () => ({
useAppContext: jest.fn().mockReturnValue({
@@ -21,6 +22,7 @@ jest.mock('../../../../../app_context', () => ({
mlApi: {
trainedModels: {
getTrainedModels: jest.fn().mockResolvedValue([]),
getTrainedModelStats: jest.fn().mockResolvedValue([]),
},
},
},
@@ -38,6 +40,7 @@ describe('SelectInferenceId', () => {
onChange: onChangeMock,
'data-test-subj': 'data-inference-endpoint-list',
setValue: setValueMock,
setNewInferenceEndpoint: setNewInferenceEndpointMock,
},
memoryRouter: { wrapComponent: false },
});
select_inference_id.tsx
@@ -30,7 +30,11 @@ import {
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
import { ModelConfig } from '@kbn/inference_integration_flyout/types';
import {
ElasticsearchModelDefaultOptions,
ModelConfig,
Service,
} from '@kbn/inference_integration_flyout/types';
import { FormattedMessage } from '@kbn/i18n-react';
import { InferenceFlyoutWrapper } from '@kbn/inference_integration_flyout/components/inference_flyout_wrapper';
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
@@ -39,15 +43,26 @@ import { getFieldConfig } from '../../../lib';
import { useAppContext } from '../../../../../app_context';
import { Form, UseField, useForm } from '../../../shared_imports';
import { useLoadInferenceModels } from '../../../../../services/api';
import { getTrainedModelStats } from '../../../../../../hooks/use_details_page_mappings_model_management';
import { InferenceToModelIdMap } from '../fields';

const inferenceServiceTypeElasticsearchModelMap: Record<string, ElasticsearchModelDefaultOptions> =
{
elser: ElasticsearchModelDefaultOptions.elser,
elasticsearch: ElasticsearchModelDefaultOptions.e5,
};

interface Props {
onChange(value: string): void;
'data-test-subj'?: string;
setValue: (value: string) => void;
setNewInferenceEndpoint: (newInferenceEndpoint: InferenceToModelIdMap) => void;
}
export const SelectInferenceId = ({
onChange,
'data-test-subj': dataTestSubj,
setValue,
setNewInferenceEndpoint,
}: Props) => {
const {
core: { application },
@@ -135,14 +150,26 @@ export const SelectInferenceId = ({
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
setIsCreateInferenceApiLoading(false);
setInferenceAddError(undefined);
const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats();
const defaultEndpointId =
inferenceServiceTypeElasticsearchModelMap[modelConfig.service] || '';
const newModelId: InferenceToModelIdMap = {};
newModelId[inferenceId] = {
trainedModelId: defaultEndpointId,
isDeployable:
modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch,
isDeployed: getTrainedModelStats(trainedModelStats)[defaultEndpointId] === 'deployed',
defaultInferenceEndpoint: false,
};
resendRequest();
setNewInferenceEndpoint(newModelId);
} catch (error) {
const errorObj = extractErrorProperties(error);
setInferenceAddError(errorObj.message);
setIsCreateInferenceApiLoading(false);
}
},
[isInferenceFlyoutVisible, resendRequest, ml]
[isInferenceFlyoutVisible, resendRequest, ml, setNewInferenceEndpoint]
);
useEffect(() => {
const subscription = subscribe((updateData) => {
create_field.tsx
@@ -16,10 +16,10 @@ import {
import { i18n } from '@kbn/i18n';
import { MlPluginStart } from '@kbn/ml-plugin/public';
import classNames from 'classnames';
import React, { useEffect } from 'react';
import React, { useCallback, useEffect } from 'react';
import { EUI_SIZE, TYPE_DEFINITION } from '../../../../constants';
import { fieldSerializer } from '../../../../lib';
import { useDispatch } from '../../../../mappings_state_context';
import { useDispatch, useMappingsState } from '../../../../mappings_state_context';
import { Form, FormDataProvider, UseField, useForm, useFormData } from '../../../../shared_imports';
import { Field, MainType, NormalizedFields } from '../../../../types';
import { NameParameter, SubTypeParameter, TypeParameter } from '../../field_parameters';
@@ -70,7 +70,6 @@ export const CreateField = React.memo(function CreateFieldComponent({
}: Props) {
const { isSemanticTextEnabled, indexName, ml, setErrorsInTrainedModelDeployment } =
semanticTextInfo ?? {};

const dispatch = useDispatch();

const { form } = useForm<Field>({
@@ -315,8 +314,26 @@ interface InferenceProps {
}

function InferenceIdCombo({ setValue }: InferenceProps) {
const { inferenceToModelIdMap } = useMappingsState();
const dispatch = useDispatch();
const [{ type }] = useFormData({ watch: 'type' });

// update new inferenceEndpoint
const setNewInferenceEndpoint = useCallback(
(newInferenceEndpoint: InferenceToModelIdMap) => {
dispatch({
type: 'inferenceToModelIdMap.update',
value: {
inferenceToModelIdMap: {
...inferenceToModelIdMap,
...newInferenceEndpoint,
},
},
});
},
[dispatch, inferenceToModelIdMap]
);

if (type === undefined || type[0]?.value !== 'semantic_text') {
return null;
}
@@ -325,7 +342,13 @@ function InferenceIdCombo({ setValue }: InferenceProps) {
<>
<EuiSpacer />
<UseField path="inferenceId">
{(field) => <SelectInferenceId onChange={field.setValue} setValue={setValue} />}
{(field) => (
<SelectInferenceId
onChange={field.setValue}
setValue={setValue}
setNewInferenceEndpoint={setNewInferenceEndpoint}
/>
)}
</UseField>
</>
);
use_details_page_mappings_model_management.ts
@@ -5,6 +5,7 @@
* 2.0.
*/

import { ElasticsearchModelDefaultOptions, Service } from '@kbn/inference_integration_flyout/types';
import { InferenceStatsResponse } from '@kbn/ml-plugin/public/application/services/ml_api_service/trained_models';
import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import { useCallback, useMemo } from 'react';
@@ -27,20 +28,24 @@ const getCustomInferenceIdMap = (
) => {
return models?.data.reduce<InferenceToModelIdMap>((inferenceMap, model) => {
const inferenceId = model.model_id;
const trainedModelId =
'model_id' in model.service_settings ? model.service_settings.model_id : '';

const trainedModelId =
'model_id' in model.service_settings &&
(model.service_settings.model_id === ElasticsearchModelDefaultOptions.elser ||
model.service_settings.model_id === ElasticsearchModelDefaultOptions.e5)
? model.service_settings.model_id
: '';
inferenceMap[inferenceId] = {
trainedModelId,
isDeployable: model.service === 'elser' || model.service === 'elasticsearch',
isDeployable: model.service === Service.elser || model.service === Service.elasticsearch,
isDeployed: deploymentStatsByModelId[trainedModelId] === 'deployed',
defaultInferenceEndpoint: false,
};
return inferenceMap;
}, {});
};

const getTrainedModelStats = (modelStats?: InferenceStatsResponse): DeploymentStatusType => {
export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): DeploymentStatusType => {
return (
modelStats?.trained_model_stats.reduce<DeploymentStatusType>((acc, modelStat) => {
if (modelStat.model_id) {
