Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Audio data sources + Numpy file support #651

Merged
merged 23 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,5 @@ logs/cache/*
flash_examples/data
flash_examples/cli/*/data
timit/
urban8k_images/
__MACOSX
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added Flash Zero, a zero code command line ML platform built with flash ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611))

- Added support for `.npy` and `.npz` files to `ImageClassificationData` and `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added support for `from_csv` to the `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

### Fixed

- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))
Expand All @@ -50,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of `torch.jit.isinstance` ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611))

- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

## [0.4.0] - 2021-06-22

### Added
Expand Down
62 changes: 44 additions & 18 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,46 @@
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.data_source import (
DefaultDataSources,
has_file_allowed_extension,
LoaderDataFrameDataSource,
PathsDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, requires_extras
from flash.image.classification.data import ImageClassificationData
from flash.image.data import ImageDeserializer

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS


NP_EXTENSIONS = (".npy", ".npz")


def spectrogram_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
data = np.array(img)
else:
data = np.load(filepath)
return data


class AudioClassificationPathsDataSource(PathsDataSource):
@requires_extras("image")
def __init__(self):
super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)


class AudioClassificationDataFrameDataSource(LoaderDataFrameDataSource):
@requires_extras("image")
def __init__(self):
super().__init__(spectrogram_loader)


class AudioClassificationPreprocess(Preprocess):
Expand All @@ -31,7 +63,7 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
spectrogram_size: Tuple[int, int] = (196, 196),
spectrogram_size: Tuple[int, int] = (128, 128),
time_mask_param: int = 80,
freq_mask_param: int = 80,
deserializer: Optional["Deserializer"] = None,
Expand All @@ -46,8 +78,10 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.FILES: AudioClassificationPathsDataSource(),
DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(),
"data_frame": AudioClassificationDataFrameDataSource(),
DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand All @@ -72,15 +106,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
class AudioClassificationData(ImageClassificationData):
"""Data module for audio classification."""

preprocess_cls = AudioClassificationPreprocess

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return MatplotlibVisualization(*args, **kwargs)
7 changes: 4 additions & 3 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

import torch
from torch import nn
from torch.utils.data._utils.collate import default_collate

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
Expand All @@ -32,12 +33,12 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]
"""The default transforms for audio classification for spectrograms: resize the spectrogram, convert the
spectrogram and target to a tensor, and collate the batch."""
return {
"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
"post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"collate": default_collate,
}


Expand Down
56 changes: 36 additions & 20 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,20 @@
# limitations under the License.
import os
import platform
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import numpy as np
import pytorch_lightning as pl
Expand Down Expand Up @@ -86,7 +99,7 @@ def __init__(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
) -> None:

super().__init__()
Expand Down Expand Up @@ -281,7 +294,10 @@ def _train_dataloader(self) -> DataLoader:
pin_memory = True

if self.sampler is None:
sampler = None
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
else:
sampler = self.sampler(train_ds)

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_train_dataset(
Expand All @@ -292,14 +308,14 @@ def _train_dataloader(self) -> DataLoader:
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
sampler=self.sampler,
sampler=sampler,
)

return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=self.sampler,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=pin_memory,
drop_last=drop_last,
Expand Down Expand Up @@ -453,7 +469,7 @@ def from_data_source(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to
Expand Down Expand Up @@ -489,7 +505,7 @@ def from_data_source(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -553,7 +569,7 @@ def from_folders(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
Expand Down Expand Up @@ -582,7 +598,7 @@ def from_folders(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -636,7 +652,7 @@ def from_files(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files
Expand Down Expand Up @@ -668,7 +684,7 @@ def from_files(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -723,7 +739,7 @@ def from_tensors(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the
Expand Down Expand Up @@ -755,7 +771,7 @@ def from_tensors(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -810,7 +826,7 @@ def from_numpy(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the
Expand Down Expand Up @@ -842,7 +858,7 @@ def from_numpy(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -896,7 +912,7 @@ def from_json(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
field: Optional[str] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -928,7 +944,7 @@ def from_json(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
field: To specify the field that holds the data in the JSON file.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -1006,7 +1022,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the
Expand Down Expand Up @@ -1037,7 +1053,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down Expand Up @@ -1090,7 +1106,7 @@ def from_datasets(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the
Expand Down Expand Up @@ -1119,7 +1135,7 @@ def from_datasets(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Expand Down
Loading