Skip to content

Commit

Permalink
Merge branch 'elastic:main' into deprecate_csp_rule
Browse files Browse the repository at this point in the history
  • Loading branch information
ofiriro3 authored Dec 20, 2022
2 parents d1aa7c3 + e6767f7 commit fbf501b
Show file tree
Hide file tree
Showing 15 changed files with 259 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
MlInferencePipeline,
CreateMlInferencePipelineParameters,
TrainedModelState,
InferencePipelineInferenceConfig,
} from '../types/pipelines';

// Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics'
Expand All @@ -37,6 +38,7 @@ export const SUPPORTED_PYTORCH_TASKS = {
export interface MlInferencePipelineParams {
description?: string;
destinationField: string;
inferenceConfig?: InferencePipelineInferenceConfig;
model: MlTrainedModelConfig;
pipelineName: string;
sourceField: string;
Expand All @@ -50,6 +52,7 @@ export interface MlInferencePipelineParams {
export const generateMlInferencePipelineBody = ({
description,
destinationField,
inferenceConfig,
model,
pipelineName,
sourceField,
Expand Down Expand Up @@ -77,6 +80,7 @@ export const generateMlInferencePipelineBody = ({
field_map: {
[sourceField]: modelInputField,
},
inference_config: inferenceConfig,
model_id: model.model_id,
on_failure: [
{
Expand Down
9 changes: 8 additions & 1 deletion x-pack/plugins/enterprise_search/common/types/pipelines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { IngestInferenceConfig, IngestPipeline } from '@elastic/elasticsearch/lib/api/types';

export interface InferencePipeline {
modelId: string | undefined;
Expand Down Expand Up @@ -67,7 +67,14 @@ export interface DeleteMlInferencePipelineResponse {

export interface CreateMlInferencePipelineParameters {
destination_field?: string;
inference_config?: InferencePipelineInferenceConfig;
model_id: string;
pipeline_name: string;
source_field: string;
}

export type InferencePipelineInferenceConfig = IngestInferenceConfig & {
zero_shot_classification?: {
labels: string[];
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { CreateMlInferencePipelineParameters } from '../../../../../common/types/pipelines';
import {
CreateMlInferencePipelineParameters,
InferencePipelineInferenceConfig,
} from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';

export interface CreateMlInferencePipelineApiLogicArgs {
destinationField?: string;
indexName: string;
inferenceConfig?: InferencePipelineInferenceConfig;
modelId: string;
pipelineName: string;
sourceField: string;
Expand All @@ -26,6 +30,7 @@ export const createMlInferencePipeline = async (
const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors`;
const params: CreateMlInferencePipelineParameters = {
destination_field: args.destinationField,
inference_config: args.inferenceConfig,
model_id: args.modelId,
pipeline_name: args.pipelineName,
source_field: args.sourceField,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { docLinks } from '../../../../../shared/doc_links';

import { IndexViewLogic } from '../../index_view_logic';

import { InferenceConfiguration } from './inference_config';
import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic';
import { MlModelSelectOption } from './model_select_option';
import { PipelineSelectOption } from './pipeline_select_option';
Expand Down Expand Up @@ -275,6 +276,7 @@ export const ConfigurePipeline: React.FC = () => {
setInferencePipelineConfiguration({
...configuration,
modelID: value,
inferenceConfig: undefined,
})
}
options={modelOptions}
Expand Down Expand Up @@ -357,6 +359,7 @@ export const ConfigurePipeline: React.FC = () => {
</EuiFormRow>
</EuiFlexItem>
</EuiFlexGroup>
<InferenceConfiguration />
</EuiForm>
</>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React from 'react';

import { useValues } from 'kea';

import { EuiSpacer, EuiText } from '@elastic/eui';
import { i18n } from '@kbn/i18n';

import {
getMlModelTypesForModelConfig,
SUPPORTED_PYTORCH_TASKS,
} from '../../../../../../../common/ml_inference_pipeline';
import { getMLType } from '../../../shared/ml_inference/utils';

import { MLInferenceLogic } from './ml_inference_logic';
import { ZeroShotClassificationInferenceConfiguration } from './zero_shot_inference_configuration';

export const InferenceConfiguration: React.FC = () => {
const {
addInferencePipelineModal: { configuration },
selectedMLModel,
} = useValues(MLInferenceLogic);
if (!selectedMLModel || configuration.existingPipeline) return null;
const modelType = getMLType(getMlModelTypesForModelConfig(selectedMLModel));
switch (modelType) {
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
return (
<InferenceConfigurationWrapper>
<ZeroShotClassificationInferenceConfiguration />
</InferenceConfigurationWrapper>
);
default:
return null;
}
};

const InferenceConfigurationWrapper: React.FC = ({ children }) => {
return (
<>
<EuiSpacer />
<EuiText>
<h4>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.inference.title',
{ defaultMessage: 'Inference Configuration' }
)}
</h4>
</EuiText>
{children}
</>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const DEFAULT_VALUES: MLInferenceProcessorsValues = {
mlInferencePipelinesData: undefined,
mlModelsData: null,
mlModelsStatus: 0,
selectedMLModel: null,
sourceFields: undefined,
supportedMLModels: [],
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ export interface MLInferenceProcessorsValues {
mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined;
mlModelsData: TrainedModel[] | null;
mlModelsStatus: Status;
selectedMLModel: TrainedModel | null;
sourceFields: string[] | undefined;
supportedMLModels: TrainedModel[];
}
Expand Down Expand Up @@ -241,6 +242,7 @@ export const MLInferenceLogic = kea<
? configuration.destinationField
: undefined,
indexName,
inferenceConfig: configuration.inferenceConfig,
modelId: configuration.modelID,
pipelineName: configuration.pipelineName,
sourceField: configuration.sourceField,
Expand Down Expand Up @@ -343,6 +345,7 @@ export const MLInferenceLogic = kea<
model,
pipelineName: configuration.pipelineName,
sourceField: configuration.sourceField,
inferenceConfig: configuration.inferenceConfig,
});
},
],
Expand Down Expand Up @@ -435,5 +438,18 @@ export const MLInferenceLogic = kea<
return existingPipelines;
},
],
selectedMLModel: [
() => [selectors.supportedMLModels, selectors.addInferencePipelineModal],
(
supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'],
addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal']
) => {
return (
supportedMLModels.find(
(model) => model.model_id === addInferencePipelineModal.configuration.modelID
) ?? null
);
},
],
}),
});
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
* 2.0.
*/

import { InferencePipelineInferenceConfig } from '../../../../../../../common/types/pipelines';

export interface InferencePipelineConfiguration {
destinationField: string;
existingPipeline?: boolean;
inferenceConfig?: InferencePipelineInferenceConfig;
modelID: string;
pipelineName: string;
sourceField: string;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React from 'react';

import { useValues, useActions } from 'kea';

import { EuiComboBox, EuiFormRow, EuiSpacer } from '@elastic/eui';

import { i18n } from '@kbn/i18n';

import { IndexViewLogic } from '../../index_view_logic';

import { MLInferenceLogic } from './ml_inference_logic';

type LabelOptions = Array<{ label: string }>;

export const ZeroShotClassificationInferenceConfiguration: React.FC = () => {
const { ingestionMethod } = useValues(IndexViewLogic);
const {
addInferencePipelineModal: { configuration },
} = useValues(MLInferenceLogic);
const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic);

const zeroShotLabels = configuration?.inferenceConfig?.zero_shot_classification?.labels ?? [];
const labelOptions = zeroShotLabels.map((label) => ({ label }));

const onLabelChange = (selectedLabels: LabelOptions) => {
const inferenceConfig =
selectedLabels.length === 0
? undefined
: {
zero_shot_classification: {
...(configuration?.inferenceConfig?.zero_shot_classification ?? {}),
labels: selectedLabels.map(({ label }) => label),
},
};
setInferencePipelineConfiguration({
...configuration,
inferenceConfig,
});
};
const onCreateLabel = (labelValue: string, labels: LabelOptions = []) => {
const normalizedLabelValue = labelValue.trim();
if (!normalizedLabelValue) return;

const existingLabel = labels.find((label) => label.label === normalizedLabelValue);
if (existingLabel) return;
setInferencePipelineConfiguration({
...configuration,
inferenceConfig: {
zero_shot_classification: {
...(configuration?.inferenceConfig?.zero_shot_classification ?? {}),
labels: [...zeroShotLabels, normalizedLabelValue],
},
},
});
};
return (
<>
<EuiSpacer size="s" />
<EuiFormRow
label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.inference.zeroShot.labels.label',
{ defaultMessage: 'Class labels' }
)}
fullWidth
>
<EuiComboBox
fullWidth
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-zeroShot-labels`}
placeholder={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.inference.zeroShot.labels.placeholder',
{ defaultMessage: 'Create labels' }
)}
options={labelOptions}
selectedOptions={labelOptions}
onChange={onLabelChange}
onCreateOption={onCreateLabel}
noSuggestions
/>
</EuiFormRow>
</>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);

Expand All @@ -74,6 +75,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);

Expand All @@ -93,6 +95,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
undefined, // Omitted destination field
undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);

Expand All @@ -110,6 +113,41 @@ describe('createMlInferencePipeline lib function', () => {
);
});

it('should set inference config when provided', async () => {
mockClient.ingest.getPipeline.mockImplementation(() => Promise.reject({ statusCode: 404 })); // Pipeline does not exist
mockClient.ingest.putPipeline.mockImplementation(() => Promise.resolve({ acknowledged: true }));

await createMlInferencePipeline(
pipelineName,
modelId,
sourceField,
destinationField,
{
zero_shot_classification: {
labels: ['foo', 'bar'],
},
},
mockClient as unknown as ElasticsearchClient
);

// Verify the object passed to pipeline creation contains the default target field name
expect(mockClient.ingest.putPipeline).toHaveBeenCalledWith(
expect.objectContaining({
processors: expect.arrayContaining([
expect.objectContaining({
inference: expect.objectContaining({
inference_config: {
zero_shot_classification: {
labels: ['foo', 'bar'],
},
},
}),
}),
]),
})
);
});

it('should throw an error without creating the pipeline if it already exists', () => {
mockClient.ingest.getPipeline.mockImplementation(() =>
Promise.resolve({
Expand All @@ -122,6 +160,7 @@ describe('createMlInferencePipeline lib function', () => {
modelId,
sourceField,
destinationField,
undefined, // Omitted inference config
mockClient as unknown as ElasticsearchClient
);

Expand Down
Loading

0 comments on commit fbf501b

Please sign in to comment.