Improve DatasetOptimizer API (#18827)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
4 people authored Oct 23, 2023
1 parent 1a5718a commit e59dc41
Showing 9 changed files with 241 additions and 93 deletions.
2 changes: 1 addition & 1 deletion requirements/app/app.txt
@@ -1,4 +1,4 @@
lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility
lightning-cloud ==0.5.43 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.0.0, <4.8.0
deepdiff >=5.7.0, <6.6.0
8 changes: 5 additions & 3 deletions src/lightning/app/cli/commands/cp.py
@@ -113,7 +113,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
else:
upload_paths = [local_src]

upload_urls = []
_upload_urls = []

clusters = client.projects_service_list_project_cluster_bindings(project_id)

@@ -129,9 +129,11 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename),
async_req=True,
)
upload_urls.append(response)
_upload_urls.append(response)

upload_urls = [upload_url.get().upload_url for upload_url in upload_urls]
upload_urls = []
for upload_url in _upload_urls:
upload_urls.extend(upload_url.get().urls)

live.stop()

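The change above drops the one-URL-per-response assumption: each asynchronous response now carries a list of URLs (`.urls`), which is flattened into a single `upload_urls` list. A minimal sketch of this collect-then-flatten pattern, using `concurrent.futures` and a hypothetical `request_upload_urls` helper rather than the real lightning-cloud client:

    from concurrent.futures import ThreadPoolExecutor
    from typing import List

    def request_upload_urls(filename: str) -> List[str]:
        # Hypothetical stand-in for the storage endpoint: one request now
        # returns several upload URLs instead of a single one.
        return [f"https://storage.example.com/{filename}/part-{i}" for i in range(2)]

    filenames = ["weights.ckpt", "logs.txt"]

    with ThreadPoolExecutor() as pool:
        # Fire off the requests asynchronously and keep the handles.
        pending = [pool.submit(request_upload_urls, name) for name in filenames]

    # Resolve each handle and flatten the per-response URL lists.
    upload_urls: List[str] = []
    for response in pending:
        upload_urls.extend(response.result())

    print(upload_urls)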
9 changes: 8 additions & 1 deletion src/lightning/data/__init__.py
@@ -1,5 +1,12 @@
from lightning.data.datasets import LightningDataset, LightningIterableDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.dataset_optimizer import DatasetOptimizer

__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"]
__all__ = [
"LightningDataset",
"StreamingDataset",
"StreamingDataLoader",
"LightningIterableDataset",
"DatasetOptimizer",
]
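With this re-export, `DatasetOptimizer` becomes importable from the package root alongside the streaming classes. A minimal import sketch (assuming a lightning build that includes this commit):

    from lightning.data import (
        DatasetOptimizer,
        LightningDataset,
        LightningIterableDataset,
        StreamingDataLoader,
        StreamingDataset,
    )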
121 changes: 110 additions & 11 deletions src/lightning/data/streaming/dataset.py
@@ -11,36 +11,135 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union

from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import IterableDataset

from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex


class StreamingDataset(Dataset):
class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""

def __init__(
self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None
self,
name: str,
version: Optional[Union[int, Literal["latest"]]] = "latest",
cache_dir: Optional[str] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = True,
seed: int = 42,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
Arguments:
name: The name of the optimised dataset.
version: The version of the dataset to use.
cache_dir: The cache dir where the data would be stored.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
seed: Random seed for shuffling.
"""
super().__init__()
self.cache = Cache(name=name, version=version, cache_dir=cache_dir)
self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1)

self.cache._reader._try_load_config()

if not self.cache.filled:
raise ValueError(f"The provided dataset `{name}` isn't filled up.")

self.shuffle = shuffle
self.distributed_env = _DistributedEnv.detect()
self.worker_env: Optional[_WorkerEnv] = None

chunk_intervals = self.cache.get_chunk_interval()
self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])

self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.seed = seed
self.num_iter = 0
self.random_state = None

def __len__(self) -> int:
return len(self.cache)
return self.L

def __iter__(self) -> "StreamingDataset":
self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore
chunk_intervals = self.cache.get_chunk_interval()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
replica_index = index % self.distributed_env.world_size
chunks_per_replica[replica_index].append(chunk_index)
intervals_per_replica[replica_index].append(chunk_interval)

current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

if self.worker_env is None:
self.worker_env = _WorkerEnv.detect()

self.worker_chunks = []
self.worker_intervals = []

for i, (chunk_index, chunk_interval) in enumerate(zip(current_chunks, current_intervals)):
if i % self.worker_env.world_size != self.worker_env.rank:
continue
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)

self.current_indexes = []
self.chunk_index = 0
self.num_iter += 1

return self

def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
if isinstance(index, int):
index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
return self.cache[index]

def __next__(self) -> Any:
# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == len(self.worker_intervals):
raise StopIteration

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(0, interval[1] - interval[0])
if self.shuffle:
current_indexes = self.random_state.permutation(current_indexes)
self.current_indexes = current_indexes.tolist()
self.chunk_index += 1

# Get the first index
index = self.current_indexes.pop(0)

# Call the `__getitem__` method.
data = self.__getitem__(
ChunkedIndex(
index=index,
chunk_index=self.worker_chunks[self.chunk_index - 1],
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
)
)

def __getitem__(self, idx: int) -> Any:
return self.cache[idx]
self.has_triggered_download = True
self.index += 1

def getitem(self, obj: Any) -> Any:
"""Override the getitem with your own logic to transform the cache object."""
return obj
return data
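Since `StreamingDataset` is now an `IterableDataset` that shards shuffled chunks across replicas and dataloader workers, it is consumed by iteration rather than by random access. A minimal usage sketch based on the constructor signature above; the dataset name and cache directory are placeholders and must point to data already processed by `DatasetOptimizer`:

    from torch.utils.data import DataLoader

    from lightning.data import StreamingDataset

    # Placeholders: the named dataset must already exist in the cache,
    # otherwise the constructor raises a ValueError ("... isn't filled up").
    dataset = StreamingDataset(name="my_optimized_dataset", cache_dir="/cache/data", shuffle=True, seed=42)

    # Each call to iter() reshuffles the chunk order using `seed + num_iter`.
    loader = DataLoader(dataset, batch_size=32)
    for batch in loader:
        ...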
102 changes: 45 additions & 57 deletions src/lightning/data/streaming/dataset_optimizer.py
@@ -3,15 +3,15 @@
import signal
import traceback
import types
from abc import ABC, abstractmethod
from enum import Enum
from multiprocessing import Process, Queue
from pathlib import Path
from queue import Empty
from shutil import copyfile
from textwrap import dedent
from threading import Thread
from time import sleep, time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Tuple, TypeVar, runtime_checkable
from urllib import parse

from tqdm.auto import tqdm
@@ -167,7 +167,7 @@ def __init__(
start_index: int,
dataset_name: str,
node_rank: int,
dataset_optimizer: "DatasetOptimizer",
prepare_item: Callable,
src_dir: str,
remote_src_dir: str,
remote_dst_dir: Optional[str],
@@ -187,7 +187,7 @@ def __init__(
self.start_index = start_index
self.dataset_name = dataset_name
self.node_rank = node_rank
self.prepare_item = dataset_optimizer.prepare_item
self.prepare_item = prepare_item
self.src_dir = src_dir
self.remote_src_dir = remote_src_dir
self.remote_dst_dir = remote_dst_dir
@@ -432,57 +432,21 @@ class WorkerType(Enum):
PROCESS = "process"


class DatasetOptimizer(ABC):
@abstractmethod
def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
"""This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a
single item when called with the prepare_item.
T = TypeVar("T")

Example::

# For a classification use case
def prepare_dataset_structure(self, src_dir, filepaths)
import numpy as np
filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...]
classes = np.unique([filepath.split("/")[0] for filepath in filepaths])
classes_to_idx_map = {c: idx for idx, c in enumerate(classes)}
# Return pair with the filepath to the obj and its class
# [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)]
return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths]
Example::
# For a image segmentation use case
def prepare_dataset_structure(self, src_dir, filepaths)
import numpy as np
filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...]
# [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
def prepare_item(self, obj):
image_filepath, mask_filepath = obj
image = load_and_resize(image_filepath)
mask = load_and_resize(mask_filepath)
return (image, mask)
"""
@runtime_checkable
class _OptimizableDataset(Protocol):
@staticmethod
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
pass

def prepare_item(self, metadata_item: Any) -> Any:
"""Using some metadata, prepare the associated item.
@staticmethod
def prepare_item(item_metadata: T) -> Any:
return item_metadata

The output of this function will be binarised
"""
return metadata_item

class DatasetOptimizer:
def __init__(
self,
name: str,
@@ -547,9 +511,29 @@ def __init__(
)
self.random_seed = random_seed

def run(self) -> None:
def run(self, optimizable_dataset: _OptimizableDataset) -> None:
"""The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into
chunks."""
if not isinstance(optimizable_dataset, _OptimizableDataset):
raise ValueError(
dedent(
"""The provided argument to the DatasetOptimizer.run(...) needs to have the following format:
Example:
class YourDataset:
@staticmethod
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
return [...]
@staticmethod
def prepare_item(item_metadata: T) -> Any:
return ...
"""
)
)

t0 = time()
print(f"Setup started for `{self.name}` with fast_dev_run={self.fast_dev_run}.")

@@ -564,7 +548,7 @@ def run(self) -> None:
seed_everything(self.random_seed)

# Call the setup method of the user
user_items = self.prepare_dataset_structure(self.src_dir, filepaths)
user_items: List[Any] = optimizable_dataset.prepare_dataset_structure(self.src_dir, filepaths)

if not isinstance(user_items, list):
raise ValueError("The setup_fn should return a list of item metadata.")
@@ -588,9 +572,9 @@ def run(self) -> None:
signal.signal(signal.SIGINT, self._signal_handler)

if self.worker_type == WorkerType.THREAD.value:
self._create_thread_workers(begins, workers_user_items)
self._create_thread_workers(optimizable_dataset, begins, workers_user_items)
else:
self._create_process_workers(begins, workers_user_items)
self._create_process_workers(optimizable_dataset, begins, workers_user_items)

print("Workers are ready ! Starting data processing...")

@@ -634,7 +618,9 @@ def _exit_on_error(self, error: str) -> None:
w.join(0)
raise RuntimeError(f"We found the following error {error}.")

def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
def _create_thread_workers(
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
) -> None:
current_total = 0
total = sum([len(w) for w in workers_user_items])
with tqdm(total=total, smoothing=0) as pbar:
@@ -649,7 +635,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
begins[worker_idx],
self.name,
_get_node_rank(),
self,
optimizable_dataset.prepare_item,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
@@ -676,7 +662,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
if current_total == total:
break

def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
def _create_process_workers(
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
) -> None:
self.progress_queue = Queue()
workers: List[DataWorkerProcess] = []
stop_queues: List[Queue] = []
@@ -688,7 +676,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
begins[worker_idx],
self.name,
_get_node_rank(),
self,
optimizable_dataset.prepare_item,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
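Putting the new API together: the user-facing hooks are no longer abstract methods on a `DatasetOptimizer` subclass but live on a separate object that satisfies the `_OptimizableDataset` protocol and is passed to `run(...)`. A sketch of this flow, reusing the classification example from the removed docstring; the constructor arguments (`name`, `src_dir`) are assumptions inferred from attributes referenced in this diff, not a documented signature:

    from typing import Any, Dict, List, Tuple

    from lightning.data import DatasetOptimizer


    class ImageClassificationData:
        @staticmethod
        def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[Tuple[str, int]]:
            # Return one metadata entry per item: (filepath, class index),
            # where the class is taken from the top-level folder name.
            classes = sorted({path.split("/")[0] for path in filepaths})
            class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
            return [(path, class_to_idx[path.split("/")[0]]) for path in filepaths]

        @staticmethod
        def prepare_item(item_metadata: Tuple[str, int]) -> Dict[str, Any]:
            # Turn one metadata entry into the object that gets serialized into chunks.
            filepath, label = item_metadata
            return {"filepath": filepath, "label": label}


    # `name` and `src_dir` are assumed arguments for illustration only.
    optimizer = DatasetOptimizer(name="imagenet", src_dir="/data/imagenet")
    optimizer.run(ImageClassificationData())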