diff --git a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh
index 379ce7e981d..f75c0e7fb23 100644
--- a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh
+++ b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh
@@ -57,7 +57,7 @@ function run_benchmark {
     if [[ ${mode} == "accuracy" ]]; then
         mode_cmd=" --accuracy_only"
     elif [[ ${mode} == "benchmark" ]]; then
-        mode_cmd=" --benchmark "
+        mode_cmd=" --benchmark --max_eval_samples 200 "
     else
         echo "Error: No such mode: ${mode}"
         exit 1
diff --git a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_translation.py b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_translation.py
index 66be454ac0f..f7dbe387ce8 100644
--- a/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_translation.py
+++ b/examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_translation.py
@@ -443,8 +443,6 @@ def preprocess_function(examples):
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = raw_datasets["train"]
-        if data_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(range(data_args.max_train_samples))
         with training_args.main_process_first(desc="train dataset map pre-processing"):
             train_dataset = train_dataset.map(
                 preprocess_function,
@@ -454,14 +452,14 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on train dataset",
             )
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
 
     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
         if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = raw_datasets["validation"]
-        if data_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         with training_args.main_process_first(desc="validation dataset map pre-processing"):
             eval_dataset = eval_dataset.map(
                 preprocess_function,
@@ -471,14 +469,14 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on validation dataset",
             )
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
 
     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
         if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = raw_datasets["test"]
-        if data_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
             predict_dataset = predict_dataset.map(
                 preprocess_function,
@@ -488,6 +486,8 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on prediction dataset",
             )
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
 
     # Data collator
     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
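
A minimal sketch of the reordering this diff applies, using a toy 🤗 Datasets pipeline. The dataset, `tokenize` function, and sample cap below are hypothetical stand-ins, not code from this repo; the point is only that `select()` now runs after `map()`, so `max_*_samples` counts preprocessed rows and the map cache for the full dataset can be reused across different caps.

```python
# Hypothetical stand-in pipeline illustrating select()-before-map vs.
# select()-after-map; only the ordering mirrors the diff above.
from datasets import Dataset

raw = Dataset.from_dict({"text": [f"sentence {i}" for i in range(1000)]})

def tokenize(examples):
    # Stand-in for preprocess_function: any per-batch transform works here.
    return {"length": [len(t) for t in examples["text"]]}

max_eval_samples = 200  # hypothetical cap, matching the 200 used in run_benchmark.sh

# Before: truncate first, then map. The map result is cached against the
# 200-row subset, so changing the cap means re-running the tokenizer.
subset_first = raw.select(range(max_eval_samples)).map(tokenize, batched=True)

# After (this diff): map the full dataset once, then truncate. The cached
# map output can be reused for any cap, and the cap applies to rows as
# they exist after preprocessing rather than to the raw dataset.
map_first = raw.map(tokenize, batched=True).select(range(max_eval_samples))

assert len(subset_first) == len(map_first) == max_eval_samples
```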