Pipeline parallel training engine. (microsoft#392)
Co-authored-by: Jeff Rasley <[email protected]>
Shaden Smith and jeffra authored Sep 10, 2020
1 parent 41db1c2 commit 65c2f97
Showing 37 changed files with 4,753 additions and 159 deletions.
51 changes: 33 additions & 18 deletions deepspeed/__init__.py
@@ -8,11 +8,14 @@

from .runtime.engine import DeepSpeedEngine
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
from .runtime.pipe.engine import PipelineEngine
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import logger
from .utils import log_dist

from .pipe import PipelineModule

try:
from .git_version_info import version, git_hash, git_branch
@@ -99,23 +102,35 @@ def initialize(args,
* ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
"""
logger.info(
"DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__,
__git_hash__,
__git_branch__),
)

engine = DeepSpeedEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__,
__git_hash__,
__git_branch__),
ranks=[0])

if not isinstance(model, PipelineModule):
engine = DeepSpeedEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
engine = PipelineEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=model.mpu(),
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config_params=config_params)

return_items = [
engine,
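For context, a minimal usage sketch of the new dispatch in initialize(): passing a PipelineModule yields a PipelineEngine, any other nn.Module still yields a DeepSpeedEngine. This assumes a distributed launch (e.g. via the deepspeed launcher) and the deepspeed.pipe API added in this commit; the layer sizes, argument handling, and config are illustrative, not part of the commit.

# Sketch only: a PipelineModule routes initialize() to the new PipelineEngine.
import argparse
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule

parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed_config etc.
args = parser.parse_args()

deepspeed.init_distributed()  # assumed: launched with the deepspeed launcher

layers = [nn.Linear(128, 128) for _ in range(8)]   # illustrative layer stack
net = PipelineModule(layers=layers, num_stages=2)  # partition over 2 stages

# mpu must stay None here; the pipeline engine takes its topology from net.mpu().
engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=net,
                                       model_parameters=net.parameters())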
1 change: 1 addition & 0 deletions deepspeed/pipe/__init__.py
@@ -0,0 +1 @@
from ..runtime.pipe import PipelineModule, LayerSpec, TiedLayerSpec
19 changes: 18 additions & 1 deletion deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -480,6 +480,10 @@ def forward(ctx, run_function, *args):
timers.log(['forward'])
if SYNCHRONIZE:
torch.cuda.synchronize()

# Tensors returned from forward() may not be differentiable, e.g., attention mask
non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
ctx.mark_non_differentiable(*non_grad_outputs)
return outputs

@staticmethod
@@ -548,7 +552,20 @@ def backward(ctx, *args):

if isinstance(outputs, torch.Tensor):
outputs = (outputs, )
torch.autograd.backward(outputs, args)

# Go over args and build the list of gradient tensors. This is usually just args,
# but if the forward pass returns tensors that do not require_grad then we should
# adjust the arguments to autograd.backward() too. This happens when forward()
# returns indices or a mask (such as an attention mask).
# We skip the first needs_input_grad because it corresponds to run_function.
output_tensors = []
grad_tensors = []
for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
if need_grad:
output_tensors.append(outputs[idx])
grad_tensors.append(args[idx])

torch.autograd.backward(output_tensors, grad_tensors)

if PROFILE_TIME:
timers('backward').stop()
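The two checkpointing changes above address the same issue from both sides: forward() marks non-floating-point outputs (such as an attention mask) as non-differentiable, and backward() only hands the outputs that actually need gradients to torch.autograd.backward(). A small self-contained illustration of that filtering pattern, using made-up tensors rather than the checkpointing context:

import torch

def backward_filtered(outputs, grads):
    # Keep only (output, grad) pairs where the output can receive a gradient,
    # mirroring the needs_input_grad filtering added to backward() above.
    output_tensors, grad_tensors = [], []
    for out, g in zip(outputs, grads):
        if out.is_floating_point() and out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(g)
    torch.autograd.backward(output_tensors, grad_tensors)

x = torch.randn(4, requires_grad=True)
y = 2 * x                                # differentiable output
mask = torch.ones(4, dtype=torch.bool)   # non-differentiable output (e.g. a mask)
backward_filtered((y, mask), (torch.ones(4), None))
print(x.grad)                            # tensor([2., 2., 2., 2.])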
35 changes: 24 additions & 11 deletions deepspeed/runtime/config.py
@@ -324,6 +324,20 @@ def get_sparse_attention_type(param_dict):
return SPARSE_ATTENTION_TYPE_DEFAULT


def get_pipeline_config(param_dict):
'''Parses pipeline engine configuration. '''
default_pipeline = {
'stages': 'auto',
'partition': 'best',
'seed_layers': False,
'activation_checkpoint_interval': 0
}
config = default_pipeline
for key, val in param_dict.get('pipeline', {}).items():
config[key] = val
return config


def get_optimizer_name(param_dict):
if OPTIMIZER in param_dict.keys() and \
TYPE in param_dict[OPTIMIZER].keys():
@@ -523,6 +537,7 @@ def _initialize_params(self, param_dict):
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)

self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict)

def _batch_assertion(self):

@@ -592,10 +607,6 @@ def _set_batch_related_parameters(self):
assert False, \
'Either train_batch_size or micro_batch_per_gpu needs to be provided'

logger.info(
f' After Train batch {self.train_batch_size} micro_batch {self.train_micro_batch_size_per_gpu} and grad_acc {self.gradient_accumulation_steps}'
)

def _configure_train_batch_size(self):
self._set_batch_related_parameters()
self._batch_assertion()
@@ -646,12 +657,14 @@ def _do_warning_check(self):
MAX_GRAD_NORM in self.optimizer_params.keys() and \
self.optimizer_params[MAX_GRAD_NORM] > 0:
if fp16_enabled:
logger.warning(
'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper'
.format(MAX_GRAD_NORM,
self.optimizer_params[MAX_GRAD_NORM]))
if self.global_rank == 0:
logger.warning(
'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper'
.format(MAX_GRAD_NORM,
self.optimizer_params[MAX_GRAD_NORM]))
else:
logger.warning(
'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero'
.format(self.optimizer_params[MAX_GRAD_NORM]))
if self.global_rank == 0:
logger.warning(
'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero'
.format(self.optimizer_params[MAX_GRAD_NORM]))
self.optimizer_params[MAX_GRAD_NORM] = 0.0
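As a quick illustration of the new pipeline configuration block: user-supplied keys under 'pipeline' in the DeepSpeed config override the defaults added above. The values here are hypothetical, assuming get_pipeline_config() is importable from deepspeed.runtime.config as introduced in this commit.

from deepspeed.runtime.config import get_pipeline_config

param_dict = {'pipeline': {'stages': 4, 'activation_checkpoint_interval': 1}}
print(get_pipeline_config(param_dict))
# {'stages': 4, 'partition': 'best', 'seed_layers': False,
#  'activation_checkpoint_interval': 1}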
23 changes: 23 additions & 0 deletions deepspeed/runtime/dataloader.py
@@ -7,6 +7,29 @@
from torch.utils.data.distributed import DistributedSampler


class RepeatingLoader:
def __init__(self, loader):
"""Wraps an iterator to allow for infinite iteration. This is especially useful
for DataLoader types that we wish to automatically restart upon completion.
Args:
loader (iterator): The data loader to repeat.
"""
self.loader = loader
self.data_iter = iter(self.loader)

def __iter__(self):
return self

def __next__(self):
try:
batch = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.loader)
batch = next(self.data_iter)
return batch


class DeepSpeedDataLoader(object):
def __init__(self,
dataset,
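A short usage sketch for the new RepeatingLoader, with a toy dataset; the loader wraps around instead of raising StopIteration, which is what lets a training loop pull a fixed number of micro-batches per step without worrying about epoch boundaries. Dataset shape and batch size are illustrative.

import torch
from torch.utils.data import DataLoader, TensorDataset
from deepspeed.runtime.dataloader import RepeatingLoader

dataset = TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,)))
loader = RepeatingLoader(DataLoader(dataset, batch_size=4))

it = iter(loader)
for _ in range(6):        # more batches than one pass over the 10 samples
    x, y = next(it)       # restarts the underlying DataLoader when exhausted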