Skip to content

Commit

Permalink
Update tests for PyTorch 2.2.1 (#19521)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 1, 2024
1 parent 9a1ca82 commit 956a895
Show file tree
Hide file tree
Showing 12 changed files with 3 additions and 96 deletions.
7 changes: 0 additions & 7 deletions tests/tests_fabric/plugins/precision/test_amp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Automatic Mixed Precision (AMP) training."""
import sys

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand All @@ -40,11 +38,6 @@ def forward(self, x):
return output


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize(
("accelerator", "precision", "expected_dtype"),
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand All @@ -31,11 +29,6 @@ def __init__(self):
self.register_buffer("buffer", torch.ones(3))


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
def test_memory_sharing_disabled(strategy):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
Expand Down
8 changes: 1 addition & 7 deletions tests/tests_fabric/strategies/test_ddp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel.distributed import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
from tests_fabric.test_fabric import BoringModel


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize(
"accelerator",
[
Expand Down
7 changes: 0 additions & 7 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import os
import sys
from functools import partial
from pathlib import Path
from unittest import mock
Expand All @@ -18,7 +17,6 @@
_sync_ddp,
is_shared_filesystem,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -120,11 +118,6 @@ def test_collective_operations(devices, process):
spawn_launch(process, devices)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
# In the non-distributed case, every location is interpreted as 'shared'
Expand Down
6 changes: 0 additions & 6 deletions tests/tests_fabric/utilities/test_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException


Expand All @@ -29,11 +28,6 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
Expand Down
6 changes: 0 additions & 6 deletions tests/tests_pytorch/callbacks/test_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.spike import SpikeDetection
Expand Down Expand Up @@ -47,11 +46,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
Expand Down
7 changes: 0 additions & 7 deletions tests/tests_pytorch/loops/test_prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import sys

import pytest
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
Expand Down Expand Up @@ -52,11 +50,6 @@ def predict_step(self, batch, batch_idx):
assert trainer.predict_loop.predictions == []


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
Expand Down
13 changes: 1 addition & 12 deletions tests/tests_pytorch/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from unittest import mock

import pytest
import torch
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -55,16 +53,7 @@ def _assert_autocast_enabled(self):
[
("single_device", "16-mixed", 1),
("single_device", "bf16-mixed", 1),
pytest.param(
"ddp_spawn",
"16-mixed",
2,
marks=pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
),
),
("ddp_spawn", "16-mixed", 2),
pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
],
)
Expand Down
7 changes: 0 additions & 7 deletions tests/tests_pytorch/serve/test_servable_module_validator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import sys
from typing import Dict

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
Expand Down Expand Up @@ -38,11 +36,6 @@ def test_servable_module_validator():
callback.on_train_start(Trainer(accelerator="cpu"), model)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(reruns=3)
def test_servable_module_validator_with_trainer(tmpdir):
callback = ServableModuleValidator()
Expand Down
12 changes: 0 additions & 12 deletions tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from multiprocessing import Process
from unittest import mock
from unittest.mock import ANY, Mock, call, patch

import pytest
import torch
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import DDPStrategy
Expand Down Expand Up @@ -196,11 +194,6 @@ def on_fit_start(self) -> None:
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
def test_memory_sharing_disabled():
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
Expand All @@ -221,11 +214,6 @@ def test_check_for_missing_main_guard():
launcher.launch(function=Mock())


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
def test_fit_twice_raises():
model = BoringModel()
trainer = Trainer(
Expand Down
7 changes: 0 additions & 7 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from re import escape
from typing import Sized
from unittest import mock
Expand All @@ -20,7 +19,6 @@
import lightning.fabric
import pytest
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
Expand Down Expand Up @@ -125,11 +123,6 @@ def on_train_end(self):
self.ctx.__exit__(None, None, None)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
"""Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
Expand Down
12 changes: 1 addition & 11 deletions tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

import collections
import itertools
import sys
from re import escape
from unittest import mock
from unittest.mock import call

import numpy as np
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer, callbacks
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.core.module import LightningModule
Expand Down Expand Up @@ -348,15 +346,7 @@ def validation_step(self, batch, batch_idx):
("devices", "accelerator"),
[
(1, "cpu"),
pytest.param(
2,
"cpu",
marks=pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
),
),
(2, "cpu"),
pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
],
)
Expand Down

0 comments on commit 956a895

Please sign in to comment.