Merge branch 'main' into ruff/D
Borda committed Sep 17, 2024
2 parents b53cf5e + 8eb516a commit ecc378b
Showing 21 changed files with 102 additions and 125 deletions.
9 changes: 6 additions & 3 deletions examples/multi_modal/dataloader.py
@@ -33,8 +33,9 @@ def load_labelencoder(self):
return joblib.load(self.hyperparameters["label_encoder_name"])

def load_tokenizer(self):
"""Load the tokenizer files and the pre-training model path from s3 spezified in the hyperparameters
Returns: tokenizer.
"""Load the tokenizer files and the pre-training model path from s3 specified in the hyperparameters
Returns: tokenizer
"""
# Load Bert tokenizer
return BertTokenizerFast.from_pretrained("bert-base-cased")
@@ -92,6 +93,7 @@ class MixedDataModule(pl.LightningDataModule):

def __init__(self, hyperparameters: dict):
"""Init if the Data Module
Args:
data_path: dataframe with the data
hyperparameters: Hyperparameters.
@@ -125,7 +127,6 @@ def train_dataloader(self) -> DataLoader:
"""Define the training dataloader.
Returns:
-------
training dataloader.
"""
@@ -145,6 +146,7 @@ def train_dataloader(self) -> DataLoader:

def val_dataloader(self) -> DataLoader:
"""Define the validation dataloader
Returns:
validation dataloader.
"""
@@ -163,6 +165,7 @@ def val_dataloader(self) -> DataLoader:

def test_dataloader(self) -> DataLoader:
"""Define the test dataloader
Returns:
test dataloader.
"""
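The docstring edits above follow the pydocstyle conventions targeted by the `ruff/D` branch: a one-line summary ending in a period, then a blank line before sections such as `Returns:`. A minimal sketch of that layout on a toy stand-in (the `ExampleDataModule` class and its body are illustrative, not copied from the repository):

import torch
from torch.utils.data import DataLoader, TensorDataset


class ExampleDataModule:
    """Toy stand-in for the LightningDataModule changed above."""

    def train_dataloader(self) -> DataLoader:
        """Define the training dataloader.

        Returns:
            The training dataloader.

        """
        return DataLoader(TensorDataset(torch.zeros(4, 2)), batch_size=2)
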
2 changes: 1 addition & 1 deletion examples/multi_modal/loop.py
@@ -84,7 +84,7 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report):
df_cr = pd.DataFrame(report).transpose()
df_cm.to_csv(f"{model_dir}/{mode}_confusion_matrix.csv", sep=";")
df_cr.to_csv(f"{model_dir}/{mode}_classification_report.csv", sep=";")
logger.info("Confusion Matrix and Classication report are saved.")
logger.info("Confusion Matrix and Classification report are saved.")

def save_test_evaluations(self, model_dir, mode, y_pred, y_true, confis, numerical_id_):
"""Save pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset."""
10 changes: 4 additions & 6 deletions examples/multi_modal/model_arc.py
@@ -1,4 +1,4 @@
"""Architecure Bert & Resnet lightning."""
"""Architecture Bert & Resnet lightning."""

import logging

@@ -70,7 +70,7 @@ def __init__(self, endpoint_mode: bool, hyperparameters: dict):
self.dropout = nn.Dropout(self.hyperparameters["dropout"])

def get_bert_model(self):
"""Load the pre-trained bert model weigths.
"""Load the pre-trained bert model weights.
Returns: model.
"""
@@ -87,13 +87,11 @@ def forward(
validation.
Args:
----
x: Tensor with id tokesn
x: Tensor with id token
y: Tensor with attention tokens.
z: Tensor with iamge.
z: Tensor with image.
Returns:
-------
torch.Tensor: The output tensor representing the computational graph.
"""
9 changes: 2 additions & 7 deletions pyproject.toml
@@ -23,11 +23,6 @@ requires = [
"wheel",
]


[tool.black]
line-length = 120
exclude = '(_notebooks/.*)'

[tool.docformatter]
recursive = true
# this need to be shorter as some docstings are r"""...
@@ -37,13 +32,13 @@ blank = true

[tool.codespell]
# Todo: enable also python files in a next step
skip = '*.py'
#skip = '*.py'
quiet-level = 3
# comma separated list of words; waiting for:
# https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
# also adding links until they ignored by its: nature
# https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960
ignore-words-list = "te, compiletime"
ignore-words-list = "cancelation"


[tool.ruff]
11 changes: 5 additions & 6 deletions src/litdata/processing/data_processor.py
@@ -585,7 +585,7 @@ def _collect_paths(self) -> None:

# For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
# Other alternative would be too slow.
# TODO: Try using dictionary for higher accurary.
# TODO: Try using dictionary for higher accuracy.
indexed_paths = {
index: _to_path(element)
for index, element in enumerate(flattened_item)
@@ -855,7 +855,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra

# Merge the index files generated by each node.
# Note: When using the Data Optimizer, they should be a single process on each node executing this section
# So no risk to get race conditon.
# So no risk to get race condition.
if num_nodes == node_rank + 1:
# Get the index file locally
for node_rank in range(num_nodes - 1):
@@ -1136,12 +1136,11 @@ def run(self, data_recipe: DataRecipe) -> None:
if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
from lightning_sdk.lightning_cloud.openapi import V1DatasetType

data_type = V1DatasetType.CHUNKED if isinstance(data_recipe, DataChunkRecipe) else V1DatasetType.TRANSFORMED
_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,
dataset_type=V1DatasetType.CHUNKED
if isinstance(data_recipe, DataChunkRecipe)
else V1DatasetType.TRANSFORMED,
dataset_type=data_type,
empty=False,
size=result.size,
num_bytes=result.num_bytes,
@@ -1197,7 +1196,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
self.stop_queues = stop_queues

def _signal_handler(self, signal: Any, frame: Any) -> None:
"""On temrination, we stop all the processes to avoid leaking RAM."""
"""On termination, we stop all the processes to avoid leaking RAM."""
for stop_queue in self.stop_queues:
stop_queue.put(None)
for w in self.workers:
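For context on the `_signal_handler` hunk above: the pattern is to push a sentinel onto each worker's stop queue so the worker loop can exit cleanly before the processes are joined. A minimal, self-contained sketch of that idea (the `_worker` function and surrounding names are illustrative, not the library's actual worker implementation):

import multiprocessing as mp


def _worker(stop_queue: "mp.Queue") -> None:
    # Block until the parent pushes the `None` sentinel, then exit cleanly.
    while stop_queue.get() is not None:
        pass


if __name__ == "__main__":
    stop_queues = [mp.Queue() for _ in range(2)]
    workers = [mp.Process(target=_worker, args=(q,)) for q in stop_queues]
    for w in workers:
        w.start()
    # Equivalent of the handler shown above: stop all processes to avoid leaking RAM.
    for q in stop_queues:
        q.put(None)
    for w in workers:
        w.join()
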
5 changes: 2 additions & 3 deletions src/litdata/processing/functions.py
@@ -180,7 +180,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> Any:
return self._inputs

def prepare_item(self, item_metadata: Any) -> Any:
"""Being overriden dynamically."""
"""Being overridden dynamically."""


def map(
@@ -524,8 +524,7 @@ class CopyInfo:
def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.
Arguments:
---------
Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
2 changes: 0 additions & 2 deletions src/litdata/processing/readers.py
@@ -34,12 +34,10 @@ def get_node_rank(self) -> int:
@abstractmethod
def remap_items(self, items: Any, num_workers: int) -> List[Any]:
"""Remap the items provided by the users into items more adapted to be distributed."""
pass

@abstractmethod
def read(self, item: Any) -> Any:
"""Read the data associated to an item."""
pass


class ParquetReader(BaseReader):
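The two `pass` statements removed above were redundant: a docstring is already a complete function body, so the abstract methods need nothing after it. A minimal sketch (simplified signatures, not the library's full `BaseReader`):

from abc import ABC, abstractmethod
from typing import Any, List


class BaseReader(ABC):
    @abstractmethod
    def remap_items(self, items: Any, num_workers: int) -> List[Any]:
        """Remap the items provided by the users into items more adapted to be distributed."""
        # The docstring alone is a valid body; no trailing `pass` is required.

    @abstractmethod
    def read(self, item: Any) -> Any:
        """Read the data associated to an item."""
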
2 changes: 1 addition & 1 deletion src/litdata/streaming/combined.py
@@ -28,7 +28,7 @@ class CombinedStreamingDataset(IterableDataset):
"""Enables to stream data from multiple StreamingDataset with the sampling ratio of
your choice.
Addtionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability
Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable reusability
of the datasets.
Note that due to the random sampling, the number of samples returned from the iterator is variable and a function
12 changes: 6 additions & 6 deletions src/litdata/streaming/dataloader.py
@@ -78,7 +78,7 @@ def __init__(
chunk_size: Optional[int],
compression: Optional[str],
):
"""The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache.
"""The `CacheDataset` is a dataset wrapper to provide a beginner experience with the Cache.
Arguments:
---------
@@ -127,7 +127,7 @@ def __call__(self, items: List[Any]) -> Any:
if all(item is None for item in items):
return None

# If the __getitem__ method is asynchornous, collect all the items.
# If the __getitem__ method is asynchronous, collect all the items.
if all(inspect.iscoroutine(item) for item in items):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -137,7 +137,7 @@ def __call__(self, items: List[Any]) -> Any:


class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter):
"""This is overriden to inform the cache is done chunking."""
"""This is ç to inform the cache is done chunking."""

def _next_data(self) -> Any:
try:
@@ -339,7 +339,7 @@ def __init__(
)

def _get_iterator(self) -> "_BaseDataLoaderIter":
"""Overriden to ensure the `Cache.done()` method is triggered on iteration done."""
"""Overridden to ensure the `Cache.done()` method is triggered on iteration done."""
if self.num_workers == 0:
return _SingleProcessDataLoaderIterPatch(self)
self.check_worker_number_rationality()
@@ -712,7 +712,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

# Inform that the dataloader is resuming.
# TODO: Check if the number of samples yielded is less than the length of the dataset.
# Also, len is not available for CombinedStreamingDataset incase of provided weights.
# Also, len is not available for CombinedStreamingDataset in case of provided weights.
self.restore = True

elif isinstance(self.dataset, StreamingDataset):
@@ -725,7 +725,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")

def _get_iterator(self) -> "_BaseDataLoaderIter":
"""Overriden to ensure the `Cache.done()` method is triggered on iteration done."""
"""Overridden to ensure the `Cache.done()` method is triggered on iteration done."""
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
self.check_worker_number_rationality()
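For the asynchronous `__getitem__` branch above: when every fetched item is a coroutine, the collate step has to drive them to completion on an event loop before batching. A minimal sketch of that collection step (the `collect_items` and `_fake_getitem` names are illustrative, not the library's collate function):

import asyncio
import inspect
from typing import Any, List


def collect_items(items: List[Any]) -> List[Any]:
    """Resolve coroutine items on a fresh event loop; return other items unchanged."""
    if items and all(inspect.iscoroutine(item) for item in items):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(asyncio.gather(*items))
        finally:
            loop.close()
    return items


async def _fake_getitem(i: int) -> int:
    return i * 2


if __name__ == "__main__":
    print(collect_items([_fake_getitem(0), _fake_getitem(1)]))  # -> [0, 2]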