Commit

Lightning Data: Refactor files (#19424)
tchaton authored Feb 8, 2024
1 parent bc56630 commit ac9d63f
Showing 11 changed files with 94 additions and 90 deletions.
2 changes: 1 addition & 1 deletion src/lightning/data/__init__.py
@@ -1,7 +1,7 @@
+from lightning.data.processing.functions import map, optimize, walk
from lightning.data.streaming.combined import CombinedStreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.functions import map, optimize, walk

__all__ = [
"LightningDataset",
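
For downstream code the re-export change is transparent: the top-level names stay importable from lightning.data, only the internal module they come from moves. A minimal sketch (the set of names shown is taken from the imports in this hunk, not the full public API):

# map/optimize/walk now resolve through lightning.data.processing.functions,
# while the streaming classes keep coming from lightning.data.streaming.
from lightning.data import StreamingDataLoader, StreamingDataset, map, optimize, walk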
@@ -22,9 +22,9 @@

import torch

+from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.streaming.resolver import (
    Dir,
    _assert_dir_has_index_file,
2 changes: 1 addition & 1 deletion src/lightning/data/processing/readers.py
@@ -5,8 +5,8 @@

from lightning_utilities.core.imports import RequirementCache

-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks

_POLARS_AVAILABLE = RequirementCache("polars")
_PYARROW_AVAILABLE = RequirementCache("pyarrow")
4 changes: 0 additions & 4 deletions src/lightning/data/streaming/__init__.py
@@ -13,18 +13,14 @@

from lightning.data.streaming.cache import Cache
from lightning.data.streaming.combined import CombinedStreamingDataset
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

__all__ = [
    "Cache",
-    "DataProcessor",
    "StreamingDataset",
    "CombinedStreamingDataset",
    "StreamingDataLoader",
-    "DataTransformRecipe",
-    "DataChunkRecipe",
    "TokensLoader",
]
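
Code that previously imported the processing classes from lightning.data.streaming has to switch to the new package; a minimal sketch of the updated import, mirroring the path used elsewhere in this commit:

# Old location (these names were dropped from lightning.data.streaming.__all__):
#   from lightning.data.streaming import DataChunkRecipe, DataProcessor, DataTransformRecipe
# New location after the refactor:
from lightning.data.processing.data_processor import (
    DataChunkRecipe,
    DataProcessor,
    DataTransformRecipe,
)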
74 changes: 1 addition & 73 deletions src/lightning/data/streaming/shuffle.py
@@ -19,6 +19,7 @@

from lightning.data.streaming import Cache
from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle


class Shuffle(ABC):
@@ -129,76 +130,3 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c

    def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
        return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()


def _intra_node_chunk_shuffle(
    distributed_env: _DistributedEnv,
    chunks_per_ranks: List[List[int]],
    seed: int,
    current_epoch: int,
) -> List[int]:
    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
            chunks_per_rank
        )

    # shuffle the chunks associated to the node
    for i in range(len(chunk_indexes_per_nodes)):
        # permute the indexes within the node
        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
            chunk_indexes_per_nodes[i]
        )

    return [index for chunks in chunk_indexes_per_nodes for index in chunks]


def _associate_chunks_and_internals_to_ranks(
    distributed_env: _DistributedEnv,
    indexes: Any,
    chunk_intervals: Any,
    drop_last: bool,
) -> Tuple[List[List[int]], List[Any]]:
    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
    num_items_per_ranks: List[int] = [
        num_items // distributed_env.world_size + num_items % distributed_env.world_size
        if rank == distributed_env.world_size - 1 and not drop_last
        else num_items // distributed_env.world_size
        for rank in range(distributed_env.world_size)
    ]
    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]

    # 4. Assign the chunk & intervals to each rank
    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
        rank = 0

        while True:
            if rank == len(num_items_per_ranks):
                break

            items_left_to_assign = num_items_per_ranks[rank]

            if items_left_to_assign == 0:
                rank += 1
                continue

            items_in_chunk = chunk_interval[-1] - chunk_interval[0]

            if items_in_chunk == 0:
                break

            if items_in_chunk > items_left_to_assign:
                chunks_per_ranks[rank].append(chunk_index)
                begin, end = chunk_interval
                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
                chunk_interval = (begin + items_left_to_assign, end)
                num_items_per_ranks[rank] = 0
                rank += 1
            else:
                chunks_per_ranks[rank].append(chunk_index)
                intervals_per_ranks[rank].append(chunk_interval)
                num_items_per_ranks[rank] -= items_in_chunk
                break

    return chunks_per_ranks, intervals_per_ranks
78 changes: 78 additions & 0 deletions src/lightning/data/utilities/shuffle.py
@@ -0,0 +1,78 @@
from typing import Any, List, Tuple

import numpy as np

from lightning.data.utilities.env import _DistributedEnv


def _intra_node_chunk_shuffle(
    distributed_env: _DistributedEnv,
    chunks_per_ranks: List[List[int]],
    seed: int,
    current_epoch: int,
) -> List[int]:
    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
            chunks_per_rank
        )

    # shuffle the chunks associated to the node
    for i in range(len(chunk_indexes_per_nodes)):
        # permute the indexes within the node
        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
            chunk_indexes_per_nodes[i]
        )

    return [index for chunks in chunk_indexes_per_nodes for index in chunks]


def _associate_chunks_and_internals_to_ranks(
    distributed_env: _DistributedEnv,
    indexes: Any,
    chunk_intervals: Any,
    drop_last: bool,
) -> Tuple[List[List[int]], List[Any]]:
    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
    num_items_per_ranks: List[int] = [
        num_items // distributed_env.world_size + num_items % distributed_env.world_size
        if rank == distributed_env.world_size - 1 and not drop_last
        else num_items // distributed_env.world_size
        for rank in range(distributed_env.world_size)
    ]
    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]

    # 4. Assign the chunk & intervals to each rank
    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
        rank = 0

        while True:
            if rank == len(num_items_per_ranks):
                break

            items_left_to_assign = num_items_per_ranks[rank]

            if items_left_to_assign == 0:
                rank += 1
                continue

            items_in_chunk = chunk_interval[-1] - chunk_interval[0]

            if items_in_chunk == 0:
                break

            if items_in_chunk > items_left_to_assign:
                chunks_per_ranks[rank].append(chunk_index)
                begin, end = chunk_interval
                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
                chunk_interval = (begin + items_left_to_assign, end)
                num_items_per_ranks[rank] = 0
                rank += 1
            else:
                chunks_per_ranks[rank].append(chunk_index)
                intervals_per_ranks[rank].append(chunk_interval)
                num_items_per_ranks[rank] -= items_in_chunk
                break

    return chunks_per_ranks, intervals_per_ranks
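
A small usage sketch of the relocated helpers. The SimpleNamespace stand-in for _DistributedEnv is an assumption for illustration only; the two functions read nothing but world_size and num_nodes from it:

from types import SimpleNamespace

from lightning.data.utilities.shuffle import (
    _associate_chunks_and_internals_to_ranks,
    _intra_node_chunk_shuffle,
)

# Two chunks of 5 items each, spread over a 2-rank, single-node environment.
env = SimpleNamespace(world_size=2, num_nodes=1)
chunk_intervals = [[0, 5], [5, 10]]

chunks_per_rank, intervals_per_rank = _associate_chunks_and_internals_to_ranks(
    env, indexes=[0, 1], chunk_intervals=chunk_intervals, drop_last=False
)
# chunks_per_rank    -> [[0], [1]]
# intervals_per_rank -> [[[0, 5]], [[5, 10]]]

# Shuffle the chunk order within each node, seeded per epoch.
shuffled = _intra_node_chunk_shuffle(env, chunks_per_rank, seed=42, current_epoch=0)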
@@ -10,10 +10,9 @@
import pytest
import torch
from lightning import seed_everything
-from lightning.data.streaming import data_processor as data_processor_module
-from lightning.data.streaming import functions, resolver
-from lightning.data.streaming.cache import Cache, Dir
-from lightning.data.streaming.data_processor import (
+from lightning.data.processing import data_processor as data_processor_module
+from lightning.data.processing import functions
+from lightning.data.processing.data_processor import (
    DataChunkRecipe,
    DataProcessor,
    DataTransformRecipe,
@@ -26,7 +25,9 @@
    _wait_for_disk_usage_higher_than_threshold,
    _wait_for_file_to_exist,
)
-from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.streaming import resolver
+from lightning.data.streaming.cache import Cache, Dir
from lightning_utilities.core.imports import RequirementCache

_PIL_AVAILABLE = RequirementCache("PIL")
@@ -162,7 +163,7 @@ def fn(*_, **__):


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
@mock.patch("lightning.data.processing.data_processor._wait_for_disk_usage_higher_than_threshold")
def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)
@@ -201,7 +202,7 @@ def fn(*_, **__):

def test_wait_for_disk_usage_higher_than_threshold():
    disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
-    with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
+    with mock.patch("lightning.data.processing.data_processor.shutil.disk_usage", disk_usage_mock):
        _wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
    assert disk_usage_mock.call_count == 3

@@ -4,7 +4,7 @@

import pytest
from lightning.data import walk
-from lightning.data.streaming.functions import _get_input_dir
+from lightning.data.processing.functions import _get_input_dir


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
3 changes: 2 additions & 1 deletion tests/tests_data/streaming/test_dataset.py
@@ -20,7 +20,8 @@
import pytest
import torch
from lightning import seed_everything
-from lightning.data.streaming import Cache, functions
+from lightning.data.processing import functions
+from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import (
    _INDEX_FILENAME,
2 changes: 1 addition & 1 deletion tests/tests_data/streaming/test_shuffle.py
@@ -1,5 +1,5 @@
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle


def test_intra_node_chunk_shuffle():
