From e330da5870fae34339170b942095a2600fa7a95e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Jun 2024 17:20:59 -0700 Subject: [PATCH] Fix torch-numpy compatibility conflict in tests (#20004) --- requirements/fabric/base.txt | 2 +- requirements/pytorch/base.txt | 2 +- src/lightning/fabric/utilities/testing/_runif.py | 4 ++-- tests/tests_fabric/strategies/test_ddp_integration.py | 5 +++++ tests/tests_fabric/utilities/test_distributed.py | 5 +++++ 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 7487dd9b754b3..aac884d9c6f43 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -6,4 +6,4 @@ torch >=2.0.0, <2.4.0 fsspec[http] >=2022.5.0, <2024.4.0 packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.10.0 -lightning-utilities >=0.8.0, <0.12.0 +lightning-utilities >=0.10.0, <0.12.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 4993a918af099..6372357b6d290 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -9,4 +9,4 @@ fsspec[http] >=2022.5.0, <2024.4.0 torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.10.0 -lightning-utilities >=0.8.0, <0.12.0 +lightning-utilities >=0.10.0, <0.12.0 diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6ab2ff730eec9..9a6f5554baa19 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -17,7 +17,7 @@ from typing import Dict, List, Optional, Tuple import torch -from lightning_utilities.core.imports import compare_version +from lightning_utilities.core.imports import RequirementCache, compare_version from packaging.version import Version from lightning.fabric.accelerators import XLAAccelerator @@ -112,7 +112,7 @@ def _runif_reasons( reasons.append("Standalone execution") kwargs["standalone"] = True - if deepspeed and not _DEEPSPEED_AVAILABLE: + if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")): reasons.append("Deepspeed") if dynamo: diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index 6f003748b9cce..281f0d47bae0c 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -19,6 +19,7 @@ import pytest import torch from lightning.fabric import Fabric +from lightning_utilities.core.imports import RequirementCache from torch._dynamo import OptimizedModule from torch.nn.parallel.distributed import DistributedDataParallel @@ -27,6 +28,10 @@ from tests_fabric.test_fabric import BoringModel +@pytest.mark.skipif( + RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"), + reason="torch.distributed not compatible with numpy>=2.0", +) @pytest.mark.parametrize( "accelerator", [ diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 5331a6f9be611..2c30b3aa62ddf 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -20,6 +20,7 @@ _sync_ddp, is_shared_filesystem, ) +from lightning_utilities.core.imports import RequirementCache from tests_fabric.helpers.runif import RunIf @@ -121,6 +122,10 @@ def test_collective_operations(devices, process): spawn_launch(process, devices) +@pytest.mark.skipif( + RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"), + reason="torch.distributed not compatible with numpy>=2.0", +) @pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO) def test_is_shared_filesystem(tmp_path, monkeypatch): # In the non-distributed case, every location is interpreted as 'shared'