From c6756d155e007694ebab63271afb06799ff22fdb Mon Sep 17 00:00:00 2001
From: Sebastian Thomas
Date: Fri, 21 Oct 2022 16:49:34 +0200
Subject: [PATCH] Add a masks option to filter files in s3 datapipe

---
 test/test_local_io.py                 |  2 +-
 test/test_s3io.py                     | 34 +++++++++++++++++++++++++++
 torchdata/datapipes/iter/load/s3io.py | 18 +++++++++++---
 3 files changed, 50 insertions(+), 4 deletions(-)
 create mode 100644 test/test_s3io.py

diff --git a/test/test_local_io.py b/test/test_local_io.py
index c6013967b..c175d2448 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -674,7 +674,7 @@ def test_disk_cache_locks(self):
 
     # TODO(120): this test currently only covers reading from local
     # filesystem. It needs to be modified once test data can be stored on
-    # gdrive/s3/onedrive
+    # gdrive/onedrive
     @skipIfNoIoPath
     def test_io_path_file_lister_iterdatapipe(self):
         datapipe = IoPathFileLister(root=self.temp_sub_dir.name)
diff --git a/test/test_s3io.py b/test/test_s3io.py
new file mode 100644
index 000000000..2ab5a9f01
--- /dev/null
+++ b/test/test_s3io.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from unittest.mock import MagicMock, patch
+
+import expecttest
+
+from torchdata.datapipes.iter import IterableWrapper, S3FileLister
+
+
+@patch("torchdata._torchdata")
+class TestS3FileListerIterDataPipe(expecttest.TestCase):
+    def test_list_files(self, mock_torchdata):
+        s3handler_mock = MagicMock()
+        mock_torchdata.S3Handler.return_value = s3handler_mock
+        s3handler_mock.list_files = MagicMock(
+            side_effect=[["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"], []]
+        )
+        s3_prefixes = IterableWrapper(["s3://bucket-name/folder/"])
+        dp_s3_urls = S3FileLister(s3_prefixes)
+        assert list(dp_s3_urls) == ["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"]
+
+    def test_list_files_with_filter_mask(self, mock_torchdata):
+        s3handler_mock = MagicMock()
+        mock_torchdata.S3Handler.return_value = s3handler_mock
+        s3handler_mock.list_files = MagicMock(
+            side_effect=[["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"], []]
+        )
+        s3_prefixes = IterableWrapper(["s3://bucket-name/folder/"])
+        dp_s3_urls = S3FileLister(s3_prefixes, masks="*.csv")
+        assert list(dp_s3_urls) == ["s3://bucket-name/folder/b.csv"]
diff --git a/torchdata/datapipes/iter/load/s3io.py b/torchdata/datapipes/iter/load/s3io.py
index b5e0960e6..3a17d43ed 100644
--- a/torchdata/datapipes/iter/load/s3io.py
+++ b/torchdata/datapipes/iter/load/s3io.py
@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from io import BytesIO
-from typing import Iterator, Tuple
+from typing import Iterator, List, Tuple, Union
 
 import torchdata
+
+from torch.utils.data.datapipes.utils.common import match_masks
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
 from torchdata.datapipes.utils import StreamWrapper
@@ -49,19 +51,29 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
         ...     pass
     """
 
-    def __init__(self, source_datapipe: IterDataPipe[str], length: int = -1, request_timeout_ms=-1, region="") -> None:
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe[str],
+        length: int = -1,
+        request_timeout_ms=-1,
+        region="",
+        masks: Union[str, List[str]] = "",
+    ) -> None:
         if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
             raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")
         self.source_datapipe: IterDataPipe[str] = source_datapipe
         self.length: int = length
         self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
+        self.masks = masks
 
     def __iter__(self) -> Iterator[str]:
         for prefix in self.source_datapipe:
             while True:
                 urls = self.handler.list_files(prefix)
-                yield from urls
+                for url in urls:
+                    if match_masks(url, self.masks):
+                        yield url
                 if not urls:
                     break
             self.handler.clear_marker()
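
Usage sketch of the new ``masks`` argument, assuming TorchData built with BUILD_S3=1, valid AWS credentials, and a hypothetical bucket/prefix. ``masks`` takes a single fnmatch-style pattern or a list of patterns; the default empty string matches every listed URL, so existing pipelines keep their old behavior.

    >>> from torchdata.datapipes.iter import IterableWrapper, S3FileLister
    >>> # Hypothetical prefix; listing requires a reachable S3 bucket
    >>> s3_prefixes = IterableWrapper(["s3://bucket-name/folder/"])
    >>> # Keep only CSV files; a list such as ["*.csv", "*.json"] also works
    >>> dp_s3_urls = S3FileLister(s3_prefixes, masks="*.csv")
    >>> for url in dp_s3_urls:
    ...     print(url)  # e.g. s3://bucket-name/folder/b.csv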