Skip to content

Commit

Permalink
Auto-detect accelerator
Browse files — browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Nov 16, 2024
1 parent 9648f31 commit 93ba283
Showing 1 changed file with 2 additions and 12 deletions.
14 changes: 2 additions & 12 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4115,18 +4115,8 @@ def from_pretrained(
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Get device type (e.g. "cuda")
try:
# torch 2.6 API
device_type = torch.distributed.distributed_c10d._device_capability()[0]
except AttributeError:
if torch.cuda.is_available():
device_type = "cuda"
else:
raise RuntimeError(
"Device type unknown. Please run model.tensor_parallel with an explicit DeviceMesh."
)
# Get torch device module (e.g. torch.cuda) based on device type
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
Expand Down

0 comments on commit 93ba283

Please sign in to comment.