Expected results
For this code I'd expect a crash related to not having preprocessed the data, but instead we get a pickling error.
Actual results
```
0%| | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/dlwh/src/mistral/src/stream_fork_crash.py", line 7, in <module>
    Trainer(model, train_dataset=ds, args=TrainingArguments("out", max_steps=1000, dataloader_num_workers=4)).train()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/transformers/trainer.py", line 1339, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 918, in __init__
    w.start()
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/dlwh/.conda/envs/mistral/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'iterable_dataset.<locals>.TorchIterableDataset'
0%| | 0/1000 [00:00<?, ?it/s]
```
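The traceback passes through popen_spawn_posix.py: on macOS, Python 3.8 defaults to the spawn start method, which pickles the worker Process object (and with it the dataset) using multiprocessing's ForkingPickler before the worker starts. The sketch below reproduces the same kind of error without torch or datasets. make_wrapper and LocalWrapper are hypothetical stand-ins for iterable_dataset and TorchIterableDataset, not the library's actual code.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


def make_wrapper(source):
    # Hypothetical factory mirroring the traceback: the class is defined inside
    # a function, so its __qualname__ is 'make_wrapper.<locals>.LocalWrapper'
    # and pickle cannot look it up by name.
    class LocalWrapper:
        def __init__(self, source):
            self.source = source

    return LocalWrapper(source)


def try_forking_pickle(obj):
    """Serialize obj with the pickler that spawn uses; return an error string or None."""
    try:
        ForkingPickler.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        # AttributeError on Python 3.8; some newer versions raise PicklingError.
        return f"{type(exc).__name__}: {exc}"
    return None


if __name__ == "__main__":
    print(try_forking_pickle(range(10)))                # builtin object: pickles fine
    print(try_forking_pickle(make_wrapper(range(10))))  # local class: error string
```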
This immediate crash can be fixed by not defining TorchIterableDataset as a local class. (Note that you have to call with_format("torch"), or you get an exception because the dataset has no len.) However, any lambdas etc. passed to map() will also trigger this crash. A more permanent fix would be to move away from multiprocessing and use something like pathos or multiprocessing_on_dill (https://stackoverflow.com/questions/19984152/what-can-multiprocessing-and-dill-do-together).
Note that if you bypass this crash, you get another crash. (I'll file a separate bug.)
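A torch-free sketch of the workaround described above, assuming the wrapper only needs to hold a source iterable and an optional per-item function: the class is defined at module level, and the transform is a named module-level function bound with functools.partial instead of a lambda. ModuleLevelWrapper and add_offset are hypothetical names; a real replacement for TorchIterableDataset would presumably also subclass torch.utils.data.IterableDataset.

```python
import functools
import pickle


class ModuleLevelWrapper:
    """Hypothetical wrapper defined at module level, so pickle can find it by name."""

    def __init__(self, source, transform=None):
        self.source = source
        self.transform = transform

    def __iter__(self):
        for item in self.source:
            yield self.transform(item) if self.transform else item


def add_offset(x, offset):
    # Named, module-level function: pickled by reference, unlike a lambda.
    return x + offset


def is_picklable(obj):
    """Return True if obj survives a pickle round trip, False otherwise."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (AttributeError, pickle.PicklingError):
        return False
    return True


if __name__ == "__main__":
    good = ModuleLevelWrapper(range(3), functools.partial(add_offset, offset=10))
    bad = ModuleLevelWrapper(range(3), lambda x: x + 10)
    print(is_picklable(good), list(good))  # True [10, 11, 12]
    print(is_picklable(bad))               # False: the lambda cannot be looked up by name
```

The lambda-based wrapper still fails, which matches the report that lambdas passed to map() hit the same crash; dill-based libraries such as those linked above avoid this by serializing functions by value rather than by name.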
Environment info
datasets version: 2.0.0
Platform: macOS-12.2-arm64-arm-64bit
Python version: 3.8.12
PyArrow version: 7.0.0
Pandas version: 1.4.1
Hi, thanks for reporting. This could be related to #3148 too.
We should definitely make TorchIterableDataset picklable by moving it to the module level instead of defining it inside a function. If you'd like to contribute, feel free to open a Pull Request :)
I'm also taking a look at your second issue, which is more technical.
Describe the bug
Streaming Datasets can't be pickled, so any interaction between them and multiprocessing results in a crash.