Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: enable testing with coming PT 2.2 #19289

Merged
merged 32 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5c95b0c
ci: build dockers for PT 2.2
Borda Jan 15, 2024
4f04497
py3.12
Borda Jan 15, 2024
83bf784
--pre
Borda Jan 15, 2024
f416bad
--extra-index-url
Borda Jan 15, 2024
9752202
test
Borda Jan 15, 2024
309fcc4
typing-extensions
Borda Jan 15, 2024
93a70df
push
Borda Jan 15, 2024
fd3d7eb
ci
Borda Jan 15, 2024
99d2e5d
ci
Borda Jan 15, 2024
aeb7d91
ci
Borda Jan 16, 2024
32d3f2f
ci
Borda Jan 16, 2024
e3d3fd5
push
Borda Jan 16, 2024
c122ef1
ci
Borda Jan 16, 2024
971bb92
Merge branch 'master' into ci/pt-2.2
Borda Jan 18, 2024
7a89d8c
bump jsonargparse
awaelchli Jan 18, 2024
0ed793f
bump jsonargparse
Borda Jan 18, 2024
ab6ad5a
Merge branch 'master' into ci/pt-2.2
Borda Jan 19, 2024
d5ac631
Merge branch 'master' into ci/pt-2.2
Borda Jan 19, 2024
5f03ac2
test jsonargparse
awaelchli Jan 25, 2024
91e16ec
try again
awaelchli Jan 25, 2024
bd597fa
Apply suggestions from code review
Borda Jan 25, 2024
27b6910
debug
awaelchli Jan 25, 2024
84965f0
debug
awaelchli Jan 25, 2024
4a51997
install latest jsonargparse
awaelchli Jan 26, 2024
4f73593
Add windows skips for Fabric
awaelchli Jan 26, 2024
0d0e3db
Merge branch 'master' into ci/pt-2.2
awaelchli Jan 26, 2024
2ce6ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
b1e74db
convert to xfail
awaelchli Jan 26, 2024
06a95d0
add pytorch skips
awaelchli Jan 26, 2024
b0f6868
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
482f056
skip checkpoint consolidation test
awaelchli Jan 26, 2024
de66d77
set max torch
awaelchli Jan 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .azure/gpu-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ jobs:
"Fabric | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0"
PACKAGE_NAME: "fabric"
"Fabric | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.2-cuda12.1.0"
PACKAGE_NAME: "fabric"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0"
PACKAGE_NAME: "lightning"
Expand All @@ -73,6 +76,10 @@ jobs:
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"

- bash: |
echo $(DEVICES)
Expand All @@ -99,7 +106,7 @@ jobs:

- bash: |
extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links ${TORCH_URL}
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
displayName: "Install package & dependencies"

- bash: |
Expand Down
10 changes: 9 additions & 1 deletion .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ jobs:
"PyTorch | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0"
PACKAGE_NAME: "pytorch"
"PyTorch | future":
# todo: failed to install `pygame` with py3.11
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.2-cuda12.1.0"
PACKAGE_NAME: "pytorch"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0"
PACKAGE_NAME: "lightning"
Expand All @@ -76,6 +80,10 @@ jobs:
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"

- bash: |
echo $(DEVICES)
Expand Down Expand Up @@ -109,7 +117,7 @@ jobs:

- bash: |
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" -r requirements/_integrations/strategies.txt pytest-timeout -U --find-links ${TORCH_URL}
pip install -e ".[${extra}dev]" -r requirements/_integrations/strategies.txt pytest-timeout -U --find-links="${TORCH_URL}"
displayName: "Install package & dependencies"

