diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 7fbc5d25de23d..d9f37f78b57fe 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819)) +- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) + ### Deprecated diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index bf6541b1ef408..ac82d022162bd 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -17,16 +17,16 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: - r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, - sets the following environment variables: + r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module. + In addition, sets the following environment variables: - ``PL_GLOBAL_SEED``: will be passed to spawned subprocesses (e.g. ddp_spawn backend). - ``PL_SEED_WORKERS``: (optional) is set to 1 if ``workers=True``. Args: seed: the integer value seed for global random state in Lightning. - If ``None``, will read seed from ``PL_GLOBAL_SEED`` env variable - or select it randomly. + If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the + ``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. workers: if set to ``True``, will properly configure all dataloaders passed to the Trainer with a ``worker_init_fn``. If the user already provides such a function for their dataloaders, setting this argument will have no influence. See also: @@ -36,20 +36,20 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: if seed is None: env_seed = os.environ.get("PL_GLOBAL_SEED") if env_seed is None: - seed = _select_seed_randomly(min_seed_value, max_seed_value) + seed = 0 rank_zero_warn(f"No seed found, seed set to {seed}") else: try: seed = int(env_seed) except ValueError: - seed = _select_seed_randomly(min_seed_value, max_seed_value) + seed = 0 rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") elif not isinstance(seed, int): seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") - seed = _select_seed_randomly(min_seed_value, max_seed_value) + seed = 0 log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) os.environ["PL_GLOBAL_SEED"] = str(seed) @@ -63,10 +63,6 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: return seed -def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int: - return random.randint(min_seed_value, max_seed_value) # noqa: S311 - - def reset_seed() -> None: r"""Reset the seed to the value that :func:`~lightning.fabric.utilities.seed.seed_everything` previously set. diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8e25bc3038bca..4b4858bc4fb39 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) - `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822)) diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index f8f6761ad2e24..f371afe96c3bb 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -4,10 +4,16 @@ import lightning.fabric.utilities import pytest import torch -from lightning.fabric.utilities import seed as seed_utils from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states +@mock.patch.dict(os.environ, clear=True) +def test_default_seed(): + """Test that the default seed is 0 when no seed provided and no environment variable set.""" + assert lightning.fabric.utilities.seed.seed_everything() == 0 + assert os.environ["PL_GLOBAL_SEED"] == "0" + + @mock.patch.dict(os.environ, {}, clear=True) def test_seed_stays_same_with_multiple_seed_everything_calls(): """Ensure that after the initial seed everything, the seed stays the same for the same run.""" @@ -30,22 +36,20 @@ def test_correct_seed_with_environment_variable(): @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) -@mock.patch.object(seed_utils, attribute="_select_seed_randomly", return_value=123) -def test_invalid_seed(_): +def test_invalid_seed(): """Ensure that we still fix the seed even if an invalid seed is given.""" with pytest.warns(UserWarning, match="Invalid seed found"): seed = lightning.fabric.utilities.seed.seed_everything() - assert seed == 123 + assert seed == 0 @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(seed_utils, attribute="_select_seed_randomly", return_value=123) @pytest.mark.parametrize("seed", [10e9, -10e9]) -def test_out_of_bounds_seed(_, seed): +def test_out_of_bounds_seed(seed): """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" with pytest.warns(UserWarning, match="is not in bounds"): actual = lightning.fabric.utilities.seed.seed_everything(seed) - assert actual == 123 + assert actual == 0 def test_reset_seed_no_op():