MP ZeRO++ (#3954)
* zero++ tutorial PR (#3783)

* [Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)

* fix conv_flops_compute when padding is a str and stride=1

* fix error

* change type of paddings to tuple

* fix padding calculation

* apply formatting check

---------

Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>

* fix interpolate flops compute (#3782)

* use `Flops Profiler` to test `model.generate()` (#2515)

* Update profiler.py

* pre-commit run --all-files

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

---------

Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Cheng Li <[email protected]>

* revert PR #3611 (#3786)

* bump to 0.9.6

* ZeRO++ chinese blog (#3793)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* remove staging trigger (#3792)

* DeepSpeed-Triton for Inference (#3748)

Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Arash Bakhtiari <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Ethan Doe <[email protected]>
Co-authored-by: yidoe <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>

* ZeRO++ (#3784)

Co-authored-by: HeyangQin <[email protected]>
Co-authored-by: GuanhuaWang <[email protected]>
Co-authored-by: cmikeh2 <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>

* adding zero++ to navigation panel of deepspeed.ai (#3796)

* Add ZeRO++ Japanese blog (#3797)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* add ZeRO++ Japanese blog

* add links

---------

Co-authored-by: HeyangQin <[email protected]>
Co-authored-by: Conglong Li <[email protected]>

* Bug Fixes for autotuner and flops profiler (#1880)

* fix autotuner when backward is not called

* fix format

---------

Co-authored-by: Olatunji Ruwase <[email protected]>

* Missing strided copy for gated MLP (#3788)

Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Logan Adams <[email protected]>

* Requires grad checking. (#3789)

Co-authored-by: Jeff Rasley <[email protected]>

* bump to 0.10.0

* Fix Bug in transform.cu (#3534)

* Bug fix

* Fixed formatting error

---------

Co-authored-by: Logan Adams <[email protected]>

* bug fix: triton importing error (#3799)

Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>

* init commit for mixed precision lora

* fix format

* patch _allgather_params & minor fixes

* make sure initial quantization is finished

* make sure dequantization is finished

* skip quantization for small parameters

* fix format

* remove unused async_op

* lazy load of quantizer kernels

* add mixed precision lora tutorial

* cleanup mics

* cleanup mics

* replace get_accelerator().current_device()

* add kwargs to mics

* fix format

* separate code and tutorial

* fix _all_gather in zero3

---------

Co-authored-by: Bill Luo <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Guorun <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: stephen youn <[email protected]>
Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Arash Bakhtiari <[email protected]>
Co-authored-by: Ethan Doe <[email protected]>
Co-authored-by: yidoe <[email protected]>
Co-authored-by: GuanhuaWang <[email protected]>
Co-authored-by: cmikeh2 <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Conglong Li <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Joe Mayer <[email protected]>
Co-authored-by: Ramya Ramineni <[email protected]>
21 people authored Aug 21, 2023
1 parent f036f00 commit 7711bdb
Showing 7 changed files with 358 additions and 183 deletions.
42 changes: 25 additions & 17 deletions deepspeed/runtime/engine.py
@@ -827,6 +827,9 @@ def zero_hpz_partition_size(self):
def zero_quantized_weights(self):
return self._config.zero_config.zero_quantized_weights

def zero_quantized_nontrainable_weights(self):
return self._config.zero_config.zero_quantized_nontrainable_weights

def zero_quantized_gradients(self):
return self._config.zero_config.zero_quantized_gradients

@@ -1470,23 +1473,26 @@ def _configure_zero_optimizer(self, optimizer):
assert not self.has_moe_layers, "MoE not supported with Stage 3"
if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
zpg = groups._get_zero_param_intra_parallel_group()
if self.zero_hpz_partition_size() > 1 and zpg is None:
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:
self._set_zero_group_parallelism()
zpg = groups._get_zero_param_intra_parallel_group()
optimizer = DeepSpeedZeRoOffload(self.module,
timers=timers,
ds_config=self.config,
overlap_comm=self.zero_overlap_comm(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu,
zero_param_parallel_group=zpg,
zero_quantized_weights=self.zero_quantized_weights())
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
optimizer = DeepSpeedZeRoOffload(
self.module,
timers=timers,
ds_config=self.config,
overlap_comm=self.zero_overlap_comm(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu,
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
)
else:
log_dist(
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
@@ -1529,7 +1535,9 @@ def _configure_zero_optimizer(self, optimizer):
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type,
zero_hpz_partition_size=self.zero_hpz_partition_size(),
zero_quantized_weights=self.zero_quantized_weights())
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
)

else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
10 changes: 9 additions & 1 deletion deepspeed/runtime/zero/config.py
@@ -37,6 +37,7 @@
"round_robin_gradients": [true|false],
"zero_hpz_partition_size": 1,
"zero_quantized_weights": [true|false],
"zero_quantized_nontrainable_weights": [true|false],
"zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
@@ -258,9 +259,16 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
zero_quantized_weights: bool = False
"""
Boolean indicating whether to quantized zero parameters (weights)
Boolean indicating whether to quantize zero parameters (weights)
for efficient all_gather comm
"""
zero_quantized_nontrainable_weights: bool = False
"""
Boolean indicating whether to quantize non-trainable zero parameters (weights)
for efficient memory usage and communication. Different from zero_quantized_weights
that stores the weights in original precision and only perform quantization during communication,
this flag will store the weights in quantized precision. This is useful for LoRA training.
"""
zero_quantized_gradients: bool = False
"""
Boolean indicating whether to use quantized zero gradients
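For context, the sketch below shows how the flags documented in this section would be enabled in a ZeRO stage-3 configuration. It is a minimal, illustrative example rather than DeepSpeed's reference usage: the toy model, batch size, optimizer settings, and the partition size of 8 are placeholder values, and the script would normally be run under the deepspeed launcher across multiple GPUs.

import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # placeholder model

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        "zero_hpz_partition_size": 8,                 # secondary partition size for hierarchical weight partitioning
        "zero_quantized_weights": True,               # quantize trainable weights for all_gather communication
        "zero_quantized_gradients": True,             # quantize gradient communication
        "zero_quantized_nontrainable_weights": True,  # store frozen weights in quantized form (useful for LoRA)
    },
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

After initialization, the setting can be read back through the accessor added to the engine above, engine.zero_quantized_nontrainable_weights().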
40 changes: 5 additions & 35 deletions deepspeed/runtime/zero/mics.py
@@ -156,7 +156,7 @@ def _convert_to_deepspeed_param(self, param):
# so that we can fallback later
old_all_gather_coalesced = param.all_gather_coalesced

def _param_all_gather_coalesced(params, safe_mode=False, param_buffers=None):
def _param_all_gather_coalesced(params, param_buffers=None, **kwargs):
""""""
mics_comm_groups: MiCS_CommGroups = params[0].comm
hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups)
@@ -165,7 +165,7 @@ def _param_all_gather_coalesced(params, safe_mode=False, param_buffers=None):
elif dist.has_coalescing_manager():
return self._flat_all_gather_with_coalescing_manager(params, param_buffers)
else:
return old_all_gather_coalesced(params, safe_mode)
return old_all_gather_coalesced(params, **kwargs)

# change the all_gather_coalesced method
param.all_gather_coalesced = _param_all_gather_coalesced
@@ -308,22 +308,6 @@ class MiCS_Offload(DeepSpeedZeRoOffload):
""" Wrapper to change the behavior for parameter sharding
"""

def __init__(self,
module,
timers,
ds_config,
overlap_comm=True,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
offload_param_config=None,
mpu=None):
super().__init__(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
max_live_parameters, param_persistence_threshold, model_persistence_threshold,
offload_param_config, mpu)

def _convert_to_zero_parameters(self, ds_config, module, mpu):
""" overload the parent class function for convert the parameters
@@ -405,24 +389,10 @@ def __init__(self,

def initialize_ds_offload(
self,
module,
timers,
ds_config,
overlap_comm,
prefetch_bucket_size,
max_reuse_distance,
max_live_parameters,
param_persistence_threshold,
model_persistence_threshold,
offload_param_config,
mpu,
zpg=None,
zero_quantized_weights=False,
*args,
**kwargs,
):
assert not zero_quantized_weights and zpg is None, "MiCS is mutually exclusive with ZeRO++"
return MiCS_Offload(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
max_live_parameters, param_persistence_threshold, model_persistence_threshold,
offload_param_config, mpu)
return MiCS_Offload(*args, **kwargs)

def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
grad_buffers = super().partition_grads(params_to_release, grad_partitions)
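The _param_all_gather_coalesced change above follows a common wrap-and-fall-back pattern: the original bound method is saved, the wrapper handles the MiCS-specific paths itself, and any other keyword arguments are forwarded unchanged so that new parameters added to the underlying method keep working. A generic, self-contained sketch of that pattern (illustrative names, not DeepSpeed API):

class Collective:
    def all_gather(self, items, safe_mode=False):
        # stand-in for the original implementation
        return [f"gathered:{item}" for item in items]

comm = Collective()
original_all_gather = comm.all_gather  # keep the original bound method for fallback

def patched_all_gather(items, use_hierarchical=False, **kwargs):
    if use_hierarchical:
        # special-cased path handled by the wrapper itself
        return [f"hierarchical:{item}" for item in items]
    # otherwise fall back to the original, passing unknown kwargs straight through
    return original_all_gather(items, **kwargs)

comm.all_gather = patched_all_gather  # install the wrapper on the instance

print(comm.all_gather(["w0", "w1"], safe_mode=True))         # original path
print(comm.all_gather(["w0", "w1"], use_hierarchical=True))  # wrapper path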
37 changes: 22 additions & 15 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -200,20 +200,23 @@ def backward(ctx, *args):

class DeepSpeedZeRoOffload(object):

def __init__(self,
module,
timers,
ds_config,
overlap_comm=True,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
offload_param_config=None,
mpu=None,
zero_param_parallel_group=None,
zero_quantized_weights=False):
def __init__(
self,
module,
timers,
ds_config,
overlap_comm=True,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
offload_param_config=None,
mpu=None,
zero_param_parallel_group=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
):

see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

Expand All @@ -226,6 +229,7 @@ def __init__(self,
self.offload_param_pin_memory = False
self.zero_param_parallel_group = zero_param_parallel_group
self.zero_quantized_weights = zero_quantized_weights
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights

if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
self.offload_device = offload_param_config.device
@@ -286,6 +290,8 @@ def get_param_coordinator(self, training):
inflight_param_registry=self.__inflight_param_registry[training],
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

return self.param_coordinators[training]
@@ -312,7 +318,8 @@ def _convert_to_zero_parameters(self, ds_config, module, mpu):
pin_memory=self.offload_param_pin_memory,
mpu=mpu,
zero_param_parallel_group=self.zero_param_parallel_group,
zero_quantized_weights=self.zero_quantized_weights)
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights)

def destroy(self):
self._remove_module_hooks()
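For completeness, a hedged sketch of passing the new keyword into the updated constructor; in practice the engine builds this object inside _configure_zero_optimizer as shown in engine.py above, and the module, timers, and ds_config arguments here are assumed to be supplied by the caller.

from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload

def build_offload(module, timers, ds_config):
    # Both quantization switches default to False in the signature above, so they
    # only take effect when enabled explicitly here or through the engine config.
    return DeepSpeedZeRoOffload(
        module,
        timers=timers,
        ds_config=ds_config,
        zero_quantized_weights=True,
        zero_quantized_nontrainable_weights=True,
    )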