Skip to content

Commit

Permalink
[ML] ELSER config in the Trained Models UI (elastic#155867)
Browse files Browse the repository at this point in the history
  • Loading branch information
darnautov authored Apr 27, 2023
1 parent 2250020 commit bf64874
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export const SUPPORTED_PYTORCH_TASKS = {
TEXT_CLASSIFICATION: 'text_classification',
TEXT_EMBEDDING: 'text_embedding',
FILL_MASK: 'fill_mask',
// Not supported yet by the Trained Models UI
TEXT_EXPANSION: 'text_expansion',
} as const;
export type SupportedPytorchTasksType =
Expand All @@ -39,4 +40,34 @@ export const BUILT_IN_MODEL_TYPE = i18n.translate(
{ defaultMessage: 'built-in' }
);

export const CURATED_MODEL_TYPE = i18n.translate(
'xpack.ml.trainedModels.modelsList.curatedModelLabel',
{ defaultMessage: 'curated' }
);

export const BUILT_IN_MODEL_TAG = 'prepackaged';

export const CURATED_MODEL_TAG = 'curated';

export const CURATED_MODEL_DEFINITIONS = {
'.elser_model_1_SNAPSHOT': {
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserDescription', {
defaultMessage: 'Elastic Learned Sparse EncodeR',
}),
},
} as const;

