From e6767f75d53a1aa837279aaab1d203b5fdd1968a Mon Sep 17 00:00:00 2001
From: Rodney Norris
Date: Tue, 20 Dec 2022 10:38:21 -0600
Subject: [PATCH] [Enterprise Search][ML Inference] support setting inference config labels for zero shot models (#147653)

## Summary

Some zero-shot models do not ship with default labels, so using them requires setting labels in the inference config on the pipeline. If labels are not set, ingestion errors when the inference pipeline runs. This change adds the ability to set labels on the pipeline's inference config for zero-shot models.

### Screenshots

(screenshot)

With labels:

(screenshot)

With non-zero shot model:

(screenshot)

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
---
 .../common/ml_inference_pipeline/index.ts     |  4 +
 .../common/types/pipelines.ts                 |  9 +-
 .../pipelines/create_ml_inference_pipeline.ts |  7 +-
 .../ml_inference/configure_pipeline.tsx       |  3 +
 .../ml_inference/inference_config.tsx         | 58 ++++++++++++
 .../ml_inference/ml_inference_logic.test.ts   |  1 +
 .../ml_inference/ml_inference_logic.ts        | 16 ++++
 .../pipelines/ml_inference/types.ts           |  3 +
 .../zero_shot_inference_configuration.tsx     | 89 +++++++++++++++++++
 .../create_ml_inference_pipeline.test.ts      | 39 ++++++++
 .../create_ml_inference_pipeline.ts           | 11 ++-
 .../create_pipeline_definitions.test.ts       |  3 +
 .../pipelines/create_pipeline_definitions.ts  |  8 +-
 .../routes/enterprise_search/indices.test.ts  |  1 +
 .../routes/enterprise_search/indices.ts       | 11 +++
 15 files changed, 259 insertions(+), 4 deletions(-)
 create mode 100644 x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx
 create mode 100644 x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx

diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
index 206e03893f4bc..4917ac29397e6 100644
--- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
+++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
@@ -17,6 +17,7 @@ import {
   MlInferencePipeline,
   CreateMlInferencePipelineParameters,
   TrainedModelState,
+  InferencePipelineInferenceConfig,
 } from '../types/pipelines';

 // Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics'
@@ -37,6 +38,7 @@ export const SUPPORTED_PYTORCH_TASKS = {
 export interface MlInferencePipelineParams {
   description?: string;
   destinationField: string;
+  inferenceConfig?: InferencePipelineInferenceConfig;
   model: MlTrainedModelConfig;
   pipelineName: string;
   sourceField: string;
@@ -50,6 +52,7 @@ export const generateMlInferencePipelineBody = ({
   description,
   destinationField,
+  inferenceConfig,
   model,
   pipelineName,
   sourceField,
@@ -77,6 +80,7 @@ export const generateMlInferencePipelineBody = ({
         field_map: {
           [sourceField]: modelInputField,
         },
+        inference_config: inferenceConfig,
         model_id: model.model_id,
         on_failure: [
           {
diff --git a/x-pack/plugins/enterprise_search/common/types/pipelines.ts b/x-pack/plugins/enterprise_search/common/types/pipelines.ts
index 38314f6d162de..53bc80687f136 100644
--- a/x-pack/plugins/enterprise_search/common/types/pipelines.ts
+++ b/x-pack/plugins/enterprise_search/common/types/pipelines.ts
@@ -5,7 +5,7 @@
  * 2.0.
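The new `inferenceConfig` parameter above flows straight through `generateMlInferencePipelineBody` into the generated `inference` processor. A minimal sketch of how a caller would pass zero-shot labels; the field names, labels, pipeline name, and import paths below are illustrative rather than taken from this change, and the model config is assumed to use the ES client's `MlTrainedModelConfig` type:

```ts
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';

import { generateMlInferencePipelineBody } from '../../common/ml_inference_pipeline';

// A zero-shot model config, e.g. fetched from the trained models API elsewhere.
declare const zeroShotModel: MlTrainedModelConfig;

const pipelineBody = generateMlInferencePipelineBody({
  destinationField: 'sentiment', // illustrative field and pipeline names
  inferenceConfig: {
    zero_shot_classification: {
      labels: ['positive', 'negative', 'neutral'], // labels the model should score
    },
  },
  model: zeroShotModel,
  pipelineName: 'my-index-sentiment',
  sourceField: 'body_content',
});

// The generated inference processor carries the config through unchanged, roughly:
// inference: {
//   field_map: { body_content: <model input field> },
//   inference_config: { zero_shot_classification: { labels: [...] } },
//   model_id: zeroShotModel.model_id,
//   on_failure: [ ... ],
// }
```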
*/ -import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types'; +import { IngestInferenceConfig, IngestPipeline } from '@elastic/elasticsearch/lib/api/types'; export interface InferencePipeline { modelId: string | undefined; @@ -67,7 +67,14 @@ export interface DeleteMlInferencePipelineResponse { export interface CreateMlInferencePipelineParameters { destination_field?: string; + inference_config?: InferencePipelineInferenceConfig; model_id: string; pipeline_name: string; source_field: string; } + +export type InferencePipelineInferenceConfig = IngestInferenceConfig & { + zero_shot_classification?: { + labels: string[]; + }; +}; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts index 78f08c4bc0ee8..afc4d789a0223 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts @@ -4,13 +4,17 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import { CreateMlInferencePipelineParameters } from '../../../../../common/types/pipelines'; +import { + CreateMlInferencePipelineParameters, + InferencePipelineInferenceConfig, +} from '../../../../../common/types/pipelines'; import { createApiLogic } from '../../../shared/api_logic/create_api_logic'; import { HttpLogic } from '../../../shared/http'; export interface CreateMlInferencePipelineApiLogicArgs { destinationField?: string; indexName: string; + inferenceConfig?: InferencePipelineInferenceConfig; modelId: string; pipelineName: string; sourceField: string; @@ -26,6 +30,7 @@ export const createMlInferencePipeline = async ( const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors`; const params: CreateMlInferencePipelineParameters = { destination_field: args.destinationField, + inference_config: args.inferenceConfig, model_id: args.modelId, pipeline_name: args.pipelineName, source_field: args.sourceField, diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx index feb4ca8c87a4e..104562439d325 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx @@ -30,6 +30,7 @@ import { docLinks } from '../../../../../shared/doc_links'; import { IndexViewLogic } from '../../index_view_logic'; +import { InferenceConfiguration } from './inference_config'; import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic'; import { MlModelSelectOption } from './model_select_option'; import { PipelineSelectOption } from './pipeline_select_option'; @@ -275,6 +276,7 @@ export const ConfigurePipeline: React.FC = () => { setInferencePipelineConfiguration({ ...configuration, modelID: value, + inferenceConfig: undefined, }) } 
options={modelOptions} @@ -357,6 +359,7 @@ export const ConfigurePipeline: React.FC = () => { + ); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx new file mode 100644 index 0000000000000..244d708224940 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import { useValues } from 'kea'; + +import { EuiSpacer, EuiText } from '@elastic/eui'; +import { i18n } from '@kbn/i18n'; + +import { + getMlModelTypesForModelConfig, + SUPPORTED_PYTORCH_TASKS, +} from '../../../../../../../common/ml_inference_pipeline'; +import { getMLType } from '../../../shared/ml_inference/utils'; + +import { MLInferenceLogic } from './ml_inference_logic'; +import { ZeroShotClassificationInferenceConfiguration } from './zero_shot_inference_configuration'; + +export const InferenceConfiguration: React.FC = () => { + const { + addInferencePipelineModal: { configuration }, + selectedMLModel, + } = useValues(MLInferenceLogic); + if (!selectedMLModel || configuration.existingPipeline) return null; + const modelType = getMLType(getMlModelTypesForModelConfig(selectedMLModel)); + switch (modelType) { + case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION: + return ( + + + + ); + default: + return null; + } +}; + +const InferenceConfigurationWrapper: React.FC = ({ children }) => { + return ( + <> + + +

+ {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.inference.title', + { defaultMessage: 'Inference Configuration' } + )} +

+
+ {children} + + ); +}; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts index 171fe0e3a53fc..9083914812263 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts @@ -55,6 +55,7 @@ const DEFAULT_VALUES: MLInferenceProcessorsValues = { mlInferencePipelinesData: undefined, mlModelsData: null, mlModelsStatus: 0, + selectedMLModel: null, sourceFields: undefined, supportedMLModels: [], }; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts index 9a6737d90da4b..5bddfbc545b53 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts @@ -162,6 +162,7 @@ export interface MLInferenceProcessorsValues { mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined; mlModelsData: TrainedModel[] | null; mlModelsStatus: Status; + selectedMLModel: TrainedModel | null; sourceFields: string[] | undefined; supportedMLModels: TrainedModel[]; } @@ -241,6 +242,7 @@ export const MLInferenceLogic = kea< ? configuration.destinationField : undefined, indexName, + inferenceConfig: configuration.inferenceConfig, modelId: configuration.modelID, pipelineName: configuration.pipelineName, sourceField: configuration.sourceField, @@ -343,6 +345,7 @@ export const MLInferenceLogic = kea< model, pipelineName: configuration.pipelineName, sourceField: configuration.sourceField, + inferenceConfig: configuration.inferenceConfig, }); }, ], @@ -435,5 +438,18 @@ export const MLInferenceLogic = kea< return existingPipelines; }, ], + selectedMLModel: [ + () => [selectors.supportedMLModels, selectors.addInferencePipelineModal], + ( + supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'], + addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal'] + ) => { + return ( + supportedMLModels.find( + (model) => model.model_id === addInferencePipelineModal.configuration.modelID + ) ?? null + ); + }, + ], }), }); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts index db8970fa62759..0e7070e2385e9 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts @@ -5,9 +5,12 @@ * 2.0. 
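The `selectedMLModel` selector added above resolves the full `TrainedModel` for the modal's current `modelID`, so view components can branch on the model's task type instead of re-deriving it. A sketch of that pattern, mirroring what `inference_config.tsx` does; the hook name is hypothetical, while the imports and helpers are the ones used in the new component:

```ts
import { useValues } from 'kea';

import {
  getMlModelTypesForModelConfig,
  SUPPORTED_PYTORCH_TASKS,
} from '../../../../../../../common/ml_inference_pipeline';
import { getMLType } from '../../../shared/ml_inference/utils';

import { MLInferenceLogic } from './ml_inference_logic';

// Hypothetical helper hook: resolves the selected model via the new selector and
// reports whether the zero-shot label configuration should be shown.
export const useIsZeroShotModelSelected = (): boolean => {
  const { selectedMLModel } = useValues(MLInferenceLogic);
  if (!selectedMLModel) return false;
  return (
    getMLType(getMlModelTypesForModelConfig(selectedMLModel)) ===
    SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION
  );
};
```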
*/ +import { InferencePipelineInferenceConfig } from '../../../../../../../common/types/pipelines'; + export interface InferencePipelineConfiguration { destinationField: string; existingPipeline?: boolean; + inferenceConfig?: InferencePipelineInferenceConfig; modelID: string; pipelineName: string; sourceField: string; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx new file mode 100644 index 0000000000000..2df6e4b8efa62 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import { useValues, useActions } from 'kea'; + +import { EuiComboBox, EuiFormRow, EuiSpacer } from '@elastic/eui'; + +import { i18n } from '@kbn/i18n'; + +import { IndexViewLogic } from '../../index_view_logic'; + +import { MLInferenceLogic } from './ml_inference_logic'; + +type LabelOptions = Array<{ label: string }>; + +export const ZeroShotClassificationInferenceConfiguration: React.FC = () => { + const { ingestionMethod } = useValues(IndexViewLogic); + const { + addInferencePipelineModal: { configuration }, + } = useValues(MLInferenceLogic); + const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic); + + const zeroShotLabels = configuration?.inferenceConfig?.zero_shot_classification?.labels ?? []; + const labelOptions = zeroShotLabels.map((label) => ({ label })); + + const onLabelChange = (selectedLabels: LabelOptions) => { + const inferenceConfig = + selectedLabels.length === 0 + ? undefined + : { + zero_shot_classification: { + ...(configuration?.inferenceConfig?.zero_shot_classification ?? {}), + labels: selectedLabels.map(({ label }) => label), + }, + }; + setInferencePipelineConfiguration({ + ...configuration, + inferenceConfig, + }); + }; + const onCreateLabel = (labelValue: string, labels: LabelOptions = []) => { + const normalizedLabelValue = labelValue.trim(); + if (!normalizedLabelValue) return; + + const existingLabel = labels.find((label) => label.label === normalizedLabelValue); + if (existingLabel) return; + setInferencePipelineConfiguration({ + ...configuration, + inferenceConfig: { + zero_shot_classification: { + ...(configuration?.inferenceConfig?.zero_shot_classification ?? 
{}), + labels: [...zeroShotLabels, normalizedLabelValue], + }, + }, + }); + }; + return ( + <> + + + + + + ); +}; diff --git a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts index 220db6cca9eb6..2f7de483cee2b 100644 --- a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts +++ b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts @@ -61,6 +61,7 @@ describe('createMlInferencePipeline lib function', () => { modelId, sourceField, destinationField, + undefined, // Omitted inference config mockClient as unknown as ElasticsearchClient ); @@ -74,6 +75,7 @@ describe('createMlInferencePipeline lib function', () => { modelId, sourceField, destinationField, + undefined, // Omitted inference config mockClient as unknown as ElasticsearchClient ); @@ -93,6 +95,7 @@ describe('createMlInferencePipeline lib function', () => { modelId, sourceField, undefined, // Omitted destination field + undefined, // Omitted inference config mockClient as unknown as ElasticsearchClient ); @@ -110,6 +113,41 @@ describe('createMlInferencePipeline lib function', () => { ); }); + it('should set inference config when provided', async () => { + mockClient.ingest.getPipeline.mockImplementation(() => Promise.reject({ statusCode: 404 })); // Pipeline does not exist + mockClient.ingest.putPipeline.mockImplementation(() => Promise.resolve({ acknowledged: true })); + + await createMlInferencePipeline( + pipelineName, + modelId, + sourceField, + destinationField, + { + zero_shot_classification: { + labels: ['foo', 'bar'], + }, + }, + mockClient as unknown as ElasticsearchClient + ); + + // Verify the object passed to pipeline creation contains the default target field name + expect(mockClient.ingest.putPipeline).toHaveBeenCalledWith( + expect.objectContaining({ + processors: expect.arrayContaining([ + expect.objectContaining({ + inference: expect.objectContaining({ + inference_config: { + zero_shot_classification: { + labels: ['foo', 'bar'], + }, + }, + }), + }), + ]), + }) + ); + }); + it('should throw an error without creating the pipeline if it already exists', () => { mockClient.ingest.getPipeline.mockImplementation(() => Promise.resolve({ @@ -122,6 +160,7 @@ describe('createMlInferencePipeline lib function', () => { modelId, sourceField, destinationField, + undefined, // Omitted inference config mockClient as unknown as ElasticsearchClient ); diff --git a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts index b454f8b495cb2..d75e844cb912e 100644 --- a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts +++ b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts @@ -10,7 +10,10 @@ import { ElasticsearchClient } from '@kbn/core/server'; import { formatPipelineName } from '../../../../../../common/ml_inference_pipeline'; import { ErrorCode } from '../../../../../../common/types/error_codes'; -import type { 
CreateMlInferencePipelineResponse } from '../../../../../../common/types/pipelines'; +import type { + CreateMlInferencePipelineResponse, + InferencePipelineInferenceConfig, +} from '../../../../../../common/types/pipelines'; import { addSubPipelineToIndexSpecificMlPipeline } from '../../../../../utils/create_ml_inference_pipeline'; import { getPrefixedInferencePipelineProcessorName } from '../../../../../utils/ml_inference_pipeline_utils'; import { formatMlPipelineBody } from '../../../../pipelines/create_pipeline_definitions'; @@ -23,6 +26,7 @@ import { formatMlPipelineBody } from '../../../../pipelines/create_pipeline_defi * @param modelId model ID selected by the user. * @param sourceField The document field that model will read. * @param destinationField The document field that the model will write to. + * @param inferenceConfig The configuration for the model. * @param esClient the Elasticsearch Client to use when retrieving pipeline and model details. */ export const createAndReferenceMlInferencePipeline = async ( @@ -31,6 +35,7 @@ export const createAndReferenceMlInferencePipeline = async ( modelId: string, sourceField: string, destinationField: string | null | undefined, + inferenceConfig: InferencePipelineInferenceConfig | undefined, esClient: ElasticsearchClient ): Promise => { const createPipelineResult = await createMlInferencePipeline( @@ -38,6 +43,7 @@ export const createAndReferenceMlInferencePipeline = async ( modelId, sourceField, destinationField, + inferenceConfig, esClient ); @@ -59,6 +65,7 @@ export const createAndReferenceMlInferencePipeline = async ( * @param modelId model ID selected by the user. * @param sourceField The document field that model will read. * @param destinationField The document field that the model will write to. + * @param inferenceConfig The configuration for the model. * @param esClient the Elasticsearch Client to use when retrieving pipeline and model details. 
*/ export const createMlInferencePipeline = async ( @@ -66,6 +73,7 @@ export const createMlInferencePipeline = async ( modelId: string, sourceField: string, destinationField: string | null | undefined, + inferenceConfig: InferencePipelineInferenceConfig | undefined, esClient: ElasticsearchClient ): Promise => { const inferencePipelineGeneratedName = getPrefixedInferencePipelineProcessorName(pipelineName); @@ -89,6 +97,7 @@ export const createMlInferencePipeline = async ( modelId, sourceField, destinationField || formatPipelineName(pipelineName), + inferenceConfig, esClient ); diff --git a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts index 3d54396a7d742..23232fda4199d 100644 --- a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts +++ b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts @@ -130,6 +130,7 @@ describe('formatMlPipelineBody util function', () => { modelId, sourceField, destField, + undefined, mockClient as unknown as ElasticsearchClient ); expect(actualResult).toEqual(expectedResult); @@ -144,6 +145,7 @@ describe('formatMlPipelineBody util function', () => { modelId, sourceField, destField, + undefined, mockClient as unknown as ElasticsearchClient ); await expect(asyncCall).rejects.toThrow(Error); @@ -184,6 +186,7 @@ describe('formatMlPipelineBody util function', () => { modelId, sourceField, destField, + undefined, mockClient as unknown as ElasticsearchClient ); expect(actualResult).toEqual(expectedResultWithNoInputField); diff --git a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts index 4eba6dc5b0c8c..8b511ab22c3e7 100644 --- a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts +++ b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts @@ -8,7 +8,10 @@ import { ElasticsearchClient } from '@kbn/core/server'; import { generateMlInferencePipelineBody } from '../../../common/ml_inference_pipeline'; -import { MlInferencePipeline } from '../../../common/types/pipelines'; +import { + InferencePipelineInferenceConfig, + MlInferencePipeline, +} from '../../../common/types/pipelines'; import { getInferencePipelineNameFromIndexName } from '../../utils/ml_inference_pipeline_utils'; export interface CreatedPipelines { @@ -221,6 +224,7 @@ export const createIndexPipelineDefinitions = async ( * @param modelId modelId selected by user. * @param sourceField The document field that model will read. * @param destinationField The document field that the model will write to. + * @param inferenceConfig The configuration for the model. * @param esClient the Elasticsearch Client to use when retrieving model details. 
*/ export const formatMlPipelineBody = async ( @@ -228,6 +232,7 @@ export const formatMlPipelineBody = async ( modelId: string, sourceField: string, destinationField: string, + inferenceConfig: InferencePipelineInferenceConfig | undefined, esClient: ElasticsearchClient ): Promise => { // this will raise a 404 if model doesn't exist @@ -235,6 +240,7 @@ export const formatMlPipelineBody = async ( const model = models.trained_model_configs[0]; return generateMlInferencePipelineBody({ destinationField, + inferenceConfig, model, pipelineName, sourceField, diff --git a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts index b1c49f9c5cf9a..b81a532862cba 100644 --- a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts +++ b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts @@ -289,6 +289,7 @@ describe('Enterprise Search Managed Indices', () => { mockRequestBody.model_id, mockRequestBody.source_field, mockRequestBody.destination_field, + undefined, mockClient.asCurrentUser ); diff --git a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts index 239310731583e..bc4c644e8666b 100644 --- a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts +++ b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts @@ -366,6 +366,15 @@ export function registerIndexRoutes({ }), body: schema.object({ destination_field: schema.maybe(schema.nullable(schema.string())), + inference_config: schema.maybe( + schema.object({ + zero_shot_classification: schema.maybe( + schema.object({ + labels: schema.arrayOf(schema.string()), + }) + ), + }) + ), model_id: schema.string(), pipeline_name: schema.string(), source_field: schema.string(), @@ -381,6 +390,7 @@ export function registerIndexRoutes({ pipeline_name: pipelineName, source_field: sourceField, destination_field: destinationField, + inference_config: inferenceConfig, } = request.body; let createPipelineResult: CreateMlInferencePipelineResponse | undefined; @@ -392,6 +402,7 @@ export function registerIndexRoutes({ modelId, sourceField, destinationField, + inferenceConfig, client.asCurrentUser ); } catch (error) {
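With the `inference_config` schema added to the route, the pipeline_processors endpoint accepts optional zero-shot labels and hands them through `createAndReferenceMlInferencePipeline` down to the generated pipeline. A sketch of the request body the UI would send for a zero-shot model; the index, model ID, field names, and labels are illustrative, and the import path is shown only for context:

```ts
import { CreateMlInferencePipelineParameters } from '../../../common/types/pipelines';

// Illustrative body for
// POST /internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors
// Only zero-shot models without default labels need inference_config; other model
// types can omit it entirely.
const params: CreateMlInferencePipelineParameters = {
  destination_field: 'sentiment',
  inference_config: {
    zero_shot_classification: {
      labels: ['positive', 'negative', 'neutral'],
    },
  },
  model_id: 'my-zero-shot-model',
  pipeline_name: 'my-index-sentiment',
  source_field: 'body_content',
};
```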