Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Laion5b dataset example #1017

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/_build_test_upload.yml
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ jobs:
- 3.8
- 3.9
- "3.10"
- 3.11
- "3.11"
steps:
- name: Checkout Source Repository
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The following is the corresponding `torchdata` versions and supported Python ver

| `torch` | `torchdata` | `python` |
| -------------------- | ------------------ | ----------------- |
| `master` / `nightly` | `main` / `nightly` | `>=3.8`, `<=3.10` |
| `master` / `nightly` | `main` / `nightly` | `>=3.8`, `<=3.11` |
| `1.13.1` | `0.5.1` | `>=3.7`, `<=3.10` |
| `1.12.1` | `0.4.1` | `>=3.7`, `<=3.10` |
| `1.12.0` | `0.4.0` | `>=3.7`, `<=3.10` |
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Features described in this documentation are classified by release status:
dataloader2.rst
reading_service.rst


.. toctree::
:maxdepth: 2
:caption: Tutorial and Examples:
Expand Down
11 changes: 11 additions & 0 deletions docs/source/torchdata.datapipes.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ DataPipe Graph Visualization

to_graph

Common Utility Functions
--------------------------------------
.. currentmodule:: torchdata.datapipes.utils

.. autosummary::
:nosignatures:
:toctree: generated/
:template: function.rst

pin_memory_fn


File Object and Stream Utility
-------------------------------------
Expand Down
84 changes: 84 additions & 0 deletions examples/vision/laion5b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from io import BytesIO

import requests

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import HuggingFaceHubReader

try:
import PIL
from PIL import Image
except ImportError:
PIL = None
Image = None


def has_no_watermark(x):
    """Keep only samples whose watermark probability is known and below 0.8."""
    pwatermark = x["pwatermark"]
    if pwatermark is None:
        return False
    return pwatermark < 0.8


def is_sfw(x):
    """Keep only samples whose unsafe-content probability is known and below 0.5."""
    punsafe = x["punsafe"]
    if punsafe is None:
        return False
    return punsafe < 0.5


def load_image(url):
    """Download *url* and decode it into a PIL Image.

    Returns None on any failure (network error, timeout, undecodable bytes)
    so that broken entries can be filtered out downstream.
    """
    try:
        response = requests.get(url, timeout=5)
        image = Image.open(BytesIO(response.content))
    except Exception:
        return None
    return image


def image_was_loaded(x):
    """Filter predicate: True for entries whose image download succeeded."""
    if x is None:
        return False
    return True


# For more information about the dataset see: https://laion.ai/blog/laion-5b/
# name of the dataset to be used
NAME = "laion/laion2B-en-joined"


# As the dataset is too large to store locally we use a streaming approach
def laion2b_en(name=NAME):
    """Build a streaming datapipe over the LAION2B-en metadata hosted on
    the HuggingFace Hub.

    The pipeline filters out likely-watermarked and NSFW entries, shuffles
    and shards, downloads each image from its URL, drops entries whose
    download failed, and yields batches of 20 dicts with keys
    "TEXT" and "IMAGE".

    Args:
        name: HuggingFace Hub dataset name (defaults to ``NAME``).

    Returns:
        An IterDataPipe yielding batches of {"TEXT", "IMAGE"} dicts.
    """
    # NOTE: removed stray review-UI text that had been pasted into the code
    # here ("SvenDS9 marked this conversation as resolved..."), which was not
    # valid Python.
    dp = HuggingFaceHubReader(name)
    dp = dp.filter(has_no_watermark)
    dp = dp.filter(is_sfw)
    dp = dp.shuffle().sharding_filter()
    dp = dp.slice(index=["TEXT", "URL"])
    dp = dp.map(fn=load_image, input_col="URL", output_col="IMAGE")  # this needs multithreading
    dp = dp.filter(filter_fn=image_was_loaded, input_col="IMAGE")
    dp = dp.drop("URL")
    dp = dp.batch(20)
    return dp


