From 23f66c4f11c9d38dae441bb10ea50833a2699744 Mon Sep 17 00:00:00 2001
From: Jian Xiao <99709935+jianoaix@users.noreply.github.com>
Date: Wed, 22 Feb 2023 20:28:19 -0800
Subject: [PATCH] [data] Streaming executor fixes #2 (#32759)

Signed-off-by: Edward Oakes
---
 .../data/_internal/execution/legacy_compat.py |  4 +-
 python/ray/data/dataset.py                    |  2 +-
 python/ray/data/tests/test_dataset.py         | 46 +++++++++++++++----
 3 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py
index e5d59d0f3106..0579f38c2cea 100644
--- a/python/ray/data/_internal/execution/legacy_compat.py
+++ b/python/ray/data/_internal/execution/legacy_compat.py
@@ -266,11 +266,13 @@ def bulk_fn(
 
 def _bundles_to_block_list(bundles: Iterator[RefBundle]) -> BlockList:
     blocks, metadata = [], []
+    owns_blocks = True
     for ref_bundle in bundles:
+        if not ref_bundle.owns_blocks:
+            owns_blocks = False
         for block, meta in ref_bundle.blocks:
             blocks.append(block)
             metadata.append(meta)
-    owns_blocks = all(b.owns_blocks for b in bundles)
     return BlockList(blocks, metadata, owned_by_consumer=owns_blocks)
 
 
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index bd61c92f645e..065d3098d1a0 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2825,7 +2825,7 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]
         for batch in self.iter_batches(
             batch_size=None, prefetch_blocks=prefetch_blocks, batch_format=batch_format
         ):
-            batch = BlockAccessor.for_block(batch)
+            batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch))
             for row in batch.iter_rows():
                 yield row
 
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 7322c9bc7df3..33624bd25d03 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -1369,17 +1369,26 @@ def test_count_lazy(ray_start_regular_shared):
 
 def test_lazy_loading_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(100, parallelism=20)
-    assert ds._plan.execute()._num_computed() == 0
+
+    def check_num_computed(expected):
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            # In the streaming executor, ds.take() does not invoke partial
+            # execution in LazyBlockList.
+            assert ds._plan.execute()._num_computed() == 0
+        else:
+            assert ds._plan.execute()._num_computed() == expected
+
+    check_num_computed(0)
     assert ds.take(10) == list(range(10))
-    assert ds._plan.execute()._num_computed() == 2
+    check_num_computed(2)
     assert ds.take(20) == list(range(20))
-    assert ds._plan.execute()._num_computed() == 4
+    check_num_computed(4)
     assert ds.take(30) == list(range(30))
-    assert ds._plan.execute()._num_computed() == 8
+    check_num_computed(8)
     assert ds.take(50) == list(range(50))
-    assert ds._plan.execute()._num_computed() == 16
+    check_num_computed(16)
     assert ds.take(100) == list(range(100))
-    assert ds._plan.execute()._num_computed() == 20
+    check_num_computed(20)
 
 
 def test_dataset_repr(ray_start_regular_shared):
@@ -1696,7 +1705,14 @@ def to_pylist(table):
     # Default ArrowRows.
     for row, t_row in zip(ds.iter_rows(), to_pylist(t)):
         assert isinstance(row, TableRow)
-        assert isinstance(row, ArrowRow)
+        # In streaming, we set batch_format to "default" because calling
+        # ds.dataset_format() would still invoke bulk execution, which we
+        # want to avoid. As a result, iteration yields PandasRow (the
+        # default batch format).
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            assert isinstance(row, PandasRow)
+        else:
+            assert isinstance(row, ArrowRow)
         assert row == t_row
 
     # PandasRows after conversion.
@@ -1710,7 +1726,14 @@ def to_pylist(table):
     # Prefetch.
     for row, t_row in zip(ds.iter_rows(prefetch_blocks=1), to_pylist(t)):
         assert isinstance(row, TableRow)
-        assert isinstance(row, ArrowRow)
+        # In streaming, we set batch_format to "default" because calling
+        # ds.dataset_format() would still invoke bulk execution, which we
+        # want to avoid. As a result, iteration yields PandasRow (the
+        # default batch format).
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            assert isinstance(row, PandasRow)
+        else:
+            assert isinstance(row, ArrowRow)
         assert row == t_row
 
 
@@ -2181,7 +2204,12 @@ def test_lazy_loading_iter_batches_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(32, parallelism=8)
     expected_num_blocks = [1, 2, 4, 4, 8, 8, 8, 8]
    for _, expected in zip(ds.iter_batches(batch_size=None), expected_num_blocks):
-        assert ds._plan.execute()._num_computed() == expected
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            # In streaming execution of ds.iter_batches(), there is no partial
+            # execution, so _num_computed() in LazyBlockList is 0.
+            assert ds._plan.execute()._num_computed() == 0
+        else:
+            assert ds._plan.execute()._num_computed() == expected
 
 
 def test_add_column(ray_start_regular_shared):
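
Note on the legacy_compat.py hunk: `bundles` is an `Iterator[RefBundle]`, so the
old `owns_blocks = all(b.owns_blocks for b in bundles)` ran only after the `for`
loop above it had already exhausted the iterator. `all()` over an exhausted
iterator sees zero elements and is vacuously True, so the resulting BlockList
could wrongly claim ownership. The fix folds the flag into the single pass. A
minimal standalone sketch of the pitfall, using illustrative names and plain
tuples instead of Ray's RefBundle (no Ray required):

    from typing import Iterator, List, Tuple

    def consume_then_check(bundles: Iterator[Tuple[str, bool]]) -> bool:
        # Mirrors the old behavior: consume the iterator first, then try to
        # aggregate the ownership flag over the (now exhausted) iterator.
        blocks: List[str] = []
        for block, _ in bundles:
            blocks.append(block)
        # Bug: `bundles` is exhausted here, so all() sees no elements and
        # returns True regardless of the flags.
        return all(owns for _, owns in bundles)

    def check_while_consuming(bundles: Iterator[Tuple[str, bool]]) -> bool:
        # Mirrors the fixed behavior: fold the flag in during the single pass.
        owns_blocks = True
        for _, owns in bundles:
            if not owns:
                owns_blocks = False
        return owns_blocks

    data = [("block0", True), ("block1", False)]
    assert consume_then_check(iter(data)) is True       # wrong: bug masked
    assert check_while_consuming(iter(data)) is False   # correct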