diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index e63c41620..309296edb 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import multiprocessing as py_mp
+import pickle
 import warnings
 
 from abc import ABC, abstractmethod
@@ -621,17 +622,31 @@ def checkpoint(self) -> bytes:
             else:
                 warnings.warn(f"{rs} doesn't support `checkpoint`, skipping...")
                 states.append(b"")
-        return b"\n".join(states)
+        return pickle.dumps(states)
 
     # Sequential Order, to align with initialize
     def restore(self, datapipe, serialized_state: bytes) -> DataPipe:
-        states = serialized_state.split(b"\n")
-        assert len(states) == len(self.reading_services)
+        """Restore every wrapped reading service from a pickled list of states.
+
+        ``serialized_state`` is unpickled into one state per entry of
+        ``self.reading_services``. A service with a callable ``restore`` gets
+        its state; any other service is initialized from scratch instead.
+        """
+        # NOTE(security): pickle.loads can execute arbitrary code; only pass
+        # checkpoint bytes that come from a trusted source.
+        states = pickle.loads(serialized_state)
+        assert len(states) == len(self.reading_services), (
+            f"checkpoint holds {len(states)} states but there are "
+            f"{len(self.reading_services)} reading services"
+        )
         for rs, state in zip(self.reading_services, states):
             if hasattr(rs, "restore") and callable(rs.restore):
                 datapipe = rs.restore(datapipe, state)
             else:
-                warnings.warn(f"{rs} doesn't support `restore` from state, skipping...")
+                warnings.warn(
+                    f"{rs} doesn't support `restore` from state, initialize from scratch"
+                )
+                datapipe = rs.initialize(datapipe)
         return datapipe
 
     def _pause(