[Enterprise Search] Ensure All Model Types are Set for Inference Processor Ingest History (#141493)

* add function to get model types

* update definition for `types`; add additional tests
markjhoy authored Sep 22, 2022
1 parent 14b8dc5 commit ca75f7a
Showing 4 changed files with 163 additions and 32 deletions.
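
The change replaces the single `type` value that the inference processor wrote to the ingest history with a `types` array covering every model type. The array is built by the new getMlModelTypesForModelConfig helper from the model's model_type, the keys of its inference_config, and the built-in model tag when present. A sketch of the recorded history entry before and after, reusing placeholder values from the tests below:

// before: a single type per entry
{ type: 'pytorch', model_id: 'my-model-id', model_version: 3, processed_timestamp: '{{{ _ingest.timestamp }}}' }

// after: all model types
{ model_id: 'my-model-id', model_version: 3, processed_timestamp: '{{{ _ingest.timestamp }}}', types: ['pytorch', 'my-model-type'] }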
@@ -5,12 +5,16 @@
* 2.0.
*/

import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { ElasticsearchClient } from '@kbn/core/server';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';

import { InferencePipeline } from '../../../common/types/pipelines';

import {
fetchAndAddTrainedModelData,
getMlModelTypesForModelConfig,
getMlModelConfigsForModelIds,
fetchMlInferencePipelineProcessorNames,
fetchMlInferencePipelineProcessors,
fetchPipelineProcessorInferenceData,
@@ -200,6 +204,95 @@ describe('fetchPipelineProcessorInferenceData lib function', () => {
});
});

describe('getMlModelTypesForModelConfig lib function', () => {
const mockModel: MlTrainedModelConfig = {
inference_config: {
ner: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'pytorch',
tags: ['test_tag'],
};
const builtInMockModel: MlTrainedModelConfig = {
inference_config: {
text_classification: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'lang_ident',
tags: [BUILT_IN_MODEL_TAG],
};

it('should return the model type and inference config type', () => {
const expected = ['pytorch', 'ner'];
const response = getMlModelTypesForModelConfig(mockModel);
expect(response.sort()).toEqual(expected.sort());
});

it('should include the built in type', () => {
const expected = ['lang_ident', 'text_classification', BUILT_IN_MODEL_TAG];
const response = getMlModelTypesForModelConfig(builtInMockModel);
expect(response.sort()).toEqual(expected.sort());
});
});
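
// Both sides of the assertions above are sorted before comparison, so the
// tests verify which types are returned without depending on their order.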

describe('getMlModelConfigsForModelIds lib function', () => {
const mockClient = {
ml: {
getTrainedModels: jest.fn(),
getTrainedModelsStats: jest.fn(),
},
};

beforeEach(() => {
jest.clearAllMocks();
});

it('should fetch the models that we ask for', async () => {
mockClient.ml.getTrainedModels.mockImplementation(() =>
Promise.resolve(mockGetTrainedModelsData)
);
mockClient.ml.getTrainedModelsStats.mockImplementation(() =>
Promise.resolve(mockGetTrainedModelStats)
);

const input = {
'trained-model-id-1': {
isDeployed: true,
pipelineName: '',
trainedModelName: 'trained-model-id-1',
types: ['pytorch', 'ner'],
},
'trained-model-id-2': {
isDeployed: true,
pipelineName: '',
trainedModelName: 'trained-model-id-2',
types: ['pytorch', 'ner'],
},
} as Record<string, InferencePipeline>;

const expected = {
'trained-model-id-2': input['trained-model-id-2'],
};
const response = await getMlModelConfigsForModelIds(
mockClient as unknown as ElasticsearchClient,
['trained-model-id-2']
);
expect(mockClient.ml.getTrainedModels).toHaveBeenCalledWith({
model_id: 'trained-model-id-2',
});
expect(mockClient.ml.getTrainedModelsStats).toHaveBeenCalledWith({
model_id: 'trained-model-id-2',
});
expect(response).toEqual(expected);
});
});
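
// Note: the ML client APIs take a comma-separated list of model ids, so
// getMlModelConfigsForModelIds can batch lookups into a single request by
// joining the ids, e.g. { model_id: 'trained-model-id-1,trained-model-id-2' }.
// This test passes one id and checks that only that model's config comes back.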

describe('fetchAndAddTrainedModelData lib function', () => {
const mockClient = {
ml: {
@@ -5,6 +5,7 @@
* 2.0.
*/

+import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { ElasticsearchClient } from '@kbn/core/server';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';

@@ -65,39 +66,67 @@ export const fetchPipelineProcessorInferenceData = async (
);
};

-export const fetchAndAddTrainedModelData = async (
+export const getMlModelTypesForModelConfig = (trainedModel: MlTrainedModelConfig): string[] => {
+if (!trainedModel) return [];

+const isBuiltIn = trainedModel.tags?.includes(BUILT_IN_MODEL_TAG);

+return [
+trainedModel.model_type,
+...Object.keys(trainedModel.inference_config || {}),
+...(isBuiltIn ? [BUILT_IN_MODEL_TAG] : []),
+].filter((type): type is string => type !== undefined);
+};

+export const getMlModelConfigsForModelIds = async (
client: ElasticsearchClient,
-pipelineProcessorData: Record<string, InferencePipeline>
+trainedModelNames: string[]
): Promise<Record<string, InferencePipeline>> => {
-const trainedModelNames = Object.keys(pipelineProcessorData);

const [trainedModels, trainedModelsStats] = await Promise.all([
client.ml.getTrainedModels({ model_id: trainedModelNames.join() }),
client.ml.getTrainedModelsStats({ model_id: trainedModelNames.join() }),
]);

+const modelConfigs: Record<string, InferencePipeline> = {};

trainedModels.trained_model_configs.forEach((trainedModelData) => {
const trainedModelName = trainedModelData.model_id;

-if (pipelineProcessorData.hasOwnProperty(trainedModelName)) {
-const isBuiltIn = trainedModelData.tags.includes(BUILT_IN_MODEL_TAG);

-pipelineProcessorData[trainedModelName].types = [
-trainedModelData.model_type,
-...Object.keys(trainedModelData.inference_config || {}),
-...(isBuiltIn ? [BUILT_IN_MODEL_TAG] : []),
-].filter((type): type is string => type !== undefined);
+if (trainedModelNames.includes(trainedModelName)) {
+modelConfigs[trainedModelName] = {
+isDeployed: false,
+pipelineName: '',
+trainedModelName,
+types: getMlModelTypesForModelConfig(trainedModelData),
+};
}
});

trainedModelsStats.trained_model_stats.forEach((trainedModelStats) => {
const trainedModelName = trainedModelStats.model_id;
-if (pipelineProcessorData.hasOwnProperty(trainedModelName)) {
+if (modelConfigs.hasOwnProperty(trainedModelName)) {
const isDeployed = trainedModelStats.deployment_stats?.state === 'started';
-pipelineProcessorData[trainedModelName].isDeployed = isDeployed;
+modelConfigs[trainedModelName].isDeployed = isDeployed;
}
});

+return modelConfigs;
+};

+export const fetchAndAddTrainedModelData = async (
+client: ElasticsearchClient,
+pipelineProcessorData: Record<string, InferencePipeline>
+): Promise<Record<string, InferencePipeline>> => {
+const trainedModelNames = Object.keys(pipelineProcessorData);
+const modelConfigs = await getMlModelConfigsForModelIds(client, trainedModelNames);

+for (const [modelName, modelData] of Object.entries(modelConfigs)) {
+if (pipelineProcessorData.hasOwnProperty(modelName)) {
+pipelineProcessorData[modelName].types = modelData.types;
+pipelineProcessorData[modelName].isDeployed = modelData.isDeployed;
+}
+}

return pipelineProcessorData;
};
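
The refactor splits the old fetchAndAddTrainedModelData into three composable exports: getMlModelTypesForModelConfig (pure derivation of the types array), getMlModelConfigsForModelIds (batched config and stats lookup), and a thinner fetchAndAddTrainedModelData that merges the results back into the caller's map. A minimal usage sketch, assuming an async context, a configured ElasticsearchClient named client, and a hypothetical model id:

const pipelines: Record<string, InferencePipeline> = {
  'my-model-id': {
    isDeployed: false,
    pipelineName: 'my-pipeline',
    trainedModelName: 'my-model-id',
    types: [],
  },
};

// Fetches configs and stats for every key of `pipelines`, derives `types`
// via getMlModelTypesForModelConfig, and writes both fields back in place.
const enriched = await fetchAndAddTrainedModelData(client, pipelines);
// enriched['my-model-id'].types => e.g. ['pytorch', 'ner']
// enriched['my-model-id'].isDeployed => true once deployment_stats.state === 'started'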

@@ -41,7 +41,9 @@ describe('createIndexPipelineDefinitions util function', () => {
describe('formatMlPipelineBody util function', () => {
const modelId = 'my-model-id';
let modelInputField = 'my-model-input-field';
-const modelType = 'my-model-type';
+const modelType = 'pytorch';
+const inferenceConfigKey = 'my-model-type';
+const modelTypes = ['pytorch', 'my-model-type'];
const modelVersion = 3;
const sourceField = 'my-source-field';
const destField = 'my-dest-field';
@@ -59,7 +61,6 @@
it('should return the pipeline body', async () => {
const expectedResult = {
description: '',
-version: 1,
processors: [
{
remove: {
@@ -69,37 +70,41 @@
},
{
inference: {
-model_id: modelId,
-target_field: `ml.inference.${destField}`,
field_map: {
[sourceField]: modelInputField,
},
+model_id: modelId,
+target_field: `ml.inference.${destField}`,
},
},
{
append: {
field: '_source._ingest.processors',
value: [
{
-type: modelType,
model_id: modelId,
model_version: modelVersion,
processed_timestamp: '{{{ _ingest.timestamp }}}',
+types: modelTypes,
},
],
},
},
],
+version: 1,
};

const mockResponse = {
count: 1,
trained_model_configs: [
{
+inference_config: {
+[inferenceConfigKey]: {},
+},
+input: { field_names: [modelInputField] },
model_id: modelId,
-version: modelVersion,
model_type: modelType,
-input: { field_names: [modelInputField] },
+version: modelVersion,
},
],
};
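
// The mock now includes an inference_config entry so the generated pipeline
// records types: ['pytorch', 'my-model-type'] (model_type plus the inference
// config key) instead of the single model_type it recorded before.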
@@ -131,7 +136,6 @@
modelInputField = 'MODEL_INPUT_FIELD';
const expectedResult = {
description: '',
-version: 1,
processors: [
{
remove: {
@@ -141,36 +145,40 @@
},
{
inference: {
-model_id: modelId,
-target_field: `ml.inference.${destField}`,
field_map: {
[sourceField]: modelInputField,
},
+model_id: modelId,
+target_field: `ml.inference.${destField}`,
},
},
{
append: {
field: '_source._ingest.processors',
value: [
{
-type: modelType,
model_id: modelId,
model_version: modelVersion,
processed_timestamp: '{{{ _ingest.timestamp }}}',
+types: modelTypes,
},
],
},
},
],
+version: 1,
};
const mockResponse = {
count: 1,
trained_model_configs: [
{
+inference_config: {
+[inferenceConfigKey]: {},
+},
+input: { field_names: [] },
model_id: modelId,
-version: modelVersion,
model_type: modelType,
-input: { field_names: [] },
+version: modelVersion,
},
],
};
@@ -9,6 +9,7 @@ import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';

import { getInferencePipelineNameFromIndexName } from '../../utils/ml_inference_pipeline_utils';
+import { getMlModelTypesForModelConfig } from '../indices/fetch_ml_inference_pipeline_processors';

export interface CreatedPipelines {
created: string[];
@@ -255,11 +256,10 @@ export const formatMlPipelineBody = async (
// if model returned no input field, insert a placeholder
const modelInputField =
model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
-const modelType = model.model_type;
+const modelTypes = getMlModelTypesForModelConfig(model);
const modelVersion = model.version;
return {
description: '',
-version: 1,
processors: [
{
remove: {
@@ -269,26 +269,27 @@
},
{
inference: {
-model_id: modelId,
-target_field: `ml.inference.${destinationField}`,
field_map: {
[sourceField]: modelInputField,
},
+model_id: modelId,
+target_field: `ml.inference.${destinationField}`,
},
},
{
append: {
field: '_source._ingest.processors',
value: [
{
-type: modelType,
model_id: modelId,
model_version: modelVersion,
processed_timestamp: '{{{ _ingest.timestamp }}}',
+types: modelTypes,
},
],
},
},
],
+version: 1,
};
};
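
When the generated pipeline runs, the append processor above adds one history entry per inference pass to each document's _source._ingest.processors array. A sketch of a recorded entry, reusing the placeholder values from the tests (the timestamp shown is illustrative; the mustache template renders the actual ingest timestamp):

{
  model_id: 'my-model-id',
  model_version: 3,
  processed_timestamp: '2022-09-22T12:34:56.789Z',
  types: ['pytorch', 'my-model-type'],
}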
