From b5131dbc2143e5a6b36ec3154e22fee60d128952 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 5 Feb 2021 17:30:17 -0800 Subject: [PATCH 1/2] deepspeed bug fixes and tests --- examples/seq2seq/test_deepspeed.py | 132 ++++++++++++++++++++++ examples/seq2seq/test_finetune_trainer.py | 39 +------ src/transformers/trainer.py | 13 ++- 3 files changed, 148 insertions(+), 36 deletions(-) create mode 100644 examples/seq2seq/test_deepspeed.py diff --git a/examples/seq2seq/test_deepspeed.py b/examples/seq2seq/test_deepspeed.py new file mode 100644 index 00000000000..1955568ff35 --- /dev/null +++ b/examples/seq2seq/test_deepspeed.py @@ -0,0 +1,132 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from transformers.integrations import is_deepspeed_available +from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu +from transformers.trainer_callback import TrainerState +from transformers.trainer_utils import set_seed +from utils import load_json + + +set_seed(42) +MBART_TINY = "sshleifer/tiny-mbart" + + +# a candidate for testing_utils +def require_deepspeed(test_case): + """ + Decorator marking a test that requires deepspeed + """ + if not is_deepspeed_available(): + return unittest.skip("test requires deepspeed")(test_case) + else: + return test_case + + +@require_deepspeed +class TestDeepSpeed(TestCasePlus): + + # XXX: need to do better validation beyond just that the run was successful + def run_quick(self, distributed=None, extra_args_str=None, remove_args_str=None): + output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str) + logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + assert "eval_bleu" in first_step_stats + + def run_quick_no_train(self, distributed=None, extra_args_str=None): + remove_args_str = "--do_train" + output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str) + val_metrics = load_json(os.path.join(output_dir, "val_results.json")) + assert "val_bleu" in val_metrics + test_metrics = load_json(os.path.join(output_dir, "test_results.json")) + assert "test_bleu" in test_metrics + + @require_torch_multi_gpu + def test_basic(self): + self.run_quick() + + @require_torch_multi_gpu + def test_grad_acum(self): + self.run_quick(extra_args_str="--gradient_accumulation_steps 2") + + @require_torch_multi_gpu + def test_no_train(self): + # we should not fail if train is skipped + self.run_quick_no_train() + + def run_trainer( + self, + eval_steps: int, + max_len: str, + model_name: str, + num_train_epochs: int, + distributed: bool = False, + extra_args_str: str = None, + remove_args_str: str = None, + ): + data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro" + output_dir = self.get_auto_remove_tmp_dir() + args = 
f""" + --model_name_or_path {model_name} + --data_dir {data_dir} + --output_dir {output_dir} + --overwrite_output_dir + --n_train 8 + --n_val 8 + --max_source_length {max_len} + --max_target_length {max_len} + --val_max_target_length {max_len} + --do_train + --do_eval + --do_predict + --num_train_epochs {str(num_train_epochs)} + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --learning_rate 3e-3 + --warmup_steps 8 + --evaluation_strategy steps + --predict_with_generate + --logging_steps 0 + --save_steps {str(eval_steps)} + --eval_steps {str(eval_steps)} + --group_by_length + --label_smoothing_factor 0.1 + --adafactor + --task translation + --tgt_lang ro_RO + --src_lang en_XX + """.split() + # --eval_beams 2 + + if extra_args_str is not None: + args.extend(extra_args_str.split()) + + if remove_args_str is not None: + remove_args = remove_args_str.split() + args = [x for x in args if x not in remove_args] + + ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split() + distributed_args = f""" + {self.test_file_dir}/finetune_trainer.py + """.split() + cmd = ["deepspeed"] + distributed_args + args + ds_args + # keep for quick debug + # print(" ".join(cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + return output_dir diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 4a925a8e425..77a1c902629 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -18,7 +18,7 @@ from unittest.mock import patch from transformers.file_utils import is_apex_available -from transformers.integrations import is_deepspeed_available, is_fairscale_available +from transformers.integrations import is_fairscale_available from transformers.testing_utils import ( TestCasePlus, execute_subprocess_async, @@ -49,17 +49,6 @@ def require_fairscale(test_case): return test_case -# a candidate for testing_utils -def require_deepspeed(test_case): - """ - Decorator marking a test that requires deepspeed - """ - if not is_deepspeed_available(): - return unittest.skip("test requires deepspeed")(test_case) - else: - return test_case - - # a candidate for testing_utils def require_apex(test_case): """ @@ -72,8 +61,8 @@ def require_apex(test_case): class TestFinetuneTrainer(TestCasePlus): - def finetune_trainer_quick(self, distributed=None, deepspeed=False, extra_args_str=None): - output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, deepspeed, extra_args_str) + def finetune_trainer_quick(self, distributed=None, extra_args_str=None): + output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str) logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history eval_metrics = [log for log in logs if "eval_loss" in log.keys()] first_step_stats = eval_metrics[0] @@ -107,16 +96,6 @@ def test_finetune_trainer_ddp_sharded_ddp_fp16(self): def test_finetune_trainer_apex(self): self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex") - @require_torch_multi_gpu - @require_deepspeed - def test_finetune_trainer_deepspeed(self): - self.finetune_trainer_quick(deepspeed=True) - - @require_torch_multi_gpu - @require_deepspeed - def test_finetune_trainer_deepspeed_grad_acum(self): - self.finetune_trainer_quick(deepspeed=True, extra_args_str="--gradient_accumulation_steps 2") - @slow def test_finetune_trainer_slow(self): # There is a missing call to __init__process_group somewhere @@ -146,7 +125,6 @@ def run_trainer( 
model_name: str, num_train_epochs: int, distributed: bool = False, - deepspeed: bool = False, extra_args_str: str = None, ): data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro" @@ -186,15 +164,7 @@ def run_trainer( if extra_args_str is not None: args.extend(extra_args_str.split()) - if deepspeed: - ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split() - distributed_args = f""" - {self.test_file_dir}/finetune_trainer.py - """.split() - cmd = ["deepspeed"] + distributed_args + args + ds_args - execute_subprocess_async(cmd, env=self.get_env()) - - elif distributed: + if distributed: n_gpu = get_gpu_count() distributed_args = f""" -m torch.distributed.launch @@ -203,7 +173,6 @@ def run_trainer( """.split() cmd = [sys.executable] + distributed_args + args execute_subprocess_async(cmd, env=self.get_env()) - else: testargs = ["finetune_trainer.py"] + args with patch.object(sys, "argv", testargs): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d604599a423..dd4aa2905b7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -932,7 +932,11 @@ def train( if (step + 1) % self.args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) - if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1: + if ( + ((step + 1) % self.args.gradient_accumulation_steps != 0) + and self.args.local_rank != -1 + and not self.args.deepspeed + ): # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. with model.no_sync(): tr_loss += self.training_step(model, inputs) @@ -1588,7 +1592,14 @@ def prediction_loop( prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only ) + if self.args.deepspeed and not self.args.do_train: + # In the future we probably can run deepspeed for inference too, but this will require some thinking about how to best run it - since while it works DeepSpeed wasn't designed for inference + + # since we have to postpone model.to() till training for DeepSpeed, if there was no training, we must put the model on the right device + self.model = self.model.to(self.args.device) + model = self.model + # multi-gpu eval if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) From f64222cb273b5e63db450474879410f9b3934f0f Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 8 Feb 2021 08:36:52 -0800 Subject: [PATCH 2/2] manual wrap? 
--- src/transformers/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index dd4aa2905b7..c7cb28afce8 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1593,9 +1593,12 @@ def prediction_loop( ) if self.args.deepspeed and not self.args.do_train: - # In the future we probably can run deepspeed for inference too, but this will require some thinking about how to best run it - since while it works DeepSpeed wasn't designed for inference + # In the future we probably can run deepspeed for inference too, but this will require + # some thinking about how to best run it - since while it works DeepSpeed wasn't + # designed for inference - # since we have to postpone model.to() till training for DeepSpeed, if there was no training, we must put the model on the right device + # since we have to postpone model.to() till training for DeepSpeed, if there was no + # training, we must put the model on the right device self.model = self.model.to(self.args.device) model = self.model
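
Not part of the patch itself: a minimal, self-contained sketch of the two trainer-side guards the hunks above add, so their conditions can be exercised in isolation. The names TrainerArgs, should_skip_ddp_sync and needs_manual_device_placement are invented for this illustration and do not exist in transformers; only the boolean logic mirrors the diff.

# Illustrative sketch only -- names below are hypothetical, the logic mirrors the patch.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainerArgs:
    gradient_accumulation_steps: int = 1
    local_rank: int = -1              # -1 means not running under DDP
    deepspeed: Optional[str] = None   # path to ds_config.json when DeepSpeed is enabled
    do_train: bool = False


def should_skip_ddp_sync(step: int, args: TrainerArgs) -> bool:
    # Mirrors the train() change: DDP's no_sync() is only used when DeepSpeed is not
    # managing the model, hence the added `and not args.deepspeed`.
    return (
        (step + 1) % args.gradient_accumulation_steps != 0
        and args.local_rank != -1
        and not args.deepspeed
    )


def needs_manual_device_placement(args: TrainerArgs) -> bool:
    # Mirrors the prediction_loop() change: with DeepSpeed, model.to() is postponed
    # until training, so an eval/predict-only run must place the model itself.
    return bool(args.deepspeed) and not args.do_train


if __name__ == "__main__":
    ddp_only = TrainerArgs(gradient_accumulation_steps=2, local_rank=0)
    with_ds = TrainerArgs(gradient_accumulation_steps=2, local_rank=0, deepspeed="ds_config.json")
    assert should_skip_ddp_sync(0, ddp_only)       # plain DDP: skip sync on accumulation steps
    assert not should_skip_ddp_sync(0, with_ds)    # DeepSpeed: never wrap in no_sync()
    assert needs_manual_device_placement(with_ds)  # DeepSpeed run without --do_train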