Add Intel Gaudi device/HPU to auto load in instantiate_device_type_tests (pytorch#126970)

### Motivation
The Intel Gaudi accelerator (device name hpu) shows a good pass rate on the PyTorch framework unit tests (UTs); however, being an out-of-tree device, we face challenges in adapting it to natively run the existing PyTorch UTs under pytorch/test. The UTs are nonetheless a good indicator of the device stack's health, so we run them regularly with adaptations.
Although we can add the Gaudi/HPU device to generate device-specific tests using the TORCH_TEST_DEVICES environment variable, that path misses many features, such as running for specific dtypes, skipping tests, and overriding OpInfo entries. With the significant changes introduced in every PyTorch release, maintaining these adaptations becomes difficult and time consuming.
Hence, with this PR we introduce the Gaudi device into the common_device_type framework, so that the device-type tests are instantiated for Gaudi when its library is loaded.
The eventual goal is to bring Gaudi's out-of-tree support to parity with in-tree devices.

### Changes
Add HPUTestBase, a subclass of DeviceTypeTestBase, specifying the appropriate attributes for Gaudi/HPU.
Include code that checks whether the Intel Gaudi software library is loaded and, if so, adds the device to the list of devices considered for instantiating device-type tests (see the usage sketch below).
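
As a minimal sketch (not part of this PR's diff), this is how an existing device-generic test file picks up the new backend once the Gaudi library is detected; the test class and assertions below are illustrative only:

```python
# Illustrative only: the standard device-generic test pattern in PyTorch.
# With this PR and the Intel Gaudi software stack installed, TEST_HPU is true
# and instantiate_device_type_tests() generates an HPU variant of the class
# alongside the CPU/CUDA ones, with no HPU-specific code in the test itself.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests


class TestArithmetic(TestCase):
    def test_add(self, device):
        # 'device' is supplied per backend, e.g. 'cpu', 'cuda:0', or 'hpu:0'
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)


# Produces TestArithmeticCPU, TestArithmeticCUDA, and (when HPU is available)
# TestArithmeticHPU in this module's namespace.
instantiate_device_type_tests(TestArithmetic, globals())

if __name__ == "__main__":
    run_tests()
```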

### Additional Context
Please refer to the following RFC: pytorch/rfcs#63

Pull Request resolved: pytorch#126970
Approved by: https://github.com/albanD
ankurneog authored and TharinduRusira committed Jun 14, 2024
1 parent 613524b commit 045a9f8
Showing 2 changed files with 38 additions and 1 deletion.
29 changes: 28 additions & 1 deletion torch/testing/_internal/common_device_type.py
@@ -15,7 +15,7 @@
import torch
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \
_TestParametrizer, compose_parametrize_fns, dtype_name, \
TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \
get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \
@@ -590,6 +590,18 @@ def setUpClass(cls):
def _should_stop_test_suite(self):
return False

class HPUTestBase(DeviceTypeTestBase):
device_type = 'hpu'
primary_device: ClassVar[str]

@classmethod
def get_primary_device(cls):
return cls.primary_device

@classmethod
def setUpClass(cls):
cls.primary_device = 'hpu:0'

class PrivateUse1TestBase(DeviceTypeTestBase):
primary_device: ClassVar[str]
device_mod = None
@@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l
test_bases.append(MPSTestBase)
if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases:
test_bases.append(XPUTestBase)
if TEST_HPU and HPUTestBase not in test_bases:
test_bases.append(HPUTestBase)
# Filter out the device types based on user inputs
desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)
if include_lazy:
@@ -1060,6 +1074,10 @@ class skipMPSIf(skipIf):
def __init__(self, dep, reason):
super().__init__(dep, reason, device_type='mps')

class skipHPUIf(skipIf):
def __init__(self, dep, reason):
super().__init__(dep, reason, device_type='hpu')

# Skips a test on XLA if the condition is true.
class skipXLAIf(skipIf):

@@ -1343,6 +1361,9 @@ def onlyMPS(fn):
def onlyXPU(fn):
return onlyOn('xpu')(fn)

def onlyHPU(fn):
return onlyOn('hpu')(fn)

def onlyPRIVATEUSE1(fn):
device_type = torch._C._get_privateuse1_backend_name()
device_mod = getattr(torch, device_type, None)
@@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn):
def expectedFailureXLA(fn):
return expectedFailure('xla')(fn)

def expectedFailureHPU(fn):
return expectedFailure('hpu')(fn)

# Skips a test on CPU if LAPACK is not available.
def skipCPUIfNoLapack(fn):
return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
@@ -1578,6 +1602,9 @@ def skipXLA(fn):
def skipMPS(fn):
return skipMPSIf(True, "test doesn't work on MPS backend")(fn)

def skipHPU(fn):
return skipHPUIf(True, "test doesn't work on HPU backend")(fn)

def skipPRIVATEUSE1(fn):
return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)

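To illustrate how the decorators added above could be used in device-generic tests, here is a hedged sketch (not taken from the PR); the test bodies and the skip condition are made up:

```python
# Illustrative sketch using the HPU decorators introduced in this diff.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyHPU,
    skipHPUIf,
)


class TestBackendSpecific(TestCase):
    @onlyHPU
    def test_runs_only_on_gaudi(self, device):
        # Instantiated (and run) only for the 'hpu' device type.
        self.assertEqual(torch.device(device).type, "hpu")

    @skipHPUIf(True, "example: feature not yet supported on HPU")
    def test_skipped_on_hpu(self, device):
        # Runs on other backends; skipped on HPU while the condition holds.
        torch.randn(3, device=device)


instantiate_device_type_tests(TestBackendSpecific, globals())

if __name__ == "__main__":
    run_tests()
```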
10 changes: 10 additions & 0 deletions torch/testing/_internal/common_utils.py
@@ -1236,6 +1236,7 @@ def TemporaryDirectoryName(suffix=None):
TEST_MKL = torch.backends.mkl.is_available()
TEST_MPS = torch.backends.mps.is_available()
TEST_XPU = torch.xpu.is_available()
TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False
TEST_CUDA = torch.cuda.is_available()
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()
@@ -1622,6 +1623,15 @@ def wrapper(*args, **kwargs):
fn(*args, **kwargs)
return wrapper

def skipIfHpu(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_HPU:
raise unittest.SkipTest("test doesn't currently work with HPU")
else:
fn(*args, **kwargs)
return wrapper

# Skips a test on CUDA if ROCm is available and its version is lower than requested.
def skipIfRocmVersionLessThan(version=None):
def dec_fn(fn):
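A short, hedged usage sketch for the common_utils helper above: unlike the device-type-aware skipHPUIf, skipIfHpu skips a plain TestCase method whenever an HPU is present; the test itself is illustrative:

```python
# Illustrative only: skipIfHpu guards a regular (non device-generic) test.
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfHpu


class TestPlain(TestCase):
    @skipIfHpu
    def test_not_supported_on_gaudi(self):
        # Skipped entirely when TEST_HPU is true.
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    run_tests()
```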
