Skip to content

Commit

Permalink
[ML] Additional testing of trained model in UI (#129209)
Browse files Browse the repository at this point in the history
* [ML] Fixing text classification testing in the UI

* disabling unsupported models

* code clean up

* small refactor

* adding zero shot classification

* translation id

* adding text embedding

* adding fill_mask

* translation id

* code clean up

* adding observable for inference

* refactoring for observables

* removing comment

* refactor

* removing num_top_classes override

* adding optional num_top_classes

* translations

* removing any

* updating error type

* removing any type

* correcting type

* combining checks

* fixing lang ident

* added start check

Co-authored-by: Kibana Machine <[email protected]>
  • Loading branch information
jgowdyelastic and kibanamachine authored Apr 29, 2022
1 parent ced7f11 commit c539923
Show file tree
Hide file tree
Showing 32 changed files with 951 additions and 171 deletions.
8 changes: 4 additions & 4 deletions x-pack/plugins/ml/common/constants/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ export type TrainedModelType = typeof TRAINED_MODEL_TYPE[keyof typeof TRAINED_MO

/**
 * Pytorch (NLP) task types supported by the trained-models test UI,
 * keyed by constant name; each value is the task identifier string.
 */
export const SUPPORTED_PYTORCH_TASKS = {
  NER: 'ner',
  ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
  TEXT_CLASSIFICATION: 'text_classification',
  TEXT_EMBEDDING: 'text_embedding',
  FILL_MASK: 'fill_mask',
} as const;

/** Union of the supported task identifier string literals. */
export type SupportedPytorchTasksType =
  typeof SUPPORTED_PYTORCH_TASKS[keyof typeof SUPPORTED_PYTORCH_TASKS];
6 changes: 3 additions & 3 deletions x-pack/plugins/ml/common/util/errors/errors.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import Boom from '@hapi/boom';

import { extractErrorMessage, MLHttpFetchError, MLResponseError, EsErrorBody } from '.';
import { extractErrorMessage, MLHttpFetchError, EsErrorBody } from '.';

describe('ML - error message utils', () => {
describe('extractErrorMessage', () => {
Expand Down Expand Up @@ -39,7 +39,7 @@ describe('ML - error message utils', () => {
expect(extractErrorMessage(stringMessage)).toBe(testMsg);

// kibana error without attributes
const bodyWithoutAttributes: MLHttpFetchError<MLResponseError> = {
const bodyWithoutAttributes: MLHttpFetchError = {
name: 'name',
req: {} as Request,
request: {} as Request,
Expand All @@ -53,7 +53,7 @@ describe('ML - error message utils', () => {
expect(extractErrorMessage(bodyWithoutAttributes)).toBe(testMsg);

// kibana error with attributes
const bodyWithAttributes: MLHttpFetchError<MLResponseError> = {
const bodyWithAttributes: MLHttpFetchError = {
name: 'name',
req: {} as Request,
request: {} as Request,
Expand Down
11 changes: 4 additions & 7 deletions x-pack/plugins/ml/common/util/errors/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,13 @@ export interface MLErrorObject {
fullError?: EsErrorBody;
}

export interface MLHttpFetchError<T> extends HttpFetchError {
/**
 * Kibana HTTP fetch error with a typed response body.
 * T is the shape of the parsed error body returned by the server.
 */
export interface MLHttpFetchErrorBase<T> extends HttpFetchError {
  body: T;
}

/** The concrete fetch-error type used throughout the ML plugin. */
export type MLHttpFetchError = MLHttpFetchErrorBase<MLResponseError>;

/** Any error shape the ML error utilities know how to extract a message from. */
export type ErrorType = MLHttpFetchError | EsErrorBody | Boom.Boom | string | undefined;

export function isEsErrorBody(error: any): error is EsErrorBody {
return error && error.error?.reason !== undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ import { useEffect, useMemo } from 'react';
import { DEFAULT_MODEL_MEMORY_LIMIT } from '../../../../../../../common/constants/new_job';
import { ml } from '../../../../../services/ml_api_service';
import { JobValidator, VALIDATION_DELAY_MS } from '../../job_validator/job_validator';
import {
MLHttpFetchError,
MLResponseError,
extractErrorMessage,
} from '../../../../../../../common/util/errors';
import { MLHttpFetchError, extractErrorMessage } from '../../../../../../../common/util/errors';
import { useMlKibana } from '../../../../../contexts/kibana';
import { JobCreator } from '../job_creator';

Expand All @@ -41,10 +37,10 @@ export const modelMemoryEstimatorProvider = (
jobValidator: JobValidator
) => {
const modelMemoryCheck$ = new Subject<CalculatePayload>();
const error$ = new Subject<MLHttpFetchError<MLResponseError>>();
const error$ = new Subject<MLHttpFetchError>();

return {
get error$(): Observable<MLHttpFetchError<MLResponseError>> {
get error$(): Observable<MLHttpFetchError> {
return error$.asObservable();
},
get updates$(): Observable<string> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ export const ModelsList: FC<Props> = ({
isPrimary: true,
available: isTestable,
onClick: setShowTestFlyout,
enabled: (item) =>
isPopulatedObject(item.stats?.deployment_stats) &&
item.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED,
},
] as Array<Action<ModelItem>>)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { NerInference } from './ner';
import {
TextClassificationInference,
ZeroShotClassificationInference,
FillMaskInference,
} from './text_classification';
import { TextEmbeddingInference } from './text_embedding';
import { LangIdentInference } from './lang_ident';

/**
 * Union of all inference runner classes usable in the test-model flyout.
 */
export type InferrerType =
  | NerInference
  | TextClassificationInference
  | TextEmbeddingInference
  | ZeroShotClassificationInference
  | FillMaskInference
  | LangIdentInference;
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,38 @@
* 2.0.
*/

import { BehaviorSubject } from 'rxjs';
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import { MLHttpFetchError } from '../../../../../../common/util/errors';
import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';

const DEFAULT_INPUT_FIELD = 'text_field';

export type FormattedNerResp = Array<{
// NER output split into fragments: each fragment's text and the entity
// detected for it, if any (null otherwise).
export type FormattedNerResponse = Array<{
  value: string;
  entity: estypes.MlTrainedModelEntities | null;
}>;

/**
 * Result of a single inference call.
 * T is the formatted (display-ready) response, U the raw API response.
 */
export interface InferResponse<T, U> {
  // the text the inference was run against
  inputText: string;
  // formatted response, used for rendering output
  response: T;
  // unmodified response from the inference API
  rawResponse: U;
}

// Lifecycle states for a single inference run in the test UI.
export enum RUNNING_STATE {
  STOPPED, // no run in progress and no results to show (initial/reset state)
  RUNNING, // inference request in flight
  FINISHED, // run completed successfully
  FINISHED_WITH_ERRORS, // run completed with an error (see inferenceError$)
}

export abstract class InferenceBase<TInferResponse> {
protected readonly inputField: string;
public inputText$ = new BehaviorSubject<string>('');
public inferenceResult$ = new BehaviorSubject<TInferResponse | null>(null);
public inferenceError$ = new BehaviorSubject<MLHttpFetchError | null>(null);
public runningState$ = new BehaviorSubject<RUNNING_STATE>(RUNNING_STATE.STOPPED);

constructor(
protected trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
Expand All @@ -26,5 +45,26 @@ export abstract class InferenceBase<TInferResponse> {
this.inputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
}

protected abstract infer(inputText: string): Promise<TInferResponse>;
// Reset to the initial state: clear any previous error and mark as stopped.
public setStopped() {
  this.inferenceError$.next(null);
  this.runningState$.next(RUNNING_STATE.STOPPED);
}
// Mark a run as in flight, clearing any error from a previous run.
public setRunning() {
  this.inferenceError$.next(null);
  this.runningState$.next(RUNNING_STATE.RUNNING);
}

// Mark the current run as completed successfully.
public setFinished() {
  this.runningState$.next(RUNNING_STATE.FINISHED);
}

// Record the error and mark the current run as completed with errors.
public setFinishedWithErrors(error: MLHttpFetchError) {
  this.inferenceError$.next(error);
  this.runningState$.next(RUNNING_STATE.FINISHED_WITH_ERRORS);
}

// Each concrete inferrer supplies its own input UI, output UI and the
// inference call itself.
protected abstract getInputComponent(): JSX.Element;
protected abstract getOutputComponent(): JSX.Element;

protected abstract infer(): Promise<TInferResponse>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,62 @@
* 2.0.
*/

import React, { FC, useState } from 'react';
import React, { FC, useState, useMemo } from 'react';

import { i18n } from '@kbn/i18n';
import useObservable from 'react-use/lib/useObservable';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiSpacer, EuiTextArea, EuiButton, EuiTabs, EuiTab } from '@elastic/eui';

import { LangIdentInference } from './lang_ident/lang_ident_inference';
import { NerInference } from './ner/ner_inference';
import type { FormattedLangIdentResp } from './lang_ident/lang_ident_inference';
import type { FormattedNerResp } from './ner/ner_inference';

import { MLJobEditor } from '../../../../jobs/jobs_list/components/ml_job_editor';
import { EuiSpacer, EuiButton, EuiTabs, EuiTab } from '@elastic/eui';
import { extractErrorMessage } from '../../../../../../common/util/errors';
import { ErrorMessage } from '../inference_error';
import { OutputLoadingContent } from '../output_loading';
import { RUNNING_STATE } from './inference_base';
import { RawOutput } from './raw_output';
import type { InferrerType } from '.';

interface Props {
inferrer: LangIdentInference | NerInference;
getOutputComponent(output: any): JSX.Element;
inferrer: InferrerType;
}

// Which output tab is selected in the test flyout.
enum TAB {
  TEXT, // formatted output ("Output" tab)
  RAW, // raw JSON response ("Raw output" tab)
}

export const InferenceInputForm: FC<Props> = ({ inferrer, getOutputComponent }) => {
const [inputText, setInputText] = useState('');
const [isRunning, setIsRunning] = useState(false);
const [output, setOutput] = useState<FormattedLangIdentResp | FormattedNerResp | null>(null);
const [rawOutput, setRawOutput] = useState<string | null>(null);
export const InferenceInputForm: FC<Props> = ({ inferrer }) => {
const [selectedTab, setSelectedTab] = useState(TAB.TEXT);
const [showOutput, setShowOutput] = useState(false);
const [errorText, setErrorText] = useState<string | null>(null);

const runningState = useObservable(inferrer.runningState$);
const inputText = useObservable(inferrer.inputText$);
const inputComponent = useMemo(() => inferrer.getInputComponent(), []);
const outputComponent = useMemo(() => inferrer.getOutputComponent(), []);

async function run() {
setShowOutput(true);
setOutput(null);
setRawOutput(null);
setIsRunning(true);
setErrorText(null);
try {
const { response, rawResponse } = await inferrer.infer(inputText);
setOutput(response);
setRawOutput(JSON.stringify(rawResponse, null, 2));
await inferrer.infer();
} catch (e) {
setIsRunning(false);
setOutput(null);
setErrorText(extractErrorMessage(e));
setRawOutput(JSON.stringify(e.body ?? e, null, 2));
}
setIsRunning(false);
}

return (
<>
<EuiTextArea
placeholder={i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.inputText', {
defaultMessage: 'Input text',
})}
value={inputText}
disabled={isRunning === true}
fullWidth
onChange={(e) => {
setInputText(e.target.value);
}}
/>
<>{inputComponent}</>
<EuiSpacer size="m" />
<div>
<EuiButton
onClick={run}
disabled={isRunning === true || inputText === ''}
disabled={runningState === RUNNING_STATE.RUNNING || inputText === ''}
fullWidth={false}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.runButton"
id="xpack.ml.trainedModels.testModelsFlyout.inferenceInputForm.runButton"
defaultMessage="Test"
/>
</EuiButton>
</div>
{showOutput === true ? (
{runningState !== RUNNING_STATE.STOPPED ? (
<>
<EuiSpacer size="m" />
<EuiTabs size={'s'}>
Expand All @@ -94,7 +69,7 @@ export const InferenceInputForm: FC<Props> = ({ inferrer, getOutputComponent })
onClick={setSelectedTab.bind(null, TAB.TEXT)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.markupTab"
id="xpack.ml.trainedModels.testModelsFlyout.inferenceInputForm.markupTab"
defaultMessage="Output"
/>
</EuiTab>
Expand All @@ -103,7 +78,7 @@ export const InferenceInputForm: FC<Props> = ({ inferrer, getOutputComponent })
onClick={setSelectedTab.bind(null, TAB.RAW)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.langIdent.rawOutput"
id="xpack.ml.trainedModels.testModelsFlyout.inferenceInputForm.rawOutput"
defaultMessage="Raw output"
/>
</EuiTab>
Expand All @@ -113,16 +88,16 @@ export const InferenceInputForm: FC<Props> = ({ inferrer, getOutputComponent })

{selectedTab === TAB.TEXT ? (
<>
{errorText !== null ? (
{runningState === RUNNING_STATE.RUNNING ? <OutputLoadingContent text={''} /> : null}

{errorText !== null || runningState === RUNNING_STATE.FINISHED_WITH_ERRORS ? (
<ErrorMessage errorText={errorText} />
) : output === null ? (
<OutputLoadingContent text={inputText} />
) : (
<>{getOutputComponent(output)}</>
)}
) : null}

{runningState === RUNNING_STATE.FINISHED ? <>{outputComponent}</> : null}
</>
) : (
<MLJobEditor value={rawOutput ?? ''} readOnly={true} />
<RawOutput inferrer={inferrer} />
)}
</>
) : null}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
* 2.0.
*/

export type { FormattedLangIdentResp } from './lang_ident_inference';
export type { FormattedLangIdentResponse, LangIdentResponse } from './lang_ident_inference';
export { LangIdentInference } from './lang_ident_inference';
export { LangIdentOutput } from './lang_ident_output';
export { getLangIdentOutputComponent } from './lang_ident_output';
Loading

0 comments on commit c539923

Please sign in to comment.