diff --git a/setup.py b/setup.py index 67575a0e04bf0..47cac5996f816 100644 --- a/setup.py +++ b/setup.py @@ -168,7 +168,7 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return torch.version.cuda is not None + return torch.version.cuda is not None and not _is_neuron() def _is_hip() -> bool: