diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 5bd5bf070fa9..526276917aa1 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1043,13 +1043,13 @@ def all_gather_coalesced(params: Iterable[Parameter], param_buffer = torch.empty( buffer_size, dtype=param.dtype if not quant else torch.int8, - device=get_accelerator().current_device(), + device=get_accelerator().current_device_name(), requires_grad=False, ) param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor if not quant: handles = _dist_allgather_fn( - param_ds_tensor.to(get_accelerator().current_device()), + param_ds_tensor.to(get_accelerator().current_device_name()), param_buffer, ds_process_group, ) @@ -1057,16 +1057,16 @@ def all_gather_coalesced(params: Iterable[Parameter], return AllGatherHandle(handles, param) else: quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor) - handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device()), param_buffer, - ds_process_group) + handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()), + param_buffer, ds_process_group) quant_scale_buffer = torch.empty( scales.numel() * world_size, dtype=torch.float32, - device=get_accelerator().current_device(), + device=get_accelerator().current_device_name(), requires_grad=False, ) - quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device()), + quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()), quant_scale_buffer, ds_process_group) quant_info = QuantizationInfo() @@ -1086,7 +1086,7 @@ def all_gather_coalesced(params: Iterable[Parameter], flat_tensor = torch.empty(partition_sz * world_size, dtype=get_only_unique_item(p.dtype for p in params) if not quant else torch.int8, - device=get_accelerator().current_device(), + device=get_accelerator().current_device_name(), requires_grad=False) if not quant: partitions: List[Parameter] = [] @@ -1118,17 +1118,17 @@ def all_gather_coalesced(params: Iterable[Parameter], use_secondary_tensor = True quantized_param, scales = self.quantizer_module.quantize( instrument_w_nvtx(torch.cat)( - [p.ds_secondary_tensor.to(get_accelerator().current_device()) for p in params])) + [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params])) else: quantized_param, scales = self.quantizer_module.quantize( instrument_w_nvtx( - torch.cat)([p.ds_tensor.to(get_accelerator().current_device()) for p in params])) + torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) quant_info = QuantizationInfo() quant_scale_buffer = torch.empty( scales.numel() * world_size, dtype=torch.float32, - device=get_accelerator().current_device(), + device=get_accelerator().current_device_name(), requires_grad=False, ) quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)