- bash: pip uninstall -y lightning
Expand Down
10 changes: 7 additions & 3 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ jobs:
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.9", pytorch-version: "1.12" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" }
# only run PyTorch latest with Python recent
# only run PyTorch latest
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
# only run PyTorch future
- { os: "macOS-12", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
Expand Down Expand Up @@ -128,7 +132,7 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
#python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.release }}' == 'pre' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.2' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
Expand Down Expand Up @@ -157,7 +161,7 @@ jobs:

- name: Testing Warnings
working-directory: tests/tests_fabric
# needs to run outside of `pytest`
# needs to run outside `pytest`
run: python utilities/test_warnings.py

- name: Testing Fabric
Expand Down
10 changes: 7 additions & 3 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,17 @@ jobs:
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.9", pytorch-version: "1.12" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" }
# only run PyTorch latest with Python recent
# only run PyTorch latest
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
# only run PyTorch future
- { os: "macOS-12", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
Borda marked this conversation as resolved.
Show resolved Hide resolved
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
Expand Down Expand Up @@ -134,7 +138,7 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
#python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.release }}' == 'pre' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.2' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
Expand Down Expand Up @@ -194,7 +198,7 @@ jobs:

- name: Testing Warnings
working-directory: tests/tests_pytorch
# needs to run outside of `pytest`
# needs to run outside `pytest`
run: python utilities/test_warnings.py

- name: Testing PyTorch
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if ver:
tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if py_ver == '3.10' and pt_ver == '2.0' and cuda_ver == '12.0.1':
if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0':
tags += ["latest"]

tags = [f"{repo}:{tag}" for tag in tags]
Expand Down Expand Up @@ -103,12 +103,16 @@ jobs:
matrix:
include:
# These are the base images for PL release docker images,
# so include at least all of the combinations in release-dockers.yml.
# so include at least all the combinations in release-dockers.yml.
- { python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.7.1" }
- { python_version: "3.9", pytorch_version: "1.13", cuda_version: "11.8.0" }
- { python_version: "3.9", pytorch_version: "1.13", cuda_version: "12.0.1" }
- { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" }
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
# - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime`
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
Expand Down
11 changes: 5 additions & 6 deletions dockers/base-cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

ARG UBUNTU_VERSION=20.04
ARG UBUNTU_VERSION=22.04
ARG CUDA_VERSION=11.7.1


Expand All @@ -38,7 +38,7 @@ RUN \
# https://github.com/NVIDIA/nvidia-docker/issues/1631
# https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1264715214
apt-get update && apt-get install -y wget && \
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \
mkdir -p /etc/apt/keyrings/ && mv 3bf863cc.pub /etc/apt/keyrings/ && \
echo "deb [signed-by=/etc/apt/keyrings/3bf863cc.pub] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" /etc/apt/sources.list.d/cuda.list && \
apt-get update -qq --fix-missing && \
Expand Down Expand Up @@ -82,9 +82,7 @@ COPY requirements/_integrations/ requirements/_integrations/
ENV PYTHONPATH="/usr/lib/python${PYTHON_VERSION}/site-packages"

RUN \
wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate && \
python${PYTHON_VERSION} get-pip.py && \
rm get-pip.py && \
curl https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \
# Disable cache \
pip config set global.cache-dir false && \
# set particular PyTorch version \
Expand All @@ -99,7 +97,8 @@ RUN \
-r requirements/pytorch/extra.txt \
-r requirements/pytorch/test.txt \
-r requirements/pytorch/strategies.txt \
--find-links "https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html"
--find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html"

RUN \
# Show what we have
Expand Down
2 changes: 1 addition & 1 deletion requirements/app/app.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
lightning-cloud == 0.5.61 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.4.0, <4.8.0
typing-extensions >=4.4.0, <4.10.0
deepdiff >=5.7.0, <6.6.0
fsspec[http] >=2022.5.0, <2023.11.0
croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust.
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ numpy >=1.17.2, <1.27.0
torch >=1.12.0, <2.2.0
fsspec[http] >=2022.5.0, <2023.11.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.8.0
typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.8.0, <0.10.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2023.11.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.8.0
typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.8.0, <0.10.0
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
matplotlib>3.1, <3.9.0
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
jsonargparse[signatures] >=4.26.1, <4.27.0
jsonargparse[signatures] >=4.26.1, <4.28.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes ==0.41.0 # strict
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Automatic Mixed Precision (AMP) training."""
import sys

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand All @@ -37,6 +40,11 @@ def forward(self, x):
return output


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
Borda marked this conversation as resolved.
Show resolved Hide resolved
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize(
("accelerator", "precision", "expected_dtype"),
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand All @@ -28,6 +31,11 @@ def __init__(self):
self.register_buffer("buffer", torch.ones(3))


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
def test_memory_sharing_disabled(strategy):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
Expand Down
8 changes: 7 additions & 1 deletion tests/tests_fabric/strategies/test_ddp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2
from torch.nn.parallel.distributed import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
from tests_fabric.test_fabric import BoringModel


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize(
"accelerator",
[
Expand Down
8 changes: 6 additions & 2 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
Expand Down Expand Up @@ -560,7 +563,8 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0")
# TODO: Support checkpoint consolidation with PyTorch >= 2.2
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""

Expand Down
7 changes: 7 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import os
import sys
from functools import partial
from pathlib import Path
from unittest import mock
Expand All @@ -17,6 +18,7 @@
_sync_ddp,
is_shared_filesystem,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -118,6 +120,11 @@ def test_collective_operations(devices, process):
spawn_launch(process, devices)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
# In the non-distributed case, every location is interpreted as 'shared'
Expand Down
6 changes: 6 additions & 0 deletions tests/tests_fabric/utilities/test_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException


Expand All @@ -28,6 +29,11 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
Expand Down
Loading
Loading