def print_label_and_copyright(label, image):
    """Print the running image counter, label, and EXIF copyright (if present).

    Relies on the module-level counter ``i`` maintained by the caller.
    """
    try:
        try:
            # 0x8298 is the EXIF-tag for copyright
            copyright_info = image.getexif().get(0x8298, "no info")
        except Exception:
            copyright_info = "EXIF data is corrupted"
        if copyright_info in ("no info", "EXIF data is corrupted"):
            print(f"image {i}: {label=}")
        else:
            print(f"image {i}: {label=}, {copyright_info=} ")
    except PIL.UnidentifiedImageError:
        print(f"image {i}: corrupted")


if __name__ == "__main__":
    # Module-level counter read by print_label_and_copyright for progress output.
    i = 0
    dp = laion2b_en()
    # Stream and decode with 4 worker processes in parallel.
    rs = MultiProcessingReadingService(num_workers=4)
    dl = DataLoader2(dp, reading_service=rs)
    for batch in dl:
        for entry in batch:
            print_label_and_copyright(entry["TEXT"], entry["IMAGE"])
            i += 1
12 changes: 11 additions & 1 deletion packaging/build_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,19 @@ export CU_VERSION=cpu
export NO_CUDA_PACKAGE=1
export BUILD_TYPE="conda"

if [[ "$PYTHON_VERSION" == "3.11" ]]; then
export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c malfet"
fi

export SOURCE_ROOT_DIR="$PWD"
setup_env
setup_conda_pytorch_constraint

mkdir -p conda-bld
conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload --output-folder conda-bld --python "$PYTHON_VERSION" packaging/torchdata
conda build \
-c defaults \
$CONDA_CHANNEL_FLAGS \
--no-anaconda-upload \
--output-folder conda-bld \
--python "$PYTHON_VERSION" \
packaging/torchdata
3 changes: 2 additions & 1 deletion packaging/torchdata/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ test:
- cpuonly
- pytest
- expecttest
- fsspec
# fsspec doesn't support Python 3.11
# - fsspec
# The following packages are not on the default conda channel
# - iopath
# - rarfile
Expand Down
29 changes: 27 additions & 2 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@

mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods())

EXCEPTION_ITERATION_NUM = 7


class _ReadingServiceWrapper:
def __init__(self, dp):
Expand All @@ -79,6 +81,18 @@ def return_one():
return 1


class MakeMistakeDataPipe(IterDataPipe):
    """Wraps a datapipe and deliberately raises on the ``exc_iteration``-th
    element, used to test that worker exceptions propagate to the main process."""

    def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
        # source_datapipe: upstream pipe to forward elements from
        # exc_iteration: zero-based index at which to raise
        self.source_datapipe = source_datapipe
        self.exc_iteration = exc_iteration

    def __iter__(self):
        for i, x in enumerate(self.source_datapipe):
            if i == self.exc_iteration:
                raise Exception("oops")
            yield x


class TestReadingService(ReadingServiceInterface):
def initialize(self, dp: DataPipe) -> DataPipe:
return _ReadingServiceWrapper(dp) # type: ignore[return-value]
Expand All @@ -99,6 +113,19 @@ def test_dataloader2_shutdown(self) -> None:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
data_loader.shutdown()

    def test_worker_exception_raised(self):
        # An exception raised inside a worker's datapipe must surface to the
        # main process as communication.iter.WorkerException, regardless of
        # prefetch settings or worker count.
        dp = IterableWrapper(range(100)).sharding_filter()
        dp = MakeMistakeDataPipe(dp)
        for worker_prefetch_cnt in [0, 5, 10]:
            for num_workers in [1, 4]:
                rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
                dl = DataLoader2(dp, reading_service=rs)
                it = iter(dl)
                # Each worker yields EXCEPTION_ITERATION_NUM good items before
                # its copy of MakeMistakeDataPipe raises.
                for i in range(EXCEPTION_ITERATION_NUM * num_workers):
                    next(it)
                with self.assertRaises(communication.iter.WorkerException):
                    next(it)

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
Expand Down Expand Up @@ -171,7 +198,6 @@ def test_dataloader2_iterates_correctly(self) -> None:
self.assertEqual(list(range(10)), actual)

