diff --git a/tensorizer/utils.py b/tensorizer/utils.py index eb9ec8c..877a085 100644 --- a/tensorizer/utils.py +++ b/tensorizer/utils.py @@ -75,7 +75,11 @@ def convert_bytes(num, decimal=True) -> str: def get_device() -> torch.device: - return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device( + "cuda" + if torch.cuda.is_available() + else ("mps" if torch.backends.mps.is_available() else "cpu") + ) class GlobalGPUMemoryUsage(NamedTuple):