Skip to content

Commit

Permalink
Adding kwargs for fs.open() in fsspec DataPipes (#804)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #804

Fixes #803

I left `FSSpecFileLister` untouched since I don't think it will be useful for `fs.ls()` to accept kwargs.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D40038331

Pulled By: NivekT

fbshipit-source-id: 45232b938693690bc0906fc6240a104e80ef51f9
  • Loading branch information
NivekT committed Oct 7, 2022
1 parent 9ad8efb commit c5df338
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
20 changes: 17 additions & 3 deletions test/test_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@

from _utils._common_utils_for_test import create_temp_dir, create_temp_files, reset_after_n_next_calls

from torchdata.datapipes.iter import FileLister, FSSpecFileLister, FSSpecFileOpener, FSSpecSaver, IterableWrapper
from torchdata.datapipes.iter import (
FileLister,
FSSpecFileLister,
FSSpecFileOpener,
FSSpecSaver,
IterableWrapper,
IterDataPipe,
)

try:
import fsspec
Expand Down Expand Up @@ -55,7 +62,7 @@ def filepath_fn(name: str) -> str:

@skipIfNoFSSpec
def test_fsspec_file_lister_iterdatapipe(self):
datapipe = FSSpecFileLister(root="file://" + self.temp_sub_dir.name)
datapipe: IterDataPipe = FSSpecFileLister(root="file://" + self.temp_sub_dir.name)

# check all file paths within sub_folder are listed
for path in datapipe:
Expand All @@ -75,7 +82,9 @@ def test_fsspec_file_lister_iterdatapipe(self):

@skipIfNoFSSpec
def test_fsspec_file_lister_iterdatapipe_with_list(self):
datapipe = FSSpecFileLister(root=["file://" + self.temp_sub_dir.name, "file://" + self.temp_sub_dir_2.name])
datapipe: IterDataPipe = FSSpecFileLister(
root=["file://" + self.temp_sub_dir.name, "file://" + self.temp_sub_dir_2.name]
)

# check all file paths within sub_folder are listed
file_lister = list(map(lambda path: path.split("://")[1], datapipe))
Expand Down Expand Up @@ -109,11 +118,16 @@ def test_fsspec_file_lister_iterdatapipe_with_list(self):
def test_fsspec_file_loader_iterdatapipe(self):
datapipe1 = FSSpecFileLister(root="file://" + self.temp_sub_dir.name)
datapipe2 = FSSpecFileOpener(datapipe1)
datapipe3 = FSSpecFileOpener(datapipe1, kwargs_for_open={"encoding": "cp037"})

# check contents of file match
for _, f in datapipe2:
self.assertEqual(f.read(), "0123456789abcdef")

# Opened with a different encoding, hence NotEqual
for _, f in datapipe3:
self.assertNotEqual(f.read(), "0123456789abcdef")

# Reset Test: Ensure the resulting streams are still readable after the DataPipe is reset/exhausted
self._write_text_files()
lister_dp = FileLister(self.temp_dir.name, "*.text")
Expand Down
34 changes: 22 additions & 12 deletions torchdata/datapipes/iter/load/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import posixpath

from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

from torch.utils.data.datapipes.utils.common import match_masks

Expand Down Expand Up @@ -67,11 +67,11 @@ def __init__(
else:
self.datapipe = root
self.masks = masks
self.kwargs = kwargs
self.kwargs_for_connection = kwargs

def __iter__(self) -> Iterator[str]:
for root in self.datapipe:
fs, path = fsspec.core.url_to_fs(root, **self.kwargs)
fs, path = fsspec.core.url_to_fs(root, **self.kwargs_for_connection)

if isinstance(fs.protocol, str):
protocol_list = [fs.protocol]
Expand Down Expand Up @@ -117,7 +117,8 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: Iterable DataPipe that provides the pathnames or URLs
mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
kwargs: Extra options that make sense to a particular storage connection,
kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``)
kwargs: Extra options that are used to establish a particular storage connection,
e.g. host, port, username, password, etc.
Example:
Expand All @@ -126,17 +127,20 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
>>> file_dp = datapipe.open_files_by_fsspec()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", **kwargs) -> None:
def __init__(
self, source_datapipe: IterDataPipe[str], mode: str = "r", *, kwargs_for_open: Optional[Dict] = None, **kwargs
) -> None:
_assert_fsspec()

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.mode: str = mode
self.kwargs = kwargs
self.kwargs_for_open = kwargs_for_open if kwargs_for_open is not None else {}
self.kwargs_for_connection = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for file_uri in self.source_datapipe:
fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs)
file = fs.open(path, self.mode)
fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs_for_connection)
file = fs.open(path, self.mode, **self.kwargs_for_open)
yield file_uri, StreamWrapper(file)

def __len__(self) -> int:
Expand All @@ -158,9 +162,11 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
source_datapipe: Iterable DataPipe with tuples of metadata and data
mode: Mode in which the file will be opened for write the data (``"w"`` by default)
filepath_fn: Function that takes in metadata and returns the target path of the new file
kwargs: Extra options that make sense to a particular storage connection,
kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``)
kwargs: Extra options that are used to establish a particular storage connection,
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
Expand All @@ -176,20 +182,24 @@ def __init__(
source_datapipe: IterDataPipe[Tuple[Any, U]],
mode: str = "w",
filepath_fn: Optional[Callable] = None,
*,
kwargs_for_open: Optional[Dict] = None,
**kwargs,
):
_assert_fsspec()

self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe
self.mode: str = mode
self.filepath_fn: Optional[Callable] = filepath_fn
self.kwargs = kwargs
self.kwargs_for_open = kwargs_for_open if kwargs_for_open is not None else {}
self.kwargs_for_connection = kwargs

def __iter__(self) -> Iterator[str]:
for meta, data in self.source_datapipe:
filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs)
with fs.open(path, self.mode) as f:
fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs_for_connection)
with fs.open(path, self.mode, **self.kwargs_for_open) as f:
print(f)
f.write(data)
yield filepath

Expand Down

0 comments on commit c5df338

Please sign in to comment.