limit eval samples for benchmark (intel#1060)
xin3he authored Jul 14, 2022
1 parent d80c364 commit 43df314
Showing 2 changed files with 7 additions and 7 deletions.
@@ -57,7 +57,7 @@ function run_benchmark {
     if [[ ${mode} == "accuracy" ]]; then
         mode_cmd=" --accuracy_only"
     elif [[ ${mode} == "benchmark" ]]; then
-        mode_cmd=" --benchmark "
+        mode_cmd=" --benchmark --max_eval_samples 200 "
     else
         echo "Error: No such mode: ${mode}"
         exit 1
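
The benchmark branch now forwards --max_eval_samples 200 to the underlying Python script, capping the evaluation set at 200 samples during benchmarking. Below is a minimal sketch of how such flags can be consumed downstream; the real script parses its arguments in the Hugging Face example style (dataclasses plus HfArgumentParser), so this argparse version is illustrative only, and just the flag names come from this diff.

import argparse

# Illustrative parser for the flags run_benchmark passes through.
# Only --benchmark and --max_eval_samples appear in this diff; the
# help text and defaults here are assumptions, not the repo's parser.
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", action="store_true",
                    help="Run performance benchmarking instead of full evaluation.")
parser.add_argument("--max_eval_samples", type=int, default=None,
                    help="Truncate the evaluation set to this many samples.")

args = parser.parse_args(["--benchmark", "--max_eval_samples", "200"])
assert args.benchmark and args.max_eval_samples == 200
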
@@ -443,8 +443,6 @@ def preprocess_function(examples):
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = raw_datasets["train"]
-        if data_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(range(data_args.max_train_samples))
         with training_args.main_process_first(desc="train dataset map pre-processing"):
             train_dataset = train_dataset.map(
                 preprocess_function,
@@ -454,14 +452,14 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on train dataset",
             )
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))

     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
         if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = raw_datasets["validation"]
-        if data_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         with training_args.main_process_first(desc="validation dataset map pre-processing"):
             eval_dataset = eval_dataset.map(
                 preprocess_function,
@@ -471,14 +469,14 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on validation dataset",
             )
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
         if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = raw_datasets["test"]
-        if data_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
             predict_dataset = predict_dataset.map(
                 preprocess_function,
@@ -488,6 +486,8 @@ def preprocess_function(examples):
                 load_from_cache_file=not data_args.overwrite_cache,
                 desc="Running tokenizer on prediction dataset",
             )
+        if data_args.max_predict_samples is not None:
+            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

     # Data collator
     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
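
The Python-side change is the same in all three branches: the select(range(max_*_samples)) truncation moves from before map() to after it. A plausible reading is that preprocessing then always runs, and is cached, over the full split, so changing the sample limit between runs reuses the cached tokenized dataset rather than tokenizing a differently sized subset each time. Below is a minimal, self-contained sketch of the resulting pattern, using a toy dataset and a stand-in preprocess function; the names and the min() guard are illustrative, not from the repository.

from datasets import Dataset

# Toy stand-in for raw_datasets["validation"].
eval_dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(1000)]})

def preprocess_function(examples):
    # Stand-in for the real tokenization step.
    return {"n_chars": [len(t) for t in examples["text"]]}

# Map over the full split first; the result is cached for the whole split,
# so a different sample limit on a later run can reuse the same cache.
eval_dataset = eval_dataset.map(preprocess_function, batched=True)

max_eval_samples = 200  # mirrors --max_eval_samples 200 from run_benchmark
if max_eval_samples is not None:
    eval_dataset = eval_dataset.select(range(min(max_eval_samples, len(eval_dataset))))

print(len(eval_dataset))  # 200
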
