diff --git a/x-pack/plugins/ml/common/types/common.ts b/x-pack/plugins/ml/common/types/common.ts index a5253456252cd..fdeffe14ddaf7 100644 --- a/x-pack/plugins/ml/common/types/common.ts +++ b/x-pack/plugins/ml/common/types/common.ts @@ -46,3 +46,7 @@ export interface ListingPageUrlState { export type AppPageState = { [key in MlPages]?: Partial; }; + +type Without = { [P in Exclude]?: never }; + +export type XOR = T | U extends object ? (Without & U) | (Without & T) : T | U; diff --git a/x-pack/plugins/ml/common/types/trained_models.ts b/x-pack/plugins/ml/common/types/trained_models.ts index 6b320d503b4c0..3c4c3af748645 100644 --- a/x-pack/plugins/ml/common/types/trained_models.ts +++ b/x-pack/plugins/ml/common/types/trained_models.ts @@ -7,6 +7,7 @@ import { DataFrameAnalyticsConfig } from './data_frame_analytics'; import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance'; +import { XOR } from './common'; export interface IngestStats { count: number; @@ -45,22 +46,54 @@ export interface TrainedModelStat { }; } +type TreeNode = object; + +export type PutTrainedModelConfig = { + description?: string; + metadata?: { + analytics_config: DataFrameAnalyticsConfig; + input: unknown; + total_feature_importance?: TotalFeatureImportance[]; + feature_importance_baseline?: FeatureImportanceBaseline; + model_aliases?: string[]; + } & Record; + tags?: string[]; + inference_config?: Record; + input: { field_names: string[] }; +} & XOR< + { compressed_definition: string }, + { + definition: { + preprocessors: object[]; + trained_model: { + tree: { + classification_labels?: string; + feature_names: string; + target_type: string; + tree_structure: TreeNode[]; + }; + tree_node: TreeNode; + ensemble?: object; + }; + }; + } +>; // compressed_definition and definition are mutually exclusive + export interface TrainedModelConfigResponse { - description: string; + description?: string; created_by: string; create_time: string; default_field_map: Record; estimated_heap_memory_usage_bytes: number; estimated_operations: number; license_level: string; - metadata?: - | { - analytics_config: DataFrameAnalyticsConfig; - input: any; - total_feature_importance?: TotalFeatureImportance[]; - feature_importance_baseline?: FeatureImportanceBaseline; - } - | Record; + metadata?: { + analytics_config: DataFrameAnalyticsConfig; + input: unknown; + total_feature_importance?: TotalFeatureImportance[]; + feature_importance_baseline?: FeatureImportanceBaseline; + model_aliases?: string[]; + } & Record; model_id: string; tags: string[]; version: string; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_navigation_bar/analytics_navigation_bar.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_navigation_bar/analytics_navigation_bar.tsx index 32aa14559da0e..d26b5d5cfc16f 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_navigation_bar/analytics_navigation_bar.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_navigation_bar/analytics_navigation_bar.tsx @@ -31,6 +31,7 @@ export const AnalyticsNavigationBar: FC<{ defaultMessage: 'Jobs', }), path: '/data_frame_analytics', + testSubj: 'mlAnalyticsJobsTab', }, { id: 'models', @@ -38,6 +39,7 @@ export const AnalyticsNavigationBar: FC<{ defaultMessage: 'Models', }), path: '/data_frame_analytics/models', + testSubj: 'mlTrainedModelsTab', }, ]; if (jobId !== undefined || modelId !== undefined) { @@ -47,6 +49,7 @@ export const AnalyticsNavigationBar: FC<{ defaultMessage: 'Map', }), path: '/data_frame_analytics/map', + testSubj: '', }); } return navTabs; @@ -67,6 +70,7 @@ export const AnalyticsNavigationBar: FC<{ key={`tab-${tab.id}`} isSelected={tab.id === selectedTabId} onClick={onTabClick.bind(null, tab)} + data-test-subj={tab.testSubj} > {tab.name} diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/expanded_row.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/expanded_row.tsx index 476931e4b8551..88ffaa0da7fdc 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/expanded_row.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/expanded_row.tsx @@ -29,6 +29,7 @@ import { ModelItemFull } from './models_list'; import { useMlKibana } from '../../../../../contexts/kibana'; import { timeFormatter } from '../../../../../../../common/util/date_utils'; import { isDefined } from '../../../../../../../common/types/guards'; +import { isPopulatedObject } from '../../../../../../../common'; interface ExpandedRowProps { item: ModelItemFull; @@ -70,6 +71,8 @@ export const ExpandedRow: FC = ({ item }) => { description, } = item; + const { analytics_config: analyticsConfig, ...restMetaData } = metadata ?? {}; + const details = { description, tags, @@ -148,6 +151,26 @@ export const ExpandedRow: FC = ({ item }) => { /> + {isPopulatedObject(restMetaData) ? ( + + + +
+ +
+
+ + +
+
+ ) : null} ), @@ -186,7 +209,7 @@ export const ExpandedRow: FC = ({ item }) => { /> - {metadata?.analytics_config && ( + {analyticsConfig && ( @@ -201,7 +224,7 @@ export const ExpandedRow: FC = ({ item }) => { diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/models_list.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/models_list.tsx index 79b8c7130e73c..b9803f1ea26e0 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/models_management/models_list.tsx @@ -292,7 +292,7 @@ export const ModelsList: FC = () => { }), icon: 'visTable', type: 'icon', - available: (item) => item.metadata?.analytics_config?.id, + available: (item) => !!item.metadata?.analytics_config?.id, onClick: async (item) => { if (item.metadata?.analytics_config === undefined) return; @@ -327,7 +327,7 @@ export const ModelsList: FC = () => { icon: 'graphApp', type: 'icon', isPrimary: true, - available: (item) => item.metadata?.analytics_config?.id, + available: (item) => !!item.metadata?.analytics_config?.id, onClick: async (item) => { const path = await mlUrlGenerator.createUrl({ page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP, diff --git a/x-pack/plugins/ml/server/models/data_frame_analytics/models_provider.ts b/x-pack/plugins/ml/server/models/data_frame_analytics/models_provider.ts index bafa5c300e79f..84f0fbaea0579 100644 --- a/x-pack/plugins/ml/server/models/data_frame_analytics/models_provider.ts +++ b/x-pack/plugins/ml/server/models/data_frame_analytics/models_provider.ts @@ -11,8 +11,8 @@ import { PipelineDefinition } from '../../../common/types/trained_models'; export function modelsProvider(client: IScopedClusterClient) { return { /** - * Retrieves the map of model ids and associated pipelines. - * @param modelIds + * Retrieves the map of model ids and aliases with associated pipelines. + * @param modelIds - Array of models ids and model aliases. */ async getModelsPipelines(modelIds: string[]) { const modelIdsMap = new Map | null>( diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index dbfc2195a12e1..c4b2d63b05d13 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -13,6 +13,7 @@ import { optionalModelIdSchema, } from './schemas/inference_schema'; import { modelsProvider } from '../models/data_frame_analytics'; +import { TrainedModelConfigResponse } from '../../common/types/trained_models'; export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) { /** @@ -42,14 +43,32 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) ...query, ...(modelId ? { model_id: modelId } : {}), }); - const result = body.trained_model_configs; + const result = body.trained_model_configs as TrainedModelConfigResponse[]; try { if (withPipelines) { + const modelIdsAndAliases: string[] = Array.from( + new Set( + result + .map(({ model_id: id, metadata }) => { + return [id, ...(metadata?.model_aliases ?? [])]; + }) + .flat() + ) + ); + const pipelinesResponse = await modelsProvider(client).getModelsPipelines( - result.map(({ model_id: id }: { model_id: string }) => id) + modelIdsAndAliases ); for (const model of result) { - model.pipelines = pipelinesResponse.get(model.model_id)!; + model.pipelines = { + ...(pipelinesResponse.get(model.model_id) ?? {}), + ...(model.metadata?.model_aliases ?? []).reduce((acc, alias) => { + return { + ...acc, + ...(pipelinesResponse.get(alias) ?? {}), + }; + }, {}), + }; } } } catch (e) { diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/index.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/index.ts index e7b5df70c99a0..4de95a5d82054 100644 --- a/x-pack/test/functional/apps/ml/data_frame_analytics/index.ts +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/index.ts @@ -16,5 +16,6 @@ export default function ({ loadTestFile }: FtrProviderContext) { loadTestFile(require.resolve('./classification_creation')); loadTestFile(require.resolve('./cloning')); loadTestFile(require.resolve('./feature_importance')); + loadTestFile(require.resolve('./trained_models')); }); } diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/trained_models.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/trained_models.ts new file mode 100644 index 0000000000000..43130651cb121 --- /dev/null +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/trained_models.ts @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { FtrProviderContext } from '../../../ftr_provider_context'; + +export default function ({ getService }: FtrProviderContext) { + const ml = getService('ml'); + + describe('trained models', function () { + before(async () => { + await ml.trainedModels.createdTestTrainedModels('classification', 15); + await ml.trainedModels.createdTestTrainedModels('regression', 15); + await ml.securityUI.loginAsMlPowerUser(); + await ml.navigation.navigateToTrainedModels(); + }); + + after(async () => { + await ml.api.cleanMlIndices(); + }); + + it('renders trained models list', async () => { + await ml.trainedModels.assertRowsNumberPerPage(10); + // +1 because of the built-in model + await ml.trainedModels.assertStats(31); + }); + }); +} diff --git a/x-pack/test/functional/services/ml/api.ts b/x-pack/test/functional/services/ml/api.ts index 7d09deff6f6b7..c0e3dedd8e191 100644 --- a/x-pack/test/functional/services/ml/api.ts +++ b/x-pack/test/functional/services/ml/api.ts @@ -23,6 +23,7 @@ import { ML_ANNOTATIONS_INDEX_ALIAS_WRITE, } from '../../../../plugins/ml/common/constants/index_patterns'; import { COMMON_REQUEST_HEADERS } from '../../../functional/services/ml/common_api'; +import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models'; export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { const es = getService('es'); @@ -935,5 +936,17 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { } } }, + + async createTrainedModel(modelId: string, body: PutTrainedModelConfig) { + log.debug(`Creating trained model with id "${modelId}"`); + const model = await esSupertest + .put(`/_ml/trained_models/${modelId}`) + .send(body) + .expect(200) + .then((res: any) => res.body); + + log.debug('> Trained model crated'); + return model; + }, }; } diff --git a/x-pack/test/functional/services/ml/common_ui.ts b/x-pack/test/functional/services/ml/common_ui.ts index b7ed95ef76ece..f42f54116c926 100644 --- a/x-pack/test/functional/services/ml/common_ui.ts +++ b/x-pack/test/functional/services/ml/common_ui.ts @@ -245,5 +245,28 @@ export function MachineLearningCommonUIProvider({ getService }: FtrProviderConte ); }); }, + + async assertRowsNumberPerPage(testSubj: string, rowsNumber: 10 | 25 | 100) { + const textContent = await testSubjects.getVisibleText( + `${testSubj} > tablePaginationPopoverButton` + ); + expect(textContent).to.be(`Rows per page: ${rowsNumber}`); + }, + + async ensurePagePopupOpen(testSubj: string) { + await retry.tryForTime(5000, async () => { + const isOpen = await testSubjects.exists('tablePagination-10-rows'); + if (!isOpen) { + await testSubjects.click(`${testSubj} > tablePaginationPopoverButton`); + await testSubjects.existOrFail('tablePagination-10-rows'); + } + }); + }, + + async setRowsNumberPerPage(testSubj: string, rowsNumber: 10 | 25 | 100) { + await this.ensurePagePopupOpen(testSubj); + await testSubjects.click(`tablePagination-${rowsNumber}-rows`); + await this.assertRowsNumberPerPage(testSubj, rowsNumber); + }, }; } diff --git a/x-pack/test/functional/services/ml/index.ts b/x-pack/test/functional/services/ml/index.ts index 05d369d890289..6a2e1158e70a3 100644 --- a/x-pack/test/functional/services/ml/index.ts +++ b/x-pack/test/functional/services/ml/index.ts @@ -48,6 +48,7 @@ import { MachineLearningAlertingProvider } from './alerting'; import { SwimLaneProvider } from './swim_lane'; import { MachineLearningDashboardJobSelectionTableProvider } from './dashboard_job_selection_table'; import { MachineLearningDashboardEmbeddablesProvider } from './dashboard_embeddables'; +import { TrainedModelsProvider } from './trained_models'; export function MachineLearningProvider(context: FtrProviderContext) { const commonAPI = MachineLearningCommonAPIProvider(context); @@ -108,6 +109,7 @@ export function MachineLearningProvider(context: FtrProviderContext) { const testResources = MachineLearningTestResourcesProvider(context); const alerting = MachineLearningAlertingProvider(context, commonUI); const swimLane = SwimLaneProvider(context); + const trainedModels = TrainedModelsProvider(context, api, commonUI); return { anomaliesTable, @@ -151,5 +153,6 @@ export function MachineLearningProvider(context: FtrProviderContext) { swimLane, testExecution, testResources, + trainedModels, }; } diff --git a/x-pack/test/functional/services/ml/navigation.ts b/x-pack/test/functional/services/ml/navigation.ts index 075c788a86336..9bebc25f2de4c 100644 --- a/x-pack/test/functional/services/ml/navigation.ts +++ b/x-pack/test/functional/services/ml/navigation.ts @@ -115,6 +115,13 @@ export function MachineLearningNavigationProvider({ await this.navigateToArea('~mlMainTab & ~dataFrameAnalytics', 'mlPageDataFrameAnalytics'); }, + async navigateToTrainedModels() { + await this.navigateToMl(); + await this.navigateToDataFrameAnalytics(); + await testSubjects.click('mlTrainedModelsTab'); + await testSubjects.existOrFail('mlModelsTableContainer'); + }, + async navigateToDataVisualizer() { await this.navigateToArea('~mlMainTab & ~dataVisualizer', 'mlPageDataVisualizerSelector'); }, diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_classification.json.gz.b64 b/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_classification.json.gz.b64 new file mode 100644 index 0000000000000..c3e3a50f52d91 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_classification.json.gz.b64 @@ -0,0 +1 @@ +H4sICOE6Ol8AA2NsZi5qc29uAD2MQQqAIBBF955CXHeCrhIhg44xYBo6LUK8e1rW4i/ef59fhJSKE1BAq/do0atZllY+NeJPjR0Cnwl1gB1zE8sQXcWoBq3Tt2dIG7Lm6+g3ynjImRwZYIrhnVfRU28zTg0thgAAAA== \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_regression.json.gz.b64 b/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_regression.json.gz.b64 new file mode 100644 index 0000000000000..05c1aec7f149a --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/minimum_valid_config_regression.json.gz.b64 @@ -0,0 +1 @@ +H4sICOc8Ol8AA3JnLmpzb24APYxBCoAgFET3nkJcd4KuEiGCkwip8f0tQrx7WtZiFm/eMEVIqZiMj7A6JItdzbK08qmBnxpvMHwSdDQBuYlliK5SUoPW6duzIQfWfB39RhEcIWef4jutoqfeWtCVIIIAAAA= \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/trained_models.ts b/x-pack/test/functional/services/ml/trained_models.ts new file mode 100644 index 0000000000000..ae799efbbd30c --- /dev/null +++ b/x-pack/test/functional/services/ml/trained_models.ts @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import fs from 'fs'; +import path from 'path'; +import expect from '@kbn/expect'; +import { FtrProviderContext } from '../../ftr_provider_context'; +import { MlApi } from './api'; +import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models'; +import { MlCommonUI } from './common_ui'; + +type ModelType = 'regression' | 'classification'; + +export function TrainedModelsProvider( + { getService }: FtrProviderContext, + mlApi: MlApi, + mlCommonUI: MlCommonUI +) { + const testSubjects = getService('testSubjects'); + + return { + async createdTestTrainedModels(modelType: ModelType, count: number = 10) { + const compressedDefinition = this.getCompressedModelDefinition(modelType); + + const models = new Array(count).fill(null).map((v, i) => { + return { + model_id: `dfa_${modelType}_model_n_${i}`, + body: { + compressed_definition: compressedDefinition, + inference_config: { + [modelType]: {}, + }, + input: { + field_names: ['common_field'], + }, + } as PutTrainedModelConfig, + }; + }); + + for (const model of models) { + await mlApi.createTrainedModel(model.model_id, model.body); + } + }, + + getCompressedModelDefinition(modelType: ModelType) { + return fs.readFileSync( + path.resolve( + __dirname, + 'resources', + 'trained_model_definitions', + `minimum_valid_config_${modelType}.json.gz.b64` + ), + 'utf-8' + ); + }, + + async assertStats(expectedTotalCount: number) { + const actualStats = await testSubjects.getVisibleText('mlInferenceModelsStatsBar'); + expect(actualStats).to.eql(`Total trained models: ${expectedTotalCount}`); + }, + + async assertRowsNumberPerPage(rowsNumber: 10 | 25 | 100) { + await mlCommonUI.assertRowsNumberPerPage('mlModelsTableContainer', rowsNumber); + }, + }; +}