Skip to content

Commit

Permalink
Add testing for PyTorch 2.4 (Fabric) (#20028)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jul 2, 2024
1 parent 37e04d0 commit 693c21a
Show file tree
Hide file tree
Showing 20 changed files with 84 additions and 104 deletions.
10 changes: 8 additions & 2 deletions .azure/gpu-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ jobs:
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "lightning"
"Lightning | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
PACKAGE_NAME: "lightning"
workspace:
clean: all
steps:
Expand All @@ -72,9 +75,12 @@ jobs:
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"
Expand Down Expand Up @@ -103,7 +109,7 @@ jobs:
- bash: |
extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
displayName: "Install package & dependencies"
- bash: |
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# TODO: PyTorch 2.4 on Windows not yet working with `torch.distributed` (not compiled with libuv support)
# - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
Expand Down Expand Up @@ -79,7 +83,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE_DIR: "_pip-wheels"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
# TODO: Remove this - Enable running MPS tests on this platform
DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
steps:
Expand Down Expand Up @@ -118,7 +122,7 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
torch >=2.1.0, <2.5.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/examples.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.16.0, <0.19.0
torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0
5 changes: 3 additions & 2 deletions src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable


Expand All @@ -39,7 +40,7 @@ def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
Expand All @@ -49,7 +50,7 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
Expand Down
4 changes: 4 additions & 0 deletions src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import warnings
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
Expand Down Expand Up @@ -83,6 +84,9 @@

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")

# TODO: Switch to new state-dict APIs
warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*") # from torch >= 2.4


class FSDPStrategy(ParallelStrategy, _Sharded):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
Expand Down
18 changes: 6 additions & 12 deletions src/lightning/fabric/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ModelParallelStrategy(ParallelStrategy):
Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch. Requires PyTorch 2.3 or newer.
experimental in PyTorch. Requires PyTorch 2.4 or newer.
Arguments:
parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
Expand All @@ -95,8 +95,8 @@ def __init__(
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_3:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._parallelize_fn = parallelize_fn
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
Expand Down Expand Up @@ -178,7 +178,7 @@ def setup_module(self, module: TModel) -> TModel:
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)

module = self._parallelize_fn(module, self.device_mesh)
Expand Down Expand Up @@ -329,10 +329,10 @@ def __init__(self, module: Module, enabled: bool) -> None:
self._enabled = enabled

def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
from torch.distributed._composable.fsdp import FSDP
from torch.distributed._composable.fsdp import FSDPModule

for mod in self._module.modules():
if isinstance(mod, FSDP):
if isinstance(mod, FSDPModule):
mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)

def __enter__(self) -> None:
Expand Down Expand Up @@ -458,9 +458,6 @@ def _load_checkpoint(
return metadata

if _is_full_checkpoint(path):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

checkpoint = torch.load(path, mmap=True, map_location="cpu")
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

Expand Down Expand Up @@ -546,9 +543,6 @@ def _load_raw_module_state(
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if _has_dtensor_modules(module):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

state_dict_options = StateDictOptions(
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/testing/_runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


def _runif_reasons(
Expand Down Expand Up @@ -111,7 +112,9 @@ def _runif_reasons(
reasons.append("Standalone execution")
kwargs["standalone"] = True

if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")):
if deepspeed and not (
_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
):
reasons.append("Deepspeed")

if dynamo:
Expand Down
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def thread_police_duuu_daaa_duuu_daaa():
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
or "(_read_thread)" in thread.name # torch.compile
):
thread.join(timeout=20)
elif (
Expand Down
7 changes: 5 additions & 2 deletions tests/tests_fabric/plugins/precision/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import pytest
import torch
from lightning.fabric.plugins.precision.amp import MixedPrecision
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


def test_amp_precision_default_scaler():
precision = MixedPrecision(precision="16-mixed", device=Mock())
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)


def test_amp_precision_scaler_with_bf16():
Expand All @@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
precision = MixedPrecision(precision="16-mixed", device="cuda")
assert precision.device == "cuda"
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16
Expand Down
4 changes: 3 additions & 1 deletion tests/tests_fabric/plugins/precision/test_amp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -82,7 +83,8 @@ def run(fused=False):
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)

model, optimizer = fabric.setup(model, optimizer)
assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(fabric._precision.scaler, scaler_cls)

data = torch.randn(10, 10, device="cuda")
target = torch.randn(10, 10, device="cuda")
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self):
precision.convert_module(model)


@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
Expand Down Expand Up @@ -232,7 +232,7 @@ def __init__(self):
assert model.l.weight.dtype == expected


@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
def test_load_quantized_checkpoint(tmp_path):
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
Expand Down
1 change: 1 addition & 0 deletions tests/tests_fabric/strategies/test_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __instancecheck__(self, instance):
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"precision",
[
Expand Down
6 changes: 5 additions & 1 deletion tests/tests_fabric/strategies/test_fsdp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_model(self):
return model


@RunIf(min_cuda_gpus=2, standalone=True)
@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_train_save_load(tmp_path, manual_wrapping, precision):
Expand Down Expand Up @@ -173,6 +173,7 @@ def test_train_save_load(tmp_path, manual_wrapping, precision):
assert state["coconut"] == 11


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
Expand Down Expand Up @@ -287,6 +288,7 @@ def test_save_full_state_dict(tmp_path):
trainer.run()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
Expand Down Expand Up @@ -469,6 +471,7 @@ def _run_setup_assertions(empty_init, expected_device):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_filter(tmp_path):
fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
Expand Down Expand Up @@ -602,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""
Expand Down
Loading

0 comments on commit 693c21a

Please sign in to comment.