Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve some bugs #121

Merged
merged 8 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ files = [
]
# This section is for folders with "-" as they are not valid python modules
exclude = [
"src/litdata/utilities/_pytree.py",
]
install_types = "True"
non_interactive = "True"
Expand Down
3 changes: 1 addition & 2 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import signal
import tempfile
import traceback
import types
from abc import abstractmethod
from dataclasses import dataclass
from multiprocessing import Process, Queue
Expand Down Expand Up @@ -625,7 +624,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
try:
current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data_or_generator = self.data_recipe.prepare_item(current_item)
if isinstance(item_data_or_generator, types.GeneratorType):
if self.data_recipe.is_generator:
for item_data in item_data_or_generator:
if item_data is not None:
chunk_filepath = self.cache._add_item(self._index_counter, item_data)
Expand Down
35 changes: 20 additions & 15 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,31 @@ def __init__(
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._fn = fn
self._inputs = inputs
self.is_generator = False

self.check_fn()

self.prepare_item = self._prepare_item_generator if self.is_generator else self._prepare_item # type: ignore

def check_fn(self) -> None:
    """Detect whether the user-supplied callable produces items lazily.

    Sets ``self.is_generator`` to ``True`` when ``self._fn`` is a generator
    function (a plain function or a ``functools.partial`` wrapping one) or a
    callable object whose ``__call__`` is a generator function. Leaves the
    flag untouched otherwise.
    """
    fn = self._fn
    # Plain functions and partials: inspect the callable itself.
    if isinstance(fn, (partial, FunctionType)) and inspect.isgeneratorfunction(fn):
        self.is_generator = True
        return
    # Callable objects: inspect their __call__ method instead.
    if callable(fn) and inspect.isgeneratorfunction(fn.__call__):  # type: ignore
        self.is_generator = True

def _prepare_item(self, item_metadata: Any) -> Any:
    """Run the wrapped callable on a single item and hand back its result."""
    result = self._fn(item_metadata)
    return result

def _prepare_item_generator(self, item_metadata: Any) -> Any:
    """Drive the wrapped generator callable, yielding each element it produces."""
    for element in self._fn(item_metadata):  # type: ignore
        yield element

def prepare_structure(self, input_dir: Optional[str]) -> Any:
    """Return the inputs collected at construction time.

    ``input_dir`` is accepted for interface compatibility but is not used:
    this recipe already holds its work items in ``self._inputs``.
    """
    inputs = self._inputs
    return inputs

def prepare_item(self, item_metadata: Any) -> Any:
    # NOTE(review): this span is merged-diff residue from the scraped PR page.
    # The branching body below is the *removed* implementation; the bare
    # string literal at the end is the *added* one-line body. In the merged
    # file only that string remains and the method is bound dynamically in
    # __init__ (to _prepare_item or _prepare_item_generator) — confirm
    # against the actual repository before relying on this rendering.
    if isinstance(self._fn, partial):
        # Removed behavior: partials were always driven as generators.
        yield from self._fn(item_metadata)

    elif isinstance(self._fn, FunctionType):
        # Removed behavior: plain functions dispatched on generator-ness.
        if inspect.isgeneratorfunction(self._fn):
            yield from self._fn(item_metadata)
        else:
            yield self._fn(item_metadata)
    elif callable(self._fn):
        # Removed behavior: callable objects inspected via __call__.
        if inspect.isgeneratorfunction(self._fn.__call__):  # type: ignore
            yield from self._fn.__call__(item_metadata)  # type: ignore
        else:
            yield self._fn.__call__(item_metadata)  # type: ignore
    else:
        raise ValueError(f"The provided {self._fn} isn't supported.")
    """This method is overriden dynamically."""


def map(
Expand Down
4 changes: 3 additions & 1 deletion src/litdata/streaming/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:
# Flatten the items provided by the users
flattened, data_spec = tree_flatten(items)

is_single_tensor = len(flattened) == 1 and isinstance(flattened[0], torch.Tensor)
is_single_tensor = (
len(flattened) == 1 and isinstance(flattened[0], torch.Tensor) and len(flattened[0].shape) == 1
)

# Collect the sizes and associated bytes for each item
sizes: List[int] = []
Expand Down
1 change: 0 additions & 1 deletion status.json

This file was deleted.

4 changes: 4 additions & 0 deletions tests/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ def test_map_items_to_workers_sequentially(monkeypatch):


class CustomDataChunkRecipe(DataChunkRecipe):
is_generator = False

def prepare_structure(self, input_dir: str) -> List[Any]:
filepaths = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]
assert len(filepaths) == 30
Expand Down Expand Up @@ -553,6 +555,8 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,


class TextTokenizeRecipe(DataChunkRecipe):
is_generator = True

def prepare_structure(self, input_dir: str) -> List[Any]:
return [os.path.join(input_dir, "dummy.txt")]

Expand Down
Loading