export const MODEL_STATE = {
...DEPLOYMENT_STATE,
DOWNLOADING: i18n.translate('xpack.ml.trainedModels.modelsList.downloadingStateLabel', {
defaultMessage: 'downloading',
}),
DOWNLOADED: i18n.translate('xpack.ml.trainedModels.modelsList.downloadedStateLabel', {
defaultMessage: 'downloaded',
}),
} as const;
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import {
DEPLOYMENT_STATE,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import {
CURATED_MODEL_TAG,
MODEL_STATE,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { getUserConfirmationProvider } from './force_stop_dialog';
import { useToastNotificationService } from '../services/toast_notification_service';
Expand Down Expand Up @@ -154,7 +158,7 @@ export function useModelActions({
type: 'icon',
isPrimary: true,
enabled: (item) => {
return canStartStopTrainedModels && !isLoading;
return canStartStopTrainedModels && !isLoading && item.state !== MODEL_STATE.DOWNLOADING;
},
available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH,
onClick: async (item) => {
Expand Down Expand Up @@ -317,6 +321,50 @@ export function useModelActions({
}
},
},
{
name: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download model',
}),
description: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download model',
}),
'data-test-subj': 'mlModelsTableRowDownloadModelAction',
icon: 'download',
type: 'icon',
isPrimary: true,
available: (item) => item.tags.includes(CURATED_MODEL_TAG),
enabled: (item) => !item.state && !isLoading,
onClick: async (item) => {
try {
onLoading(true);
await trainedModelsApiService.putTrainedModelConfig(
item.model_id,
item.putModelConfig!
);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.downloadSuccess', {
defaultMessage: '"{modelId}" model download has been started successfully.',
values: {
modelId: item.model_id,
},
})
);
// Need to fetch model state updates
await fetchModels();
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
defaultMessage: 'Failed to download "{modelId}"',
values: {
modelId: item.model_id,
},
})
);
onLoading(false);
}
},
},
{
name: (model) => {
const enabled = !isPopulatedObject(model.pipelines);
Expand Down Expand Up @@ -350,7 +398,8 @@ export function useModelActions({
onClick: (model) => {
onModelsDeleteRequest([model.model_id]);
},
available: (item) => canDeleteTrainedModels && !isBuiltInModel(item),
available: (item) =>
canDeleteTrainedModels && !isBuiltInModel(item) && !item.putModelConfig,
enabled: (item) => {
// TODO check for permissions to delete ingest pipelines.
// ATM undefined means pipelines fetch failed server-side.
Expand Down
123 changes: 100 additions & 23 deletions x-pack/plugins/ml/public/application/model_management/models_list.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ import {
DEPLOYMENT_STATE,
} from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import {
CURATED_MODEL_DEFINITIONS,
CURATED_MODEL_TAG,
CURATED_MODEL_TYPE,
MODEL_STATE,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { useModelActions } from './model_actions';
import { ModelsTableToConfigMapping } from '.';
import { ModelsBarStats, StatsBar } from '../components/stats_bar';
Expand Down Expand Up @@ -62,6 +68,8 @@ export type ModelItem = TrainedModelConfigResponse & {
stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
pipelines?: ModelPipelines['pipelines'] | null;
deployment_ids: string[];
putModelConfig?: object;
state: string;
};

export type ModelItemFull = Required<ModelItem>;
Expand Down Expand Up @@ -134,6 +142,35 @@ export const ModelsList: FC<Props> = ({
[]
);

const isCuratedModel = useCallback(
(item: ModelItem) => item.tags.includes(CURATED_MODEL_TAG),
[]
);

/**
* Checks if the model download complete.
*/
const isDownloadComplete = useCallback(
async (modelId: string): Promise<boolean> => {
try {
const response = await trainedModelsApiService.getTrainedModels(modelId, {
include: 'definition_status',
});
// @ts-ignore
return !!response[0]?.fully_defined;
} catch (error) {
displayErrorToast(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadStatusCheckErrorMessage', {
defaultMessage: 'Failed to check download status',
})
);
}
return false;
},
[trainedModelsApiService, displayErrorToast]
);

/**
* Fetches trained models.
*/
Expand All @@ -158,6 +195,7 @@ export const ModelsList: FC<Props> = ({
model.model_type,
...Object.keys(model.inference_config),
...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []),
...(isCuratedModel(model as ModelItem) ? [CURATED_MODEL_TYPE] : []),
],
}
: {}),
Expand Down Expand Up @@ -239,7 +277,26 @@ export const ModelsList: FC<Props> = ({
model.deployment_ids = modelStats
.map((v) => v.deployment_stats?.deployment_id)
.filter(isDefined);
model.state = model.stats.deployment_stats?.some(
(v) => v.state === DEPLOYMENT_STATE.STARTED
)
? DEPLOYMENT_STATE.STARTED
: '';
});

const curatedModels = models.filter((model) =>
CURATED_MODEL_DEFINITIONS.hasOwnProperty(model.model_id)
);
if (curatedModels.length > 0) {
for (const model of curatedModels) {
if (model.state === MODEL_STATE.STARTED) {
// no need to check for the download status if the model has been deployed
continue;
}
const isDownloaded = await isDownloadComplete(model.model_id);
model.state = isDownloaded ? MODEL_STATE.DOWNLOADED : MODEL_STATE.DOWNLOADING;
}
}
}

return true;
Expand Down Expand Up @@ -310,21 +367,26 @@ export const ModelsList: FC<Props> = ({
align: 'left',
width: '40px',
isExpander: true,
render: (item: ModelItem) => (
<EuiButtonIcon
onClick={toggleDetails.bind(null, item)}
aria-label={
itemIdToExpandedRowMap[item.model_id]
? i18n.translate('xpack.ml.trainedModels.modelsList.collapseRow', {
defaultMessage: 'Collapse',
})
: i18n.translate('xpack.ml.trainedModels.modelsList.expandRow', {
defaultMessage: 'Expand',
})
}
iconType={itemIdToExpandedRowMap[item.model_id] ? 'arrowDown' : 'arrowRight'}
/>
),
render: (item: ModelItem) => {
if (!item.stats) {
return null;
}
return (
<EuiButtonIcon
onClick={toggleDetails.bind(null, item)}
aria-label={
itemIdToExpandedRowMap[item.model_id]
? i18n.translate('xpack.ml.trainedModels.modelsList.collapseRow', {
defaultMessage: 'Collapse',
})
: i18n.translate('xpack.ml.trainedModels.modelsList.expandRow', {
defaultMessage: 'Expand',
})
}
iconType={itemIdToExpandedRowMap[item.model_id] ? 'arrowDown' : 'arrowRight'}
/>
);
},
'data-test-subj': 'mlModelsTableRowDetailsToggle',
},
{
Expand Down Expand Up @@ -368,17 +430,13 @@ export const ModelsList: FC<Props> = ({
'data-test-subj': 'mlModelsTableColumnType',
},
{
field: 'state',
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
defaultMessage: 'State',
}),
align: 'left',
truncateText: false,
render: (model: ModelItem) => {
const state = model.stats?.deployment_stats?.some(
(v) => v.state === DEPLOYMENT_STATE.STARTED
)
? DEPLOYMENT_STATE.STARTED
: '';
render: (state: string) => {
return state ? <EuiBadge color="hollow">{state}</EuiBadge> : null;
},
'data-test-subj': 'mlModelsTableColumnDeploymentState',
Expand Down Expand Up @@ -473,7 +531,10 @@ export const ModelsList: FC<Props> = ({

return '';
},
selectable: (item) => !isPopulatedObject(item.pipelines) && !isBuiltInModel(item),
selectable: (item) =>
!isPopulatedObject(item.pipelines) &&
!isBuiltInModel(item) &&
!(isCuratedModel(item) && !item.state),
onSelectionChange: (selectedItems) => {
setSelectedModels(selectedItems);
},
Expand Down Expand Up @@ -510,6 +571,22 @@ export const ModelsList: FC<Props> = ({
: {}),
};

const resultItems = useMemo<ModelItem[]>(() => {
const idSet = new Set(items.map((i) => i.model_id));
const notDownloaded: ModelItem[] = Object.entries(CURATED_MODEL_DEFINITIONS)
.filter(([modelId]) => !idSet.has(modelId))
.map(([modelId, modelDefinition]) => {
return {
model_id: modelId,
type: [CURATED_MODEL_TYPE],
tags: [CURATED_MODEL_TAG],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
} as ModelItem;
});
return [...items, ...notDownloaded];
}, [items]);

return (
<>
<SavedObjectsWarning onCloseFlyout={fetchModelsData} forceRefresh={isLoading} />
Expand All @@ -531,7 +608,7 @@ export const ModelsList: FC<Props> = ({
isExpandable={true}
itemIdToExpandedRowMap={itemIdToExpandedRowMap}
isSelectable={false}
items={items}
items={resultItems}
itemId={ModelsTableToConfigMapping.id}
loading={isLoading}
search={search}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import {
} from '@kbn/ml-trained-models-utils';
import type { ModelItem } from '../models_list';

const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS).filter(
(taskType) => taskType !== SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION
);

export function isTestable(modelItem: ModelItem, checkForState = false) {
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
query: { type, node, showClosedJobs },
});
},

putTrainedModelConfig(modelId: string, config: object) {
return httpService.http<estypes.MlPutTrainedModelResponse>({
path: `${apiBasePath}/trained_models/${modelId}`,
method: 'PUT',
body: JSON.stringify(config),
});
},
};
}

Expand Down

0 comments on commit bf64874

Please sign in to comment.