Skip to content

Commit

Permalink
Use device_name instead of device index to support other device
Browse files Browse the repository at this point in the history
Create tensor with device=Integer will always choose cuda as its
deivce for current pytorch version(2.1), other device should use
device={device}:{index}.

Change get_accelerator().current_device() to
get_accelerator().current_device_name() to support other devices.
  • Loading branch information
hipudding committed Jul 13, 2023
1 parent aef6c65 commit ce41575
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,30 +1043,30 @@ def all_gather_coalesced(params: Iterable[Parameter],
param_buffer = torch.empty(
buffer_size,
dtype=param.dtype if not quant else torch.int8,
device=get_accelerator().current_device(),
device=get_accelerator().current_device_name(),
requires_grad=False,
)
param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
if not quant:
handles = _dist_allgather_fn(
param_ds_tensor.to(get_accelerator().current_device()),
param_ds_tensor.to(get_accelerator().current_device_name()),
param_buffer,
ds_process_group,
)
param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
return AllGatherHandle(handles, param)
else:
quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device()), param_buffer,
ds_process_group)
handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()),
param_buffer, ds_process_group)

quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device(),
device=get_accelerator().current_device_name(),
requires_grad=False,
)
quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device()),
quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
quant_scale_buffer, ds_process_group)
quant_info = QuantizationInfo()

Expand All @@ -1086,7 +1086,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
flat_tensor = torch.empty(partition_sz * world_size,
dtype=get_only_unique_item(p.dtype
for p in params) if not quant else torch.int8,
device=get_accelerator().current_device(),
device=get_accelerator().current_device_name(),
requires_grad=False)
if not quant:
partitions: List[Parameter] = []
Expand Down Expand Up @@ -1118,17 +1118,17 @@ def all_gather_coalesced(params: Iterable[Parameter],
use_secondary_tensor = True
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)(
[p.ds_secondary_tensor.to(get_accelerator().current_device()) for p in params]))
[p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params]))
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(
torch.cat)([p.ds_tensor.to(get_accelerator().current_device()) for p in params]))
torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
quant_info = QuantizationInfo()
quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device(),
device=get_accelerator().current_device_name(),
requires_grad=False,
)
quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
Expand Down

0 comments on commit ce41575

Please sign in to comment.