[trainer] deepspeed bug fixes and tests #10039

Merged (2 commits) on Feb 8, 2021
132 changes: 132 additions & 0 deletions examples/seq2seq/test_deepspeed.py
@@ -0,0 +1,132 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from transformers.integrations import is_deepspeed_available
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from utils import load_json


set_seed(42)
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case


@require_deepspeed
class TestDeepSpeed(TestCasePlus):

    # XXX: need to do better validation beyond just that the run was successful
    def run_quick(self, distributed=None, extra_args_str=None, remove_args_str=None):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str)
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    def run_quick_no_train(self, distributed=None, extra_args_str=None):
        remove_args_str = "--do_train"
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str)
        val_metrics = load_json(os.path.join(output_dir, "val_results.json"))
        assert "val_bleu" in val_metrics
        test_metrics = load_json(os.path.join(output_dir, "test_results.json"))
        assert "test_bleu" in test_metrics

    @require_torch_multi_gpu
    def test_basic(self):
        self.run_quick()

    @require_torch_multi_gpu
    def test_grad_acum(self):
        self.run_quick(extra_args_str="--gradient_accumulation_steps 2")

    @require_torch_multi_gpu
    def test_no_train(self):
        # we should not fail if train is skipped
        self.run_quick_no_train()

    def run_trainer(
        self,
        eval_steps: int,
        max_len: str,
        model_name: str,
        num_train_epochs: int,
        distributed: bool = False,
        extra_args_str: str = None,
        remove_args_str: str = None,
    ):
        data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name_or_path {model_name}
            --data_dir {data_dir}
            --output_dir {output_dir}
            --overwrite_output_dir
            --n_train 8
            --n_val 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-3
            --warmup_steps 8
            --evaluation_strategy steps
            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --task translation
            --tgt_lang ro_RO
            --src_lang en_XX
        """.split()
        # --eval_beams 2

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if remove_args_str is not None:
            remove_args = remove_args_str.split()
            args = [x for x in args if x not in remove_args]

        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
        distributed_args = f"""
            {self.test_file_dir}/finetune_trainer.py
        """.split()
        cmd = ["deepspeed"] + distributed_args + args + ds_args
        # keep for quick debug
        # print(" ".join(cmd)); die
        execute_subprocess_async(cmd, env=self.get_env())

        return output_dir
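For context, the require_deepspeed decorator above gates on is_deepspeed_available() from transformers.integrations, which is presumably just an import-availability probe. A minimal sketch of such a check (an illustrative re-implementation, not necessarily the library's exact code):

import importlib.util


def is_deepspeed_available_sketch():
    # Illustrative only: probe whether the deepspeed package can be imported,
    # without actually importing it at module load time.
    return importlib.util.find_spec("deepspeed") is not None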
39 changes: 4 additions & 35 deletions examples/seq2seq/test_finetune_trainer.py
@@ -18,7 +18,7 @@
from unittest.mock import patch

from transformers.file_utils import is_apex_available
from transformers.integrations import is_deepspeed_available, is_fairscale_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
@@ -49,17 +49,6 @@ def require_fairscale(test_case):
        return test_case


# a candidate for testing_utils
def require_deepspeed(test_case):
    """
    Decorator marking a test that requires deepspeed
    """
    if not is_deepspeed_available():
        return unittest.skip("test requires deepspeed")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
"""
@@ -72,8 +61,8 @@ def require_apex(test_case):


class TestFinetuneTrainer(TestCasePlus):
    def finetune_trainer_quick(self, distributed=None, deepspeed=False, extra_args_str=None):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, deepspeed, extra_args_str)
    def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
@@ -107,16 +96,6 @@ def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
    def test_finetune_trainer_apex(self):
        self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")

    @require_torch_multi_gpu
    @require_deepspeed
    def test_finetune_trainer_deepspeed(self):
        self.finetune_trainer_quick(deepspeed=True)

    @require_torch_multi_gpu
    @require_deepspeed
    def test_finetune_trainer_deepspeed_grad_acum(self):
        self.finetune_trainer_quick(deepspeed=True, extra_args_str="--gradient_accumulation_steps 2")

    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
@@ -146,7 +125,6 @@ def run_trainer(
        model_name: str,
        num_train_epochs: int,
        distributed: bool = False,
        deepspeed: bool = False,
        extra_args_str: str = None,
    ):
        data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
@@ -186,15 +164,7 @@ def run_trainer(
        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if deepspeed:
            ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
            distributed_args = f"""
                {self.test_file_dir}/finetune_trainer.py
            """.split()
            cmd = ["deepspeed"] + distributed_args + args + ds_args
            execute_subprocess_async(cmd, env=self.get_env())

        elif distributed:
        if distributed:
            n_gpu = get_gpu_count()
            distributed_args = f"""
                -m torch.distributed.launch
@@ -203,7 +173,6 @@ def run_trainer(
""".split()
cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())

else:
testargs = ["finetune_trainer.py"] + args
with patch.object(sys, "argv", testargs):
16 changes: 15 additions & 1 deletion src/transformers/trainer.py
@@ -932,7 +932,11 @@ def train(
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

                if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
                if (
                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
                    and self.args.local_rank != -1
                    and not self.args.deepspeed
                ):
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss += self.training_step(model, inputs)
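The added "and not self.args.deepspeed" guard keeps the trainer from entering model.no_sync() when DeepSpeed is active, presumably because the DeepSpeed engine manages gradient accumulation and synchronization itself rather than exposing DDP's no_sync() context manager. A standalone sketch of the same decision logic (helper name and argument shapes are illustrative, not Trainer API):

import contextlib


def backward_context(model, args, step):
    # Illustrative helper: choose the context for the backward pass of a
    # non-final gradient-accumulation step.
    not_sync_step = (step + 1) % args.gradient_accumulation_steps != 0
    if not_sync_step and args.local_rank != -1 and not args.deepspeed:
        # Plain DDP: skip the gradient all-reduce, no optimizer step happens yet.
        return model.no_sync()
    # DeepSpeed (or single process): nothing special here; the engine
    # handles accumulation internally.
    return contextlib.nullcontext()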
@@ -1588,7 +1592,17 @@ def prediction_loop(
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        if self.args.deepspeed and not self.args.do_train:
            # In the future we probably can run deepspeed for inference too, but this will require
            # some thinking about how to best run it - since while it works DeepSpeed wasn't
            # designed for inference

            # since we have to postpone model.to() till training for DeepSpeed, if there was no
            # training, we must put the model on the right device
            self.model = self.model.to(self.args.device)
Collaborator (@sgugger):

Mmm, this means a future training might have the model on the device already now? Maybe we should just put the model that gets used on the device (so model = self.model.to(self.args.device), but not stored in self.model).

Contributor Author (@stas00):

This is the case where one bypasses the training stage. Remember, the last PR here had to make a special case for DeepSpeed not to preload the model on the device, so that it could load a model in fp16?

Next I'm experimenting with DeepSpeed for inference only, so this will change again. But for now this is a bug fix.

Collaborator (@sgugger):

Yes, I understand that. But what if someone does:

trainer = Trainer(...)
trainer.evaluate()
trainer.train()

(agreed it would be weird, but trying to have the bug fix be general)

Contributor Author (@stas00):

This is yet another combination I hadn't thought of. Thank you for thinking of it, @sgugger.

As I mentioned, I'm already working on DeepSpeed for inference, so this code will change again shortly. If I manage to do it, this code will be replaced with deepspeed_init and no switching to device at all. So this area is a WIP and this PR is a temporary patch.

So let me know whether you prefer a more general fix, or hopefully today/tomorrow I will have a new version if DeepSpeed supports it. I just started working on it, and in the worst case, if it doesn't let me init for inference (i.e. without optimizer/scheduler), I'll just init DeepSpeed as I would for training, so really it'd be the same as train. Down the road, as DeepSpeed avails itself for inference, this will improve again. That's the plan at the moment.

And yes, I need to test all these different variations you're pointing at.

Collaborator (@sgugger):

Ok for the quick hotfix then! I just want to make sure the proper fix down the road supports all kinds of combinations of train/eval.

Contributor Author (@stas00), Feb 8, 2021:

Yes, let's merge it and I will work on the new tests location and then add new tests for all the different combinations.

        model = self.model

        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
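Following up on the review thread above ("add new tests for all the different combinations"), a hedged sketch of how such combination tests might be parameterized once the tests move to their new location; the class name, flag sets, and the use of the parameterized package are illustrative, not part of this PR:

from parameterized import parameterized

from transformers.testing_utils import TestCasePlus, require_torch_multi_gpu

# Illustrative enumeration of the train/eval/predict combinations discussed above.
STAGE_COMBOS = [
    ("train_eval_predict", "--do_train --do_eval --do_predict"),
    ("eval_predict_only", "--do_eval --do_predict"),  # the case this PR fixes
    ("train_only", "--do_train"),
]


class TestDeepSpeedStageCombos(TestCasePlus):
    @parameterized.expand(STAGE_COMBOS)
    @require_torch_multi_gpu
    def test_stage_combo(self, name, stage_flags):
        # Hypothetical body: strip all stage flags from the base args, add only
        # `stage_flags`, then reuse a runner like TestDeepSpeed.run_trainer above.
        pass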