Skip to content

Commit

Permalink
[ML] Testing trained models in UI (#128359)
Browse files Browse the repository at this point in the history
* [ML] Testing trained models in UI

* folder rename

* code clean up

* translations

* adding comments

* endpoint comments

* small changes based on review

* removing testing text

* refactoring to remove duplicate code

* changing misc entities

* probably is now 3 sig figs

* class refactor

* another refactor

* fixing enitiy highlighting

* adding infer timeout

* show class name for known types

* refactoring highlighting

* moving unknown entity type

* removing default badge tooltips

* fixing linting error

* small import changes

Co-authored-by: Kibana Machine <[email protected]>
  • Loading branch information
jgowdyelastic and kibanamachine authored Mar 29, 2022
1 parent ac4e96c commit f4ed8e1
Show file tree
Hide file tree
Showing 23 changed files with 1,007 additions and 14 deletions.
22 changes: 8 additions & 14 deletions x-pack/plugins/ml/common/types/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { DataFrameAnalyticsConfig } from './data_frame_analytics';
import type { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';
import type { XOR } from './common';
Expand Down Expand Up @@ -87,14 +87,12 @@ export type PutTrainedModelConfig = {
}
>; // compressed_definition and definition are mutually exclusive

export interface TrainedModelConfigResponse {
description?: string;
created_by: string;
create_time: string;
default_field_map: Record<string, string>;
estimated_heap_memory_usage_bytes: number;
estimated_operations: number;
license_level: string;
export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & {
/**
* Associated pipelines. Extends response from the ES endpoint.
*/
pipelines?: Record<string, PipelineDefinition> | null;

metadata?: {
analytics_config: DataFrameAnalyticsConfig;
input: unknown;
Expand All @@ -107,11 +105,7 @@ export interface TrainedModelConfigResponse {
tags: string[];
version: string;
inference_config?: Record<string, any>;
/**
* Associated pipelines. Extends response from the ES endpoint.
*/
pipelines?: Record<string, PipelineDefinition> | null;
}
};

export interface PipelineDefinition {
processors?: Array<Record<string, any>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { resultsApiProvider } from './results';
import { jobsApiProvider } from './jobs';
import { fileDatavisualizer } from './datavisualizer';
import { savedObjectsApiProvider } from './saved_objects';
import { trainedModelsApiProvider } from './trained_models';
import type {
MlServerDefaults,
MlServerLimits,
Expand Down Expand Up @@ -719,5 +720,6 @@ export function mlApiServicesProvider(httpService: HttpService) {
jobs: jobsApiProvider(httpService),
fileDatavisualizer,
savedObjects: savedObjectsApiProvider(httpService),
trainedModels: trainedModelsApiProvider(httpService),
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* 2.0.
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import { useMemo } from 'react';
import { HttpFetchQuery } from 'kibana/public';
import { HttpService } from '../http_service';
Expand Down Expand Up @@ -138,6 +140,25 @@ export function trainedModelsApiProvider(httpService: HttpService) {
query: { force },
});
},

inferTrainedModel(modelId: string, payload: any, timeout?: string) {
const body = JSON.stringify(payload);
return httpService.http<estypes.MlInferTrainedModelDeploymentResponse>({
path: `${apiBasePath}/trained_models/infer/${modelId}`,
method: 'POST',
body,
...(timeout ? { query: { timeout } as HttpFetchQuery } : {}),
});
},

ingestPipelineSimulate(payload: estypes.IngestSimulateRequest['body']) {
const body = JSON.stringify(payload);
return httpService.http<estypes.IngestSimulateResponse>({
path: `${apiBasePath}/trained_models/ingest_pipeline_simulate`,
method: 'POST',
body,
});
},
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import { DEPLOYMENT_STATE, TRAINED_MODEL_TYPE } from '../../../../common/constan
import { getUserConfirmationProvider } from './force_stop_dialog';
import { MLSavedObjectsSpacesList } from '../../components/ml_saved_objects_spaces_list';
import { SavedObjectsWarning } from '../../components/saved_objects_warning';
import { TestTrainedModelFlyout, isTestable } from './test_models';

type Stats = Omit<TrainedModelStat, 'model_id'>;

Expand Down Expand Up @@ -134,6 +135,7 @@ export const ModelsList: FC<Props> = ({
const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState<Record<string, JSX.Element>>(
{}
);
const [showTestFlyout, setShowTestFlyout] = useState<ModelItem | null>(null);
const getUserConfirmation = useMemo(() => getUserConfirmationProvider(overlays, theme), []);

const navigateToPath = useNavigateToPath();
Expand Down Expand Up @@ -470,6 +472,19 @@ export const ModelsList: FC<Props> = ({
return !isPopulatedObject(item.pipelines);
},
},
{
name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', {
defaultMessage: 'Test model',
}),
description: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', {
defaultMessage: 'Test model',
}),
icon: 'inputOutput',
type: 'icon',
isPrimary: true,
available: isTestable,
onClick: setShowTestFlyout,
},
] as Array<Action<ModelItem>>)
);
}
Expand Down Expand Up @@ -769,6 +784,12 @@ export const ModelsList: FC<Props> = ({
modelIds={modelIdsToDelete}
/>
)}
{showTestFlyout === null ? null : (
<TestTrainedModelFlyout
model={showTestFlyout}
onClose={setShowTestFlyout.bind(null, null)}
/>
)}
</>
);
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { TestTrainedModelFlyout } from './test_flyout';
export { isTestable } from './utils';
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React, { FC } from 'react';
import { i18n } from '@kbn/i18n';
import { EuiCallOut } from '@elastic/eui';

