diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.test.ts index 3ea6890c41932..4a9c11faa7f73 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.test.ts @@ -31,11 +31,6 @@ describe('ml inference utils', () => { ner: {}, }, }), - makeFakeModel({ - inference_config: { - classification: {}, - }, - }), makeFakeModel({ inference_config: { text_classification: {}, @@ -53,6 +48,16 @@ describe('ml inference utils', () => { }, }, }), + makeFakeModel({ + inference_config: { + question_answering: {}, + }, + }), + makeFakeModel({ + inference_config: { + fill_mask: {}, + }, + }), ]; for (const model of models) { @@ -61,7 +66,14 @@ describe('ml inference utils', () => { }); it('returns false for expected models', () => { - const models: TrainedModelConfigResponse[] = [makeFakeModel({})]; + const models: TrainedModelConfigResponse[] = [ + makeFakeModel({}), + makeFakeModel({ + inference_config: { + classification: {}, + }, + }), + ]; for (const model of models) { expect(isSupportedMLModel(model)).toBe(false); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts index 83cf04585b5ff..b788a522d395f 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts @@ -11,10 +11,11 @@ import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_ import { AddInferencePipelineFormErrors, InferencePipelineConfiguration } from './types'; const NLP_CONFIG_KEYS = [ + 'fill_mask', 'ner', - 'classification', 'text_classification', 'text_embedding', + 'question_answering', 'zero_shot_classification', ]; export const isSupportedMLModel = (model: TrainedModelConfigResponse): boolean => {