From cbdb9940ad1f6efee1b7dd1ef54ec61c246a6000 Mon Sep 17 00:00:00 2001 From: animan42 Date: Sat, 10 Aug 2024 02:45:10 +0000 Subject: [PATCH 1/4] Unit test for max_pre_download param --- tests/streaming/test_dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f653efec..846d580f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -89,6 +89,16 @@ def test_streaming_dataset(tmpdir, monkeypatch, compression): assert len(dataloader) == 60 dataloader = DataLoader(dataset, num_workers=2, batch_size=2) assert len(dataloader) == 30 + +@pytest.mark.timeout(30) +def test_streaming_dataset_max_pre_download(tmpdir, monkeypatch, compression): + seed_everything(42) + + dataset = StreamingDataset(input_dir=str(tmpdir)) + assert dataset.cache._reader._max_pre_download == 2 + + dataset = StreamingDataset(input_dir=str(tmpdir), max_pre_download=10) + assert dataset.cache._reader._max_pre_download == 2 @pytest.mark.parametrize("drop_last", [False, True]) From 22da7645626cf7517261e66353c1573c576d3524 Mon Sep 17 00:00:00 2001 From: animan42 Date: Sat, 10 Aug 2024 02:47:23 +0000 Subject: [PATCH 2/4] fix asserts --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 846d580f..f1c7a751 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -98,7 +98,7 @@ def test_streaming_dataset_max_pre_download(tmpdir, monkeypatch, compression): assert dataset.cache._reader._max_pre_download == 2 dataset = StreamingDataset(input_dir=str(tmpdir), max_pre_download=10) - assert dataset.cache._reader._max_pre_download == 2 + assert dataset.cache._reader._max_pre_download == 10 @pytest.mark.parametrize("drop_last", [False, True]) From 644d125e85ac8740ea456fcec8f494779a929bb7 Mon Sep 17 00:00:00 2001 From: animan42 Date: Sat, 10 Aug 2024 02:53:57 +0000 Subject: [PATCH 3/4] Pass the test --- tests/streaming/test_dataset.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f1c7a751..1aed5b58 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -91,13 +91,25 @@ def test_streaming_dataset(tmpdir, monkeypatch, compression): assert len(dataloader) == 30 @pytest.mark.timeout(30) -def test_streaming_dataset_max_pre_download(tmpdir, monkeypatch, compression): +def test_streaming_dataset_max_pre_download(tmpdir): seed_everything(42) + + cache = Cache(str(tmpdir), chunk_size=10) + for i in range(60): + cache[i] = i + cache.done() + cache.merge() dataset = StreamingDataset(input_dir=str(tmpdir)) + assert len(dataset) == 60 + for i in range(60): + assert dataset[i] == i assert dataset.cache._reader._max_pre_download == 2 dataset = StreamingDataset(input_dir=str(tmpdir), max_pre_download=10) + assert len(dataset) == 60 + for i in range(60): + assert dataset[i] == i assert dataset.cache._reader._max_pre_download == 10 From 6413cb650cfd56cd4e9adb20340a0a1de8bab02f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 10 Aug 2024 02:54:45 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 1aed5b58..2bb6e834 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -89,11 +89,12 @@ def test_streaming_dataset(tmpdir, monkeypatch, compression): assert len(dataloader) == 60 dataloader = DataLoader(dataset, num_workers=2, batch_size=2) assert len(dataloader) == 30 - + + @pytest.mark.timeout(30) def test_streaming_dataset_max_pre_download(tmpdir): seed_everything(42) - + cache = Cache(str(tmpdir), chunk_size=10) for i in range(60): cache[i] = i @@ -105,7 +106,7 @@ def test_streaming_dataset_max_pre_download(tmpdir): for i in range(60): assert dataset[i] == i assert dataset.cache._reader._max_pre_download == 2 - + dataset = StreamingDataset(input_dir=str(tmpdir), max_pre_download=10) assert len(dataset) == 60 for i in range(60):