diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index 2f0f3c381..d304f9a26 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -8,9 +8,14 @@
 from . import datapipes
 
+janitor = datapipes.utils.janitor
+
 try:
     from .version import __version__  # noqa: F401
 except ImportError:
     pass
 
-__all__ = ["datapipes"]
+__all__ = [
+    "datapipes",
+    "janitor",
+]
diff --git a/torchdata/datapipes/iter/util/bz2fileloader.py b/torchdata/datapipes/iter/util/bz2fileloader.py
index 93f4739f7..442df0392 100644
--- a/torchdata/datapipes/iter/util/bz2fileloader.py
+++ b/torchdata/datapipes/iter/util/bz2fileloader.py
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = bz2.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".bz2")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index a7148d26f..ea07248a7 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -11,6 +11,8 @@
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
+from torchdata.datapipes.utils.janitor import janitor
+
 T_co = TypeVar("T_co", covariant=True)
 
 
@@ -109,6 +111,14 @@ def __iter__(self) -> Iterator:
             else:
                 yield res
 
+        for remaining in ref_it:
+            janitor(remaining)
+
+        # TODO(VitalyFedyunin): This should raise an Exception or warn when debug mode is enabled
+        if len(self.buffer) > 0:
+            for k, v in self.buffer.items():
+                janitor(v)
+
     def __len__(self) -> int:
         return len(self.source_datapipe)
diff --git a/torchdata/datapipes/iter/util/decompressor.py b/torchdata/datapipes/iter/util/decompressor.py
index 3cc949bde..30e153b42 100644
--- a/torchdata/datapipes/iter/util/decompressor.py
+++ b/torchdata/datapipes/iter/util/decompressor.py
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for path, file in self.source_datapipe:
             file_type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[file_type]
-            yield path, StreamWrapper(decompressor(file))
+            yield path, StreamWrapper(decompressor(file), file, name=path)
+            if isinstance(file, StreamWrapper):
+                file.autoclose()
 
 
 @functional_datapipe("extract")
diff --git a/torchdata/datapipes/iter/util/rararchiveloader.py b/torchdata/datapipes/iter/util/rararchiveloader.py
index 5b5546316..d7c805fc4 100644
--- a/torchdata/datapipes/iter/util/rararchiveloader.py
+++ b/torchdata/datapipes/iter/util/rararchiveloader.py
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
                 inner_path = os.path.join(path, info.filename)
                 file_obj = rar.open(info)
 
-                yield inner_path, StreamWrapper(file_obj)  # type: ignore[misc]
+                yield inner_path, StreamWrapper(file_obj, stream)  # type: ignore[misc]
+            if isinstance(stream, StreamWrapper):
+                stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
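Note on the recurring pattern above (bz2, decompressor, rar — and tar/xz/zip below): each extracted member stream is now constructed as `StreamWrapper(child, parent)`, registering the outer archive stream as its parent, and the parent is marked `autoclose()` once the loader is done reading from it. A minimal sketch of these semantics, not part of the patch, assuming the `StreamWrapper` behavior in `torch.utils.data.datapipes.utils.common`; the `BytesIO` payloads and the member name are stand-ins:

```python
from io import BytesIO

from torchdata.datapipes.utils import StreamWrapper

# Outer (archive) stream, and an extracted member that registers it as parent.
archive = StreamWrapper(BytesIO(b"archive bytes"))
member = StreamWrapper(BytesIO(b"member bytes"), archive, name="member.txt")

# autoclose() does not close `archive` while children are still open; it
# defers the close until the last child stream closes, so the yielded
# member stays readable after the loader has moved on.
archive.autoclose()
member.read()
member.close()  # last child closed -> `archive` is closed as well
```

This is why the loaders call `autoclose()` rather than `close()`: the archive stream must stay alive until every member stream yielded from it has been consumed.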
diff --git a/torchdata/datapipes/iter/util/tararchiveloader.py b/torchdata/datapipes/iter/util/tararchiveloader.py
index ccb9e0c0f..4c8adef19 100644
--- a/torchdata/datapipes/iter/util/tararchiveloader.py
+++ b/torchdata/datapipes/iter/util/tararchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
                         raise tarfile.ExtractError
                     inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/xzfileloader.py b/torchdata/datapipes/iter/util/xzfileloader.py
index ac5c46beb..637560cce 100644
--- a/torchdata/datapipes/iter/util/xzfileloader.py
+++ b/torchdata/datapipes/iter/util/xzfileloader.py
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".xz")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/ziparchiveloader.py b/torchdata/datapipes/iter/util/ziparchiveloader.py
index ae0c8f836..d61e061d9 100644
--- a/torchdata/datapipes/iter/util/ziparchiveloader.py
+++ b/torchdata/datapipes/iter/util/ziparchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         continue
                     extracted_fobj = zips.open(zipinfo)
                     inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
         # We are unable to close 'data_stream' here, because it needs to be available to use later
 
     def __len__(self) -> int:
diff --git a/torchdata/datapipes/utils/__init__.py b/torchdata/datapipes/utils/__init__.py
index 889fd4d08..c74e8f702 100644
--- a/torchdata/datapipes/utils/__init__.py
+++ b/torchdata/datapipes/utils/__init__.py
@@ -7,5 +7,6 @@
 from torch.utils.data.datapipes.utils.common import StreamWrapper
 
 from ._visualization import to_graph
+from .janitor import janitor
 
-__all__ = ["StreamWrapper", "to_graph"]
+__all__ = ["StreamWrapper", "janitor", "to_graph"]
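Stepping back to the `IterKeyZipper.__iter__` change in combining.py above: once the source datapipe is exhausted, both the unconsumed tail of the reference iterator and any unmatched entries left in the buffer may still hold open streams, so each is handed to `janitor` (added below). A hedged sketch of that drain-and-clean idiom; `ref_it` and `buffer` are hypothetical stand-ins mirroring the names in the patch:

```python
from torchdata.datapipes.utils import janitor

def drain_and_clean(ref_it, buffer):
    # Clean up unconsumed elements of the reference iterator.
    for remaining in ref_it:
        janitor(remaining)
    # Clean up buffered elements that never found a match.
    for v in buffer.values():
        janitor(v)
    buffer.clear()
```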
diff --git a/torchdata/datapipes/utils/janitor.py b/torchdata/datapipes/utils/janitor.py
new file mode 100644
index 000000000..bc649123f
--- /dev/null
+++ b/torchdata/datapipes/utils/janitor.py
@@ -0,0 +1,10 @@
+from torchdata.datapipes.utils import StreamWrapper
+
+
+def janitor(obj):
+    """
+    Invokes various `obj` cleanup procedures such as:
+    - Closing streams
+    """
+    # TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
+    StreamWrapper.close_streams(obj)
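Finally, a short usage sketch of the new utility, assuming a hypothetical local file `sample.bin` as the stream source. `janitor` delegates to `StreamWrapper.close_streams`, which closes the wrapped stream(s) it is handed on a best-effort basis; the function is also re-exported at the top level as `torchdata.janitor` per the `__init__.py` change above:

```python
from torchdata.datapipes.utils import StreamWrapper, janitor

raw = open("sample.bin", "rb")  # hypothetical local file
wrapped = StreamWrapper(raw, name="sample.bin")

janitor(wrapped)  # delegates to StreamWrapper.close_streams
assert raw.closed  # the underlying file handle has been closed
```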