I’m currently working with the Grain DataLoader and have run into a problem with stopping and resuming training at each epoch boundary when the sampler is initialized with num_epochs > 0. I’d like some guidance on how to handle this scenario properly, and I’d also like to understand how shuffling works across epochs.
The Issue:
I’m using the IndexSampler with num_epochs > 0, and I expect to be able to stop the dataloader at the end of each epoch and resume training at the next one. However, the sampler yields num_records * num_epochs items in total, so it behaves as if all epochs were concatenated into a single stream.
Here’s how I understand the current behavior:
The IndexSampler creates a MapDataset over the num_records record keys. This acts like one long list of record keys (of length num_epochs * num_records), with one copy per epoch stitched together.
When __getitem__ is called with an index greater than or equal to num_records (but less than num_epochs * num_records), it behaves like a modulus operation: index i refers to position i % num_records within epoch i // num_records. This is how each record is visited exactly once per epoch.
Shuffling appears to be applied separately to each num_records-sized slice, so every epoch is its own permutation of all records, while any index in the combined stream can still be accessed directly.
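To make my understanding concrete, here is a minimal pure-Python model of the index mapping. It is not Grain’s code: record_key is a hypothetical helper, and random.Random(seed + epoch) is only a stand-in for whatever per-epoch shuffle Grain actually uses.

```python
# Illustrative model of my understanding of IndexSampler indexing; NOT Grain's code.
# record_key is a hypothetical helper, and random.Random(seed + epoch) is only a
# stand-in for Grain's actual per-epoch shuffle.
import random


def record_key(index: int, num_records: int, num_epochs: int, seed: int) -> int:
    """Map a global sampler index to a record key."""
    if not 0 <= index < num_epochs * num_records:
        raise IndexError(f"index {index} is outside [0, {num_epochs * num_records})")
    # The "modulus" behaviour: which epoch, and which position inside it.
    epoch, position = divmod(index, num_records)
    # Each num_records-sized slice gets its own permutation of all record keys.
    order = list(range(num_records))
    random.Random(seed + epoch).shuffle(order)
    return order[position]


if __name__ == "__main__":
    num_records, num_epochs = 5, 2
    stream = [record_key(i, num_records, num_epochs, seed=42)
              for i in range(num_records * num_epochs)]
    print(stream[:num_records])  # epoch 0: a permutation of 0..4
    print(stream[num_records:])  # epoch 1: a (generally different) permutation of 0..4
```

In this model, indices epoch * num_records through (epoch + 1) * num_records - 1 always cover every record exactly once, which is what makes an epoch boundary well defined inside the combined stream.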
My Goal:
I want to stop at the end of the first epoch and later resume training from the start of the next epoch. My current implementation tracks the iterator state with get_state() and set_state(), but it feels somewhat clunky.
What I’ve Tried:
Tracking the iterator state:
After each batch is processed, I call get_state() and inspect last_seen_indices.
Checking for the epoch boundary:
I detect when the indices reach the end of the current epoch by comparing last_seen_indices against the boundary index (current_epoch + 1) * num_records.
State update upon crossing the boundary:
Once the boundary is crossed, I rewrite the state: for each worker, I set its entry in last_seen_indices to the index just before that worker’s first index in the next epoch, and I reset last_worker_index so the workers resume in the correct order.
Training loop logic:
After each batch, I check the epoch boundary. If it is crossed, I break the loop, save the state, and later resume from this state.
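The steps above roughly correspond to the following sketch. It simulates the iterator state as a plain dict instead of calling Grain, and the helper names (initial_state, step, epoch_finished, state_for_epoch_start, run_epoch) are hypothetical. The round-robin assignment of indices to workers and the exact meaning of last_seen_indices and last_worker_index are my assumptions, not Grain’s documented state format.

```python
# Hypothetical sketch of my approach, using a plain dict in place of the real
# iterator state. ASSUMPTIONS (not taken from Grain's source): worker w reads global
# indices w, w + W, w + 2W, ... (W = worker_count), one element per step; the state
# stores each worker's last index under "last_seen_indices" and the last worker that
# produced an element under "last_worker_index".
import copy
import json


def initial_state(worker_count: int) -> dict:
    """State of a fresh iterator: worker w's next index is w."""
    return {
        "last_seen_indices": {str(w): w - worker_count for w in range(worker_count)},
        "last_worker_index": -1,
        "worker_count": worker_count,
    }


def step(state: dict) -> int:
    """Simulate reading one element round-robin; returns the global index read."""
    w_count = state["worker_count"]
    worker = (state["last_worker_index"] + 1) % w_count
    index = state["last_seen_indices"][str(worker)] + w_count
    state["last_seen_indices"][str(worker)] = index
    state["last_worker_index"] = worker
    return index


def epoch_finished(state: dict, current_epoch: int, num_records: int) -> bool:
    """True once every index below (current_epoch + 1) * num_records was read."""
    boundary = (current_epoch + 1) * num_records
    w_count = state["worker_count"]
    # Each worker's next index is its last seen index + worker_count.
    return all(last + w_count >= boundary
               for last in state["last_seen_indices"].values())


def state_for_epoch_start(state: dict, next_epoch: int, num_records: int) -> dict:
    """Copy of the state positioned just before the first index of next_epoch."""
    start = next_epoch * num_records
    w_count = state["worker_count"]
    new_state = copy.deepcopy(state)
    new_state["last_seen_indices"] = {
        str(w): start + w - w_count for w in range(w_count)
    }
    new_state["last_worker_index"] = -1  # worker 0 goes first, as in a fresh iterator
    return new_state


def run_epoch(state: dict, epoch: int, num_records: int) -> tuple[list[int], str]:
    """Train until the epoch boundary, then return the indices read and a checkpoint."""
    read = []
    while not epoch_finished(state, epoch, num_records):
        read.append(step(state))  # stands in for "train on one batch"
    checkpoint = json.dumps(state_for_epoch_start(state, epoch + 1, num_records))
    return read, checkpoint


if __name__ == "__main__":
    first, ckpt = run_epoch(initial_state(worker_count=3), epoch=0, num_records=10)
    # Stop here; later (possibly in a new process) resume from the checkpoint.
    second, _ = run_epoch(json.loads(ckpt), epoch=1, num_records=10)
    print(first)   # indices 0..9
    print(second)  # indices 10..19
```

In this element-at-a-time model the stream is already aligned at the boundary, so the rewrite only makes worker 0 start the new epoch; it matters more when workers read past the boundary, for example at batch granularity.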
However, this approach feels rather manual, and I’m wondering whether there is a more idiomatic or intended way to handle this scenario, especially regarding the shuffling behavior across epochs and stopping/resuming the dataloader.
Questions:
Is there a cleaner way to stop the dataloader at the end of each epoch and resume from the next?
Is my approach of tracking the iterator state with get_state() and set_state() appropriate for this use case, or is there a better approach?
Thank you for your assistance!