From 8e3c3c06bb3a4d47f90b637c2c6e6c6a73232ef0 Mon Sep 17 00:00:00 2001 From: xinhe Date: Fri, 5 Aug 2022 10:57:01 +0800 Subject: [PATCH] fix bug in pt/translation example (#1128) --- .../quantization/ptq_dynamic/eager/run_benchmark.sh | 13 +++++++------ .../quantization/ptq_dynamic/eager/run_tuning.sh | 12 +++++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh index f75c0e7fb23..5817f967664 100644 --- a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh +++ b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh @@ -52,8 +52,7 @@ function init_params { # run_benchmark function run_benchmark { - extra_cmd='' - + extra_cmd='None' if [[ ${mode} == "accuracy" ]]; then mode_cmd=" --accuracy_only" elif [[ ${mode} == "benchmark" ]]; then @@ -65,10 +64,9 @@ function run_benchmark { if [ "${topology}" = "t5_WMT_en_ro" ];then model_name_or_path='t5-small' - extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en" + extra_cmd='translate English to Romanian: ' elif [ "${topology}" = "marianmt_WMT_en_ro" ]; then model_name_or_path='Helsinki-NLP/opus-mt-en-ro' - extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en" fi if [[ ${int8} == "true" ]]; then @@ -82,9 +80,12 @@ function run_benchmark { --predict_with_generate \ --per_device_eval_batch_size ${batch_size} \ --output_dir ${tuned_checkpoint} \ - --source_prefix "translate English to Romanian: " \ + --source_lang en \ + --target_lang ro \ + --dataset_name wmt16 \ + --dataset_config_name ro-en\ ${mode_cmd} \ - ${extra_cmd} + --source_prefix "$extra_cmd" } main "$@" diff --git a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_tuning.sh b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_tuning.sh index bdc1d0c4dae..f7238918d5f 100644 --- a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_tuning.sh +++ b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_tuning.sh @@ -37,18 +37,17 @@ function init_params { # run_tuning function run_tuning { - extra_cmd='' + extra_cmd='None' batch_size=16 model_type='bert' if [ "${topology}" = "t5_WMT_en_ro" ];then model_name_or_path='t5-small' model_type='t5' - extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en" + extra_cmd='translate English to Romanian: ' elif [ "${topology}" = "marianmt_WMT_en_ro" ]; then model_name_or_path='Helsinki-NLP/opus-mt-en-ro' model_type='marianmt' - extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en" fi sed -i "/: bert/s|name:.*|name: $model_type|g" conf.yaml @@ -61,10 +60,13 @@ function run_tuning { --predict_with_generate \ --per_device_eval_batch_size ${batch_size} \ --output_dir ${tuned_checkpoint} \ - --source_prefix "translate English to Romanian: " \ + --source_lang en \ + --target_lang ro \ + --dataset_name wmt16 \ + --dataset_config_name ro-en\ --tune \ --overwrite_output_dir \ - $extra_cmd + --source_prefix "$extra_cmd" } main "$@"