From b53cf5eb542882156e35abe93a2cac1c0a69d6f0 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 17 Sep 2024 20:06:36 +0200 Subject: [PATCH] more fixing --- .pre-commit-config.yaml | 7 -- examples/multi_modal/create_labelencoder.py | 5 +- examples/multi_modal/dataloader.py | 31 ++++--- examples/multi_modal/loop.py | 99 ++++++--------------- examples/multi_modal/model_arc.py | 15 ++-- pyproject.toml | 39 +++++--- src/litdata/processing/data_processor.py | 17 ++-- src/litdata/processing/functions.py | 15 ++-- src/litdata/processing/readers.py | 2 +- src/litdata/streaming/cache.py | 4 +- src/litdata/streaming/combined.py | 7 +- src/litdata/streaming/config.py | 5 +- src/litdata/streaming/dataloader.py | 2 +- src/litdata/streaming/item_loader.py | 6 +- src/litdata/streaming/shuffle.py | 3 - src/litdata/streaming/writer.py | 5 +- src/litdata/utilities/broadcast.py | 4 +- src/litdata/utilities/dataset_utilities.py | 2 +- src/litdata/utilities/train_test_split.py | 8 +- tests/conftest.py | 4 +- tests/processing/test_data_processor.py | 2 +- tests/streaming/test_dataset.py | 4 +- 22 files changed, 117 insertions(+), 169 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93cb15a1..104ebe5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,13 +46,6 @@ repos: additional_dependencies: [tomli] #args: ["--write-changes"] # uncomment if you want to get automatic fixing - - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 - hooks: - - id: docformatter - additional_dependencies: [tomli] - args: ["--in-place"] - - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.2 hooks: diff --git a/examples/multi_modal/create_labelencoder.py b/examples/multi_modal/create_labelencoder.py index 570a83d4..83549924 100644 --- a/examples/multi_modal/create_labelencoder.py +++ b/examples/multi_modal/create_labelencoder.py @@ -3,10 +3,7 @@ def create_labelencoder(): - """Create a label encoder - Returns: - - """ + """Create a label encoder.""" data = ["Cancelation", "IBAN Change", "Damage Report"] # Create an instance of LabelEncoder label_encoder = LabelEncoder() diff --git a/examples/multi_modal/dataloader.py b/examples/multi_modal/dataloader.py index ffe6ab38..ae5e35e5 100644 --- a/examples/multi_modal/dataloader.py +++ b/examples/multi_modal/dataloader.py @@ -29,14 +29,12 @@ def __init__(self): self.hyperparameters = HYPERPARAMETERS def load_labelencoder(self): - """Function to load the label encoder from s3 - Returns: - """ + """Function to load the label encoder from s3.""" return joblib.load(self.hyperparameters["label_encoder_name"]) def load_tokenizer(self): - """Load the tokenizer files and the pre training model path from s3 spezified in the hyperparameters - Returns: tokenizer + """Load the tokenizer files and the pre-training model path from s3 spezified in the hyperparameters + Returns: tokenizer. """ # Load Bert tokenizer return BertTokenizerFast.from_pretrained("bert-base-cased") @@ -60,12 +58,10 @@ def __init__(self, input_dir: Union[str, Any], hyperparameters: Union[dict, Any] self.labelencoder = EC.load_labelencoder() def tokenize_data(self, tokenizer, texts, max_length: int): - """Tokenize the text - Args: - tokenizer: - texts: - max_length: - Returns: input_ids, attention_masks + """Tokenize the text. + + Returns: input_ids, attention_masks. 
+ """ encoded_text = tokenizer( texts, @@ -98,7 +94,7 @@ def __init__(self, hyperparameters: dict): """Init if the Data Module Args: data_path: dataframe with the data - hyperparameters: Hyperparameters + hyperparameters: Hyperparameters. """ super().__init__() self.hyperparameters = hyperparameters @@ -126,9 +122,12 @@ def __init__(self, hyperparameters: dict): ) def train_dataloader(self) -> DataLoader: - """Define the training dataloader + """Define the training dataloader. + Returns: - training dataloader + ------- + training dataloader. + """ dataset_train = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -147,7 +146,7 @@ def train_dataloader(self) -> DataLoader: def val_dataloader(self) -> DataLoader: """Define the validation dataloader Returns: - validation dataloader + validation dataloader. """ dataset_val = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -165,7 +164,7 @@ def val_dataloader(self) -> DataLoader: def test_dataloader(self) -> DataLoader: """Define the test dataloader Returns: - test dataloader + test dataloader. """ dataset_test = DocumentClassificationDataset( hyperparameters=self.hyperparameters, diff --git a/examples/multi_modal/loop.py b/examples/multi_modal/loop.py index 4acf4997..3a25c52d 100644 --- a/examples/multi_modal/loop.py +++ b/examples/multi_modal/loop.py @@ -78,7 +78,6 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): mode: train, test or val report_confusion_matrix: sklearn confusion matrix report: sklear classification report - Returns: """ df_cm = pd.DataFrame(report_confusion_matrix) @@ -88,16 +87,7 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): logger.info("Confusion Matrix and Classication report are saved.") def save_test_evaluations(self, model_dir, mode, y_pred, y_true, confis, numerical_id_): - """Save a pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset - Args: - model_dir: - mode: - y_pred: - y_true: - confis: - numerical_id_: - Returns: - """ + """Save pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset.""" df_test = pd.DataFrame() df_test["pred"] = y_pred df_test["confidence"] = confis.max(axis=1) @@ -152,9 +142,6 @@ def forward( Used for train, test and val. - Args: - ---- - y: tensor with text data as tokens Returns: computional graph @@ -162,31 +149,29 @@ def forward( return self.module(x, y, z) def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict: - """Call the eval share for training - Args: - batch: tensor + """Call the eval share for training. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ return self._shared_eval_step(batch, "train") def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """Call the eval share for validation - Args: - batch: - batch_idx: + """Call the eval share for validation. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ return self._shared_eval_step(batch, "val") def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """Call the eval share for test - Args: - batch: - batch_idx: + """Call the eval share for test. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. 
+ """ ret = self._shared_eval_step(batch, "test") self.pred_list.append(ret) @@ -199,7 +184,9 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: ---- batch: tensor mode: train, test or val + Returns: + ------- dict with loss, outputs and ground_truth """ @@ -227,13 +214,8 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: return {"outputs": out, "loss": loss, "ground_truth": ground_truth, "numerical_id": numerical_id} - def _epoch_end(self, mode: str): - """Calculate loss and metricies at end of epoch - Args: - mode: - Returns: - None - """ + def _epoch_end(self, mode: str) -> None: + """Calculate loss and metrics at end of epoch.""" if mode == "val": output = self.val_metrics.compute() self.log_dict(output) @@ -248,15 +230,7 @@ def _epoch_end(self, mode: str): self.test_metrics.reset() def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader_idx: int = 0) -> torch.Tensor: - """Model prediction without softmax and argmax to predict class label. - - Args: - ---- - outputs: - Returns: - None - - """ + """Model prediction without softmax and argmax to predict class label.""" self.eval() with torch.no_grad(): ids = batch["ID"] @@ -265,48 +239,31 @@ def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader return self.forward(ids, atts, img) def on_test_epoch_end(self) -> None: - """Calculate the metrics at the end of epoch for test step - Args: - outputs: - Returns: - None - """ + """Calculate the metrics at the end of epoch for test step.""" self._epoch_end("test") - def on_validation_epoch_end(self): - """Calculate the metrics at the end of epoch for val step - Args: - outputs: - Returns: - None - """ + def on_validation_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for val step.""" self._epoch_end("val") - def on_train_epoch_end(self): - """Calculate the metrics at the end of epoch for train step - Args: - outputs: - Returns: - None - """ + def on_train_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for train step.""" self._epoch_end("train") def configure_optimizers(self) -> Any: - """Configure the optimizer + """Configure the optimizer. + Returns: + ------- optimizer + """ optimizer = AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.hyperparameters["weight_decay"]) scheduler = StepLR(optimizer, step_size=1, gamma=0.1) return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] def configure_callbacks(self) -> Union[Sequence[pl.pytorch.Callback], pl.pytorch.Callback]: - """Configure Early stopping or Model Checkpointing. - - Returns - ------- - - """ + """Configure Early stopping or Model Checkpointing.""" early_stop = EarlyStopping( monitor="val_MulticlassAccuracy", patience=self.hyperparameters["patience"], mode="max" ) diff --git a/examples/multi_modal/model_arc.py b/examples/multi_modal/model_arc.py index 8ea36276..7ecb0652 100644 --- a/examples/multi_modal/model_arc.py +++ b/examples/multi_modal/model_arc.py @@ -42,10 +42,6 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): Used for train, test and val. - Args: - ---- - input_ids - attention_mask Returns: computional graph @@ -74,8 +70,9 @@ def __init__(self, endpoint_mode: bool, hyperparameters: dict): self.dropout = nn.Dropout(self.hyperparameters["dropout"]) def get_bert_model(self): - """Load the pre trained bert model weigths - Returns: model + """Load the pre-trained bert model weigths. + + Returns: model. 
""" model = BertModel.from_pretrained("bert-base-cased") return BertClassifier(model) @@ -91,9 +88,9 @@ def forward( Args: ---- - x (torch.Tensor): Tensor with id tokesn - y (torch.Tensor): Tensor with attention tokens. - z (torch.Tensor): Tensor with iamge. + x: Tensor with id tokesn + y: Tensor with attention tokens. + z: Tensor with iamge. Returns: ------- diff --git a/pyproject.toml b/pyproject.toml index e22dd9b0..59a8b220 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,12 @@ ignore-words-list = "te, compiletime" [tool.ruff] line-length = 120 target-version = "py38" +# Exclude a variety of commonly ignored directories. +exclude = [ + ".git", + "docs", + "src/litdata/utilities/_pytree.py", +] # Enable Pyflakes `E` and `F` codes by default. lint.select = [ "E", "W", # see: https://pypi.org/project/pycodestyle @@ -65,40 +71,47 @@ lint.extend-select = [ "RET", # see: https://pypi.org/project/flake8-return "PT", # see: https://pypi.org/project/flake8-pytest-style "NPY201", # see: https://docs.astral.sh/ruff/rules/numpy2-deprecation - "RUF100" # yesqa + "RUF100", # yesqa ] lint.ignore = [ "E731", # Do not assign a lambda expression, use a def "S101", # todo: Use of `assert` detected ] -# Exclude a variety of commonly ignored directories. -exclude = [ - ".git", - "docs", - "src/litdata/utilities/_pytree.py", -] lint.ignore-init-module-imports = true +# Unlike Flake8, default to a complexity level of 10. +lint.mccabe.max-complexity = 10 +# Use Google-style docstrings. +lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] -".actions/*" = ["S101", "S310"] -"setup.py" = ["S101", "SIM115"] +"setup.py" = ["D100", "SIM115"] "examples/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # First line should be in imperative mood; try rephrasing "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] "src/**" = [ + "D100", # Missing docstring in public module + "D101", # todo: Missing docstring in public class + "D102", # todo: Missing docstring in public method + "D103", # todo: Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # todo: Missing docstring in magic method + "D107", # todo: Missing docstring in __init__ + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # todo: First line should be in imperative mood; try rephrasing "S602", # todo: `subprocess` call with `shell=True` identified, security issue "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` "S607", # todo: Starting a process with a partial executable path "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. ] "tests/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D401", "D404", # First line should be in imperative mood; try rephrasing "S105", "S106", # todo: Possible hardcoded password: ... - "D100", "D101", "D102", "D103", "D104", "D105", ] -[tool.ruff.lint.mccabe] -# Unlike Flake8, default to a complexity level of 10. 
-max-complexity = 10
 
 [tool.mypy]
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index 71d110a0..b0ab60d5 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -119,7 +119,7 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb:
 
 
 def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
-    """This function is used to download data from a remote directory to a cache directory to optimise reading."""
+    """Download data from a remote directory to a cache directory to optimise reading."""
     s3 = S3Client()
 
     while True:
@@ -176,7 +176,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
 
 
 def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None:
-    """This function is used to delete files from the cache directory to minimise disk space."""
+    """Delete files from the cache directory to minimise disk space."""
     while True:
         # 1. Collect paths
         paths = queue_in.get()
@@ -199,7 +199,7 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None:
 
 
 def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None:
-    """This function is used to upload optimised chunks from a local to remote dataset directory."""
+    """Upload optimised chunks from a local to remote dataset directory."""
     obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path)
 
     if obj.scheme == "s3":
@@ -787,7 +787,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
 
     @abstractmethod
     def prepare_item(self, item_metadata: T) -> Any:
-        """The return of this `prepare_item` method is persisted in chunked binary files."""
+        """Return an item whose output is persisted in chunked binary files."""
 
     def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
         num_nodes = _get_num_nodes()
@@ -832,7 +832,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
         )
 
     def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None:
-        """This method upload the index file to the remote cloud directory."""
+        """Upload the index file to the remote cloud directory."""
         if output_dir.path is None and output_dir.url is None:
             return
 
@@ -909,8 +909,7 @@ def __init__(
         item_loader: Optional[BaseItemLoader] = None,
         start_method: Optional[str] = None,
     ):
-        """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
-        training faster.
+        """Provides an efficient way to process data across multiple machines into chunks to make training faster.
Arguments: --------- @@ -985,7 +984,7 @@ def __init__( self.random_seed = random_seed def run(self, data_recipe: DataRecipe) -> None: - """The `DataProcessor.run(...)` method triggers the data recipe processing over your dataset.""" + """Triggers the data recipe processing over your dataset.""" if not isinstance(data_recipe, DataRecipe): raise ValueError("The provided value should be a data recipe.") if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): @@ -1394,7 +1393,7 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: def in_notebook() -> bool: - """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython shell or other Python + """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython or other Python shell. """ return "ipykernel" in sys.modules diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index a7516041..130d50d5 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -180,7 +180,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> Any: return self._inputs def prepare_item(self, item_metadata: Any) -> Any: - """This method is overriden dynamically.""" + """Being overriden dynamically.""" def map( @@ -200,10 +200,9 @@ def map( reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, ) -> None: - """This function maps a callable over a collection of inputs, possibly in a distributed way. + """Maps a callable over a collection of inputs, possibly in a distributed way. - Arguments: - --------- + Args: fn: A function to be executed over each input element inputs: A sequence of input to be processed by the `fn` function, or a streaming dataloader. output_dir: The folder where the processed data should be stored. @@ -219,6 +218,7 @@ def map( reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. error_when_not_empty: Whether we should error if the output folder isn't empty. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. batch_size: Group the inputs into batches of batch_size length. """ @@ -318,8 +318,7 @@ def optimize( ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. - Arguments: - --------- + Args: fn: A function to be executed over each input element. The function should return the data sample that corresponds to the input. Every invocation of the function should return a similar hierarchy of objects, where the object types and list sizes don't change. @@ -338,6 +337,7 @@ def optimize( machine: When doing remote execution, the machine to use. Only supported on https://lightning.ai/. num_downloaders: The number of downloaders per worker. num_uploaders: The numbers of uploaders per worker. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. batch_size: Group the inputs into batches of batch_size length. @@ -522,8 +522,7 @@ class CopyInfo: def merge_datasets(input_dirs: List[str], output_dir: str) -> None: - """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized - dataset. 
+    """Merges multiple existing optimized datasets into a single optimized dataset.
 
     Arguments:
     ---------
diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py
index b5d99cc6..1b2bdc1f 100644
--- a/src/litdata/processing/readers.py
+++ b/src/litdata/processing/readers.py
@@ -33,7 +33,7 @@ def get_node_rank(self) -> int:
 
     @abstractmethod
    def remap_items(self, items: Any, num_workers: int) -> List[Any]:
-        """This method is meant to remap the items provided by the users into items more adapted to be distributed."""
+        """Remap the items provided by the user into items better suited for distribution."""
         pass
 
     @abstractmethod
diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py
index e9ffc1f5..d045d5da 100644
--- a/src/litdata/streaming/cache.py
+++ b/src/litdata/streaming/cache.py
@@ -51,12 +51,12 @@ def __init__(
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
         together in order to accelerate fetching.
 
-        Arguments:
-        ---------
+        Args:
             input_dir: The path to where the chunks will be or are stored.
             subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
             region_of_interest: List of tuples of (start,end) of region of interest for each chunk.
             compression: The name of the algorithm to reduce the size of the chunks.
+            encryption: The encryption algorithm to use.
             chunk_bytes: The maximum number of bytes within a chunk.
             chunk_size: The maximum number of items within a chunk.
             item_loader: The object responsible to generate the chunk intervals and load an item froma chunk.
diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index dad1bb7d..41589226 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -25,7 +25,7 @@
 
 
 class CombinedStreamingDataset(IterableDataset):
-    """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of
+    """Enables streaming data from multiple StreamingDataset with the sampling ratio of
     your choice.
 
     Addtionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability
@@ -43,13 +43,16 @@ def __init__(
         weights: Optional[Sequence[float]] = None,
         iterate_over_all: bool = True,
     ) -> None:
-        """ "
+        """Enable streaming data from multiple StreamingDataset instances with the sampling ratio of your choice.
+
         Arguments:
+        ---------
         datasets: The list of the StreamingDataset to use.
         seed: The random seed to initialize the sampler
         weights: The sampling ratio for the datasets
         iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
             Otherwise, it stops as soon as one raises a StopIteration.
+
         """
         self._check_datasets(datasets)
 
diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py
index 67ad2dcb..df0ea012 100644
--- a/src/litdata/streaming/config.py
+++ b/src/litdata/streaming/config.py
@@ -35,15 +35,14 @@ def __init__(
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
         storage_options: Optional[Dict] = {},
     ) -> None:
-        """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
-        chunk.
+        """Reads the index files associated with a chunked dataset and enables mapping an index to its chunk.
 
         Arguments:
-        ---------
            cache_dir: The path to cache folder.
            serializers: The serializers used to serialize and deserialize the chunks.
remote_dir: The path to a remote folder where the data are located. The scheme needs to be added to the path. + item_loader: The item loader used to load the data from the chunks. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of {start,end} of region of interest for each chunk. storage_options: Additional connection options for accessing storage services. diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index b01597ac..30d1467a 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -85,7 +85,7 @@ def __init__( dataset: The dataset of the user cache_dir: The folder where the chunks are written to. chunk_bytes: The maximal number of bytes to write within a chunk. - chunk_sie: The maximal number of items to write to a chunk. + chunk_size: The maximal number of items to write to a chunk. compression: The compression algorithm to use to reduce the size of the chunk. """ diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index ab663e96..6d70537f 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -70,11 +70,11 @@ def state_dict(self) -> Dict: @abstractmethod def generate_intervals(self) -> List[Interval]: - """Returns a list of intervals: [chunk_start, - region_of_interest_start, region_of_interest_end, chunk_end] + """Returns a list of intervals. - region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. + The structure is: [chunk_start, region_of_interest_start, region_of_interest_end, chunk_end] + region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. """ pass diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 0edb3415..90e1085b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -84,13 +84,10 @@ class FullShuffle(Shuffle): """FullShuffle shuffles the chunks and associates them to the ranks. As the number of items in a chunk varies, it is possible for a rank to end up with more or less items. - To ensure the same fixed dataset length for all ranks while dropping as few items as possible, - we adopt the following strategy. We compute the maximum number of items per rank (M) and iterate through the chunks and ranks - until we have associated at least M items per rank. As a result, we lose at most (number of ranks) items. However, as some chunks are shared across ranks. This leads to diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index f4c93f33..bae647b8 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -58,15 +58,16 @@ def __init__( ): """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. - Arguments: - --------- + Args: cache_dir: The path to where the chunks will be saved. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. compression: The compression algorithm to use. encryption: The encryption algorithm to use. + follow_tensor_dimension: Whether to follow the tensor dimension when serializing the data. serializers: Provide your own serializers. chunk_index: The index of the chunk to start from. + item_loader: The object responsible to generate the chunk intervals and load an item from a chunk. 
""" self._cache_dir = cache_dir diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index 8968a905..8d94c5c0 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -48,9 +48,7 @@ def _response(r: Any, *args: Any, **kwargs: Any) -> Any: class _HTTPClient: - """A wrapper class around the requests library which handles chores like logging, retries, and timeouts - automatically. - """ + """A wrapper around the requests library which handles chores like logging, retries, and timeouts automatically.""" def __init__( self, diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 45e21bd7..bf31b780 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -210,7 +210,7 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: """Adapt mds shard-based index data to chunk-based format for compatibility. - For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming. Args: ---- diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index 730d2c64..a44b8b5c 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -20,18 +20,16 @@ def train_test_split( These subsets can be used for training, testing, and validation purposes. Args: - ---- - streaming_dataset (StreamingDataset): An instance of StreamingDataset that needs to be split. - splits (List[float]): A list of floats representing the proportion of data to be allocated to each split + streaming_dataset: An instance of StreamingDataset that needs to be split. + splits: A list of floats representing the proportion of data to be allocated to each split (e.g., [0.8, 0.1, 0.1] for 80% training, 10% testing, and 10% validation). + seed: An integer used to seed the random number generator for reproducibility. Returns: - ------- List[StreamingDataset]: A list of StreamingDataset instances, where each element represents a split of the original dataset according to the proportions specified in the 'splits' argument. Raises: - ------ ValueError: If any element in the 'splits' list is not a float between 0 (inclusive) and 1 (exclusive). ValueError: If the sum of the values in the 'splits' list is greater than 1. Exception: If the provided StreamingDataset is already a subsample (not currently supported). diff --git a/tests/conftest.py b/tests/conftest.py index 32a86645..a4b1f707 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture(autouse=True) def teardown_process_group(): # noqa: PT004 - """Ensures that the distributed process group gets closed before the next test runs.""" + """Ensures distributed process group gets closed before the next test runs.""" yield if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.destroy_process_group() @@ -84,7 +84,7 @@ def lightning_sdk_mock(monkeypatch): @pytest.fixture(autouse=True) def _thread_police(): - """Attempts to stop left-over threads to avoid test interactions. + """Attempts stopping left-over threads to avoid test interactions. Adapted from PyTorch Lightning. 
 
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py
index 75153405..cc3ef9b2 100644
--- a/tests/processing/test_data_processor.py
+++ b/tests/processing/test_data_processor.py
@@ -460,7 +460,7 @@ def _broadcast_object(self, obj: Any) -> Any:
     condition=(not _PIL_AVAILABLE or sys.platform == "win32" or sys.platform == "linux"), reason="Requires: ['pil']"
 )
 def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
-    """This test ensures the data optimizer works in a fully distributed settings."""
+    """Ensures the data optimizer works in fully distributed settings."""
     seed_everything(42)
 
     monkeypatch.setattr(data_processor_module.os, "_exit", mock.MagicMock())
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 00d30464..ef93021c 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -892,9 +892,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):
 @pytest.mark.timeout(60)
 @pytest.mark.parametrize("shuffle", [True, False])
 def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
-    """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have
-    the same size.
-    """
+    """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size."""
     s3_cache_dir = str(tmpdir / "s3cache")
     optimize_data_cache_dir = str(tmpdir / "optimize_data_cache")
     optimize_cache_dir = str(tmpdir / "optimize_cache")
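
For reference, a minimal sketch of the Google-style docstring shape that the Ruff settings introduced in pyproject.toml above (`lint.pydocstyle.convention = "google"`, with D205/D401 still deferred via per-file ignores) steer these helpers toward. It is loosely adapted from the `load_tokenizer` helper in examples/multi_modal/dataloader.py; the function signature and default checkpoint here are illustrative, not part of the patch.

from transformers import BertTokenizerFast  # same dependency the multi_modal example already uses


def load_tokenizer(model_name: str = "bert-base-cased"):
    """Load the tokenizer specified in the hyperparameters.

    Args:
        model_name: Name of the pretrained tokenizer checkpoint to load.

    Returns:
        The loaded ``BertTokenizerFast`` instance.
    """
    return BertTokenizerFast.from_pretrained(model_name)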