diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index d4126108367e..f39e773c9129 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -57,6 +57,7 @@
     ParquetDatasource,
     BlockWritePathProvider,
     DefaultBlockWritePathProvider,
+    ReadTask,
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,28 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         start_time = time.perf_counter()

         context = DatasetContext.get_current()
-        calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
-        metadata: List[BlockPartitionMetadata] = []
-        block_partitions: List[ObjectRef[BlockPartition]] = []
+        tasks: List[ReadTask] = []
+        block_partition_refs: List[ObjectRef[BlockPartition]] = []
+        block_partition_meta_refs: List[ObjectRef[BlockPartitionMetadata]] = []

         datasets = [self] + list(other)
         for ds in datasets:
             bl = ds._plan.execute()
             if isinstance(bl, LazyBlockList):
-                calls.extend(bl._calls)
-                metadata.extend(bl._metadata)
-                block_partitions.extend(bl._block_partitions)
+                tasks.extend(bl._tasks)
+                block_partition_refs.extend(bl._block_partition_refs)
+                block_partition_meta_refs.extend(bl._block_partition_meta_refs)
             else:
-                calls.extend([None] * bl.initial_num_blocks())
-                metadata.extend(bl._metadata)
+                tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
                 if context.block_splitting_enabled:
-                    block_partitions.extend(
+                    block_partition_refs.extend(
                         [ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
                     )
                 else:
-                    block_partitions.extend(bl.get_blocks())
+                    block_partition_refs.extend(bl.get_blocks())
+                block_partition_meta_refs.extend(
+                    [ray.put(meta) for meta in bl._metadata]
+                )

         epochs = [ds._get_epoch() for ds in datasets]
         max_epoch = max(*epochs)
@@ -1028,7 +1031,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
             ExecutionPlan(
-                LazyBlockList(calls, metadata, block_partitions), dataset_stats
+                LazyBlockList(tasks, block_partition_refs, block_partition_meta_refs),
+                dataset_stats,
             ),
             max_epoch,
             self._lazy,
@@ -2548,6 +2552,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2666,6 +2671,7 @@ def window(
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
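
The repeat() and window() changes above clear the lazy block list's computed partitions before rewriting the read stage, so each pipeline window re-executes reads on demand instead of pinning every base block in the object store. A minimal sketch of that clear-then-recompute contract, assuming a local Ray runtime with this patch applied (it reaches into private attributes purely for illustration):

    import ray

    ray.init()
    ds = ray.data.range(100, parallelism=10)
    lazy_bl = ds._plan._in_blocks           # LazyBlockList
    assert lazy_bl._num_computed() == 1     # read_datasource() eagerly kicks off block 0
    lazy_bl.clear()                         # drop block/metadata refs; read tasks remain
    assert lazy_bl._num_computed() == 0
    print(ds.take(5))                       # re-submits read tasks on demand
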
""" - blocks = self.get_internal_block_refs() - bar = ProgressBar("Force reads", len(blocks)) - bar.block_until_complete(blocks) + blocks, metadata = [], [] + for b, m in self._plan.execute().get_blocks_with_metadata(): + blocks.append(b) + metadata.append(m) ds = Dataset( ExecutionPlan( - BlockList(blocks, self._plan.execute().get_metadata()), + BlockList(blocks, metadata), self._plan.stats(), dataset_uuid=self._get_uuid(), ), diff --git a/python/ray/data/impl/block_list.py b/python/ray/data/impl/block_list.py index d13ad0f54e7b..158f492335f8 100644 --- a/python/ray/data/impl/block_list.py +++ b/python/ray/data/impl/block_list.py @@ -1,15 +1,11 @@ import math -from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING - -if TYPE_CHECKING: - import pyarrow +from typing import List, Iterator, Tuple import numpy as np import ray from ray.types import ObjectRef -from ray.data.block import Block, BlockMetadata, BlockAccessor -from ray.data.impl.remote_fn import cached_remote_fn +from ray.data.block import Block, BlockMetadata class BlockList: @@ -26,11 +22,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata] self._num_blocks = len(self._blocks) self._metadata: List[BlockMetadata] = metadata - def set_metadata(self, i: int, metadata: BlockMetadata) -> None: - """Set the metadata for a given block.""" - self._metadata[i] = metadata - - def get_metadata(self) -> List[BlockMetadata]: + def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]: """Get the metadata for all blocks.""" return self._metadata.copy() @@ -182,23 +174,3 @@ def executed_num_blocks(self) -> int: doesn't know how many blocks will be produced until tasks finish. """ return len(self.get_blocks()) - - def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]: - """Ensure that the schema is set for the first block. - - Returns None if the block list is empty. - """ - get_schema = cached_remote_fn(_get_schema) - try: - block = next(self.iter_blocks()) - except (StopIteration, ValueError): - # Dataset is empty (no blocks) or was manually cleared. - return None - schema = ray.get(get_schema.remote(block)) - # Set the schema. 
-        self._metadata[0].schema = schema
-        return schema
-
-
-def _get_schema(block: Block) -> Any:
-    return BlockAccessor.for_block(block).schema()
diff --git a/python/ray/data/impl/lazy_block_list.py b/python/ray/data/impl/lazy_block_list.py
index 247d17e7e09b..3ddcc09bddaa 100644
--- a/python/ray/data/impl/lazy_block_list.py
+++ b/python/ray/data/impl/lazy_block_list.py
@@ -1,5 +1,6 @@
 import math
-from typing import Callable, List, Iterator, Tuple
+from typing import List, Iterator, Tuple, Optional, Dict, Any
+import uuid

 import numpy as np

@@ -7,12 +8,18 @@
 from ray.types import ObjectRef
 from ray.data.block import (
     Block,
+    BlockAccessor,
+    BlockExecStats,
     BlockMetadata,
     BlockPartitionMetadata,
     MaybeBlockPartition,
 )
 from ray.data.context import DatasetContext
+from ray.data.datasource import ReadTask
 from ray.data.impl.block_list import BlockList
+from ray.data.impl.progress_bar import ProgressBar
+from ray.data.impl.remote_fn import cached_remote_fn
+from ray.data.impl.stats import DatasetStats, _get_or_create_stats_actor


 class LazyBlockList(BlockList):
@@ -25,100 +32,301 @@ class LazyBlockList(BlockList):

     def __init__(
         self,
-        calls: Callable[[], ObjectRef[MaybeBlockPartition]],
-        metadata: List[BlockPartitionMetadata],
-        block_partitions: List[ObjectRef[MaybeBlockPartition]] = None,
+        tasks: List[ReadTask],
+        block_partition_refs: Optional[List[ObjectRef[MaybeBlockPartition]]] = None,
+        block_partition_meta_refs: Optional[
+            List[ObjectRef[BlockPartitionMetadata]]
+        ] = None,
+        cached_metadata: Optional[List[BlockPartitionMetadata]] = None,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        stats_uuid: str = None,
     ):
-        self._calls = calls
-        self._num_blocks = len(self._calls)
-        self._metadata = metadata
-        if block_partitions:
-            self._block_partitions = block_partitions
+        """Create a LazyBlockList on the provided read tasks.
+
+        Args:
+            tasks: The read tasks that will produce the blocks of this lazy block list.
+            block_partition_refs: An optional list of already submitted read task
+                futures (i.e. block partition refs). This should be the same length as
+                the tasks argument.
+            block_partition_meta_refs: An optional list of block partition metadata
+                refs. This should be the same length as the tasks argument.
+            cached_metadata: An optional list of already computed AND fetched metadata.
+                This serves as a cache of fetched block metadata.
+            ray_remote_args: Ray remote arguments for the read tasks.
+            stats_uuid: UUID for the dataset stats, used to group and fetch read task
+                stats. If not provided, a new UUID will be created.
+        """
+        self._tasks = tasks
+        self._num_blocks = len(self._tasks)
+        if stats_uuid is None:
+            stats_uuid = uuid.uuid4()
+        self._stats_uuid = stats_uuid
+        self._execution_started = False
+        self._remote_args = ray_remote_args or {}
+        # Block partition metadata that have already been computed and fetched.
+        if cached_metadata is not None:
+            self._cached_metadata = cached_metadata
+        else:
+            self._cached_metadata = [None] * len(tasks)
+        # Block partition metadata that have already been computed.
+        if block_partition_meta_refs is not None:
+            self._block_partition_meta_refs = block_partition_meta_refs
+        else:
+            self._block_partition_meta_refs = [None] * len(tasks)
+        # Block partitions that have already been computed.
+        if block_partition_refs is not None:
+            self._block_partition_refs = block_partition_refs
+        else:
+            self._block_partition_refs = [None] * len(tasks)
+        assert len(tasks) == len(self._block_partition_refs), (
+            tasks,
+            self._block_partition_refs,
+        )
+        assert len(tasks) == len(self._block_partition_meta_refs), (
+            tasks,
+            self._block_partition_meta_refs,
+        )
+        assert len(tasks) == len(self._cached_metadata), (
+            tasks,
+            self._cached_metadata,
+        )
+
+    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
+        """Get the metadata for all blocks."""
+        if all(meta is not None for meta in self._cached_metadata):
+            # Always return fetched metadata if we already have it.
+            metadata = self._cached_metadata
+        elif not fetch_if_missing:
+            metadata = [
+                m if m is not None else t.get_metadata()
+                for m, t in zip(self._cached_metadata, self._tasks)
+            ]
         else:
-            self._block_partitions = [None] * len(calls)
-            # Immediately compute the first block at least.
-            if calls:
-                self._block_partitions[0] = calls[0]()
-        assert len(calls) == len(metadata), (calls, metadata)
-        assert len(calls) == len(self._block_partitions), (
-            calls,
-            self._block_partitions,
+            _, metadata = self._get_blocks_with_metadata()
+        return metadata
+
+    def stats(self) -> DatasetStats:
+        """Create DatasetStats for this LazyBlockList."""
+        return DatasetStats(
+            stages={"read": self.get_metadata(fetch_if_missing=False)},
+            parent=None,
+            needs_stats_actor=True,
+            stats_uuid=self._stats_uuid,
         )

     def copy(self) -> "LazyBlockList":
         return LazyBlockList(
-            self._calls.copy(), self._metadata.copy(), self._block_partitions.copy()
+            self._tasks.copy(),
+            block_partition_refs=self._block_partition_refs.copy(),
+            block_partition_meta_refs=self._block_partition_meta_refs.copy(),
+            cached_metadata=self._cached_metadata,
+            ray_remote_args=self._remote_args.copy(),
+            stats_uuid=self._stats_uuid,
         )

-    def clear(self) -> None:
-        self._block_partitions = [None for _ in self._block_partitions]
+    def clear(self):
+        """Clears all object references (block partition refs, block partition
+        metadata refs, and cached metadata) from this lazy block list.
+        """
+        self._block_partition_refs = [None for _ in self._block_partition_refs]
+        self._block_partition_meta_refs = [
+            None for _ in self._block_partition_meta_refs
+        ]
+        self._cached_metadata = [None for _ in self._cached_metadata]

-    def _check_if_cleared(self) -> None:
+    def _check_if_cleared(self):
         pass  # LazyBlockList can always be re-computed.

     # Note: does not force execution prior to splitting.
     def split(self, split_size: int) -> List["LazyBlockList"]:
-        num_splits = math.ceil(len(self._calls) / split_size)
-        calls = np.array_split(self._calls, num_splits)
-        meta = np.array_split(self._metadata, num_splits)
-        block_partitions = np.array_split(self._block_partitions, num_splits)
+        num_splits = math.ceil(len(self._tasks) / split_size)
+        tasks = np.array_split(self._tasks, num_splits)
+        block_partition_refs = np.array_split(self._block_partition_refs, num_splits)
+        block_partition_meta_refs = np.array_split(
+            self._block_partition_meta_refs, num_splits
+        )
         output = []
-        for c, m, b in zip(calls, meta, block_partitions):
-            output.append(LazyBlockList(c.tolist(), m.tolist(), b.tolist()))
+        for t, b, m in zip(tasks, block_partition_refs, block_partition_meta_refs):
+            output.append(
+                LazyBlockList(
+                    t.tolist(),
+                    b.tolist(),
+                    m.tolist(),
+                )
+            )
         return output
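
The get_metadata() override above resolves each block's metadata in layers: fully fetched post-read metadata when cached, the ReadTask's cheap pre-read metadata otherwise, and a blocking bulk fetch only when fetch_if_missing=True. A pure-Python sketch of that fallback order (names and types here are illustrative, not from the patch):

    from typing import List, Optional

    def resolve_metadata(
        cached: List[Optional[str]], pre_read: List[str], fetch_if_missing: bool
    ) -> List[str]:
        if all(m is not None for m in cached):
            return list(cached)  # all post-read metadata already fetched
        if not fetch_if_missing:
            # Fall back to cheap pre-read metadata for any unfetched entries.
            return [c if c is not None else p for c, p in zip(cached, pre_read)]
        # Placeholder for the blocking path (_get_blocks_with_metadata()).
        raise NotImplementedError

    print(resolve_metadata([None, "post1"], ["pre0", "pre1"], fetch_if_missing=False))
    # -> ['pre0', 'post1']
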
     # Note: does not force execution prior to splitting.
     def split_by_bytes(self, bytes_per_split: int) -> List["BlockList"]:
         self._check_if_cleared()
         output = []
-        cur_calls, cur_meta, cur_blocks = [], [], []
+        cur_tasks, cur_blocks, cur_blocks_meta = [], [], []
         cur_size = 0
-        for c, m, b in zip(self._calls, self._metadata, self._block_partitions):
+        for t, b, bm in zip(
+            self._tasks,
+            self._block_partition_refs,
+            self._block_partition_meta_refs,
+        ):
+            m = t.get_metadata()
             if m.size_bytes is None:
                 raise RuntimeError(
                     "Block has unknown size, cannot use split_by_bytes()"
                 )
             size = m.size_bytes
             if cur_blocks and cur_size + size > bytes_per_split:
-                output.append(LazyBlockList(cur_calls, cur_meta, cur_blocks))
-                cur_calls, cur_meta, cur_blocks = [], [], []
+                output.append(
+                    LazyBlockList(cur_tasks, cur_blocks, cur_blocks_meta),
+                )
+                cur_tasks, cur_blocks, cur_blocks_meta = [], [], []
                 cur_size = 0
-            cur_calls.append(c)
-            cur_meta.append(m)
+            cur_tasks.append(t)
             cur_blocks.append(b)
+            cur_blocks_meta.append(bm)
             cur_size += size
         if cur_blocks:
-            output.append(LazyBlockList(cur_calls, cur_meta, cur_blocks))
+            output.append(LazyBlockList(cur_tasks, cur_blocks, cur_blocks_meta))
         return output

     # Note: does not force execution prior to division.
     def divide(self, part_idx: int) -> ("LazyBlockList", "LazyBlockList"):
         left = LazyBlockList(
-            self._calls[:part_idx],
-            self._metadata[:part_idx],
-            self._block_partitions[:part_idx],
+            self._tasks[:part_idx],
+            self._block_partition_refs[:part_idx],
+            self._block_partition_meta_refs[:part_idx],
         )
         right = LazyBlockList(
-            self._calls[part_idx:],
-            self._metadata[part_idx:],
-            self._block_partitions[part_idx:],
+            self._tasks[part_idx:],
+            self._block_partition_refs[part_idx:],
+            self._block_partition_meta_refs[part_idx:],
         )
         return left, right

     def get_blocks(self) -> List[ObjectRef[Block]]:
-        # Force bulk evaluation of all block partitions futures.
-        list(self._iter_block_partitions())
-        return list(self.iter_blocks())
+        """Bulk version of iter_blocks().
+
+        Prefer calling this instead of the iter form for performance if you
+        don't need lazy evaluation.
+        """
+        blocks, _ = self._get_blocks_with_metadata()
+        return blocks
+
+    def get_blocks_with_metadata(self) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
+        """Bulk version of iter_blocks_with_metadata().
+
+        Prefer calling this instead of the iter form for performance if you
+        don't need lazy evaluation.
+        """
+        blocks, metadata = self._get_blocks_with_metadata()
+        return list(zip(blocks, metadata))
+
+    def _get_blocks_with_metadata(
+        self,
+    ) -> Tuple[List[ObjectRef[Block]], List[BlockMetadata]]:
+        """Get all underlying block futures and concrete metadata.
+
+        This will block on the completion of the underlying read tasks and will fetch
+        all block metadata output by those tasks.
+        """
+        context = DatasetContext.get_current()
+        block_refs, meta_refs = [], []
+        for block_ref, meta_ref in self._iter_block_partition_refs():
+            block_refs.append(block_ref)
+            meta_refs.append(meta_ref)
+        if context.block_splitting_enabled:
+            # If block splitting is enabled, fetch the partitions.
+            parts = ray.get(block_refs)
+            block_refs, metadata = [], []
+            for part in parts:
+                for block_ref, meta in part:
+                    block_refs.append(block_ref)
+                    metadata.append(meta)
+            self._cached_metadata = metadata
+            return block_refs, metadata
+        if all(meta is not None for meta in self._cached_metadata):
+            # Short-circuit on cached metadata.
+            return block_refs, self._cached_metadata
+        if not meta_refs:
+            # Short-circuit on empty set of block partitions.
+            assert not block_refs, block_refs
+            return [], []
+        read_progress_bar = ProgressBar("Read progress", total=len(meta_refs))
+        # Fetch the metadata in bulk.
+        # Handle duplicates (e.g. due to unioning the same dataset).
+        unique_meta_refs = set(meta_refs)
+        metadata = read_progress_bar.fetch_until_complete(list(unique_meta_refs))
+        ref_to_data = {
+            meta_ref: data for meta_ref, data in zip(unique_meta_refs, metadata)
+        }
+        metadata = [ref_to_data[meta_ref] for meta_ref in meta_refs]
+        self._cached_metadata = metadata
+        return block_refs, metadata
+
+    def compute_first_block(self):
+        """Kick off computation for the first block in the list.
+
+        This is useful if looking to support rapid lightweight interaction with a
+        small amount of the dataset.
+        """
+        if self._tasks:
+            self._get_or_compute(0)
+
+    def ensure_metadata_for_first_block(self) -> Optional[BlockMetadata]:
+        """Ensure that the metadata is fetched and set for the first block.
+
+        This will only block execution in order to fetch the post-read metadata for
+        the first block if the pre-read metadata for the first block has no schema.
+
+        Returns:
+            None if the block list is empty, the metadata for the first block
+            otherwise.
+        """
+        if not self._tasks:
+            return None
+        metadata = self._tasks[0].get_metadata()
+        if metadata.schema is not None:
+            # If the pre-read schema is not None, we consider it to be "good enough"
+            # and use it.
+            return metadata
+        # Otherwise, we trigger computation (if needed), wait until the task
+        # completes, and fetch the block partition metadata.
+        try:
+            _, metadata_ref = next(self._iter_block_partition_refs())
+        except (StopIteration, ValueError):
+            # Dataset is empty (no blocks) or was manually cleared.
+            pass
+        else:
+            # This blocks until the underlying read task is finished.
+            metadata = ray.get(metadata_ref)
+            self._cached_metadata[0] = metadata
+        return metadata

     def iter_blocks_with_metadata(
         self,
+        block_for_metadata: bool = False,
     ) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
+        """Iterate over the blocks along with their metadata.
+
+        Note that, if block_for_metadata is False (default), this iterator returns
+        pre-read metadata from the ReadTasks given to this LazyBlockList, so it
+        doesn't have to block on the execution of the read tasks. Therefore, the
+        metadata may be under-specified, e.g. missing schema or the number of rows.
+        If fully-specified block metadata is required, pass block_for_metadata=True.
+
+        The length of this iterator is not known until execution.
+
+        Args:
+            block_for_metadata: Whether we should block on the execution of read
+                tasks in order to obtain fully-specified block metadata.
+
+        Returns:
+            An iterator of block references and the corresponding block metadata.
+        """
         context = DatasetContext.get_current()
         outer = self

         class Iter:
             def __init__(self):
-                self._base_iter = outer._iter_block_partitions()
+                self._base_iter = outer._iter_block_partition_refs()
+                self._pos = -1
                 self._buffer = []

             def __iter__(self):
@@ -126,21 +334,36 @@ def __iter__(self):
             def __next__(self):
                 while not self._buffer:
+                    self._pos += 1
                     if context.block_splitting_enabled:
                         part_ref, _ = next(self._base_iter)
                         partition = ray.get(part_ref)
                     else:
-                        block, metadata = next(self._base_iter)
-                        partition = [(block, metadata)]
-                    for ref, metadata in partition:
-                        self._buffer.append((ref, metadata))
+                        block_ref, metadata_ref = next(self._base_iter)
+                        if block_for_metadata:
+                            # This blocks until the read task completes, returning
+                            # fully-specified block metadata.
+                            metadata = ray.get(metadata_ref)
+                        else:
+                            # This does not block, returning (possibly
+                            # under-specified) pre-read block metadata.
+                            metadata = outer._tasks[self._pos].get_metadata()
+                        partition = [(block_ref, metadata)]
+                    for block_ref, metadata in partition:
+                        self._buffer.append((block_ref, metadata))
                 return self._buffer.pop(0)

         return Iter()

-    def _iter_block_partitions(
+    def _iter_block_partition_refs(
         self,
-    ) -> Iterator[Tuple[ObjectRef[MaybeBlockPartition], BlockPartitionMetadata]]:
+    ) -> Iterator[
+        Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]
+    ]:
+        """Iterate over the block futures and their corresponding metadata futures.
+
+        This does NOT block on the execution of each submitted task.
+        """
         outer = self

         class Iter:
@@ -152,31 +375,80 @@ def __iter__(self):
             def __next__(self):
                 self._pos += 1
-                if self._pos < len(outer._calls):
-                    return (
-                        outer._get_or_compute(self._pos),
-                        outer._metadata[self._pos],
-                    )
+                if self._pos < len(outer._tasks):
+                    return outer._get_or_compute(self._pos)
                 raise StopIteration

         return Iter()

-    def _get_or_compute(self, i: int) -> ObjectRef[MaybeBlockPartition]:
-        assert i < len(self._calls), i
-        # Check if we need to compute more block_partitions.
-        if not self._block_partitions[i]:
+    def _get_or_compute(
+        self,
+        i: int,
+    ) -> Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]:
+        assert i < len(self._tasks), i
+        # Check if we need to compute more block_partition_refs.
+        if not self._block_partition_refs[i]:
             # Exponentially increase the number computed per batch.
             for j in range(max(i + 1, i * 2)):
-                if j >= len(self._block_partitions):
+                if j >= len(self._block_partition_refs):
                     break
-                if not self._block_partitions[j]:
-                    self._block_partitions[j] = self._calls[j]()
-        assert self._block_partitions[i], self._block_partitions
-        return self._block_partitions[i]
+                if not self._block_partition_refs[j]:
+                    (
+                        self._block_partition_refs[j],
+                        self._block_partition_meta_refs[j],
+                    ) = self._submit_task(j)
+            assert self._block_partition_refs[i], self._block_partition_refs
+            assert self._block_partition_meta_refs[i], self._block_partition_meta_refs
+        return self._block_partition_refs[i], self._block_partition_meta_refs[i]
+
+    def _submit_task(
+        self, task_idx: int
+    ) -> Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]:
+        """Submit the task with index task_idx."""
+        stats_actor = _get_or_create_stats_actor()
+        if not self._execution_started:
+            stats_actor.record_start.remote(self._stats_uuid)
+            self._execution_started = True
+        task = self._tasks[task_idx]
+        return (
+            cached_remote_fn(_execute_read_task)
+            .options(num_returns=2, **self._remote_args)
+            .remote(
+                i=task_idx,
+                task=task,
+                context=DatasetContext.get_current(),
+                stats_uuid=self._stats_uuid,
+                stats_actor=stats_actor,
+            )
+        )

     def _num_computed(self) -> int:
         i = 0
-        for b in self._block_partitions:
+        for b in self._block_partition_refs:
             if b is not None:
                 i += 1
         return i
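
The ramp-up loop in _get_or_compute() above means sequential block access submits read tasks in exponentially growing batches: short prefixes stay cheap, while full scans quickly reach full parallelism. A standalone sketch of the resulting submission schedule (illustrative only):

    def submitted_after_access(i: int, n: int) -> range:
        # Mirrors: for j in range(max(i + 1, i * 2)), clamped to n tasks.
        return range(min(max(i + 1, i * 2), n))

    n = 20
    done = set()
    for i in range(n):
        if i in done:
            continue  # ref already exists; the real code short-circuits here
        newly = [j for j in submitted_after_access(i, n) if j not in done]
        done.update(newly)
        print(f"access block {i}: submit read tasks {newly}")
    # access block 0: submit read tasks [0]
    # access block 1: submit read tasks [1]
    # access block 2: submit read tasks [2, 3]
    # access block 4: submit read tasks [4, 5, 6, 7]
    # access block 8: submit read tasks [8, 9, 10, 11, 12, 13, 14, 15]
    # access block 16: submit read tasks [16, 17, 18, 19]
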
+
+
+def _execute_read_task(
+    i: int,
+    task: ReadTask,
+    context: DatasetContext,
+    stats_uuid: str,
+    stats_actor: ray.actor.ActorHandle,
+) -> Tuple[MaybeBlockPartition, BlockPartitionMetadata]:
+    DatasetContext._set_current(context)
+    stats = BlockExecStats.builder()
+
+    # Execute the task.
+    block = task()
+
+    metadata = task.get_metadata()
+    if context.block_splitting_enabled:
+        metadata.exec_stats = stats.build()
+    else:
+        metadata = BlockAccessor.for_block(block).get_metadata(
+            input_files=metadata.input_files, exec_stats=stats.build()
+        )
+    stats_actor.record_task.remote(stats_uuid, i, metadata)
+    return block, metadata
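
_execute_read_task above is submitted with num_returns=2, so the block and its post-read metadata come back as two separate object refs, letting callers fetch the small metadata without pulling the potentially large block. A minimal standalone sketch of this Ray pattern (the function and payloads are stand-ins, not from the patch):

    import ray

    @ray.remote(num_returns=2)
    def read_block():
        block = list(range(100))             # stand-in for a real block
        metadata = {"num_rows": len(block)}  # stand-in for BlockMetadata
        return block, metadata

    ray.init()
    block_ref, meta_ref = read_block.remote()
    print(ray.get(meta_ref))  # {'num_rows': 100} -- fetched without the block
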
diff --git a/python/ray/data/impl/plan.py b/python/ray/data/impl/plan.py
index ccd141a79c7a..d26d134bba92 100644
--- a/python/ray/data/impl/plan.py
+++ b/python/ray/data/impl/plan.py
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple, Optional, Union, Iterable, TYPE_CHECKING
+from typing import Callable, Tuple, Optional, Union, Iterable, Iterator, TYPE_CHECKING
 import uuid

 if TYPE_CHECKING:
@@ -79,7 +79,14 @@ def schema(
             blocks = self._out_blocks
         else:
             blocks = self._in_blocks
-        metadata = blocks.get_metadata() if blocks else []
+        if blocks:
+            # Don't force fetching in case it's a lazy block list, in which case we
+            # don't want to trigger full execution for a schema read. If we want to
+            # trigger execution to get schema, we'll trigger read tasks progressively
+            # until a viable schema is available, below.
+            metadata = blocks.get_metadata(fetch_if_missing=False)
+        else:
+            metadata = []
         # Some blocks could be empty, in which case we cannot get their schema.
         # TODO(ekl) validate schema is the same across different blocks.
         for m in metadata:
@@ -87,8 +94,13 @@
                 return m.schema
         if not fetch_if_missing:
             return None
-        # Need to synchronously fetch schema.
-        return blocks.ensure_schema_for_first_block()
+        # Synchronously fetch the schema.
+        # For lazy block lists, this launches read tasks and fetches block metadata
+        # until we find valid block schema.
+        for _, m in blocks.iter_blocks_with_metadata():
+            if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
+                return m.schema
+        return None

     def meta_count(self) -> Optional[int]:
         """Get the number of rows after applying all plan stages if possible.
@@ -103,7 +115,7 @@ def meta_count(self) -> Optional[int]:
         else:
             blocks = self._in_blocks
         metadata = blocks.get_metadata() if blocks else None
-        if metadata and metadata[0].num_rows is not None:
+        if metadata and all(m.num_rows is not None for m in metadata):
             return sum(m.num_rows for m in metadata)
         else:
             return None
@@ -166,9 +178,7 @@ def _rewrite_read_stages(self) -> None:

     def _has_read_stage(self) -> bool:
         """Whether this plan has a read stage for its input."""
-        return isinstance(self._in_blocks, LazyBlockList) and hasattr(
-            self._in_blocks, "_read_tasks"
-        )
+        return isinstance(self._in_blocks, LazyBlockList)

     def _is_read_stage(self) -> bool:
         """Whether this plan is a bare read stage."""
@@ -185,18 +195,17 @@ def _rewrite_read_stage(self) -> Tuple[BlockList, "Stage"]:
             [GetReadTasks -> MapBatches(DoRead -> Fn)].
         """
         # Generate the "GetReadTasks" stage blocks.
-        remote_args = self._in_blocks._read_remote_args
+        remote_args = self._in_blocks._remote_args
         blocks = []
         metadata = []
-        for i, read_task in enumerate(self._in_blocks._read_tasks):
-            blocks.append(ray.put([read_task]))
-            metadata.append(self._in_blocks._metadata[i])
+        for read_task in self._in_blocks._tasks:
+            blocks.append(ray.put(read_task._read_fn))
+            metadata.append(read_task.get_metadata())
         block_list = BlockList(blocks, metadata)

-        def block_fn(block: Block) -> Iterable[Block]:
-            [read_task] = block
-            for tmp1 in read_task._read_fn():
-                yield tmp1
+        def block_fn(read_fn: Callable[[], Iterator[Block]]) -> Iterator[Block]:
+            for block in read_fn():
+                yield block

         return block_list, OneToOneStage("read", block_fn, "tasks", remote_args)
diff --git a/python/ray/data/impl/stats.py b/python/ray/data/impl/stats.py
index ca022a1fa635..ae4bf48a9a0f 100644
--- a/python/ray/data/impl/stats.py
+++ b/python/ray/data/impl/stats.py
@@ -112,22 +112,23 @@ def get(self, stats_uuid):
 _stats_actor = [None, None]


-def get_or_create_stats_actor():
+def _get_or_create_stats_actor():
     # Need to re-create it if Ray restarts (mostly for unit tests).
     if (
         not _stats_actor[0]
         or not ray.is_initialized()
         or _stats_actor[1] != ray.get_runtime_context().job_id.hex()
     ):
-        _stats_actor[0] = _StatsActor.remote()
+        _stats_actor[0] = _StatsActor.options(
+            name="datasets_stats_actor", get_if_exists=True
+        ).remote()
         _stats_actor[1] = ray.get_runtime_context().job_id.hex()

-    # Clear the actor handle after Ray reinits since it's no longer
-    # valid.
-    def clear_actor():
-        _stats_actor[0] = None
+        # Clear the actor handle after Ray reinits since it's no longer valid.
+        def clear_actor():
+            _stats_actor[0] = None

-    ray.worker._post_init_hooks.append(clear_actor)
+        ray.worker._post_init_hooks.append(clear_actor)
     return _stats_actor[0]
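
With name= and get_if_exists=True, concurrent callers of _get_or_create_stats_actor() above race safely to a single shared actor rather than erroring on a name collision. A minimal sketch of this get-or-create named-actor pattern (the Counter actor is a stand-in):

    import ray

    @ray.remote
    class Counter:
        def __init__(self):
            self.n = 0

        def inc(self):
            self.n += 1
            return self.n

    ray.init()
    # Both handles resolve to the same named actor instance.
    a = Counter.options(name="demo_counter", get_if_exists=True).remote()
    b = Counter.options(name="demo_counter", get_if_exists=True).remote()
    print(ray.get(a.inc.remote()), ray.get(b.inc.remote()))  # 1 2
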
@@ -143,7 +144,7 @@ def __init__(
         *,
         stages: Dict[str, List[BlockMetadata]],
         parent: Union[Optional["DatasetStats"], List["DatasetStats"]],
-        stats_actor=None,
+        needs_stats_actor=False,
         stats_uuid=None
     ):
         """Create dataset stats.
@@ -153,8 +154,9 @@
                 previous one. Typically one entry, e.g., {"map": [...]}.
             parent: Reference to parent Dataset's stats, or a list of parents
                 if there are multiple.
-            stats_actor: Reference to actor where stats should be pulled
-                from. This is only used for Datasets using LazyBlockList.
+            needs_stats_actor: Whether this Dataset's stats needs a stats actor for
+                stats collection. This is currently only used for Datasets using a
+                lazy datasource (i.e. a LazyBlockList).
             stats_uuid: The uuid for the stats, used to fetch the right stats
                 from the stats actor.
         """
@@ -171,7 +173,7 @@ def __init__(
         )
         self.dataset_uuid: str = None
         self.time_total_s: float = 0
-        self.stats_actor = stats_actor
+        self.needs_stats_actor = needs_stats_actor
         self.stats_uuid = stats_uuid

         # Iteration stats, filled out if the user iterates over the dataset.
@@ -181,6 +183,10 @@ def __init__(
         self.iter_user_s: Timer = Timer()
         self.iter_total_s: Timer = Timer()

+    @property
+    def stats_actor(self):
+        return _get_or_create_stats_actor()
+
     def child_builder(self, name: str) -> _DatasetStatsBuilder:
         """Start recording stats for an op of the given name (e.g., map)."""
         return _DatasetStatsBuilder(name, self)
@@ -199,7 +205,7 @@ def summary_string(self, already_printed: Set[str] = None) -> str:
         if already_printed is None:
             already_printed = set()
-        if self.stats_actor:
+        if self.needs_stats_actor:
             # XXX this is a super hack, clean it up.
             stats_map, self.time_total_s = ray.get(
                 self.stats_actor.get.remote(self.stats_uuid)
             )
@@ -249,8 +255,12 @@ def _summarize_iter(self) -> str:

     def _summarize_blocks(self, blocks: List[BlockMetadata]) -> str:
         exec_stats = [m.exec_stats for m in blocks if m.exec_stats is not None]
+        rounded_total = round(self.time_total_s, 2)
+        if rounded_total <= 0:
+            # Handle -0.0 case.
+            rounded_total = 0
         out = "{}/{} blocks executed in {}s\n".format(
-            len(exec_stats), len(blocks), round(self.time_total_s, 2)
+            len(exec_stats), len(blocks), rounded_total
         )

         if exec_stats:
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 3c26d91b741b..d96ed8e996bb 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -7,11 +7,9 @@
     Union,
     Optional,
     Tuple,
-    Callable,
     TypeVar,
     TYPE_CHECKING,
 )
-import uuid

 import numpy as np

@@ -30,9 +28,7 @@
     Block,
     BlockAccessor,
     BlockMetadata,
-    MaybeBlockPartition,
     BlockExecStats,
-    BlockPartitionMetadata,
 )
 from ray.data.context import DatasetContext
 from ray.data.dataset import Dataset
@@ -60,7 +56,7 @@
 from ray.data.impl.lazy_block_list import LazyBlockList
 from ray.data.impl.plan import ExecutionPlan
 from ray.data.impl.remote_fn import cached_remote_fn
-from ray.data.impl.stats import DatasetStats, get_or_create_stats_actor
+from ray.data.impl.stats import DatasetStats
 from ray.data.impl.util import _lazy_import_pyarrow_dataset

 T = TypeVar("T")
@@ -252,62 +248,17 @@ def read_datasource(
             "dataset blocks.".format(len(read_tasks), len(read_tasks))
         )

-    context = DatasetContext.get_current()
-    stats_actor = get_or_create_stats_actor()
-    stats_uuid = uuid.uuid4()
-    stats_actor.record_start.remote(stats_uuid)
-
-    def remote_read(i: int, task: ReadTask, stats_actor) -> MaybeBlockPartition:
-        DatasetContext._set_current(context)
-        stats = BlockExecStats.builder()
-
-        # Execute the read task.
-        block = task()
-
-        if context.block_splitting_enabled:
-            metadata = task.get_metadata()
-            metadata.exec_stats = stats.build()
-        else:
-            metadata = BlockAccessor.for_block(block).get_metadata(
-                input_files=task.get_metadata().input_files, exec_stats=stats.build()
-            )
-        stats_actor.record_task.remote(stats_uuid, i, metadata)
-        return block
-
     if ray_remote_args is None:
         ray_remote_args = {}
     if "scheduling_strategy" not in ray_remote_args:
         ray_remote_args["scheduling_strategy"] = "SPREAD"

-    remote_read = cached_remote_fn(remote_read)
-
-    calls: List[Callable[[], ObjectRef[MaybeBlockPartition]]] = []
-    metadata: List[BlockPartitionMetadata] = []
+    block_list = LazyBlockList(read_tasks, ray_remote_args=ray_remote_args)
+    block_list.compute_first_block()
+    block_list.ensure_metadata_for_first_block()

-    for i, task in enumerate(read_tasks):
-        calls.append(
-            lambda i=i, task=task: remote_read.options(**ray_remote_args).remote(
-                i, task, stats_actor
-            )
-        )
-        metadata.append(task.get_metadata())
-
-    block_list = LazyBlockList(calls, metadata)
-    # TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
-    block_list._read_tasks = read_tasks
-    block_list._read_remote_args = ray_remote_args
-
-    # Get the schema from the first block synchronously.
-    if metadata and metadata[0].schema is None:
-        block_list.ensure_schema_for_first_block()
-
-    stats = DatasetStats(
-        stages={"read": metadata},
-        parent=None,
-        stats_actor=stats_actor,
-        stats_uuid=stats_uuid,
-    )
     return Dataset(
-        ExecutionPlan(block_list, stats),
+        ExecutionPlan(block_list, block_list.stats()),
         0,
         False,
     )
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index ef22bd284a36..f8edb3a5782c 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -961,6 +961,16 @@ def test_schema(ray_start_regular_shared):
     )


+def test_schema_lazy(ray_start_regular_shared):
+    ds = ray.data.range(100, parallelism=10)
+    # We kick off the read task for the first block by default.
+    assert ds._plan._in_blocks._num_computed() == 1
+    schema = ds.schema()
+    assert schema == int
+    # Fetching the schema should not trigger execution of extra read tasks.
+    assert ds._plan.execute()._num_computed() == 1
+
+
 def test_lazy_loading_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(100, parallelism=20)
     assert ds._plan.execute()._num_computed() == 1
diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py
index a4e9270c6545..1d2e67736b1a 100644
--- a/python/ray/data/tests/test_dataset_formats.py
+++ b/python/ray/data/tests/test_dataset_formats.py
@@ -763,7 +763,7 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
     np.save(os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1))
     ds = ray.data.read_numpy(path)
     assert str(ds) == (
-        "Dataset(num_blocks=1, num_rows=None, "
+        "Dataset(num_blocks=1, num_rows=10, "
         "schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})"
     )
     assert str(ds.take(2)) == "[{'value': array([0])}, {'value': array([1])}]"
diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py
index bb2f353335b2..d0dcfc8bd57b 100644
--- a/python/ray/data/tests/test_optimize.py
+++ b/python/ray/data/tests/test_optimize.py
@@ -253,6 +253,34 @@ def test_optimize_reread_base_data(ray_start_regular_shared, local_path):
     assert num_reads == 1, num_reads


+@pytest.mark.skip(reason="reusing base data not enabled")
+@pytest.mark.parametrize("with_shuffle", [True, False])
+@pytest.mark.parametrize("enable_dynamic_splitting", [True, False])
+def test_optimize_lazy_reuse_base_data(
+    ray_start_regular_shared, local_path, enable_dynamic_splitting, with_shuffle
+):
+    context = DatasetContext.get_current()
+    context.block_splitting_enabled = enable_dynamic_splitting
+
+    num_blocks = 4
+    dfs = [pd.DataFrame({"one": list(range(i, i + 4))}) for i in range(num_blocks)]
+    paths = [os.path.join(local_path, f"test{i}.csv") for i in range(num_blocks)]
+    for df, path in zip(dfs, paths):
+        df.to_csv(path, index=False)
+    counter = Counter.remote()
+    source = MySource(counter)
+    ds = ray.data.read_datasource(source, parallelism=4, paths=paths)
+    num_reads = ray.get(counter.get.remote())
+    assert num_reads == 1, num_reads
+    ds = ds._experimental_lazy()
+    ds = ds.map(lambda x: x)
+    if with_shuffle:
+        ds = ds.random_shuffle()
+    ds.take()
+    num_reads = ray.get(counter.get.remote())
+    assert num_reads == num_blocks, num_reads
+
+
 if __name__ == "__main__":
     import sys
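
Taken together, the read path is now driven end to end by LazyBlockList: read_datasource() eagerly submits only the first read task, schema and row counts come from pre-read or cached metadata, and the remaining tasks run on demand. A closing usage sketch that mirrors test_schema_lazy above (private attributes are used purely for illustration, assuming a local Ray runtime):

    import ray

    ray.init()
    ds = ray.data.range(100, parallelism=10)
    # Only the first read task has run (compute_first_block +
    # ensure_metadata_for_first_block), yet the schema is already known.
    assert ds._plan._in_blocks._num_computed() == 1
    assert ds.schema() == int
    # Fetching the schema did not trigger any extra read tasks.
    assert ds._plan.execute()._num_computed() == 1
    print(ds.take(5))  # [0, 1, 2, 3, 4] -- triggers more reads on demand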