
Bug: Inconsistent Behavior with StreamingDataloader loading states (specific to StreamingDataset) #316

Closed
bhimrazy opened this issue Aug 8, 2024 · 0 comments
Assignees
Labels
bug Something isn't working help wanted Extra attention is needed priority 0


bhimrazy commented Aug 8, 2024

🐛 Bug

Bug: Inconsistent Behavior in StreamingDataloader after loading states

Description
The StreamingDataloader exhibits inconsistent behavior when handling loading states across different scenarios. Specifically, issues arise when iterating over the dataloader after loading a state saved during a complete or partial first epoch.

To Reproduce

Create Optimized Dataset
from litdata import optimize


def random_data(index):
    return index

if __name__ == "__main__":
    optimize(
        fn=random_data,
        inputs=list(range(100)),
        output_dir="my_optimized_dataset",
        num_workers=4,
        chunk_bytes="64MB",
    )
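A note on the batch counts in the outputs below: every repro script streams this dataset with `num_workers=4` and `batch_size=4`, and the printed batch indices run 0..27 rather than 0..24. A plausible explanation, consistent with the printed indices, is that each worker batches its own 25-sample shard independently:

```python
import math

# 100 samples split across 4 workers, each batching its own shard of 25.
samples, batch_size, num_workers = 100, 4, 4
per_worker = samples // num_workers                          # 25 samples per worker
batches = num_workers * math.ceil(per_worker / batch_size)   # 4 * 7
print(batches)  # 28
```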

Bugs

  1. Iterating over the dataloader after loading a state saved at the end of a complete first epoch throws an error.
    a. Without loading state -> [OK]

    from litdata import StreamingDataLoader, StreamingDataset
    
    dataset = StreamingDataset("my_optimized_dataset")
    dataloader = StreamingDataLoader(dataset, num_workers=4, batch_size=4)
    
    print("Epoch", dataloader.current_epoch)
    for batch_idx, batch in enumerate(dataloader):
        print(batch_idx, end=" ")
    
    print("\nEpoch", dataloader.current_epoch)
    for batch_idx, batch in enumerate(dataloader):
        print(batch_idx, end=" ")

    Output

    Epoch 0
    0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 
    Epoch 1
    0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27                                            

    b. With state loaded after the complete first epoch -> throws [IndexError]

    from litdata import StreamingDataLoader, StreamingDataset
    
    dataset = StreamingDataset("my_optimized_dataset")
    dataloader = StreamingDataLoader(dataset, num_workers=4, batch_size=4)
    
    print("Epoch", dataloader.current_epoch)
    for batch_idx, batch in enumerate(dataloader):
        print(batch_idx, end=" ")
    
    # load dataloader state
    dataloader.load_state_dict(dataloader.state_dict())
    
    print("\nEpoch", dataloader.current_epoch)
    for batch_idx, batch in enumerate(dataloader):
        print(batch_idx, end=" ")

    Output

    Epoch 0
    0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 
    Epoch 1
    Traceback (most recent call last):
      File "/Users/bhimrajyadav/litdata/uncover_dataloader_bug.py", line 24, in <module>
        main()
      File "/Users/bhimrajyadav/litdata/uncover_dataloader_bug.py", line 19, in main
        for batch_idx, batch in enumerate(dataloader):
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 623, in __iter__
        for batch in super().__iter__():
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
        data = self._next_data()
               ^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
        return self._process_data(data)
               ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
        data.reraise()
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/_utils.py", line 706, in reraise
        raise exception
    IndexError: Caught IndexError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 253, in _worker_loop
        fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 80, in create_fetcher
        return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 22, in __init__
        self.dataset_iter = iter(dataset)
                            ^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 236, in __iter__
        self._resume(workers_chunks, workers_intervals)
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 308, in _resume
        interval = self.worker_intervals[self.chunk_index]
                   ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
    IndexError: list index out of range                              
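The traceback points at `self.worker_intervals[self.chunk_index]`. A hypothetical sketch of that failure mode (the names mirror the traceback, but this is not litdata's actual implementation): a state saved after a complete epoch leaves the restored chunk index pointing one past the worker's last assigned chunk, so the lookup overflows instead of starting a fresh epoch.

```python
def resume(worker_intervals, chunk_index):
    # A guard such as `if chunk_index >= len(worker_intervals): reset()`
    # would start a fresh epoch instead of crashing here.
    return worker_intervals[chunk_index]

worker_intervals = [(0, 25), (25, 50)]  # chunks assigned to one worker
resume(worker_intervals, 1)             # mid-epoch state: OK
try:
    resume(worker_intervals, 2)         # end-of-epoch state: overflows
except IndexError as exc:
    print("reproduced:", exc)
```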
  2. After loading a state saved partway through the first epoch, the dataloader does not reset once the resumed epoch completes.

    from litdata import StreamingDataLoader, StreamingDataset
    
    dataset = StreamingDataset("my_optimized_dataset")
    dataloader = StreamingDataLoader(dataset, num_workers=4, batch_size=4)
    print("len of dataloader", len(dataloader))
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx == 0:
            print("Epoch", dataloader.current_epoch)
        print(batch.numpy(), end=" ")
    
        if batch_idx == 20:
            break
    
    # load dataloader state
    dataloader.load_state_dict(dataloader.state_dict())
    
    for _ in range(3):
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx == 0:
                print("\nEpoch", dataloader.current_epoch)
            print(batch.numpy(), end=" ")

    Output:

    len of dataloader 25
    Epoch 1
    [0 1 2 3] [25 26 27 28] [50 51 52 53] [75 76 77 78] [4 5 6 7] [29 30 31 32] [54 55 56 57] [79 80 81 82] [ 8  9 10 11] [33 34 35 36] [58 59 60 61] [83 84 85 86] [12 13 14 15] [37 38 39 40] [62 63 64 65] [87 88 89 90] [16 17 18 19] [41 42 43 44] [66 67 68 69] [91 92 93 94] [20 21 22 23] 
    Epoch 1
    [45 46 47 48] [70 71 72 73] [95 96 97 98] [24] [49] [74] [99] 
    Epoch 2 # Expected: the epoch resets and all batches are shown from epoch 2 onwards
    [24] [45 46 47 48] [70 71 72 73] [95 96 97 98] [49] [74] [99] 
    Epoch 3
    [24] [45 46 47 48] [70 71 72 73] [95 96 97 98] [49] [74] [99] 
    
  3. Throws a num_workers mismatch error when loading state with num_workers=0

    from litdata import StreamingDataLoader, StreamingDataset
    
    dataset = StreamingDataset("my_optimized_dataset")
    dataloader = StreamingDataLoader(dataset, batch_size=4)
    
    # load dataloader state
    dataloader.load_state_dict(dataloader.state_dict())
    
    batch = next(iter(dataloader))

    Output

    ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.
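A hedged sketch of the mismatch (illustrative only, not litdata's actual check): with `num_workers=0` PyTorch loads data in the main process, which behaves like a single worker, so a saved state can record `1` while the freshly built loader reports `0`.

```python
def check_num_workers(saved, current):
    # Reproduces the shape of the error message from the output above.
    if saved != current:
        raise ValueError(
            f"The provided `num_workers` state doesn't match the current "
            f"one. Found `{saved}` instead of `{current}`."
        )

effective = lambda n: max(n, 1)  # treat 0 workers as 1 in-process worker

try:
    check_num_workers(1, 0)                    # reproduces the mismatch
except ValueError as exc:
    print(exc)

check_num_workers(effective(1), effective(0))  # passes after normalizing
```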
  4. current_epoch is not synchronized between the dataloader and the dataset in the dataloader state when num_workers=0

    from litdata import StreamingDataLoader, StreamingDataset
    
    dataset = StreamingDataset("my_optimized_dataset")
    dataloader = StreamingDataLoader(dataset, num_workers=0, batch_size=4)
    for _ in dataloader:
        pass
    
    print("state dict", dataloader.state_dict())

    Output
    {'dataset': {'num_samples_yielded': 100,
    'num_workers': 0,
    'batch_size': 4,
    'current_epoch': 2,
    'input_dir_path': '/Users/bhimrajyadav/litdata/my_optimized_dataset',
    'input_dir_url': None,
    'item_loader': None,
    'drop_last': False,
    'seed': 42,
    'world_size': 1,
    'shuffle': False,
    'subsampled_files': ['chunk-0-0.bin', 'chunk-1-0.bin'],
    'region_of_interest': [(0, 50), (0, 50)]},
    'current_epoch': 1,
    'num_samples_yielded': 100,
    'latest_worker_idx': 0}
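The inconsistency can be stated as a simple check against the state dict printed above (epoch values copied from that output): after one full pass the two `current_epoch` counters should agree, but the nested dataset entry reads 2 while the top-level dataloader entry reads 1.

```python
# Minimal subset of the printed state dict, values copied from the output.
state = {
    "dataset": {"num_samples_yielded": 100, "current_epoch": 2},
    "current_epoch": 1,
    "num_samples_yielded": 100,
}

# The counters disagree, which is the bug being reported.
assert state["dataset"]["current_epoch"] != state["current_epoch"]
print("dataset epoch:", state["dataset"]["current_epoch"],
      "dataloader epoch:", state["current_epoch"])
```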

Environment

  • PyTorch Version (e.g., 1.0): 2.4.0
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.12.4
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@bhimrazy bhimrazy added bug Something isn't working help wanted Extra attention is needed labels Aug 8, 2024
@bhimrazy bhimrazy self-assigned this Aug 9, 2024
@bhimrazy bhimrazy changed the title Bug: Inconsistent Behavior with StreamingDataloader loading states Bug: Inconsistent Behavior with StreamingDataloader loading states (specific for StreamingDataset) Aug 14, 2024
@bhimrazy bhimrazy changed the title Bug: Inconsistent Behavior with StreamingDataloader loading states (specific for StreamingDataset) Bug: Inconsistent Behavior with StreamingDataloader loading states (specific to StreamingDataset) Aug 14, 2024