Skip to content

Commit

Permalink
Simplify random seed in epoch data for reproducibility
Browse files: browse the repository at this point in the history
  • Loading branch information
Lauler committed Nov 28, 2024
1 parent eab4770 commit f060414
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/nanotron/data/nanoset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -113,15 +113,13 @@ def build_nanoset_index(self) -> np.ndarray:
)

# Shuffle indices in each epoch with different random seeds and concatenate them
- r = np.random.RandomState(self.random_seed)
- epoch_random_seeds = r.randint(0, 2**32 - 1, num_epochs)
dataset_indices = []
dataset_sample_indices = []
- for i in range(num_epochs):
+ for num_epoch in range(num_epochs):
# Shuffle the sample and dataset indices in epoch with a given seed
- numpy_random_state = np.random.RandomState(epoch_random_seeds[i])
+ numpy_random_state = np.random.RandomState(self.random_seed + num_epoch)
numpy_random_state.shuffle(dataset_index)
- numpy_random_state = np.random.RandomState(epoch_random_seeds[i])
+ numpy_random_state = np.random.RandomState(self.random_seed + num_epoch)
numpy_random_state.shuffle(dataset_sample_index)

dataset_indices.append(dataset_index)
Expand Down

0 comments on commit f060414

Please sign in to comment.