interface Props {
errorText: string | null;
}

export const ErrorMessage: FC<Props> = ({ errorText }) => {
return errorText === null ? null : (
<>
<EuiCallOut
title={i18n.translate('xpack.ml.trainedModels.testModelsFlyout.inferenceError', {
defaultMessage: 'An error occurred',
})}
color="danger"
iconType="cross"
>
<p>{errorText}</p>
</EuiCallOut>
</>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';

const DEFAULT_INPUT_FIELD = 'text_field';

export type FormattedNerResp = Array<{
value: string;
entity: estypes.MlTrainedModelEntities | null;
}>;

export abstract class InferenceBase<TInferResponse> {
protected readonly inputField: string;

constructor(
protected trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected model: estypes.MlTrainedModelConfig
) {
this.inputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
}

protected abstract infer(inputText: string): Promise<TInferResponse>;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React, { FC, useState } from 'react';

import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiSpacer, EuiTextArea, EuiButton, EuiTabs, EuiTab } from '@elastic/eui';

import { LangIdentInference } from './lang_ident/lang_ident_inference';
import { NerInference } from './ner/ner_inference';
import type { FormattedLangIdentResp } from './lang_ident/lang_ident_inference';
import type { FormattedNerResp } from './ner/ner_inference';

import { MLJobEditor } from '../../../../jobs/jobs_list/components/ml_job_editor';
import { extractErrorMessage } from '../../../../../../common/util/errors';
import { ErrorMessage } from '../inference_error';
import { OutputLoadingContent } from '../output_loading';

interface Props {
inferrer: LangIdentInference | NerInference;
getOutputComponent(output: any): JSX.Element;
}

enum TAB {
TEXT,
RAW,
}

export const InferenceInputForm: FC<Props> = ({ inferrer, getOutputComponent }) => {
const [inputText, setInputText] = useState('');
const [isRunning, setIsRunning] = useState(false);
const [output, setOutput] = useState<FormattedLangIdentResp | FormattedNerResp | null>(null);
const [rawOutput, setRawOutput] = useState<string | null>(null);
const [selectedTab, setSelectedTab] = useState(TAB.TEXT);
const [showOutput, setShowOutput] = useState(false);
const [errorText, setErrorText] = useState<string | null>(null);

async function run() {
setShowOutput(true);
setOutput(null);
setRawOutput(null);
setIsRunning(true);
setErrorText(null);
try {
const { response, rawResponse } = await inferrer.infer(inputText);
setOutput(response);
setRawOutput(JSON.stringify(rawResponse, null, 2));
} catch (e) {
setIsRunning(false);
setOutput(null);
setErrorText(extractErrorMessage(e));
setRawOutput(JSON.stringify(e.body ?? e, null, 2));
}
setIsRunning(false);
}

return (
<>
<EuiTextArea
placeholder={i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.inputText', {
defaultMessage: 'Input text',
})}
value={inputText}
disabled={isRunning === true}
fullWidth
onChange={(e) => {
setInputText(e.target.value);
}}
/>
<EuiSpacer size="m" />
<div>
<EuiButton
onClick={run}
disabled={isRunning === true || inputText === ''}
fullWidth={false}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.runButton"
defaultMessage="Test"
/>
</EuiButton>
</div>
{showOutput === true ? (
<>
<EuiSpacer size="m" />
<EuiTabs size={'s'}>
<EuiTab
isSelected={selectedTab === TAB.TEXT}
onClick={setSelectedTab.bind(null, TAB.TEXT)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.markupTab"
defaultMessage="Output"
/>
</EuiTab>
<EuiTab
isSelected={selectedTab === TAB.RAW}
onClick={setSelectedTab.bind(null, TAB.RAW)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.rawOutput"
defaultMessage="Raw output"
/>
</EuiTab>
</EuiTabs>

<EuiSpacer size="m" />

{selectedTab === TAB.TEXT ? (
<>
{errorText !== null ? (
<ErrorMessage errorText={errorText} />
) : output === null ? (
<OutputLoadingContent text={inputText} />
) : (
<>{getOutputComponent(output)}</>
)}
</>
) : (
<MLJobEditor value={rawOutput ?? ''} readOnly={true} />
)}
</>
) : null}
</>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export type { FormattedLangIdentResp } from './lang_ident_inference';
export { LangIdentInference } from './lang_ident_inference';
export { LangIdentOutput } from './lang_ident_output';
Loading

0 comments on commit f4ed8e1

Please sign in to comment.