Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adds ELSER config to the Trained Models UI #155867

Merged
merged 10 commits into from
Apr 27, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export const SUPPORTED_PYTORCH_TASKS = {
TEXT_CLASSIFICATION: 'text_classification',
TEXT_EMBEDDING: 'text_embedding',
FILL_MASK: 'fill_mask',
// Not supported yet by the Trained Models UI
TEXT_EXPANSION: 'text_expansion',
} as const;
export type SupportedPytorchTasksType =
Expand All @@ -39,4 +40,34 @@ export const BUILT_IN_MODEL_TYPE = i18n.translate(
{ defaultMessage: 'built-in' }
);

export const CURATED_MODEL_TYPE = i18n.translate(
'xpack.ml.trainedModels.modelsList.curatedModelLabel',
{ defaultMessage: 'curated' }
);

export const BUILT_IN_MODEL_TAG = 'prepackaged';

export const CURATED_MODEL_TAG = 'curated';

export const CURATED_MODEL_DEFINITIONS = {
'.elser_model_1_SNAPSHOT': {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _SNAPSHOT will need to be removed in good time for release. Please set up a calendar reminder to do this. I will try to remember too, but the more people who are thinking about the need to update these places the better.

config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserDescription', {
defaultMessage: 'Elastic Learned Sparse EncodeR',
}),
},
} as const;

