[Datasets] [Out-of-Band Serialization: 1/3] Refactor LazyBlockList. (ray-project#23821)

This PR refactors `LazyBlockList` in service of out-of-band serialization (see [mono-PR](ray-project#22616)) and is a precursor to an execution plan refactor (PR #2) and to the actual out-of-band serialization APIs (PR #3). The following is included in this refactor:
1. `ReadTask`s are now a first-class concept, replacing the plain lists of read calls and metadata;
2. read stage progress tracking is consolidated into `LazyBlockList._get_blocks_with_metadata()`, and more of the read task complexity, e.g. the read remote function, is pushed into `LazyBlockList` to make `ray.data.read_datasource()` simpler;
3. task launching and metadata fetching/caching are now smarter and more progressive: the metadata for read tasks is fetched in `.iter_blocks_with_metadata()` instead of relying on the less accurate pre-read task metadata, and some small bugs in the lazy ramp-up around progressive metadata fetching are fixed.

(1) is the most important item for supporting out-of-band serialization and fundamentally changes the `LazyBlockList` data model. This is required since we need to be able to reference the underlying read tasks when rewriting read stages during optimization and when serializing the lineage of the Dataset. See the [mono-PR](ray-project#22616) for more context.
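
To make the new data model concrete, here is a minimal sketch of the read-task-centric structure (illustrative only, not the actual implementation: the real `ReadTask` and `LazyBlockList` carry more state, such as cached metadata and ramp-up bookkeeping, and the simplified constructor signatures below are assumptions based on the diff in this PR):

```python
from typing import Callable, List, Optional

from ray.types import ObjectRef
from ray.data.block import Block, BlockMetadata


class ReadTask:
    """Bundles the function that produces a block with the metadata known
    about that block before the read actually runs."""

    def __init__(self, read_fn: Callable[[], Block], metadata: BlockMetadata):
        self._read_fn = read_fn
        self._metadata = metadata

    def get_metadata(self) -> BlockMetadata:
        return self._metadata

    def __call__(self) -> Block:
        # Executing the task performs the (potentially expensive) read.
        return self._read_fn()


class LazyBlockList:
    """Tracks read tasks alongside object refs for any outputs that have
    already been launched, so read stages can be rewritten or re-executed."""

    def __init__(
        self,
        tasks: List[ReadTask],
        block_partition_refs: Optional[List[ObjectRef]] = None,
        block_partition_meta_refs: Optional[List[ObjectRef]] = None,
    ):
        self._tasks = tasks
        # A None entry marks a read task that has not been launched yet.
        self._block_partition_refs = block_partition_refs or [None] * len(tasks)
        self._block_partition_meta_refs = (
            block_partition_meta_refs or [None] * len(tasks)
        )
```

Because the lineage is now a list of `ReadTask`s plus optional cached object refs, a read stage can be rewritten during optimization or re-executed after serialization by re-invoking the tasks, rather than depending on already-materialized block refs.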

Other changes:
1. Changed the stats actor to a global named actor singleton to avoid having to serialize the actor handle with the Dataset stats; without this, we were encountering serialization failures.
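
The named-singleton pattern looks roughly like the following sketch (the actor name, options, and `_StatsActor` methods shown here are illustrative assumptions, not the exact Ray implementation):

```python
import ray


@ray.remote(num_cpus=0)
class _StatsActor:
    """Illustrative stats actor that records per-dataset stats by UUID."""

    def __init__(self):
        self._stats = {}

    def record(self, dataset_uuid, stats):
        self._stats[dataset_uuid] = stats

    def get(self, dataset_uuid):
        return self._stats.get(dataset_uuid)


def _get_or_create_stats_actor():
    # Fetching the singleton by name at call time avoids storing an actor
    # handle on the Dataset's stats object, which would otherwise have to be
    # serialized along with the stats.
    name = "datasets_stats_actor"  # assumed name, for illustration only
    try:
        return ray.get_actor(name)
    except ValueError:
        # Note: this simple get-or-create pattern has a small race if two
        # workers try to create the actor concurrently.
        return _StatsActor.options(name=name, lifetime="detached").remote()
```

Looking the actor up by name at use time means the Dataset stats never need to hold (and therefore never need to serialize) an actor handle.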
clarkzinzow authored Apr 14, 2022
1 parent d96ac25 commit efc5ac5
Showing 9 changed files with 454 additions and 195 deletions.
37 changes: 22 additions & 15 deletions python/ray/data/dataset.py
@@ -57,6 +57,7 @@
     ParquetDatasource,
     BlockWritePathProvider,
     DefaultBlockWritePathProvider,
+    ReadTask,
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,28 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
 
         start_time = time.perf_counter()
         context = DatasetContext.get_current()
-        calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
-        metadata: List[BlockPartitionMetadata] = []
-        block_partitions: List[ObjectRef[BlockPartition]] = []
+        tasks: List[ReadTask] = []
+        block_partition_refs: List[ObjectRef[BlockPartition]] = []
+        block_partition_meta_refs: List[ObjectRef[BlockPartitionMetadata]] = []
 
         datasets = [self] + list(other)
         for ds in datasets:
             bl = ds._plan.execute()
             if isinstance(bl, LazyBlockList):
-                calls.extend(bl._calls)
-                metadata.extend(bl._metadata)
-                block_partitions.extend(bl._block_partitions)
+                tasks.extend(bl._tasks)
+                block_partition_refs.extend(bl._block_partition_refs)
+                block_partition_meta_refs.extend(bl._block_partition_meta_refs)
             else:
-                calls.extend([None] * bl.initial_num_blocks())
-                metadata.extend(bl._metadata)
+                tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
                 if context.block_splitting_enabled:
-                    block_partitions.extend(
+                    block_partition_refs.extend(
                         [ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
                     )
                 else:
-                    block_partitions.extend(bl.get_blocks())
+                    block_partition_refs.extend(bl.get_blocks())
+                block_partition_meta_refs.extend(
+                    [ray.put(meta) for meta in bl._metadata]
+                )
 
         epochs = [ds._get_epoch() for ds in datasets]
         max_epoch = max(*epochs)
@@ -1028,7 +1031,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
             ExecutionPlan(
-                LazyBlockList(calls, metadata, block_partitions), dataset_stats
+                LazyBlockList(tasks, block_partition_refs, block_partition_meta_refs),
+                dataset_stats,
             ),
             max_epoch,
             self._lazy,
@@ -2548,6 +2552,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2666,6 +2671,7 @@ def window(
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2749,12 +2755,13 @@ def fully_executed(self) -> "Dataset[T]":
         Returns:
             A Dataset with all blocks fully materialized in memory.
         """
-        blocks = self.get_internal_block_refs()
-        bar = ProgressBar("Force reads", len(blocks))
-        bar.block_until_complete(blocks)
+        blocks, metadata = [], []
+        for b, m in self._plan.execute().get_blocks_with_metadata():
+            blocks.append(b)
+            metadata.append(m)
         ds = Dataset(
             ExecutionPlan(
-                BlockList(blocks, self._plan.execute().get_metadata()),
+                BlockList(blocks, metadata),
                 self._plan.stats(),
                 dataset_uuid=self._get_uuid(),
             ),
34 changes: 3 additions & 31 deletions python/ray/data/impl/block_list.py
@@ -1,15 +1,11 @@
 import math
-from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING
 
-if TYPE_CHECKING:
-    import pyarrow
+from typing import List, Iterator, Tuple
 
 import numpy as np
 
-import ray
 from ray.types import ObjectRef
-from ray.data.block import Block, BlockMetadata, BlockAccessor
-from ray.data.impl.remote_fn import cached_remote_fn
+from ray.data.block import Block, BlockMetadata
 
 
 class BlockList:
@@ -26,11 +22,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata]
         self._num_blocks = len(self._blocks)
         self._metadata: List[BlockMetadata] = metadata
 
-    def set_metadata(self, i: int, metadata: BlockMetadata) -> None:
-        """Set the metadata for a given block."""
-        self._metadata[i] = metadata
-
-    def get_metadata(self) -> List[BlockMetadata]:
+    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
         """Get the metadata for all blocks."""
         return self._metadata.copy()
 
@@ -182,23 +174,3 @@ def executed_num_blocks(self) -> int:
         doesn't know how many blocks will be produced until tasks finish.
         """
         return len(self.get_blocks())
-
-    def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]:
-        """Ensure that the schema is set for the first block.
-
-        Returns None if the block list is empty.
-        """
-        get_schema = cached_remote_fn(_get_schema)
-        try:
-            block = next(self.iter_blocks())
-        except (StopIteration, ValueError):
-            # Dataset is empty (no blocks) or was manually cleared.
-            return None
-        schema = ray.get(get_schema.remote(block))
-        # Set the schema.
-        self._metadata[0].schema = schema
-        return schema
-
-
-def _get_schema(block: Block) -> Any:
-    return BlockAccessor.for_block(block).schema()