Skip to content

Commit

Permalink
Seed NumPy using np.random.SeedSequence() in `pl_worker_init_functi…
Browse files Browse the repository at this point in the history
…on()` to robustly seed NumPy-dependent dataloader workers (#20369)

* Update seed.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update seed.py

* Update seed.py

* Update seed.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
4 people authored Nov 25, 2024
1 parent 1f4a77c commit 29c0396
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
if _NUMPY_AVAILABLE:
import numpy as np

np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
np_rng_seed = ss.generate_state(4)

np.random.seed(np_rng_seed)


def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
Expand Down

0 comments on commit 29c0396

Please sign in to comment.