def test_dataloader2_reset(self) -> None:

test_data_pipe = IterableWrapper(range(10))
reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)]

Expand Down Expand Up @@ -264,7 +290,6 @@ def test_dataloader2_shuffle(self) -> None:
"fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDataLoader2EventLoop(TestCase):

# TODO: This needs fixing, see issue 624
# @skipIfNoDill
# def test_basic_threading(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from unittest import TestCase

from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


Expand All @@ -29,9 +29,9 @@ def _add_one(x: int) -> int:
dp_parametrize = parametrize("dp", test_dps)


class TestPrototypeMultiProcessingReadingService(TestCase):
class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
This tests specific functionalities of MultiProcessingReadingService, notably
`pause`, `resume`, `snapshot`.
"""

Expand All @@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

# Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
# with `num_workers = 0`
rs0 = PrototypeMultiProcessingReadingService(
rs0 = MultiProcessingReadingService(
num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
)
dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
Expand All @@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_

# Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
# properly pauses and resumes
rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

# Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers,
worker_prefetch_cnt=worker_prefetch_cnt,
main_prefetch_cnt=main_prefetch_cnt,
Expand All @@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe
@parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
)

Expand Down Expand Up @@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# those DPs belong to a dispatching process and only do pause if worker_id == 0
# There might still be a race condition, need to look into the messages

# rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
# rs2 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
# rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
# rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
# rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
# rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
# rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
# rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
# rss = [rs1, rs2, rs3, rs4]

# for n, rs in enumerate(rss):
Expand Down Expand Up @@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# pass


instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
instantiate_parametrized_tests(TestMultiProcessingReadingService)


if __name__ == "__main__":
Expand Down
44 changes: 43 additions & 1 deletion test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Dict

import expecttest
import torch.utils.data.datapipes.iter
import torch

import torchdata

Expand Down Expand Up @@ -42,6 +42,8 @@
)
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper

skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")


def test_torchdata_pytorch_consistency() -> None:
def extract_datapipe_names(module):
Expand All @@ -68,6 +70,14 @@ def extract_datapipe_names(module):
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))


def _convert_to_tensor(data):
if isinstance(data, dict):
return {k: _convert_to_tensor(v) for k, v in data.items()}
elif isinstance(data, list):
return [_convert_to_tensor(v) for v in data]
return torch.tensor(data)


class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
Expand Down Expand Up @@ -1475,6 +1485,38 @@ def test_random_splitter_iterdatapipe(self):
next(it_train)
next(it_valid) # No error, can keep going

    @skipIfNoCUDA
    def test_pin_memory(self):
        # pin_memory() must pin every tensor found inside arbitrarily nested
        # list/dict containers yielded by the datapipe.
        # Tensor
        dp = IterableWrapper([(i, i + 1) for i in range(10)]).map(_convert_to_tensor).pin_memory()
        self.assertTrue(all(d.is_pinned() for d in dp))

        # List of Tensors
        dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory()
        self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp))

        # Dict of Tensors
        dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
        self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))

        # Dict of List of Tensors
        dp = (
            IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
            .map(_convert_to_tensor)
            .pin_memory()
        )
        self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch))

        # List of Dict of Tensors
        dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
        self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values()))

        # List of List of Tensors
        dp = (
            IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
        )
        self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch))


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def _filter_by_module_availability(datapipes):
return [dp for dp in datapipes if dp[0] not in filter_set]


def _convert_to_tensor(data):
return torch.tensor(data)


class TestIterDataPipeSerialization(expecttest.TestCase):
def setUp(self):
self.temp_dir = create_temp_dir()
Expand Down Expand Up @@ -272,6 +276,7 @@ def test_serializable(self):
(),
{},
),
(iterdp.Prefetcher, None, (), {}),
(iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}),
(iterdp.RarArchiveLoader, None, (), {}),
(
Expand Down
Loading