export const MODEL_STATE = {
...DEPLOYMENT_STATE,
DOWNLOADING: i18n.translate('xpack.ml.trainedModels.modelsList.downloadingStateLabel', {
defaultMessage: 'downloading',
}),
DOWNLOADED: i18n.translate('xpack.ml.trainedModels.modelsList.downloadedStateLabel', {
defaultMessage: 'downloaded',
}),
} as const;
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import {
DEPLOYMENT_STATE,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import {
CURATED_MODEL_TAG,
MODEL_STATE,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { getUserConfirmationProvider } from './force_stop_dialog';
import { useToastNotificationService } from '../services/toast_notification_service';
Expand Down Expand Up @@ -154,7 +158,7 @@ export function useModelActions({
type: 'icon',
isPrimary: true,
enabled: (item) => {
return canStartStopTrainedModels && !isLoading;
return canStartStopTrainedModels && !isLoading && item.state !== MODEL_STATE.DOWNLOADING;
},
available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH,
onClick: async (item) => {
Expand Down Expand Up @@ -317,6 +321,50 @@ export function useModelActions({
}
},
},
{
name: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download model',
}),
description: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download model',
}),
'data-test-subj': 'mlModelsTableRowDownloadModelAction',
icon: 'download',
type: 'icon',
isPrimary: true,
available: (item) => item.tags.includes(CURATED_MODEL_TAG),
enabled: (item) => !item.state && !isLoading,
onClick: async (item) => {
try {
onLoading(true);
await trainedModelsApiService.putTrainedModelConfig(
item.model_id,
item.putModelConfig!
);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.downloadSuccess', {
defaultMessage: '"{modelId}" model download has been started successfully.',
values: {
modelId: item.model_id,
},
})
);
// Need to fetch model state updates
await fetchModels();
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
defaultMessage: 'Failed to download "{modelId}"',
values: {
modelId: item.model_id,
},
})
);
onLoading(false);
}
},
},
{
name: (model) => {
const enabled = !isPopulatedObject(model.pipelines);
Expand Down Expand Up @@ -350,7 +398,8 @@ export function useModelActions({
onClick: (model) => {
onModelsDeleteRequest([model.model_id]);
},
available: (item) => canDeleteTrainedModels && !isBuiltInModel(item),
available: (item) =>
canDeleteTrainedModels && !isBuiltInModel(item) && !item.putModelConfig,
enabled: (item) => {
// TODO check for permissions to delete ingest pipelines.
// ATM undefined means pipelines fetch failed server-side.
Expand Down
123 changes: 100 additions & 23 deletions x-pack/plugins/ml/public/application/model_management/models_list.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ import {
DEPLOYMENT_STATE,
} from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import {
CURATED_MODEL_DEFINITIONS,
CURATED_MODEL_TAG,
CURATED_MODEL_TYPE,
MODEL_STATE,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { useModelActions } from './model_actions';
import { ModelsTableToConfigMapping } from '.';
import { ModelsBarStats, StatsBar } from '../components/stats_bar';
Expand Down Expand Up @@ -62,6 +68,8 @@ export type ModelItem = TrainedModelConfigResponse & {
stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
pipelines?: ModelPipelines['pipelines'] | null;
deployment_ids: string[];
putModelConfig?: object;
state: string;
};

export type ModelItemFull = Required<ModelItem>;
Expand Down Expand Up @@ -134,6 +142,35 @@ export const ModelsList: FC<Props> = ({
[]
);

const isCuratedModel = useCallback(
(item: ModelItem) => item.tags.includes(CURATED_MODEL_TAG),
[]
);

/**
* Checks if the model download complete.
*/
const isDownloadComplete = useCallback(
async (modelId: string): Promise<boolean> => {
try {
const response = await trainedModelsApiService.getTrainedModels(modelId, {
include: 'definition_status',
});
// @ts-ignore
return !!response[0]?.fully_defined;
} catch (error) {
displayErrorToast(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadStatusCheckErrorMessage', {
defaultMessage: 'Failed to check download status',
})
);
}
return false;
},
[trainedModelsApiService, displayErrorToast]
);

/**
* Fetches trained models.
*/
Expand All @@ -158,6 +195,7 @@ export const ModelsList: FC<Props> = ({
model.model_type,
...Object.keys(model.inference_config),
...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []),
...(isCuratedModel(model as ModelItem) ? [CURATED_MODEL_TYPE] : []),
],
}
: {}),
Expand Down Expand Up @@ -239,7 +277,26 @@ export const ModelsList: FC<Props> = ({
model.deployment_ids = modelStats
.map((v) => v.deployment_stats?.deployment_id)
.filter(isDefined);
model.state = model.stats.deployment_stats?.some(
(v) => v.state === DEPLOYMENT_STATE.STARTED
)
? DEPLOYMENT_STATE.STARTED
: '';
});

const curatedModels = models.filter((model) =>
CURATED_MODEL_DEFINITIONS.hasOwnProperty(model.model_id)
);
if (curatedModels.length > 0) {
for (const model of curatedModels) {
if (model.state === MODEL_STATE.STARTED) {
// no need to check for the download status if the model has been deployed
continue;
}
const isDownloaded = await isDownloadComplete(model.model_id);
model.state = isDownloaded ? MODEL_STATE.DOWNLOADED : MODEL_STATE.DOWNLOADING;
}
}
}

return true;
Expand Down Expand Up @@ -310,21 +367,26 @@ export const ModelsList: FC<Props> = ({
align: 'left',
width: '40px',
isExpander: true,
render: (item: ModelItem) => (
<EuiButtonIcon
onClick={toggleDetails.bind(null, item)}
aria-label={
itemIdToExpandedRowMap[item.model_id]
? i18n.translate('xpack.ml.trainedModels.modelsList.collapseRow', {
defaultMessage: 'Collapse',
})
: i18n.translate('xpack.ml.trainedModels.modelsList.expandRow', {
defaultMessage: 'Expand',
})
}
iconType={itemIdToExpandedRowMap[item.model_id] ? 'arrowDown' : 'arrowRight'}
/>
),
render: (item: ModelItem) => {
if (!item.stats) {
return null;
}
return (
<EuiButtonIcon
onClick={toggleDetails.bind(null, item)}
aria-label={
itemIdToExpandedRowMap[item.model_id]
? i18n.translate('xpack.ml.trainedModels.modelsList.collapseRow', {
defaultMessage: 'Collapse',
})
: i18n.translate('xpack.ml.trainedModels.modelsList.expandRow', {
defaultMessage: 'Expand',
})
}
iconType={itemIdToExpandedRowMap[item.model_id] ? 'arrowDown' : 'arrowRight'}
/>
);
},
'data-test-subj': 'mlModelsTableRowDetailsToggle',
},
{
Expand Down Expand Up @@ -368,17 +430,13 @@ export const ModelsList: FC<Props> = ({
'data-test-subj': 'mlModelsTableColumnType',
},
{
field: 'state',
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
defaultMessage: 'State',
}),
align: 'left',
truncateText: false,
render: (model: ModelItem) => {
const state = model.stats?.deployment_stats?.some(
(v) => v.state === DEPLOYMENT_STATE.STARTED
)
? DEPLOYMENT_STATE.STARTED
: '';
render: (state: string) => {
return state ? <EuiBadge color="hollow">{state}</EuiBadge> : null;
},
'data-test-subj': 'mlModelsTableColumnDeploymentState',
Expand Down Expand Up @@ -473,7 +531,10 @@ export const ModelsList: FC<Props> = ({

return '';
},
selectable: (item) => !isPopulatedObject(item.pipelines) && !isBuiltInModel(item),
selectable: (item) =>
!isPopulatedObject(item.pipelines) &&
!isBuiltInModel(item) &&
!(isCuratedModel(item) && !item.state),
onSelectionChange: (selectedItems) => {
setSelectedModels(selectedItems);
},
Expand Down Expand Up @@ -510,6 +571,22 @@ export const ModelsList: FC<Props> = ({
: {}),
};

const resultItems = useMemo<ModelItem[]>(() => {
const idSet = new Set(items.map((i) => i.model_id));
const notDownloaded: ModelItem[] = Object.entries(CURATED_MODEL_DEFINITIONS)
.filter(([modelId]) => !idSet.has(modelId))
.map(([modelId, modelDefinition]) => {
return {
model_id: modelId,
type: [CURATED_MODEL_TYPE],
tags: [CURATED_MODEL_TAG],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
} as ModelItem;
});
return [...items, ...notDownloaded];
}, [items]);

return (
<>
<SavedObjectsWarning onCloseFlyout={fetchModelsData} forceRefresh={isLoading} />
Expand All @@ -531,7 +608,7 @@ export const ModelsList: FC<Props> = ({
isExpandable={true}
itemIdToExpandedRowMap={itemIdToExpandedRowMap}
isSelectable={false}
items={items}
items={resultItems}
itemId={ModelsTableToConfigMapping.id}
loading={isLoading}
search={search}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import {
} from '@kbn/ml-trained-models-utils';
import type { ModelItem } from '../models_list';

const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS).filter(
(taskType) => taskType !== SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION
);

export function isTestable(modelItem: ModelItem, checkForState = false) {
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
query: { type, node, showClosedJobs },
});
},

putTrainedModelConfig(modelId: string, config: object) {
return httpService.http<estypes.MlPutTrainedModelResponse>({
path: `${apiBasePath}/trained_models/${modelId}`,
method: 'PUT',
body: JSON.stringify(config),
});
},
};
}

Expand Down