diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5100da5844f52b..cee499f37875f8 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4115,18 +4115,8 @@ def from_pretrained(
             if not torch.distributed.is_initialized():
                 raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
 
-            # Get device type (e.g. "cuda")
-            try:
-                # torch 2.6 API
-                device_type = torch.distributed.distributed_c10d._device_capability()[0]
-            except AttributeError:
-                if torch.cuda.is_available():
-                    device_type = "cuda"
-                else:
-                    raise RuntimeError(
-                        "Device type unknown. Please run model.tensor_parallel with an explicit DeviceMesh."
-                    )
-            # Get torch device module (e.g. torch.cuda) based on device type
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
             device_module = torch.get_device_module(device_type)
             # Get device with index assuming equal number of devices per host
             tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
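
For context, here is a minimal standalone sketch of the new device-selection path, assuming torch >= 2.6 (where the `torch._C._get_accelerator()` binding used above is available) and a process group launched with e.g. `torchrun`. Unlike the removed `try`/`except`, it handles non-CUDA accelerators uniformly and no longer hard-fails on unknown device types:

```python
import torch
import torch.distributed as dist

# Resolve the accelerator backend for this machine; falls back to "cpu"
# when no accelerator is present. Note this is a private torch binding,
# so it may change between releases.
device_type = torch._C._get_accelerator().type        # e.g. "cuda" or "cpu"
device_module = torch.get_device_module(device_type)  # e.g. torch.cuda

# Mirror the rank -> local device mapping used in the patch, which
# assumes an equal number of devices per host.
if dist.is_initialized():
    tp_device = torch.device(device_type, dist.get_rank() % device_module.device_count())
    print(f"rank {dist.get_rank()} -> {tp_device}")
```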