diff --git a/x-pack/plugins/ml/common/types/trained_models.ts b/x-pack/plugins/ml/common/types/trained_models.ts index ad59b7a917c49..182cf277d93cc 100644 --- a/x-pack/plugins/ml/common/types/trained_models.ts +++ b/x-pack/plugins/ml/common/types/trained_models.ts @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import type { DataFrameAnalyticsConfig } from './data_frame_analytics'; import type { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance'; import type { XOR } from './common'; @@ -87,14 +87,12 @@ export type PutTrainedModelConfig = { } >; // compressed_definition and definition are mutually exclusive -export interface TrainedModelConfigResponse { - description?: string; - created_by: string; - create_time: string; - default_field_map: Record; - estimated_heap_memory_usage_bytes: number; - estimated_operations: number; - license_level: string; +export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & { + /** + * Associated pipelines. Extends response from the ES endpoint. + */ + pipelines?: Record | null; + metadata?: { analytics_config: DataFrameAnalyticsConfig; input: unknown; @@ -107,11 +105,7 @@ export interface TrainedModelConfigResponse { tags: string[]; version: string; inference_config?: Record; - /** - * Associated pipelines. Extends response from the ES endpoint. - */ - pipelines?: Record | null; -} +}; export interface PipelineDefinition { processors?: Array>; diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/index.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/index.ts index 40187f70f1680..87f1a8eec2478 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/index.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/index.ts @@ -17,6 +17,7 @@ import { resultsApiProvider } from './results'; import { jobsApiProvider } from './jobs'; import { fileDatavisualizer } from './datavisualizer'; import { savedObjectsApiProvider } from './saved_objects'; +import { trainedModelsApiProvider } from './trained_models'; import type { MlServerDefaults, MlServerLimits, @@ -719,5 +720,6 @@ export function mlApiServicesProvider(httpService: HttpService) { jobs: jobsApiProvider(httpService), fileDatavisualizer, savedObjects: savedObjectsApiProvider(httpService), + trainedModels: trainedModelsApiProvider(httpService), }; } diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts index 97027b86a88e1..738f5e1ace74a 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts @@ -5,6 +5,8 @@ * 2.0. */ +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; + import { useMemo } from 'react'; import { HttpFetchQuery } from 'kibana/public'; import { HttpService } from '../http_service'; @@ -138,6 +140,25 @@ export function trainedModelsApiProvider(httpService: HttpService) { query: { force }, }); }, + + inferTrainedModel(modelId: string, payload: any, timeout?: string) { + const body = JSON.stringify(payload); + return httpService.http({ + path: `${apiBasePath}/trained_models/infer/${modelId}`, + method: 'POST', + body, + ...(timeout ? { query: { timeout } as HttpFetchQuery } : {}), + }); + }, + + ingestPipelineSimulate(payload: estypes.IngestSimulateRequest['body']) { + const body = JSON.stringify(payload); + return httpService.http({ + path: `${apiBasePath}/trained_models/ingest_pipeline_simulate`, + method: 'POST', + body, + }); + }, }; } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx index bd3e3638e8310..1604e265b1617 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx @@ -53,6 +53,7 @@ import { DEPLOYMENT_STATE, TRAINED_MODEL_TYPE } from '../../../../common/constan import { getUserConfirmationProvider } from './force_stop_dialog'; import { MLSavedObjectsSpacesList } from '../../components/ml_saved_objects_spaces_list'; import { SavedObjectsWarning } from '../../components/saved_objects_warning'; +import { TestTrainedModelFlyout, isTestable } from './test_models'; type Stats = Omit; @@ -134,6 +135,7 @@ export const ModelsList: FC = ({ const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState>( {} ); + const [showTestFlyout, setShowTestFlyout] = useState(null); const getUserConfirmation = useMemo(() => getUserConfirmationProvider(overlays, theme), []); const navigateToPath = useNavigateToPath(); @@ -470,6 +472,19 @@ export const ModelsList: FC = ({ return !isPopulatedObject(item.pipelines); }, }, + { + name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { + defaultMessage: 'Test model', + }), + description: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { + defaultMessage: 'Test model', + }), + icon: 'inputOutput', + type: 'icon', + isPrimary: true, + available: isTestable, + onClick: setShowTestFlyout, + }, ] as Array>) ); } @@ -769,6 +784,12 @@ export const ModelsList: FC = ({ modelIds={modelIdsToDelete} /> )} + {showTestFlyout === null ? null : ( + + )} ); }; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/index.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/index.ts new file mode 100644 index 0000000000000..da7c12c1c0c58 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/index.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { TestTrainedModelFlyout } from './test_flyout'; +export { isTestable } from './utils'; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/inference_error.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/inference_error.tsx new file mode 100644 index 0000000000000..dc7ae508ab270 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/inference_error.tsx @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { FC } from 'react'; +import { i18n } from '@kbn/i18n'; +import { EuiCallOut } from '@elastic/eui'; + +interface Props { + errorText: string | null; +} + +export const ErrorMessage: FC = ({ errorText }) => { + return errorText === null ? null : ( + <> + +

{errorText}

+
+ + ); +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts new file mode 100644 index 0000000000000..777ca2d314c4d --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; + +import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; + +const DEFAULT_INPUT_FIELD = 'text_field'; + +export type FormattedNerResp = Array<{ + value: string; + entity: estypes.MlTrainedModelEntities | null; +}>; + +export abstract class InferenceBase { + protected readonly inputField: string; + + constructor( + protected trainedModelsApi: ReturnType, + protected model: estypes.MlTrainedModelConfig + ) { + this.inputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD; + } + + protected abstract infer(inputText: string): Promise; +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx new file mode 100644 index 0000000000000..6503486d98211 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { FC, useState } from 'react'; + +import { i18n } from '@kbn/i18n'; +import { FormattedMessage } from '@kbn/i18n-react'; +import { EuiSpacer, EuiTextArea, EuiButton, EuiTabs, EuiTab } from '@elastic/eui'; + +import { LangIdentInference } from './lang_ident/lang_ident_inference'; +import { NerInference } from './ner/ner_inference'; +import type { FormattedLangIdentResp } from './lang_ident/lang_ident_inference'; +import type { FormattedNerResp } from './ner/ner_inference'; + +import { MLJobEditor } from '../../../../jobs/jobs_list/components/ml_job_editor'; +import { extractErrorMessage } from '../../../../../../common/util/errors'; +import { ErrorMessage } from '../inference_error'; +import { OutputLoadingContent } from '../output_loading'; + +interface Props { + inferrer: LangIdentInference | NerInference; + getOutputComponent(output: any): JSX.Element; +} + +enum TAB { + TEXT, + RAW, +} + +export const InferenceInputForm: FC = ({ inferrer, getOutputComponent }) => { + const [inputText, setInputText] = useState(''); + const [isRunning, setIsRunning] = useState(false); + const [output, setOutput] = useState(null); + const [rawOutput, setRawOutput] = useState(null); + const [selectedTab, setSelectedTab] = useState(TAB.TEXT); + const [showOutput, setShowOutput] = useState(false); + const [errorText, setErrorText] = useState(null); + + async function run() { + setShowOutput(true); + setOutput(null); + setRawOutput(null); + setIsRunning(true); + setErrorText(null); + try { + const { response, rawResponse } = await inferrer.infer(inputText); + setOutput(response); + setRawOutput(JSON.stringify(rawResponse, null, 2)); + } catch (e) { + setIsRunning(false); + setOutput(null); + setErrorText(extractErrorMessage(e)); + setRawOutput(JSON.stringify(e.body ?? e, null, 2)); + } + setIsRunning(false); + } + + return ( + <> + { + setInputText(e.target.value); + }} + /> + +
+ + + +
+ {showOutput === true ? ( + <> + + + + + + + + + + + + + {selectedTab === TAB.TEXT ? ( + <> + {errorText !== null ? ( + + ) : output === null ? ( + + ) : ( + <>{getOutputComponent(output)} + )} + + ) : ( + + )} + + ) : null} + + ); +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/index.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/index.ts new file mode 100644 index 0000000000000..b3439d90e8828 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/index.ts @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export type { FormattedLangIdentResp } from './lang_ident_inference'; +export { LangIdentInference } from './lang_ident_inference'; +export { LangIdentOutput } from './lang_ident_output'; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_codes.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_codes.ts new file mode 100644 index 0000000000000..eff2fdcdd94e7 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_codes.ts @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +const langCodes: Record = { + af: 'Afrikaans', + hr: 'Croatian', + pa: 'Punjabi', + am: 'Amharic', + ht: 'Haitian', + pl: 'Polish', + ar: 'Arabic', + hu: 'Hungarian', + ps: 'Pashto', + az: 'Azerbaijani', + hy: 'Armenian', + pt: 'Portuguese', + be: 'Belarusian', + id: 'Indonesian', + ro: 'Romanian', + bg: 'Bulgarian', + ig: 'Igbo', + ru: 'Russian', + 'bg-Latn': 'Bulgarian', + is: 'Icelandic', + 'ru-Latn': 'Russian', + bn: 'Bengali', + it: 'Italian', + sd: 'Sindhi', + bs: 'Bosnian', + iw: 'Hebrew', + si: 'Sinhala', + ca: 'Catalan', + ja: 'Japanese', + sk: 'Slovak', + ceb: 'Cebuano', + 'ja-Latn': 'Japanese', + sl: 'Slovenian', + co: 'Corsican', + jv: 'Javanese', + sm: 'Samoan', + cs: 'Czech', + ka: 'Georgian', + sn: 'Shona', + cy: 'Welsh', + kk: 'Kazakh', + so: 'Somali', + da: 'Danish', + km: 'Central Khmer', + sq: 'Albanian', + de: 'German', + kn: 'Kannada', + sr: 'Serbian', + el: 'Greek,modern', + ko: 'Korean', + st: 'Southern Sotho', + 'el-Latn': 'Greek,modern', + ku: 'Kurdish', + su: 'Sundanese', + en: 'English', + ky: 'Kirghiz', + sv: 'Swedish', + eo: 'Esperanto', + la: 'Latin', + sw: 'Swahili', + es: 'Spanish,Castilian', + lb: 'Luxembourgish', + ta: 'Tamil', + et: 'Estonian', + lo: 'Lao', + te: 'Telugu', + eu: 'Basque', + lt: 'Lithuanian', + tg: 'Tajik', + fa: 'Persian', + lv: 'Latvian', + th: 'Thai', + fi: 'Finnish', + mg: 'Malagasy', + tr: 'Turkish', + fil: 'Filipino', + mi: 'Maori', + uk: 'Ukrainian', + fr: 'French', + mk: 'Macedonian', + ur: 'Urdu', + fy: 'Western Frisian', + ml: 'Malayalam', + uz: 'Uzbek', + ga: 'Irish', + mn: 'Mongolian', + vi: 'Vietnamese', + gd: 'Gaelic', + mr: 'Marathi', + xh: 'Xhosa', + gl: 'Galician', + ms: 'Malay', + yi: 'Yiddish', + gu: 'Gujarati', + mt: 'Maltese', + yo: 'Yoruba', + ha: 'Hausa', + my: 'Burmese', + zh: 'Chinese', + haw: 'Hawaiian', + ne: 'Nepali', + 'zh-Latn': 'Chinese', + hi: 'Hindi', + nl: 'Dutch,Flemish', + zu: 'Zulu', + 'hi-Latn': 'Hindi', + no: 'Norwegian', + hmn: 'Hmong', + ny: 'Chichewa', + + zxx: 'unknown', +}; + +export function getLanguage(code: string) { + return langCodes[code] ?? 'unknown'; +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts new file mode 100644 index 0000000000000..9108a59197617 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; + +import { InferenceBase } from '../inference_base'; + +export type FormattedLangIdentResp = Array<{ + className: string; + classProbability: number; + classScore: number; +}>; + +interface InferResponse { + response: FormattedLangIdentResp; + rawResponse: estypes.IngestSimulateResponse; +} + +export class LangIdentInference extends InferenceBase { + public async infer(inputText: string) { + const payload: estypes.IngestSimulateRequest['body'] = { + pipeline: { + processors: [ + { + inference: { + model_id: this.model.model_id, + inference_config: { + // @ts-expect-error classification missing from type + classification: { + num_top_classes: 3, + }, + }, + field_mappings: { + contents: this.inputField, + }, + target_field: '_ml.lang_ident', + }, + }, + ], + }, + docs: [ + { + _source: { + contents: inputText, + }, + }, + ], + }; + const resp = await this.trainedModelsApi.ingestPipelineSimulate(payload); + if (resp.docs.length) { + const topClasses = resp.docs[0].doc?._source._ml?.lang_ident?.top_classes ?? []; + + return { + response: topClasses.map((t: any) => ({ + className: t.class_name, + classProbability: t.class_probability, + classScore: t.class_score, + })), + rawResponse: resp, + }; + } + return { response: [], rawResponse: resp }; + } +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx new file mode 100644 index 0000000000000..e4968bc516f83 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { FC } from 'react'; +import { i18n } from '@kbn/i18n'; +import { EuiSpacer, EuiBasicTable, EuiTitle } from '@elastic/eui'; + +import type { FormattedLangIdentResp } from './lang_ident_inference'; +import { getLanguage } from './lang_codes'; + +const PROBABILITY_SIG_FIGS = 3; + +export const LangIdentOutput: FC<{ result: FormattedLangIdentResp }> = ({ result }) => { + if (result.length === 0) { + return null; + } + + const lang = getLanguage(result[0].className); + + const items = result.map(({ className, classProbability }, i) => { + return { + noa: `${i + 1}`, + className: getLanguage(className), + classProbability: `${Number(classProbability).toPrecision(PROBABILITY_SIG_FIGS)}`, + }; + }); + + const columns = [ + { + field: 'noa', + name: '#', + width: '5%', + truncateText: false, + isExpander: false, + }, + { + field: 'className', + name: i18n.translate( + 'xpack.ml.trainedModels.testModelsFlyout.langIdent.output.language_title', + { + defaultMessage: 'Language', + } + ), + width: '30%', + truncateText: false, + isExpander: false, + }, + { + field: 'classProbability', + name: i18n.translate( + 'xpack.ml.trainedModels.testModelsFlyout.langIdent.output.probability_title', + { + defaultMessage: 'Probability', + } + ), + truncateText: false, + isExpander: false, + }, + ]; + + const title = + lang !== 'unknown' + ? i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.output.title', { + defaultMessage: 'This looks like {lang}', + values: { lang }, + }) + : i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.output.titleUnknown', { + defaultMessage: 'Language code unknown: {code}', + values: { code: result[0].className }, + }); + + return ( + <> + +

{title}

+
+ + + + + ); +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/index.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/index.ts new file mode 100644 index 0000000000000..38ddad8bdeb80 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/index.ts @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export type { FormattedNerResp } from './ner_inference'; +export { NerInference } from './ner_inference'; +export { NerOutput } from './ner_output'; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts new file mode 100644 index 0000000000000..e4dcfcc2c6333 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; + +import { InferenceBase } from '../inference_base'; + +export type FormattedNerResp = Array<{ + value: string; + entity: estypes.MlTrainedModelEntities | null; +}>; + +interface InferResponse { + response: FormattedNerResp; + rawResponse: estypes.MlInferTrainedModelDeploymentResponse; +} + +export class NerInference extends InferenceBase { + public async infer(inputText: string) { + const payload = { docs: { [this.inputField]: inputText } }; + const resp = await this.trainedModelsApi.inferTrainedModel(this.model.model_id, payload, '30s'); + + return { response: parseResponse(resp), rawResponse: resp }; + } +} + +function parseResponse(resp: estypes.MlInferTrainedModelDeploymentResponse): FormattedNerResp { + const { predicted_value: predictedValue, entities } = resp; + const splitWordsAndEntitiesRegex = /(\[.*?\]\(.*?&.*?\))/; + const matchEntityRegex = /(\[.*?\])\((.*?)&(.*?)\)/; + if (predictedValue === undefined || entities === undefined) { + return []; + } + + const sentenceChunks = (predictedValue as unknown as string).split(splitWordsAndEntitiesRegex); + let count = 0; + return sentenceChunks.map((chunk) => { + const matchedEntity = chunk.match(matchEntityRegex); + if (matchedEntity) { + const entityValue = matchedEntity[3]; + const entity = entities[count]; + if (entityValue !== entity.entity && entityValue.replaceAll('+', ' ') !== entity.entity) { + // entityValue may not equal entity.entity if the entity is comprised of + // two words as they are joined with a plus symbol + // Replace any plus symbols and check again. If they still don't match, log an error + + // eslint-disable-next-line no-console + console.error('mismatch entity', entity); + } + count++; + return { value: entity.entity, entity }; + } + return { value: chunk, entity: null }; + }); +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx new file mode 100644 index 0000000000000..e9db3fa8efd36 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import React, { FC, ReactNode } from 'react'; +import { FormattedMessage } from '@kbn/i18n-react'; +import { + EuiHorizontalRule, + EuiBadge, + EuiToolTip, + EuiFlexGroup, + EuiFlexItem, + EuiIcon, +} from '@elastic/eui'; + +import { + useCurrentEuiTheme, + EuiThemeType, +} from '../../../../../components/color_range_legend/use_color_range'; +import type { FormattedNerResp } from './ner_inference'; + +const ICON_PADDING = '2px'; +const PROBABILITY_SIG_FIGS = 3; + +const ENTITY_TYPES = { + PER: { + label: 'Person', + icon: 'user', + color: 'euiColorVis5_behindText', + borderColor: 'euiColorVis5', + }, + LOC: { + label: 'Location', + icon: 'visMapCoordinate', + color: 'euiColorVis1_behindText', + borderColor: 'euiColorVis1', + }, + ORG: { + label: 'Organization', + icon: 'home', + color: 'euiColorVis0_behindText', + borderColor: 'euiColorVis0', + }, + MISC: { + label: 'Miscellaneous', + icon: 'questionInCircle', + color: 'euiColorVis7_behindText', + borderColor: 'euiColorVis7', + }, +}; + +const UNKNOWN_ENTITY_TYPE = { + label: '', + icon: 'questionInCircle', + color: 'euiColorVis5_behindText', + borderColor: 'euiColorVis5', +}; + +export const NerOutput: FC<{ result: FormattedNerResp }> = ({ result }) => { + const { euiTheme } = useCurrentEuiTheme(); + const lineSplit: JSX.Element[] = []; + result.forEach(({ value, entity }) => { + if (entity === null) { + const lines = value + .split(/(\n)/) + .map((line) => (line === '\n' ?
: {line})); + + lineSplit.push(...lines); + } else { + lineSplit.push( + +
+ + {value} +
+ +
+
+ + : {getClassLabel(entity.class_name)} +
+
+ + : {Number(entity.class_probability).toPrecision(PROBABILITY_SIG_FIGS)} +
+
+ + } + > + {value} +
+ ); + } + }); + return
{lineSplit}
; +}; + +const EntityBadge = ({ + entity, + children, +}: { + entity: estypes.MlTrainedModelEntities; + children: ReactNode; +}) => { + const { euiTheme } = useCurrentEuiTheme(); + return ( + + + + + + {children} + + + ); +}; + +function getClassIcon(className: string) { + const entity = ENTITY_TYPES[className as keyof typeof ENTITY_TYPES]; + return entity?.icon ?? UNKNOWN_ENTITY_TYPE.icon; +} + +function getClassLabel(className: string) { + const entity = ENTITY_TYPES[className as keyof typeof ENTITY_TYPES]; + return entity?.label ?? className; +} + +function getClassColor(euiTheme: EuiThemeType, className: string, border: boolean = false) { + const entity = ENTITY_TYPES[className as keyof typeof ENTITY_TYPES]; + let color = entity?.color ?? UNKNOWN_ENTITY_TYPE.color; + if (border) { + color = entity?.borderColor ?? UNKNOWN_ENTITY_TYPE.borderColor; + } + return euiTheme[color as keyof typeof euiTheme]; +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/output_loading.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/output_loading.tsx new file mode 100644 index 0000000000000..4cceed23edd25 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/output_loading.tsx @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { FC } from 'react'; +import { EuiLoadingContent } from '@elastic/eui'; +import { LineRange } from '@elastic/eui/src/components/loading/loading_content'; + +export const OutputLoadingContent: FC<{ text: string }> = ({ text }) => { + const actualLines = text.split(/\r\n|\r|\n/).length + 1; + const lines = actualLines > 4 && actualLines <= 10 ? actualLines : 4; + + return ; +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx new file mode 100644 index 0000000000000..cab0826d5584a --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import React, { FC } from 'react'; + +import { NerOutput, NerInference } from './models/ner'; +import type { FormattedNerResp } from './models/ner'; +import { LangIdentOutput, LangIdentInference } from './models/lang_ident'; +import type { FormattedLangIdentResp } from './models/lang_ident'; + +import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models'; +import { useMlApiContext } from '../../../contexts/kibana'; +import { InferenceInputForm } from './models/inference_input_form'; + +interface Props { + model: estypes.MlTrainedModelConfig | null; +} + +export const SelectedModel: FC = ({ model }) => { + const { trainedModels } = useMlApiContext(); + + if (model === null) { + return null; + } + + if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) { + const inferrer = new NerInference(trainedModels, model); + return ( + } + /> + ); + } + if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) { + const inferrer = new LangIdentInference(trainedModels, model); + return ( + } + /> + ); + } + + return null; +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/test_flyout.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/test_flyout.tsx new file mode 100644 index 0000000000000..343cd32addce7 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/test_flyout.tsx @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import React, { FC } from 'react'; + +import { FormattedMessage } from '@kbn/i18n-react'; +import { EuiFlyout, EuiFlyoutHeader, EuiTitle, EuiFlyoutBody, EuiSpacer } from '@elastic/eui'; + +import { SelectedModel } from './selected_model'; + +interface Props { + model: estypes.MlTrainedModelConfig; + onClose: () => void; +} +export const TestTrainedModelFlyout: FC = ({ model, onClose }) => { + return ( + <> + + + +

+ +

+
+
+ + +

{model.model_id}

+
+ + + + +
+
+ + ); +}; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/utils.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/utils.ts new file mode 100644 index 0000000000000..ccddd960349d2 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/utils.ts @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models'; + +const TESTABLE_MODEL_TYPES: estypes.MlTrainedModelType[] = [ + TRAINED_MODEL_TYPE.PYTORCH, + TRAINED_MODEL_TYPE.LANG_IDENT, +]; + +export function isTestable(model: estypes.MlTrainedModelConfig) { + return model.model_type && TESTABLE_MODEL_TYPES.includes(model.model_type); +} diff --git a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts index 342a3913a6cba..122162777d9a5 100644 --- a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts +++ b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts @@ -494,6 +494,10 @@ export function getMlClient( await modelIdsCheck(p); return mlClient.stopTrainedModelDeployment(...p); }, + async inferTrainedModelDeployment(...p: Parameters) { + await modelIdsCheck(p); + return mlClient.inferTrainedModelDeployment(...p); + }, async info(...p: Parameters) { return mlClient.info(...p); }, diff --git a/x-pack/plugins/ml/server/routes/apidoc.json b/x-pack/plugins/ml/server/routes/apidoc.json index 59ed08664da3b..ac09aee7fcbb9 100644 --- a/x-pack/plugins/ml/server/routes/apidoc.json +++ b/x-pack/plugins/ml/server/routes/apidoc.json @@ -171,6 +171,8 @@ "StopTrainedModelDeployment", "PutTrainedModel", "DeleteTrainedModel", + "InferTrainedModelDeployment", + "IngestPipelineSimulate", "Alerting", "PreviewAlert" diff --git a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts index 941edb31c79fa..1b9a865dcfca9 100644 --- a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts +++ b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts @@ -30,3 +30,19 @@ export const getInferenceQuerySchema = schema.object({ export const putTrainedModelQuerySchema = schema.object({ defer_definition_decompression: schema.maybe(schema.boolean()), }); + +export const pipelineSchema = schema.object({ + pipeline: schema.object({ + description: schema.maybe(schema.string()), + processors: schema.arrayOf(schema.recordOf(schema.string(), schema.any())), + version: schema.maybe(schema.number()), + on_failure: schema.maybe(schema.arrayOf(schema.recordOf(schema.string(), schema.any()))), + }), + docs: schema.arrayOf(schema.recordOf(schema.string(), schema.any())), + verbose: schema.maybe(schema.boolean()), +}); + +export const inferTrainedModelQuery = schema.object({ timeout: schema.maybe(schema.string()) }); +export const inferTrainedModelBody = schema.object({ + docs: schema.any(), +}); diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index 887ad47f1ceb2..27a062b45767c 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -5,6 +5,7 @@ * 2.0. */ +import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import { schema } from '@kbn/config-schema'; import { RouteInitialization } from '../types'; import { wrapError } from '../client/error_wrapper'; @@ -13,6 +14,9 @@ import { modelIdSchema, optionalModelIdSchema, putTrainedModelQuerySchema, + pipelineSchema, + inferTrainedModelQuery, + inferTrainedModelBody, } from './schemas/inference_schema'; import { modelsProvider } from '../models/data_frame_analytics'; import { TrainedModelConfigResponse } from '../../common/types/trained_models'; @@ -352,4 +356,77 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) } }) ); + + /** + * @apiGroup TrainedModels + * + * @api {post} /api/ml/trained_models/infer/:modelId Evaluates a trained model + * @apiName InferTrainedModelDeployment + * @apiDescription Evaluates a trained model. + */ + router.post( + { + path: '/api/ml/trained_models/infer/{modelId}', + validate: { + params: modelIdSchema, + query: inferTrainedModelQuery, + body: inferTrainedModelBody, + }, + options: { + tags: ['access:ml:canStartStopTrainedModels'], + }, + }, + routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => { + try { + const { modelId } = request.params; + const body = await mlClient.inferTrainedModelDeployment({ + model_id: modelId, + docs: request.body.docs, + ...(request.query.timeout ? { timeout: request.query.timeout } : {}), + }); + return response.ok({ + body, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + }) + ); + + /** + * @apiGroup TrainedModels + * + * @api {post} /api/ml/trained_models/ingest_pipeline_simulate Ingest pipeline simulate + * @apiName IngestPipelineSimulate + * @apiDescription Simulates an ingest pipeline call using supplied documents + */ + router.post( + { + path: '/api/ml/trained_models/ingest_pipeline_simulate', + validate: { + body: pipelineSchema, + }, + options: { + tags: ['access:ml:canStartStopTrainedModels'], + }, + }, + routeGuard.fullLicenseAPIGuard(async ({ client, request, response }) => { + try { + const { pipeline, docs, verbose } = request.body; + + const body = await client.asCurrentUser.ingest.simulate({ + verbose, + body: { + pipeline, + docs: docs as estypes.IngestSimulateDocument[], + }, + }); + return response.ok({ + body, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + }) + ); }