diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index c8bde6390b3c..9ff5a7232a73 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -271,6 +271,7 @@ def __init__(self,
                  remote_device=None,
                  pin_memory=False,
                  deepspeed_config=None,
+                 param_dict=None,
                  enabled=True):
         """A context to enable massive model construction for training with
         ZeRO-3. Models are automatically partitioned (or, sharded) across the
@@ -293,6 +294,8 @@ def __init__(self,
                 ``"cpu"``. Defaults to ``False``.
             deepspeed_config (``json file``, optional): If provided, provides configuration
                 for swapping fp16 params to NVMe.
+            param_dict (dict, optional): Instead of requiring a deepspeed_config you can pass your
+                deepspeed config as a dictionary for swapping fp16 params to NVMe.
             enabled (bool, optional): If ``False``, this context has no effect. Defaults to ``True``.
@@ -382,7 +385,7 @@ def get_model():
         #It is the device where parameters are fully instantiated using allgather
         self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))

-        self._validate_remote_device(remote_device, deepspeed_config)
+        self._validate_remote_device(remote_device, deepspeed_config, param_dict)

         #Remote device is the device where parameter partiitons are stored
         #It can be same as local_device or it could be CPU or NVMe.
@@ -392,7 +395,7 @@ def get_model():

         # Enable fp16 param swapping to NVMe
         if self.remote_device == OFFLOAD_NVME_DEVICE:
-            _ds_config = DeepSpeedConfig(deepspeed_config)
+            _ds_config = DeepSpeedConfig(deepspeed_config, param_dict=param_dict)
             self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
         else:
             self.param_swapper = None
@@ -406,9 +409,9 @@ def get_model():
             self._convert_to_deepspeed_param(param)
             param.partition()

-    def _validate_remote_device(self, remote_device, ds_config):
+    def _validate_remote_device(self, remote_device, ds_config, param_dict):
         if ds_config is not None:
-            _ds_config = DeepSpeedConfig(ds_config)
+            _ds_config = DeepSpeedConfig(ds_config, param_dict=param_dict)
             if remote_device in [None, OFFLOAD_CPU_DEVICE]:
                 if _ds_config.zero_config.offload_param is not None:
                     offload_param_device = _ds_config.zero_config.offload_param[
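
For context, here is a minimal usage sketch (not part of the diff) of how the new ``param_dict`` argument could be used with ``deepspeed.zero.Init`` to enable NVMe param swapping from an in-memory dictionary rather than a ``deepspeed_config`` JSON file. The config keys follow DeepSpeed's documented ZeRO-3 ``offload_param`` schema; the ``nvme_path`` and batch-size values are placeholders, and the snippet assumes it runs under a distributed launcher that sets ``LOCAL_RANK``.

```python
# Sketch: pass the DeepSpeed config as a dict via param_dict instead of a JSON file.
# nvme_path and the batch size below are placeholders, not values from this PR.
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",           # swap fp16 params to NVMe
            "nvme_path": "/local_nvme"  # placeholder mount point
        }
    }
}

# Before this change, only deepspeed_config=<path to a JSON file> was accepted;
# param_dict lets the same configuration be handed over directly as a dictionary.
with deepspeed.zero.Init(remote_device="nvme", pin_memory=True, param_dict=ds_config):
    model = torch.nn.Linear(4096, 4096)
```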