diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
index 206e03893f4bc..4917ac29397e6 100644
--- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
+++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts
@@ -17,6 +17,7 @@ import {
MlInferencePipeline,
CreateMlInferencePipelineParameters,
TrainedModelState,
+ InferencePipelineInferenceConfig,
} from '../types/pipelines';
// Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics'
@@ -37,6 +38,7 @@ export const SUPPORTED_PYTORCH_TASKS = {
export interface MlInferencePipelineParams {
description?: string;
destinationField: string;
+ inferenceConfig?: InferencePipelineInferenceConfig;
model: MlTrainedModelConfig;
pipelineName: string;
sourceField: string;
@@ -50,6 +52,7 @@ export interface MlInferencePipelineParams {
export const generateMlInferencePipelineBody = ({
description,
destinationField,
+ inferenceConfig,
model,
pipelineName,
sourceField,
@@ -77,6 +80,7 @@ export const generateMlInferencePipelineBody = ({
field_map: {
[sourceField]: modelInputField,
},
+ inference_config: inferenceConfig,
model_id: model.model_id,
on_failure: [
{
diff --git a/x-pack/plugins/enterprise_search/common/types/pipelines.ts b/x-pack/plugins/enterprise_search/common/types/pipelines.ts
index 38314f6d162de..53bc80687f136 100644
--- a/x-pack/plugins/enterprise_search/common/types/pipelines.ts
+++ b/x-pack/plugins/enterprise_search/common/types/pipelines.ts
@@ -5,7 +5,7 @@
* 2.0.
*/
-import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
+import { IngestInferenceConfig, IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
export interface InferencePipeline {
modelId: string | undefined;
@@ -67,7 +67,14 @@ export interface DeleteMlInferencePipelineResponse {
export interface CreateMlInferencePipelineParameters {
destination_field?: string;
+ inference_config?: InferencePipelineInferenceConfig;
model_id: string;
pipeline_name: string;
source_field: string;
}
+
+export type InferencePipelineInferenceConfig = IngestInferenceConfig & {
+ zero_shot_classification?: {
+ labels: string[];
+ };
+};
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts
index 78f08c4bc0ee8..afc4d789a0223 100644
--- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts
@@ -4,13 +4,17 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-import { CreateMlInferencePipelineParameters } from '../../../../../common/types/pipelines';
+import {
+ CreateMlInferencePipelineParameters,
+ InferencePipelineInferenceConfig,
+} from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export interface CreateMlInferencePipelineApiLogicArgs {
destinationField?: string;
indexName: string;
+ inferenceConfig?: InferencePipelineInferenceConfig;
modelId: string;
pipelineName: string;
sourceField: string;
@@ -26,6 +30,7 @@ export const createMlInferencePipeline = async (
const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors`;
const params: CreateMlInferencePipelineParameters = {
destination_field: args.destinationField,
+ inference_config: args.inferenceConfig,
model_id: args.modelId,
pipeline_name: args.pipelineName,
source_field: args.sourceField,
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx
index feb4ca8c87a4e..104562439d325 100644
--- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx
@@ -30,6 +30,7 @@ import { docLinks } from '../../../../../shared/doc_links';
import { IndexViewLogic } from '../../index_view_logic';
+import { InferenceConfiguration } from './inference_config';
import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic';
import { MlModelSelectOption } from './model_select_option';
import { PipelineSelectOption } from './pipeline_select_option';
@@ -275,6 +276,7 @@ export const ConfigurePipeline: React.FC = () => {
setInferencePipelineConfiguration({
...configuration,
modelID: value,
+ inferenceConfig: undefined,
})
}
options={modelOptions}
@@ -357,6 +359,7 @@ export const ConfigurePipeline: React.FC = () => {
+
>
);
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx
new file mode 100644
index 0000000000000..244d708224940
--- /dev/null
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import React from 'react';
+
+import { useValues } from 'kea';
+
+import { EuiSpacer, EuiText } from '@elastic/eui';
+import { i18n } from '@kbn/i18n';
+
+import {
+ getMlModelTypesForModelConfig,
+ SUPPORTED_PYTORCH_TASKS,
+} from '../../../../../../../common/ml_inference_pipeline';
+import { getMLType } from '../../../shared/ml_inference/utils';
+
+import { MLInferenceLogic } from './ml_inference_logic';
+import { ZeroShotClassificationInferenceConfiguration } from './zero_shot_inference_configuration';
+
+export const InferenceConfiguration: React.FC = () => {
+ const {
+ addInferencePipelineModal: { configuration },
+ selectedMLModel,
+ } = useValues(MLInferenceLogic);
+ if (!selectedMLModel || configuration.existingPipeline) return null;
+ const modelType = getMLType(getMlModelTypesForModelConfig(selectedMLModel));
+ switch (modelType) {
+ case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
+ return (
+
+
+
+ );
+ default:
+ return null;
+ }
+};
+
+const InferenceConfigurationWrapper: React.FC = ({ children }) => {
+ return (
+ <>
+
+
+
+ {i18n.translate(
+ 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.inference.title',
+ { defaultMessage: 'Inference Configuration' }
+ )}
+
+
+ {children}
+ >
+ );
+};
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts
index 171fe0e3a53fc..9083914812263 100644
--- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts
@@ -55,6 +55,7 @@ const DEFAULT_VALUES: MLInferenceProcessorsValues = {
mlInferencePipelinesData: undefined,
mlModelsData: null,
mlModelsStatus: 0,
+ selectedMLModel: null,
sourceFields: undefined,
supportedMLModels: [],
};
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts
index 9a6737d90da4b..5bddfbc545b53 100644
--- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts
@@ -162,6 +162,7 @@ export interface MLInferenceProcessorsValues {
mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined;
mlModelsData: TrainedModel[] | null;
mlModelsStatus: Status;
+ selectedMLModel: TrainedModel | null;
sourceFields: string[] | undefined;
supportedMLModels: TrainedModel[];
}
@@ -241,6 +242,7 @@ export const MLInferenceLogic = kea<
? configuration.destinationField
: undefined,
indexName,
+ inferenceConfig: configuration.inferenceConfig,
modelId: configuration.modelID,
pipelineName: configuration.pipelineName,
sourceField: configuration.sourceField,
@@ -343,6 +345,7 @@ export const MLInferenceLogic = kea<
model,
pipelineName: configuration.pipelineName,
sourceField: configuration.sourceField,
+ inferenceConfig: configuration.inferenceConfig,
});
},
],
@@ -435,5 +438,18 @@ export const MLInferenceLogic = kea<
return existingPipelines;
},
],
+ selectedMLModel: [
+ () => [selectors.supportedMLModels, selectors.addInferencePipelineModal],
+ (
+ supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'],
+ addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal']
+ ) => {
+ return (
+ supportedMLModels.find(
+ (model) => model.model_id === addInferencePipelineModal.configuration.modelID
+ ) ?? null
+ );
+ },
+ ],
}),
});
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts
index db8970fa62759..0e7070e2385e9 100644
--- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts
@@ -5,9 +5,12 @@
* 2.0.
*/
+import { InferencePipelineInferenceConfig } from '../../../../../../../common/types/pipelines';
+
export interface InferencePipelineConfiguration {
destinationField: string;
existingPipeline?: boolean;
+ inferenceConfig?: InferencePipelineInferenceConfig;
modelID: string;
pipelineName: string;
sourceField: string;
diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx
new file mode 100644
index 0000000000000..2df6e4b8efa62
--- /dev/null
+++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/zero_shot_inference_configuration.tsx
@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import React from 'react';
+
+import { useValues, useActions } from 'kea';
+
+import { EuiComboBox, EuiFormRow, EuiSpacer } from '@elastic/eui';
+
+import { i18n } from '@kbn/i18n';
+
+import { IndexViewLogic } from '../../index_view_logic';
+
+import { MLInferenceLogic } from './ml_inference_logic';
+
+type LabelOptions = Array<{ label: string }>;
+
+export const ZeroShotClassificationInferenceConfiguration: React.FC = () => {
+ const { ingestionMethod } = useValues(IndexViewLogic);
+ const {
+ addInferencePipelineModal: { configuration },
+ } = useValues(MLInferenceLogic);
+ const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic);
+
+ const zeroShotLabels = configuration?.inferenceConfig?.zero_shot_classification?.labels ?? [];
+ const labelOptions = zeroShotLabels.map((label) => ({ label }));
+
+ const onLabelChange = (selectedLabels: LabelOptions) => {
+ const inferenceConfig =
+ selectedLabels.length === 0
+ ? undefined
+ : {
+ zero_shot_classification: {
+ ...(configuration?.inferenceConfig?.zero_shot_classification ?? {}),
+ labels: selectedLabels.map(({ label }) => label),
+ },
+ };
+ setInferencePipelineConfiguration({
+ ...configuration,
+ inferenceConfig,
+ });
+ };
+ const onCreateLabel = (labelValue: string, labels: LabelOptions = []) => {
+ const normalizedLabelValue = labelValue.trim();
+ if (!normalizedLabelValue) return;
+
+ const existingLabel = labels.find((label) => label.label === normalizedLabelValue);
+ if (existingLabel) return;
+ setInferencePipelineConfiguration({
+ ...configuration,
+ inferenceConfig: {
+ zero_shot_classification: {
+ ...(configuration?.inferenceConfig?.zero_shot_classification ?? {}),
+ labels: [...zeroShotLabels, normalizedLabelValue],
+ },
+ },
+ });
+ };
+ return (
+ <>
+
+
+
+
+ >
+ );
+};
diff --git a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts
index 220db6cca9eb6..2f7de483cee2b 100644
--- a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts
+++ b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.test.ts
@@ -61,6 +61,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
+ undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);
@@ -74,6 +75,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
+ undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);
@@ -93,6 +95,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
undefined, // Omitted destination field
+ undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);
@@ -110,6 +113,41 @@ describe('createMlInferencePipeline lib function', () => {
);
});
+ it('should set inference config when provided', async () => {
+ mockClient.ingest.getPipeline.mockImplementation(() => Promise.reject({ statusCode: 404 })); // Pipeline does not exist
+ mockClient.ingest.putPipeline.mockImplementation(() => Promise.resolve({ acknowledged: true }));
+
+ await createMlInferencePipeline(
+ pipelineName,
+ modelId,
+ sourceField,
+ destinationField,
+ {
+ zero_shot_classification: {
+ labels: ['foo', 'bar'],
+ },
+ },
+ mockClient as unknown as ElasticsearchClient
+ );
+
+ // Verify the object passed to pipeline creation contains the default target field name
+ expect(mockClient.ingest.putPipeline).toHaveBeenCalledWith(
+ expect.objectContaining({
+ processors: expect.arrayContaining([
+ expect.objectContaining({
+ inference: expect.objectContaining({
+ inference_config: {
+ zero_shot_classification: {
+ labels: ['foo', 'bar'],
+ },
+ },
+ }),
+ }),
+ ]),
+ })
+ );
+ });
+
it('should throw an error without creating the pipeline if it already exists', () => {
mockClient.ingest.getPipeline.mockImplementation(() =>
Promise.resolve({
@@ -122,6 +160,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
+ undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);
diff --git a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts
index b454f8b495cb2..d75e844cb912e 100644
--- a/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts
+++ b/x-pack/plugins/enterprise_search/server/lib/indices/pipelines/ml_inference/pipeline_processors/create_ml_inference_pipeline.ts
@@ -10,7 +10,10 @@ import { ElasticsearchClient } from '@kbn/core/server';
import { formatPipelineName } from '../../../../../../common/ml_inference_pipeline';
import { ErrorCode } from '../../../../../../common/types/error_codes';
-import type { CreateMlInferencePipelineResponse } from '../../../../../../common/types/pipelines';
+import type {
+ CreateMlInferencePipelineResponse,
+ InferencePipelineInferenceConfig,
+} from '../../../../../../common/types/pipelines';
import { addSubPipelineToIndexSpecificMlPipeline } from '../../../../../utils/create_ml_inference_pipeline';
import { getPrefixedInferencePipelineProcessorName } from '../../../../../utils/ml_inference_pipeline_utils';
import { formatMlPipelineBody } from '../../../../pipelines/create_pipeline_definitions';
@@ -23,6 +26,7 @@ import { formatMlPipelineBody } from '../../../../pipelines/create_pipeline_defi
* @param modelId model ID selected by the user.
* @param sourceField The document field that model will read.
* @param destinationField The document field that the model will write to.
+ * @param inferenceConfig The configuration for the model.
* @param esClient the Elasticsearch Client to use when retrieving pipeline and model details.
*/
export const createAndReferenceMlInferencePipeline = async (
@@ -31,6 +35,7 @@ export const createAndReferenceMlInferencePipeline = async (
modelId: string,
sourceField: string,
destinationField: string | null | undefined,
+ inferenceConfig: InferencePipelineInferenceConfig | undefined,
esClient: ElasticsearchClient
): Promise => {
const createPipelineResult = await createMlInferencePipeline(
@@ -38,6 +43,7 @@ export const createAndReferenceMlInferencePipeline = async (
modelId,
sourceField,
destinationField,
+ inferenceConfig,
esClient
);
@@ -59,6 +65,7 @@ export const createAndReferenceMlInferencePipeline = async (
* @param modelId model ID selected by the user.
* @param sourceField The document field that model will read.
* @param destinationField The document field that the model will write to.
+ * @param inferenceConfig The configuration for the model.
* @param esClient the Elasticsearch Client to use when retrieving pipeline and model details.
*/
export const createMlInferencePipeline = async (
@@ -66,6 +73,7 @@ export const createMlInferencePipeline = async (
modelId: string,
sourceField: string,
destinationField: string | null | undefined,
+ inferenceConfig: InferencePipelineInferenceConfig | undefined,
esClient: ElasticsearchClient
): Promise => {
const inferencePipelineGeneratedName = getPrefixedInferencePipelineProcessorName(pipelineName);
@@ -89,6 +97,7 @@ export const createMlInferencePipeline = async (
modelId,
sourceField,
destinationField || formatPipelineName(pipelineName),
+ inferenceConfig,
esClient
);
diff --git a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts
index 3d54396a7d742..23232fda4199d 100644
--- a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts
+++ b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.test.ts
@@ -130,6 +130,7 @@ describe('formatMlPipelineBody util function', () => {
modelId,
sourceField,
destField,
+ undefined,
mockClient as unknown as ElasticsearchClient
);
expect(actualResult).toEqual(expectedResult);
@@ -144,6 +145,7 @@ describe('formatMlPipelineBody util function', () => {
modelId,
sourceField,
destField,
+ undefined,
mockClient as unknown as ElasticsearchClient
);
await expect(asyncCall).rejects.toThrow(Error);
@@ -184,6 +186,7 @@ describe('formatMlPipelineBody util function', () => {
modelId,
sourceField,
destField,
+ undefined,
mockClient as unknown as ElasticsearchClient
);
expect(actualResult).toEqual(expectedResultWithNoInputField);
diff --git a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts
index 4eba6dc5b0c8c..8b511ab22c3e7 100644
--- a/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts
+++ b/x-pack/plugins/enterprise_search/server/lib/pipelines/create_pipeline_definitions.ts
@@ -8,7 +8,10 @@
import { ElasticsearchClient } from '@kbn/core/server';
import { generateMlInferencePipelineBody } from '../../../common/ml_inference_pipeline';
-import { MlInferencePipeline } from '../../../common/types/pipelines';
+import {
+ InferencePipelineInferenceConfig,
+ MlInferencePipeline,
+} from '../../../common/types/pipelines';
import { getInferencePipelineNameFromIndexName } from '../../utils/ml_inference_pipeline_utils';
export interface CreatedPipelines {
@@ -221,6 +224,7 @@ export const createIndexPipelineDefinitions = async (
* @param modelId modelId selected by user.
* @param sourceField The document field that model will read.
* @param destinationField The document field that the model will write to.
+ * @param inferenceConfig The configuration for the model.
* @param esClient the Elasticsearch Client to use when retrieving model details.
*/
export const formatMlPipelineBody = async (
@@ -228,6 +232,7 @@ export const formatMlPipelineBody = async (
modelId: string,
sourceField: string,
destinationField: string,
+ inferenceConfig: InferencePipelineInferenceConfig | undefined,
esClient: ElasticsearchClient
): Promise => {
// this will raise a 404 if model doesn't exist
@@ -235,6 +240,7 @@ export const formatMlPipelineBody = async (
const model = models.trained_model_configs[0];
return generateMlInferencePipelineBody({
destinationField,
+ inferenceConfig,
model,
pipelineName,
sourceField,
diff --git a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts
index b1c49f9c5cf9a..b81a532862cba 100644
--- a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts
+++ b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.test.ts
@@ -289,6 +289,7 @@ describe('Enterprise Search Managed Indices', () => {
mockRequestBody.model_id,
mockRequestBody.source_field,
mockRequestBody.destination_field,
+ undefined,
mockClient.asCurrentUser
);
diff --git a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts
index 239310731583e..bc4c644e8666b 100644
--- a/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts
+++ b/x-pack/plugins/enterprise_search/server/routes/enterprise_search/indices.ts
@@ -366,6 +366,15 @@ export function registerIndexRoutes({
}),
body: schema.object({
destination_field: schema.maybe(schema.nullable(schema.string())),
+ inference_config: schema.maybe(
+ schema.object({
+ zero_shot_classification: schema.maybe(
+ schema.object({
+ labels: schema.arrayOf(schema.string()),
+ })
+ ),
+ })
+ ),
model_id: schema.string(),
pipeline_name: schema.string(),
source_field: schema.string(),
@@ -381,6 +390,7 @@ export function registerIndexRoutes({
pipeline_name: pipelineName,
source_field: sourceField,
destination_field: destinationField,
+ inference_config: inferenceConfig,
} = request.body;
let createPipelineResult: CreateMlInferencePipelineResponse | undefined;
@@ -392,6 +402,7 @@ export function registerIndexRoutes({
modelId,
sourceField,
destinationField,
+ inferenceConfig,
client.asCurrentUser
);
} catch (error) {