Skip to content

Commit

Permalink
Fix: Chunks deletion issue (#375)
Browse files Browse the repository at this point in the history
* not sure if it works. hehe

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added test to check behaviour in CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
  • Loading branch information
3 people authored Sep 26, 2024
1 parent b039b64 commit 5cae73c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def run(self) -> None:
try:
self._setup()
self._loop()
self._terminate()
except Exception:
traceback_format = traceback.format_exc()
self.error_queue.put(traceback_format)
Expand All @@ -469,6 +470,19 @@ def _setup(self) -> None:
self._start_uploaders()
self._start_remover()

def _terminate(self) -> None:
"""Make sure all the uploaders, downloaders and removers are terminated."""
for uploader in self.uploaders:
if uploader.is_alive():
uploader.join()

for downloader in self.downloaders:
if downloader.is_alive():
downloader.join()

if self.remover and self.remover.is_alive():
self.remover.join()

def _loop(self) -> None:
num_downloader_finished = 0

Expand Down Expand Up @@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul

chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
if chunks and delete_cached_files and output_dir.path is not None:
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}")
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}")

merge_cache = Cache(cache_dir, chunk_bytes=1)
node_rank = _get_node_rank()
Expand Down Expand Up @@ -1110,6 +1124,10 @@ def run(self, data_recipe: DataRecipe) -> None:

current_total = new_total
if current_total == num_items:
# make sure all processes are terminated
for w in self.workers:
if w.is_alive():
w.join()
break

if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
Expand All @@ -1118,17 +1136,13 @@ def run(self, data_recipe: DataRecipe) -> None:

# Exit early if all the workers are done.
# This means there were some kinda of errors.
# TODO: Check whether this is still required.
if all(not w.is_alive() for w in self.workers):
raise RuntimeError("One of the worker has failed")

if _TQDM_AVAILABLE:
pbar.close()

# TODO: Check whether this is still required.
if num_nodes == 1:
for w in self.workers:
w.join()

print("Workers are finished.")
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

Expand Down
55 changes: 55 additions & 0 deletions tests/processing/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import glob
import os
import random
import shutil
import sys
from pathlib import Path
from unittest import mock

import cryptography
import numpy as np
import pytest
import requests
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache
Expand Down Expand Up @@ -475,3 +480,53 @@ def test_optimize_with_rsa_encryption(tmpdir):
# encryption=rsa,
# mode="overwrite",
# )


def tokenize(filename: str):
with open(filename, encoding="utf-8") as file:
text = file.read()
text = text.strip().split(" ")
word_to_int = {word: random.randint(1, 1000) for word in set(text)} # noqa: S311
tokenized = [word_to_int[word] for word in text]

yield tokenized


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
def test_optimize_race_condition(tmpdir):
# issue: https://github.com/Lightning-AI/litdata/issues/367
# run_commands = [
# "mkdir -p tempdir/custom_texts",
# "curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output tempdir/custom_texts/book1.txt",
# "curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output tempdir/custom_texts/book2.txt",
# ]
shutil.rmtree(f"{tmpdir}/custom_texts", ignore_errors=True)
os.makedirs(f"{tmpdir}/custom_texts", exist_ok=True)

urls = [
"https://www.gutenberg.org/cache/epub/24440/pg24440.txt",
"https://www.gutenberg.org/cache/epub/26393/pg26393.txt",
]

for i, url in enumerate(urls):
print(f"downloading {i+1} file")
with requests.get(url, stream=True, timeout=10) as r:
r.raise_for_status() # Raise an exception for bad status codes

with open(f"{tmpdir}/custom_texts/book{i+1}.txt", "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

print("=" * 100)

train_files = sorted(glob.glob(str(Path(f"{tmpdir}/custom_texts") / "*.txt")))
print("=" * 100)
print(train_files)
print("=" * 100)
optimize(
fn=tokenize,
inputs=train_files,
output_dir=f"{tmpdir}/temp",
num_workers=1,
chunk_bytes="50MB",
)

0 comments on commit 5cae73c

Please sign in to comment.