From 140d0348917350982041ed4cc787d16ae61b821a Mon Sep 17 00:00:00 2001 From: James Gowdy Date: Thu, 20 Jul 2023 11:34:21 +0100 Subject: [PATCH] [ML] Using model supplied mask token (#162168) Fixes https://github.com/elastic/kibana/issues/159577 Using the `mask_token` property from the model config for testing the model. This is shown in the input placeholder text, in the input validation and for displaying the results. image --- .../text_classification/fill_mask_inference.ts | 16 +++++++++++----- .../plugins/translations/translations/fr-FR.json | 1 - .../plugins/translations/translations/ja-JP.json | 1 - .../plugins/translations/translations/zh-CN.json | 1 - 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts index 0c109292f16f1..d2b9b2dc23304 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts @@ -16,7 +16,7 @@ import { getGeneralInputComponent } from '../text_input'; import { getFillMaskOutputComponent } from './fill_mask_output'; import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; -const MASK = '[MASK]'; +const DEFAULT_MASK_TOKEN = '[MASK]'; export class FillMaskInference extends InferenceBase { protected inferenceType = SUPPORTED_PYTORCH_TASKS.FILL_MASK; @@ -30,6 +30,7 @@ export class FillMaskInference extends InferenceBase defaultMessage: 'Test how well the model predicts a missing word in a phrase.', }), ]; + private maskToken = DEFAULT_MASK_TOKEN; constructor( trainedModelsApi: ReturnType, @@ -38,9 +39,14 @@ export class FillMaskInference extends InferenceBase deploymentId: string ) { super(trainedModelsApi, model, inputType, deploymentId); + // @ts-expect-error mask_token is missing in type + const maskToken = model.inference_config?.[this.inferenceType]?.mask_token; + if (maskToken) { + this.maskToken = maskToken; + } this.initialize([ - this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(MASK)))), + this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(this.maskToken)))), ]); } @@ -71,7 +77,7 @@ export class FillMaskInference extends InferenceBase public predictedValue(resp: TextClassificationResponse) { const { response, inputText } = resp; - return response[0]?.value ? inputText.replace(MASK, response[0].value) : inputText; + return response[0]?.value ? inputText.replace(this.maskToken, response[0].value) : inputText; } public getInputComponent(): JSX.Element | null { @@ -79,8 +85,8 @@ export class FillMaskInference extends InferenceBase const placeholder = i18n.translate( 'xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText', { - defaultMessage: - 'Enter a phrase to test. Use [MASK] as a placeholder for the missing words.', + defaultMessage: `Enter a phrase to test. Use {maskToken} as a placeholder for the missing words.`, + values: { maskToken: this.maskToken }, } ); diff --git a/x-pack/plugins/translations/translations/fr-FR.json b/x-pack/plugins/translations/translations/fr-FR.json index 5ed58fa443cc8..5446ba0496aae 100644 --- a/x-pack/plugins/translations/translations/fr-FR.json +++ b/x-pack/plugins/translations/translations/fr-FR.json @@ -25599,7 +25599,6 @@ "xpack.ml.trainedModels.nodesList.totalAmountLabel": "Nombre total de nœuds Machine Learning", "xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "ID de déploiement", "xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "Testez la capacité du modèle à prédire un mot manquant dans une phrase.", - "xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "Entrez une expression à tester. Utilisez [MASK] comme espace réservé pour les mots manquants.", "xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "Masque de remplissage", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "Texte d'entrée", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "Texte d'entrée", diff --git a/x-pack/plugins/translations/translations/ja-JP.json b/x-pack/plugins/translations/translations/ja-JP.json index 3706ed45645c9..dd39918345c57 100644 --- a/x-pack/plugins/translations/translations/ja-JP.json +++ b/x-pack/plugins/translations/translations/ja-JP.json @@ -25598,7 +25598,6 @@ "xpack.ml.trainedModels.nodesList.totalAmountLabel": "合計機械学習ノード", "xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "デプロイID", "xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "モデルがフレーズの不足している単語を予測する精度をテストします。", - "xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "テストするフレーズを入力してください。足りない語句のプレースホルダーとして[MASK]を使用します。", "xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "マスクを塗りつぶす", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "入力テキスト", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "入力テキスト", diff --git a/x-pack/plugins/translations/translations/zh-CN.json b/x-pack/plugins/translations/translations/zh-CN.json index f18e9550a95e5..fb5015ff8467e 100644 --- a/x-pack/plugins/translations/translations/zh-CN.json +++ b/x-pack/plugins/translations/translations/zh-CN.json @@ -25597,7 +25597,6 @@ "xpack.ml.trainedModels.nodesList.totalAmountLabel": "Machine Learning 节点总数", "xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "部署 ID", "xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "测试模型预测短语中缺失的词的表现。", - "xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "输入短语以进行测试。将 [MASK] 用作缺失词的占位符。", "xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "填充掩码", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "输入文本", "xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "输入文本",