Handling Per-Epoch Training with Grain Dataloader #572

Open
danbnyn opened this issue Sep 5, 2024 · 0 comments
danbnyn commented Sep 5, 2024

Hi there,

I’m currently working with the Grain Dataloader and encountered an issue with stopping and resuming training at each epoch boundary when initializing the sampler with num_epochs > 0. I’d like some guidance on how to handle this scenario properly and understand the shuffling mechanism across epochs.

The Issue:

I’m using the IndexSampler with num_epochs > 0, and my expectation is to stop the dataloader at the end of each epoch and resume training at the next epoch. However, the sampler seems to provide a total number of items equal to num_samples * num_epochs, making it behave as if all epochs are combined into a single stream.

Here’s how I understand the current behavior:

  • The IndexSampler builds a MapDataset over the record keys (num_records of them), which behaves like the epochs stitched together into one long list of keys.
  • Calling __getitem__ with an index greater than or equal to num_records (but less than num_epochs * num_records) effectively reduces the index modulo num_records, so each global index maps back to a record within its epoch.
  • Shuffling appears to happen per slice of num_records, i.e. each epoch gets its own permutation, which preserves random access across epochs while keeping the per-epoch structure.
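As a minimal sketch of this understanding (plain Python, not the actual Grain internals — the seeding scheme and function name here are assumptions for illustration), the index-to-record mapping might look like:

```python
import random

def record_key(index, num_records, seed=0, shuffle=True):
    """Map a global index in [0, num_epochs * num_records) to a record key.

    The epoch is index // num_records; within an epoch the index is
    reduced modulo num_records, and each epoch gets its own shuffle
    permutation (seeded per epoch, so access remains random-access).
    """
    epoch, within_epoch = divmod(index, num_records)
    if not shuffle:
        return within_epoch
    perm = list(range(num_records))
    random.Random(seed * 1_000_003 + epoch).shuffle(perm)  # per-epoch permutation
    return perm[within_epoch]

# With 4 records and 2 epochs, indices 0..7 form one continuous stream:
keys = [record_key(i, num_records=4, shuffle=False) for i in range(8)]
# -> [0, 1, 2, 3, 0, 1, 2, 3]
```

With shuffling enabled, indices 0..3 and 4..7 each yield a (differently ordered) permutation of the four record keys, which matches the "one stream of num_samples * num_epochs items" behavior described above.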

My Goal:

I want to stop at the end of the first epoch and then resume at the next epoch. My current implementation tracks the iterator state using get_state() and set_state(), but this feels somewhat clunky.

What I’ve Tried:

  1. Tracking the iterator state:

    • After each batch is processed, I use get_state() to check the last_seen_indices.
  2. Checking for the epoch boundary:

    • I attempt to detect when the indices reach the end of the current epoch by comparing the last_seen_indices to a calculated boundary value based on (current_epoch + 1) * num_records.
  3. State update upon crossing the boundary:

    • Once the boundary is crossed, I update the state by setting last_seen_indices for each worker to the index just before the start of the next epoch, and I reset last_worker_index so the workers resume correctly.
  4. Training loop logic:

    • After each batch, I check the epoch boundary. If it is crossed, I break the loop, save the state, and later resume from this state.
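Steps 2 and 3 above can be sketched roughly like this (a hypothetical state layout — the field names mirror the ones mentioned above, but the exact structure returned by get_state() and the round-robin worker sharding are assumptions):

```python
def crossed_epoch_boundary(last_seen_indices, current_epoch, num_records, num_workers):
    """Step 2: detect the epoch boundary from the per-worker indices.

    Under round-robin sharding (worker w handles indices w, w + num_workers, ...),
    each worker's final index of the epoch lies in
    [boundary - num_workers, boundary); once every worker has passed that
    point, the epoch is done.
    """
    boundary = (current_epoch + 1) * num_records
    return all(idx >= boundary - num_workers for idx in last_seen_indices)

def align_state_to_next_epoch(state, current_epoch, num_records, num_workers):
    """Step 3: rewrite the saved state so each worker resumes just before
    the start of the next epoch, and reset the worker cursor."""
    start = (current_epoch + 1) * num_records
    state = dict(state)
    # Worker w's next index should be start + w, so its "last seen" index
    # is one full stride earlier: start + w - num_workers.
    state["last_seen_indices"] = [start + w - num_workers for w in range(num_workers)]
    state["last_worker_index"] = -1  # the next batch starts again at worker 0
    return state
```

In the training loop (step 4), the shape would then be: after each batch, call get_state(); if crossed_epoch_boundary(...) is true, align the state, break out of the loop, persist the aligned state, and later call set_state() on a fresh iterator to resume at the next epoch.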

However, this approach feels quite manual, and I’m wondering whether there is a simpler, intended way to handle this scenario, especially with regard to the shuffling behavior across epochs and stopping/resuming the dataloader.

Questions:

  1. Is there a cleaner way to stop the dataloader at the end of each epoch and resume from the next?
  2. Is my approach of tracking the iterator state with get_state() and set_state() appropriate for this use case, or is there a better approach?

Thank you for your assistance!
