Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Trained Models: Optimize trained models Kibana API #200977

Merged
merged 26 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
66022ba
fix isNaN issue
darnautov Nov 18, 2024
2d803f6
refactor types and guards
darnautov Nov 20, 2024
5c3d9bf
rename types
darnautov Nov 20, 2024
6509cba
move types to common
darnautov Nov 25, 2024
17250c5
Merge remote-tracking branch 'origin/main' into ml-191939-fix-trained…
darnautov Nov 28, 2024
356c804
wip fetch server side
darnautov Nov 29, 2024
c21c453
create new endpoint
darnautov Dec 1, 2024
6f3d663
assign indices, refactor
darnautov Dec 2, 2024
fa4d474
check for dfa jobs
darnautov Dec 2, 2024
8ce7687
cleanup
darnautov Dec 2, 2024
a654bfe
cleanup model actions
darnautov Dec 2, 2024
9f5c6f1
fix ts issues
darnautov Dec 2, 2024
fd307ac
update jest tests
darnautov Dec 2, 2024
725825c
api integration tests
darnautov Dec 2, 2024
a2753de
fix i18n
darnautov Dec 2, 2024
44b9587
Merge branch 'main' into ml-191939-fix-trained-models-init
elasticmachine Dec 2, 2024
82b7fc3
fix spaces sync
darnautov Dec 3, 2024
39f3462
update test for spaces check
darnautov Dec 3, 2024
280defd
fix imports
darnautov Dec 3, 2024
9920c58
update tests, fix adding a built-in type
darnautov Dec 3, 2024
7b4d219
Merge remote-tracking branch 'origin/ml-191939-fix-trained-models-ini…
darnautov Dec 3, 2024
3f9de1b
Merge remote-tracking branch 'origin/main' into ml-191939-fix-trained…
darnautov Dec 3, 2024
1c41a84
fix typo, clean up tests
darnautov Dec 4, 2024
d869cca
typo
darnautov Dec 4, 2024
307bb40
replace transport call, update tests and types
darnautov Dec 4, 2024
dacbab1
fix ts issue in tests
darnautov Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 136 additions & 26 deletions x-pack/plugins/ml/common/types/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
* 2.0.
*/
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { TrainedModelType } from '@kbn/ml-trained-models-utils';
import type { MlInferenceConfigCreateContainer } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type {
InferenceAPIConfigResponse,
ModelDefinitionResponse,
ModelState,
TrainedModelType,
} from '@kbn/ml-trained-models-utils';
import {
BUILT_IN_MODEL_TAG,
ELASTIC_MODEL_TAG,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import type {
DataFrameAnalyticsConfig,
FeatureImportanceBaseline,
TotalFeatureImportance,
} from '@kbn/ml-data-frame-analytics-utils';
import type { IndexName, IndicesIndexState } from '@elastic/elasticsearch/lib/api/types';
import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import type { XOR } from './common';
import type { MlSavedObjectType } from './saved_objects';

Expand Down Expand Up @@ -95,33 +104,12 @@ export type PutTrainedModelConfig = {
>; // compressed_definition and definition are mutually exclusive

export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & {
/**
* Associated pipelines. Extends response from the ES endpoint.
*/
pipelines?: Record<string, PipelineDefinition> | null;
origin_job_exists?: boolean;

metadata?: {
analytics_config: DataFrameAnalyticsConfig;
metadata?: estypes.MlTrainedModelConfig['metadata'] & {
analytics_config?: DataFrameAnalyticsConfig;
input: unknown;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
model_aliases?: string[];
} & Record<string, unknown>;
model_id: string;
model_type: TrainedModelType;
tags: string[];
version: string;
inference_config?: Record<string, any>;
indices?: Array<Record<IndexName, IndicesIndexState | null>>;
/**
* Whether the model has inference services
*/
hasInferenceServices?: boolean;
/**
* Inference services associated with the model
*/
inference_apis?: InferenceAPIConfigResponse[];
};

export interface PipelineDefinition {
Expand Down Expand Up @@ -309,3 +297,125 @@ export interface ModelDownloadState {
total_parts: number;
downloaded_parts: number;
}

/** Trained model stats without the model_id and deployment_stats fields. */
export type Stats = Omit<TrainedModelStat, 'model_id' | 'deployment_stats'>;

/**
 * Additional properties for all items in the Trained models table
 * */
interface BaseModelItem {
  // Model type tags; contains "pytorch" for NLP models (see isBaseNLPModelItem).
  type?: string[];
  // Model tags, e.g. the built-in/Elastic markers (see isBuiltInModel / isElasticModel).
  tags: string[];
  /**
   * Whether the model has inference services
   */
  hasInferenceServices?: boolean;
  /**
   * Inference services associated with the model
   */
  inference_apis?: InferenceAPIConfigResponse[];
  /**
   * Associated pipelines. Extends response from the ES endpoint.
   */
  pipelines?: Record<string, PipelineDefinition>;
  /**
   * Indices with associated pipelines that have inference processors utilizing the model deployments.
   */
  indices?: string[];
}

/** Common properties for existing NLP models and NLP model download configs */
interface BaseNLPModelItem extends BaseModelItem {
  // Disclaimer text associated with the model, if any.
  disclaimer?: string;
  // NOTE(review): semantics inferred from names — presumably flags for whether the
  // model is recommended/supported for the current cluster; confirm against callers.
  recommended?: boolean;
  supported?: boolean;
  // Current model state; the key is required but the value may be explicitly undefined.
  state: ModelState | undefined;
  // Download progress (total/downloaded parts) while the model is being fetched.
  downloadState?: ModelDownloadState;
}

/** Model available for download */
export type ModelDownloadItem = BaseNLPModelItem &
  Omit<ModelDefinitionResponse, 'version' | 'config'> & {
    // Config for the PUT trained model API; its presence identifies a download item
    // (see isModelDownloadItem).
    putModelConfig?: object;
    // License identifier of the hosted model, if any.
    softwareLicense?: string;
  };
/** Trained NLP model, i.e. pytorch model returned by the trained_models API */
export type NLPModelItem = BaseNLPModelItem &
  TrainedModelItem & {
    // Stats narrowed to always include per-deployment stats for NLP models.
    stats: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
    /**
     * Description of the current model state
     */
    stateDescription?: string;
    /**
     * Deployment ids extracted from the deployment stats
     */
    deployment_ids: string[];
  };

/**
 * Type guard for NLP model items (both existing models and download configs),
 * identified by a "pytorch" entry in the item's type array.
 */
export function isBaseNLPModelItem(item: unknown): item is BaseNLPModelItem {
  if (typeof item !== 'object' || item === null) {
    return false;
  }
  if (!('type' in item)) {
    return false;
  }
  const { type } = item as { type: unknown };
  return Array.isArray(type) && type.includes(TRAINED_MODEL_TYPE.PYTORCH);
}

/**
 * Type guard for an existing trained NLP (pytorch) model returned by the
 * trained_models API.
 */
export function isNLPModelItem(item: unknown): item is NLPModelItem {
  if (!isExistingModel(item)) {
    return false;
  }
  return item.model_type === TRAINED_MODEL_TYPE.PYTORCH;
}

export const isElasticModel = (item: TrainedModelConfigResponse) =>
item.tags.includes(ELASTIC_MODEL_TAG);

/** Raw trained_models API response enriched with the common UI table properties. */
export type ExistingModelBase = TrainedModelConfigResponse & BaseModelItem;

/** Any model returned by the trained_models API, e.g. lang_ident, elser, dfa model */
export type TrainedModelItem = ExistingModelBase & { stats: Stats };

/** Trained DFA model */
export type DFAModelItem = Omit<TrainedModelItem, 'inference_config'> & {
  // Whether the originating data frame analytics job still exists.
  origin_job_exists?: boolean;
  // DFA models only support classification or regression inference configs.
  inference_config?: Pick<MlInferenceConfigCreateContainer, 'classification' | 'regression'>;
  // Narrows the generic metadata with the DFA-specific analytics configuration.
  metadata?: estypes.MlTrainedModelConfig['metadata'] & {
    analytics_config: DataFrameAnalyticsConfig;
    input: unknown;
    total_feature_importance?: TotalFeatureImportance[];
    feature_importance_baseline?: FeatureImportanceBaseline;
  } & Record<string, unknown>;
};

/** Existing trained model that is guaranteed to have associated pipelines. */
export type TrainedModelWithPipelines = TrainedModelItem & {
  pipelines: Record<string, PipelineDefinition>;
};

/**
 * Type guard for a model that already exists in Elasticsearch (returned by the
 * trained_models API), as opposed to a config available for download.
 * A truthy create_time is what distinguishes an existing model.
 */
export function isExistingModel(item: unknown): item is TrainedModelItem {
  if (typeof item !== 'object' || item === null) {
    return false;
  }
  const candidate = item as { model_type?: unknown; create_time?: unknown };
  return 'model_type' in candidate && 'create_time' in candidate && Boolean(candidate.create_time);
}

/**
 * Type guard for an existing trained DFA (tree_ensemble) model returned by the
 * trained_models API.
 */
export function isDFAModelItem(item: unknown): item is DFAModelItem {
  if (!isExistingModel(item)) {
    return false;
  }
  return item.model_type === TRAINED_MODEL_TYPE.TREE_ENSEMBLE;
}

/**
 * Type guard for a pytorch model config that is available for download but not
 * yet present in Elasticsearch, identified by the putModelConfig property.
 */
export function isModelDownloadItem(item: TrainedModelUIItem): item is ModelDownloadItem {
  const isPytorch = item.type?.includes(TRAINED_MODEL_TYPE.PYTORCH) ?? false;
  return isPytorch && 'putModelConfig' in item;
}

/** Checks if a model carries the built-in tag, e.g. lang_ident. */
export const isBuiltInModel = (item: TrainedModelConfigResponse | TrainedModelUIItem) => {
  const { tags } = item;
  return tags.includes(BUILT_IN_MODEL_TAG);
};
/**
 * Union of the model entities rendered in the Trained models table:
 * - Any existing trained model returned by the API, e.g., lang_ident_model_1
 * - Hosted model configurations available for download, e.g., ELSER or E5
 * - NLP (pytorch) models already downloaded into Elasticsearch
 * - DFA models
 */
export type TrainedModelUIItem = TrainedModelItem | ModelDownloadItem | NLPModelItem | DFAModelItem;
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
import { i18n } from '@kbn/i18n';
import { extractErrorProperties } from '@kbn/ml-error-utils';

import type { ModelItem } from '../../model_management/models_list';
import type { DFAModelItem } from '../../../../common/types/trained_models';
import type { AddInferencePipelineSteps } from './types';
import { ADD_INFERENCE_PIPELINE_STEPS } from './constants';
import { AddInferencePipelineFooter } from '../shared';
Expand All @@ -39,7 +39,7 @@ import { useFetchPipelines } from './hooks/use_fetch_pipelines';

export interface AddInferencePipelineFlyoutProps {
onClose: () => void;
model: ModelItem;
model: DFAModelItem;
}

export const AddInferencePipelineFlyout: FC<AddInferencePipelineFlyoutProps> = ({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { CodeEditor } from '@kbn/code-editor';
import type { ModelItem } from '../../../model_management/models_list';
import type { DFAModelItem } from '../../../../../common/types/trained_models';
import {
EDIT_MESSAGE,
CANCEL_EDIT_MESSAGE,
Expand Down Expand Up @@ -56,9 +56,9 @@ interface Props {
condition?: string;
fieldMap: MlInferenceState['fieldMap'];
handleAdvancedConfigUpdate: (configUpdate: Partial<MlInferenceState>) => void;
inferenceConfig: ModelItem['inference_config'];
modelInferenceConfig: ModelItem['inference_config'];
modelInputFields: ModelItem['input'];
inferenceConfig: DFAModelItem['inference_config'];
modelInferenceConfig: DFAModelItem['inference_config'];
modelInputFields: DFAModelItem['input'];
modelType?: InferenceModelTypes;
setHasUnsavedChanges: React.Dispatch<React.SetStateAction<boolean>>;
tag?: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
*/

import { getAnalysisType } from '@kbn/ml-data-frame-analytics-utils';
import type { DFAModelItem } from '../../../../common/types/trained_models';
import type { MlInferenceState } from './types';
import type { ModelItem } from '../../model_management/models_list';

export const getModelType = (model: ModelItem): string | undefined => {
export const getModelType = (model: DFAModelItem): string | undefined => {
const analysisConfig = model.metadata?.analytics_config?.analysis;
return analysisConfig !== undefined ? getAnalysisType(analysisConfig) : undefined;
};
Expand Down Expand Up @@ -54,13 +54,17 @@ export const getDefaultOnFailureConfiguration = (): MlInferenceState['onFailure'
},
];

export const getInitialState = (model: ModelItem): MlInferenceState => {
export const getInitialState = (model: DFAModelItem): MlInferenceState => {
const modelType = getModelType(model);
let targetField;

if (modelType !== undefined) {
targetField = model.inference_config
? `ml.inference.${model.inference_config[modelType].results_field}`
? `ml.inference.${
model.inference_config[
modelType as keyof Exclude<DFAModelItem['inference_config'], undefined>
]!.results_field
}`
: undefined;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ export function AnalyticsIdSelector({
async function fetchAnalyticsModels() {
setIsLoading(true);
try {
// FIXME should it fetch all trained models?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alvarezmelissa87 could you please check this one?

Copy link
Contributor

@alvarezmelissa87 alvarezmelissa87 Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is correct for when we're using the analytics selection flyout from the maps page - we want to allow users to choose jobs or models, which requires us to fetch the models. Though, looks like in x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/page.tsx
the AnalyticsIdSelector should have the jobsOnly prop set to 'true' so that we don't fetch models for the exploration page - which only pertains to jobs.
Looks like I missed setting that correctly. Would you be up for adding it? Otherwise, I can create a tiny PR for it.
So in that file it should be:

<AnalyticsIdSelector
    setAnalyticsId={setAnalyticsId}
    setIsIdSelectorFlyoutVisible={setIsIdSelectorFlyoutVisible}
    jobsOnly
  />
        

Copy link
Contributor Author

@darnautov darnautov Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, it also fetches all NLP models, and I believe that doesn't make sense.
It should only list DFA models.

image

const response = await trainedModelsApiService.getTrainedModels();
setTrainedModels(response);
} catch (e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ import { FormattedMessage } from '@kbn/i18n-react';
import React, { type FC, useMemo, useState } from 'react';
import { groupBy } from 'lodash';
import { ElandPythonClient } from '@kbn/inference_integration_flyout';
import type { ModelDownloadItem } from '../../../common/types/trained_models';
import { usePermissionCheck } from '../capabilities/check_capabilities';
import { useMlKibana } from '../contexts/kibana';
import type { ModelItem } from './models_list';

export interface AddModelFlyoutProps {
modelDownloads: ModelItem[];
modelDownloads: ModelDownloadItem[];
onClose: () => void;
onSubmit: (modelId: string) => void;
}
Expand Down Expand Up @@ -138,7 +138,7 @@ export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, mod
};

interface ClickToDownloadTabContentProps {
modelDownloads: ModelItem[];
modelDownloads: ModelDownloadItem[];
onModelDownload: (modelId: string) => void;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { i18n } from '@kbn/i18n';
import { extractErrorProperties } from '@kbn/ml-error-utils';
import type { SupportedPytorchTasksType } from '@kbn/ml-trained-models-utils';

import type { ModelItem } from '../models_list';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import type { AddInferencePipelineSteps } from '../../components/ml_inference/types';
import { ADD_INFERENCE_PIPELINE_STEPS } from '../../components/ml_inference/constants';
import { AddInferencePipelineFooter } from '../../components/shared';
Expand All @@ -40,7 +40,7 @@ import { useTestTrainedModelsContext } from '../test_models/test_trained_models_

export interface CreatePipelineForModelFlyoutProps {
onClose: (refreshList?: boolean) => void;
model: ModelItem;
model: TrainedModelItem;
}

export const CreatePipelineForModelFlyout: FC<CreatePipelineForModelFlyoutProps> = ({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { IngestInferenceProcessor } from '@elastic/elasticsearch/lib/api/types';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { getDefaultOnFailureConfiguration } from '../../components/ml_inference/state';
import type { ModelItem } from '../models_list';

export interface InferecePipelineCreationState {
creatingPipeline: boolean;
Expand All @@ -26,7 +26,7 @@ export interface InferecePipelineCreationState {
}

export const getInitialState = (
model: ModelItem,
model: TrainedModelItem,
initialPipelineConfig: estypes.IngestPipeline | undefined
): InferecePipelineCreationState => ({
creatingPipeline: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import type { ModelItem } from '../models_list';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { TestTrainedModelContent } from '../test_models/test_trained_model_content';
import { useMlKibana } from '../../contexts/kibana';
import { type InferecePipelineCreationState } from './state';

interface ContentProps {
model: ModelItem;
model: TrainedModelItem;
handlePipelineConfigUpdate: (configUpdate: Partial<InferecePipelineCreationState>) => void;
externalPipelineConfig?: estypes.IngestPipeline;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ import {
EuiSpacer,
} from '@elastic/eui';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import type { TrainedModelItem, TrainedModelUIItem } from '../../../common/types/trained_models';
import { isExistingModel } from '../../../common/types/trained_models';
import { type WithRequired } from '../../../common/types/common';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { useToastNotificationService } from '../services/toast_notification_service';
import { DeleteSpaceAwareItemCheckModal } from '../components/delete_space_aware_item_check_modal';
import { type ModelItem } from './models_list';

interface DeleteModelsModalProps {
models: ModelItem[];
models: TrainedModelUIItem[];
onClose: (refreshList?: boolean) => void;
}

Expand All @@ -42,11 +43,14 @@ export const DeleteModelsModal: FC<DeleteModelsModalProps> = ({ models, onClose

const modelIds = models.map((m) => m.model_id);

const modelsWithPipelines = models.filter((m) => isPopulatedObject(m.pipelines)) as Array<
WithRequired<ModelItem, 'pipelines'>
>;
const modelsWithPipelines = models.filter(
(m): m is WithRequired<TrainedModelItem, 'pipelines'> =>
isExistingModel(m) && isPopulatedObject(m.pipelines)
);

const modelsWithInferenceAPIs = models.filter((m) => m.hasInferenceServices);
const modelsWithInferenceAPIs = models.filter(
(m): m is TrainedModelItem => isExistingModel(m) && !!m.hasInferenceServices
);

const inferenceAPIsIDs: string[] = modelsWithInferenceAPIs.flatMap((model) => {
return (model.inference_apis ?? []).map((inference) => inference.inference_id);
Expand Down
Loading