diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index a20a94b76..a271a5a3e 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -55,18 +55,23 @@ def check_nightly_binaries_date(package: str) -> None: f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}" ) +def test_cuda_runtime_errors_captured() -> None: + cuda_exception_missed=True + try: + torch._assert_async(torch.tensor(0, device="cuda")) + torch._assert_async(torch.tensor(0 + 0j, device="cuda")) + except RuntimeError as e: + if re.search("CUDA", f"{e}"): + print(f"Caught CUDA exception with success: {e}") + cuda_exception_missed = False + else: + raise e + if(cuda_exception_missed): + raise RuntimeError( f"Expected CUDA RuntimeError but have not received!") + def smoke_test_cuda(package: str) -> None: if not torch.cuda.is_available() and is_cuda_system: raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.") - if torch.cuda.is_available(): - if torch.version.cuda != gpu_arch_ver: - raise RuntimeError( - f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}" - ) - print(f"torch cuda: {torch.version.cuda}") - # todo add cudnn version validation - print(f"torch cudnn: {torch.backends.cudnn.version()}") - print(f"cuDNN enabled? {torch.backends.cudnn.enabled}") if(package == 'all' and is_cuda_system): for module in MODULES: @@ -80,6 +85,19 @@ def smoke_test_cuda(package: str) -> None: version = imported_module._extension._check_cuda_version() print(f"{module['name']} CUDA: {version}") + if torch.cuda.is_available(): + if torch.version.cuda != gpu_arch_ver: + raise RuntimeError( + f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}" + ) + print(f"torch cuda: {torch.version.cuda}") + # todo add cudnn version validation + print(f"torch cudnn: {torch.backends.cudnn.version()}") + print(f"cuDNN enabled? {torch.backends.cudnn.enabled}") + + # This check has to be run last, since its messing up CUDA runtime + test_cuda_runtime_errors_captured() + def smoke_test_conv2d() -> None: import torch.nn as nn @@ -128,7 +146,6 @@ def main() -> None: ) options = parser.parse_args() print(f"torch: {torch.__version__}") - smoke_test_cuda(options.package) smoke_test_conv2d() if options.package == "all": @@ -138,6 +155,8 @@ def main() -> None: if installation_str.find("nightly") != -1: check_nightly_binaries_date(options.package) + smoke_test_cuda(options.package) + if __name__ == "__main__": main()