
bf16+pipeline parallelism #1801

Merged
merged 52 commits into master from olruwase/bf16-updates on Apr 19, 2022

Changes from 44 commits

Commits (52)
fb0dc00
bf16 updates
jeffra Feb 18, 2022
6eb4f1f
Got bf16 working
tjruwase Feb 21, 2022
a3d3576
fp32 reduction; flattened tensors
tjruwase Feb 23, 2022
6f5ebc3
bf16+zero_stage_1 first cut
tjruwase Feb 24, 2022
819abe2
finish zero_stage 1 sharding
tjruwase Feb 24, 2022
e48035b
Matching fp16 with debugging codes
tjruwase Feb 27, 2022
8245053
Matching loss with fp16
tjruwase Feb 27, 2022
1529313
Fix gradient clipping
tjruwase Feb 28, 2022
27e5b95
bf16 gradient clipping fix
tjruwase Mar 1, 2022
f497702
Unscale grad norm
tjruwase Mar 1, 2022
0ad7c7d
Fix grad norm scaling
tjruwase Mar 2, 2022
b81d862
Enable loading fp16_zero_1 into bf16_zero_1 engine and vice versa
tjruwase Mar 4, 2022
35ea380
Fix clip_grad key error
tjruwase Mar 4, 2022
37011a9
Reduce tied weight gradients
tjruwase Mar 5, 2022
8fbd4bf
Rebase with master
tjruwase Mar 5, 2022
61d51fd
Fix grad norm for moe
tjruwase Mar 5, 2022
3ee61cd
Merge branch 'master' into olruwase/bf16-updates
jeffra Mar 7, 2022
46cc2ce
Merge branch 'master' into olruwase/bf16-updates
jeffra Mar 8, 2022
de3616c
Reduce specified gradients
tjruwase Mar 8, 2022
89e054d
Merge branch 'olruwase/reduce_specified_gradients' of github.com:micr…
tjruwase Mar 8, 2022
ab61edb
Use O(n) instead of O(n^2)
tjruwase Mar 9, 2022
b7d64fd
Remove optimizer restriction for bf16
tjruwase Mar 9, 2022
1919868
Link bf16 & fp32 params
tjruwase Mar 10, 2022
77b649d
Clip gradients of last stage tied weights
tjruwase Mar 11, 2022
4a505ec
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 11, 2022
ff99cb2
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 11, 2022
20fdba3
Merge branch 'master' into olruwase/bf16-updates
jeffra Mar 15, 2022
86fa437
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 16, 2022
71499a8
Merge branch 'master' into olruwase/bf16-updates
jeffra Mar 16, 2022
7e7fa60
Merge branch 'master' into olruwase/bf16-updates
jeffra Mar 17, 2022
2aa612a
Simplify tied weights reduction logic
tjruwase Mar 17, 2022
2cd21f1
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 18, 2022
a4cbf0c
Merge branch 'olruwase/bf16-updates' of github.com:microsoft/DeepSpee…
tjruwase Mar 18, 2022
67ea260
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 18, 2022
6a4d6e6
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 19, 2022
4e1dcfd
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 22, 2022
e24814a
Also clip all tp rank parameters
tjruwase Mar 24, 2022
88cdf61
Merge branch 'olruwase/bf16-updates' of github.com:microsoft/DeepSpee…
tjruwase Mar 24, 2022
20697bc
lp to hp mapping
tjruwase Mar 25, 2022
4e8f7ff
Link lp/hp/optim state; Refresh links after checkpoint load
tjruwase Mar 28, 2022
52a2f10
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 28, 2022
3ed5703
Merge branch 'olruwase/bf16-updates' of github.com:microsoft/DeepSpee…
tjruwase Mar 28, 2022
5481b86
Remove debug print
tjruwase Mar 28, 2022
d911e67
Remove debug print
tjruwase Mar 28, 2022
144f652
Simplify zero_grad logic
tjruwase Mar 30, 2022
bb70816
fp32 accessors
tjruwase Mar 30, 2022
89b4b3f
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 30, 2022
a9bfaee
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 30, 2022
fa4ff11
Fix update bug
tjruwase Mar 31, 2022
cfd5638
Merge branch 'olruwase/bf16-updates' of github.com:microsoft/DeepSpee…
tjruwase Mar 31, 2022
5ea1c60
Merge branch 'master' into olruwase/bf16-updates
tjruwase Mar 31, 2022
0e2a1c5
Merge branch 'master' into olruwase/bf16-updates
tjruwase Apr 18, 2022
Files changed

1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
*.swp
*.log
deepspeed/git_version_info_installed.py
__pycache__

# Build + installation data
build/
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/constants.py
@@ -11,9 +11,11 @@

BASE_OPTIMIZER_STATE = 'base_optimizer_state'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
GROUPS_PADDING = 'groups_padding'

PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'

#########################################
# Module checkpoint keys
600 changes: 600 additions & 0 deletions deepspeed/runtime/bf16_optimizer.py

Large diffs are not rendered by default.
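
The new deepspeed/runtime/bf16_optimizer.py is not rendered above. As orientation, here is a minimal sketch of the pattern the rest of the PR leans on: bf16 ("lp", low-precision) parameters paired with fp32 ("hp", high-precision) master weights and gradient accumulators. Only the method names (`update_hp_grads`, `clear_lp_grads`, `get_grads_for_reduction`) are taken from the calls in deepspeed/runtime/pipe/engine.py below; the class and its bodies are illustrative assumptions, not the actual BF16_Optimizer, which also flattens and partitions these tensors for ZeRO stage 1 and links lp/hp/optimizer state for checkpointing.

```python
import torch

class Bf16MasterWeightsSketch:
    """Illustrative sketch only -- not DeepSpeed's BF16_Optimizer."""

    def __init__(self, bf16_params, inner_optimizer_cls=torch.optim.AdamW):
        self.bf16_params = [p for p in bf16_params]   # "lp" params the model computes with
        self.fp32_params = []                         # "hp" master copies the optimizer updates
        for lp in self.bf16_params:
            hp = lp.detach().clone().float()
            hp.grad = torch.zeros_like(hp)            # fp32 gradient accumulator
            lp._hp_grad = hp.grad                     # lp -> hp grad link (cf. tied-weight reduction below)
            self.fp32_params.append(hp)
        self.optimizer = inner_optimizer_cls(self.fp32_params)

    def clear_lp_grads(self):
        for lp in self.bf16_params:
            lp.grad = None

    def update_hp_grads(self, clear_lp_grads=False):
        # Fold this micro-batch's bf16 grads into the fp32 accumulators.
        for lp, hp in zip(self.bf16_params, self.fp32_params):
            if lp.grad is not None:
                hp.grad.add_(lp.grad.float())
                if clear_lp_grads:
                    lp.grad = None

    def get_grads_for_reduction(self):
        # The fp32 grads are what gets all-reduced across data-parallel ranks.
        return [hp.grad for hp in self.fp32_params]

    def step(self):
        self.optimizer.step()                         # update fp32 master weights
        for lp, hp in zip(self.bf16_params, self.fp32_params):
            lp.data.copy_(hp.data)                    # cast back down to bf16
            hp.grad.zero_()
```

A typical use of such a wrapper would be `clear_lp_grads()`, `backward()`, `update_hp_grads()` once per micro-batch, then `step()` after the fp32 gradients have been reduced — the same sequence the pipeline engine drives in the pipe/engine.py changes below.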

1 change: 0 additions & 1 deletion deepspeed/runtime/config.py
@@ -901,7 +901,6 @@ def _initialize_params(self, param_dict):
self.fp16_enabled = get_fp16_enabled(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {1, 2, 3})), f'bfloat16 mode is only enabled for Zero 1,2,3 currently. got {self.zero_optimization_stage}'
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
param_dict)
self.amp_enabled = get_amp_enabled(param_dict)
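
The deleted assertion above previously limited bfloat16 to ZeRO stages 1–3; removing it is what lets bf16 pair with pipeline parallelism at ZeRO stage 0/1 (see the new check in deepspeed/runtime/pipe/engine.py further down). As a hedged illustration only — the config key names (`bf16`, `zero_optimization`, `gradient_clipping`) reflect my understanding of the DeepSpeed config schema and are not taken from this PR — an enabling setup might look like:

```python
import torch
import deepspeed
from deepspeed.pipe import PipelineModule

# Run under the deepspeed launcher (e.g. `deepspeed --num_gpus 2 train.py`).
deepspeed.init_distributed()

net = PipelineModule(layers=[torch.nn.Linear(32, 32) for _ in range(8)],
                     num_stages=2,
                     loss_fn=torch.nn.MSELoss())

ds_config = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_clipping": 1.0,
    "bf16": {"enabled": True},          # bf16 in place of fp16; no dynamic loss scaling
    "zero_optimization": {"stage": 1},  # this PR targets bf16 with ZeRO stage 0/1
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=net,
    model_parameters=[p for p in net.parameters() if p.requires_grad],
    config=ds_config,
)
```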
5 changes: 5 additions & 0 deletions deepspeed/runtime/constants.py
@@ -446,3 +446,8 @@ class ValidationMode:
'''
DATALOADER_DROP_LAST = "dataloader_drop_last"
DATALOADER_DROP_LAST_DEFAULT = False

#########################################
# PIPELINE PARALLELISM
#########################################
PIPE_REPLICATED = 'ds_pipe_replicated'
145 changes: 96 additions & 49 deletions deepspeed/runtime/engine.py

Large diffs are not rendered by default.

60 changes: 39 additions & 21 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -11,7 +11,7 @@
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
import torch.distributed as dist


@@ -32,11 +32,13 @@ def __init__(self,
mpu=None,
clip_grad=0.0,
fused_adam_legacy=False,
has_moe_layers=False,
timers=None):

self.fused_adam_legacy = fused_adam_legacy
self.timers = timers
self.deepspeed = deepspeed
self.has_moe_layers = has_moe_layers
self.using_pipeline = self.deepspeed.pipeline_parallelism
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
@@ -93,6 +95,7 @@ def __init__(self,

self.clip_grad = clip_grad
self.norm_type = 2
self.step_count = 0

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -167,11 +170,15 @@ def step_fused_adam(self, closure=None):
self.cur_scale))
return self.overflow

self._global_grad_norm = get_global_norm(norm_list=norm_groups)
scaled_grad_norm = get_global_norm(norm_list=norm_groups)

combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
self._global_grad_norm,
scaled_grad_norm,
apply_scale=False)

# Stash unscaled gradient norm
self._global_grad_norm = scaled_grad_norm / self.cur_scale

# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
output_params=[[p] for p in self.fp16_groups_flat],
@@ -258,26 +265,19 @@ def step(self, closure=None):
self.start_timers([COMPUTE_NORM])

all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
#all_groups_norm_old = all_groups_norm
# Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
if self.using_pipeline:
pg = self.deepspeed.mpu.get_data_parallel_group()
else:
pg = groups._get_data_parallel_group()
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
scaled_norm_tensor = torch.tensor(scaled_norm,
device=self.fp32_groups_flat[i].device,
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}")

self.stop_timers([COMPUTE_NORM])

self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
if self.has_moe_layers:
all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
self.stop_timers([UNSCALE_AND_CLIP])

self.start_timers([BASIC_STEP])
@@ -300,8 +300,26 @@

self.log_timers(STEP_TIMERS)

self.step_count += 1

return self.overflow

def _get_norm_with_moe_layers(self, all_groups_norm):
#all_groups_norm_old = all_groups_norm
# Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
if self.using_pipeline:
pg = self.deepspeed.mpu.get_data_parallel_group()
else:
pg = groups._get_data_parallel_group()
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
scaled_norm_tensor = torch.tensor(scaled_norm,
device=self.fp32_groups_flat[0].device,
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}")
return all_groups_norm

def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
@@ -325,8 +343,8 @@ def backward(self, loss, create_graph=False, retain_graph=False):
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
scaled_loss = (loss.float()) * self.cur_scale

scaled_loss = (loss.float()) * self.cur_scale
scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)

def _update_scale(self, skip):
@@ -399,7 +417,7 @@ def state_dict(self):
state_dict['scale_window'] = self.scale_window
state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
state_dict['clip_grad'] = self.clip_grad
state_dict[CLIP_GRAD] = self.clip_grad
return state_dict

# Refresh fp32 master params from fp16 copies
@@ -433,7 +451,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
self.scale_window = state_dict['scale_window']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
self.clip_grad = state_dict['clip_grad']
self.clip_grad = state_dict[CLIP_GRAD]
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
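A side note on the two "Stash unscaled gradient norm" changes above: because the loss is multiplied by `cur_scale` before `backward()`, every gradient — and therefore the global gradient norm returned by `get_grad_norm`/`get_global_norm` — is larger by that same factor, so dividing by `cur_scale` recovers the true norm that `_global_grad_norm` is meant to expose. A tiny standalone check of that identity (the tensors here are stand-ins, not DeepSpeed objects):

```python
import torch

g = torch.randn(1000)                 # stand-in for the flattened fp32 gradients
cur_scale = 4096.0                    # stand-in for the dynamic loss scale

true_norm = g.norm(2).item()
scaled_norm = (cur_scale * g).norm(2).item()   # what the optimizer measures after a scaled backward

assert abs(scaled_norm / cur_scale - true_norm) < 1e-3 * true_norm
print(scaled_norm / cur_scale, true_norm)      # unscaled norm == scaled norm / cur_scale
```
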
29 changes: 27 additions & 2 deletions deepspeed/runtime/pipe/engine.py
@@ -241,14 +241,31 @@ def _exec_reduce_tied_grads(self):
# (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
self.module.allreduce_tied_weight_gradients()

weight_group_list = self.module.get_tied_weights_and_groups()
for weight, group in weight_group_list:
grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
dist.all_reduce(grad, group=group)

def _exec_reduce_grads(self):
self._force_grad_boundary = True
if self.pipeline_enable_backward_allreduce:
self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
if self.bfloat16_enabled():
if self.zero_optimization_stage() == 0:
self._bf16_reduce_grads()
else:
assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
raise NotImplementedError()
else:
self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
self._force_grad_boundary = False

def _bf16_reduce_grads(self):
# Make our own list of gradients from the optimizer's FP32 grads
grads = []
self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)

def _reserve_pipe_buffers(self, num_buffers):
"""Ensure that each pipeline buffer has at least ``num_buffers`` slots.

@@ -726,6 +743,10 @@ def _exec_backward_pass(self, buffer_id):
part_grad = None
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

if self.bfloat16_enabled() and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
self.optimizer.clear_lp_grads()

# This handles either a single tensor or tuple of tensors.
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if t.is_floating_point()]
Expand All @@ -734,6 +755,10 @@ def _exec_backward_pass(self, buffer_id):
else:
torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

if self.bfloat16_enabled() and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
self.optimizer.update_hp_grads(clear_lp_grads=False)

# Free up the memory from the output of forward()
self.pipe_buffers['output_tensors'][buffer_id] = None
self.pipe_buffers['outputs'][buffer_id] = None
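The `clear_lp_grads`/`update_hp_grads` pair in `_exec_backward_pass` above brackets each micro-batch: bf16 gradients are dropped before autograd runs and folded into fp32 accumulators right after, so gradient accumulation across micro-batches happens in fp32 rather than bf16. The usual motivation (a general bf16 property, not a claim quoted from this PR) is bf16's 8-bit mantissa: small per-micro-batch contributions can be rounded away entirely when added to a larger running sum. A self-contained illustration:

```python
import torch

# Add 64 small "micro-batch gradient" contributions of 1e-3 onto a running value of 1.0.
updates = torch.full((64,), 1e-3)

acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(1.0, dtype=torch.float32)
for u in updates:
    acc_bf16 += u.to(torch.bfloat16)   # each 1e-3 is below bf16's rounding step near 1.0
    acc_fp32 += u                      # fp32 keeps every contribution

print(f"bf16 accumulation: {acc_bf16.item():.6f}")   # prints 1.000000 -- the updates were lost
print(f"fp32 accumulation: {acc_fp32.item():.6f}")   # prints 1.064000
```
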
7 changes: 7 additions & 0 deletions deepspeed/runtime/pipe/module.py
@@ -419,6 +419,13 @@ def allreduce_tied_weight_gradients(self):
weight = getattr(self.tied_modules[key], comm['weight_attr'])
dist.all_reduce(weight.grad, group=comm['group'])

def get_tied_weights_and_groups(self):
weight_group_list = []
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
weight_group_list.append((weight, comm['group']))
return weight_group_list

def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
dist.broadcast(