Skip to content

Commit

Permalink
[DataPipe] Automatically close parent streams and discarded streams
Browse files Browse the repository at this point in the history
ghstack-source-id: 88237d7a9aa2f2e473bbfbe7311fdec6e067f0a9
Pull Request resolved: #560
  • Loading branch information
VitalyFedyunin committed Jun 29, 2022
1 parent 242ec0d commit c0c1554
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 8 deletions.
7 changes: 6 additions & 1 deletion torchdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@

from . import datapipes

janitor = datapipes.utils.janitor

try:
from .version import __version__ # noqa: F401
except ImportError:
pass

__all__ = ["datapipes"]
__all__ = [
"datapipes",
"janitor",
]
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/bz2fileloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
try:
extracted_fobj = bz2.open(data_stream, mode="rb") # type: ignore[call-overload]
new_pathname = pathname.rstrip(".bz2")
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
Expand Down
10 changes: 10 additions & 0 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from torchdata.datapipes.utils.janitor import janitor

T_co = TypeVar("T_co", covariant=True)


Expand Down Expand Up @@ -109,6 +111,14 @@ def __iter__(self) -> Iterator:
else:
yield res

for remaining in ref_it:
janitor(remaining)

# TODO(VitalyFedyunin): This should be Exception or warn when debug mode is enabled
if len(self.buffer) > 0:
for k, v in self.buffer.items():
janitor(v)

def __len__(self) -> int:
return len(self.source_datapipe)

Expand Down
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for path, file in self.source_datapipe:
file_type = self._detect_compression_type(path)
decompressor = self._DECOMPRESSORS[file_type]
yield path, StreamWrapper(decompressor(file))
yield path, StreamWrapper(decompressor(file), file, name=path)
if isinstance(file, StreamWrapper):
file.autoclose()


@functional_datapipe("extract")
Expand Down
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/rararchiveloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
inner_path = os.path.join(path, info.filename)
file_obj = rar.open(info)

yield inner_path, StreamWrapper(file_obj) # type: ignore[misc]
yield inner_path, StreamWrapper(file_obj, stream) # type: ignore[misc]
if isinstance(stream, StreamWrapper):
stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
Expand Down
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/tararchiveloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
Expand Down
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/xzfileloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
try:
extracted_fobj = lzma.open(data_stream, mode="rb") # type: ignore[call-overload]
new_pathname = pathname.rstrip(".xz")
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield new_pathname, StreamWrapper(extracted_fobj, data_stream) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
Expand Down
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/ziparchiveloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
continue
extracted_fobj = zips.open(zipinfo)
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()
# We are unable to close 'data_stream' here, because it needs to be available to use later

def __len__(self) -> int:
Expand Down
3 changes: 2 additions & 1 deletion torchdata/datapipes/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from torch.utils.data.datapipes.utils.common import StreamWrapper

from ._visualization import to_graph
from .janitor import janitor

__all__ = ["StreamWrapper", "to_graph"]
__all__ = ["StreamWrapper", "janitor", "to_graph"]
10 changes: 10 additions & 0 deletions torchdata/datapipes/utils/janitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Import StreamWrapper from its defining module in torch rather than from
# torchdata.datapipes.utils: that package's __init__ imports this module
# (janitor), so importing back from it here creates a circular import.
# torchdata.datapipes.utils.StreamWrapper is a re-export of this exact class.
from torch.utils.data.datapipes.utils.common import StreamWrapper


def janitor(obj):
    """
    Invoke cleanup procedures on `obj`, currently:
        - Closing any streams it holds (via ``StreamWrapper.close_streams``).

    Args:
        obj: Any object. Stream wrappers are closed; containers are walked
            by ``close_streams``; other values are left untouched.

    Returns:
        None
    """
    # TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
    StreamWrapper.close_streams(obj)

0 comments on commit c0c1554

Please sign in to comment.