Streaming Datasets don't work with Transformers Trainer when dataloader_num_workers>1 #3950

Closed
dlwh opened this issue Mar 16, 2022 · 1 comment · Fixed by #4375

Labels: bug (Something isn't working), good first issue (Good for newcomers)

Comments

dlwh commented Mar 16, 2022

Describe the bug

Streaming Datasets can't be pickled, so any interaction between them and multiprocessing results in a crash.

Steps to reproduce the bug

import transformers
from transformers import Trainer, AutoModelForCausalLM, TrainingArguments
import datasets

# with_format("torch") is needed here, otherwise the Trainer raises because the streaming dataset has no len()
ds = datasets.load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True).with_format("torch")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# dataloader_num_workers > 1 makes the DataLoader spawn worker processes, which have to pickle the dataset
Trainer(model, train_dataset=ds, args=TrainingArguments("out", max_steps=1000, dataloader_num_workers=4)).train()
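
The Trainer isn't actually required to trigger the failure. As a rough minimal sketch (assuming the same datasets 2.0.0 behavior as above), pickling the formatted streaming dataset directly fails the same way:

import pickle
import datasets

ds = datasets.load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True).with_format("torch")
# with_format("torch") wraps the stream in a class defined inside a function,
# so neither spawned DataLoader workers nor plain pickle can serialize it
pickle.dumps(ds)  # AttributeError: Can't pickle local object 'iterable_dataset.<locals>.TorchIterableDataset'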

Expected results

For this code I'd expect a crash related to not having preprocessed the data, but instead we get a pickling error.

Actual results

  0%|          | 0/1000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/Users/dlwh/src/mistral/src/stream_fork_crash.py", line 7, in <module>
    Trainer(model, train_dataset=ds, args=TrainingArguments("out", max_steps=1000, dataloader_num_workers=4)).train()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/transformers/trainer.py", line 1339, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 918, in __init__
    w.start()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'iterable_dataset.<locals>.TorchIterableDataset'
  0%|          | 0/1000 [00:00<?, ?it/s]

This immediate crash can be fixed by not defining TorchIterableDataset as a local class inside a function. (Note that you have to call with_format("torch"), or you get an exception because the dataset has no len.) However, any lambdas etc. used as maps will also trigger this crash. A more permanent fix would be to move away from multiprocessing and instead use something like pathos or multiprocessing_on_dill (https://stackoverflow.com/questions/19984152/what-can-multiprocessing-and-dill-do-together). A sketch of the module-level idea is below.
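
For illustration only (these class names are made up, not the actual datasets internals), a minimal sketch of why moving the class to module level matters for pickling:

import pickle
from torch.utils.data import IterableDataset

def make_local_dataset():
    # mirrors what datasets does today: the class only exists inside this function
    class TorchIterableDataset(IterableDataset):
        def __iter__(self):
            yield from range(3)
    return TorchIterableDataset()

class ModuleLevelTorchIterableDataset(IterableDataset):
    # defined at module level, so pickle can find it by its qualified name
    def __iter__(self):
        yield from range(3)

pickle.dumps(ModuleLevelTorchIterableDataset())  # works
pickle.dumps(make_local_dataset())  # AttributeError: Can't pickle local object ...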

Note that if you bypass this crash you get another crash. (I'll file a separate bug).

Environment info

  • datasets version: 2.0.0
  • Platform: macOS-12.2-arm64-arm-64bit
  • Python version: 3.8.12
  • PyArrow version: 7.0.0
  • Pandas version: 1.4.1

lhoestq commented Mar 28, 2022

Hi, thanks for reporting. This could be related to #3148 too.

We should definitely make TorchIterableDataset picklable by moving it into the main code instead of defining it inside a function. If you'd like to contribute, feel free to open a Pull Request :)

I'm also taking a look at your second issue, which is more technical.
