From 291c4fac604115118c5417871a3a73637fd67b9a Mon Sep 17 00:00:00 2001 From: n1ck-guo <110074967+n1ck-guo@users.noreply.github.com> Date: Fri, 28 Jul 2023 13:37:05 +0800 Subject: [PATCH] fix hpo exmple (#1122) * fix hpo exmple Signed-off-by: Guo, Heng --- .../text-classification/pruning/hpo/run_glue_no_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py index 41fd826ae16..dbbf4ba0329 100644 --- a/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py +++ b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py @@ -628,6 +628,7 @@ def preprocess_function(examples): eval_metric = metric.compute() logger.info(f"mnli-mm: {eval_metric}") + return eval_metric if __name__ == "__main__": @@ -652,7 +653,7 @@ def preprocess_function(examples): higher_is_better=True, min_train_samples=3) searcher = prepare_hpo(config) - for iter in range(10): + for iter in range(5): print(f'search iter {iter}') st = time.time() params = searcher.suggest() @@ -666,5 +667,5 @@ def preprocess_function(examples): acc = metric['accuracy'] f1 = metric['f1'] rt = time.time() - st - tmp_str = f'{iter + 10}\t{params}\t{acc}\t{f1}\t{rt}\n' + tmp_str = f'{iter}\t{params}\t{acc}\t{f1}\t{rt}\n' print(tmp_str)