Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

AudioClassificationData from_numpy and from_tensors #745

Merged
merged 5 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for MP3 files to the `SpeechRecognition` task with librosa ([#726](https://github.com/PyTorchLightning/lightning-flash/pull/726))

- Added support for `from_numpy` and `from_tensors` to `AudioClassificationData` ([#745](https://github.com/PyTorchLightning/lightning-flash/pull/745))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
20 changes: 18 additions & 2 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.data_source import (
DefaultDataKeys,
DefaultDataSources,
has_file_allowed_extension,
LoaderDataFrameDataSource,
NumpyDataSource,
PathsDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
Expand All @@ -40,6 +42,18 @@ def spectrogram_loader(filepath: str):
return data


class AudioClassificationNumpyDataSource(NumpyDataSource):
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = np.transpose(sample[DefaultDataKeys.INPUT], (1, 2, 0))
return sample


class AudioClassificationTensorDataSource(AudioClassificationNumpyDataSource):
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].numpy()
return super().load_sample(sample, dataset=dataset)


class AudioClassificationPathsDataSource(PathsDataSource):
def __init__(self):
super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
Expand All @@ -58,8 +72,8 @@ def __init__(
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
spectrogram_size: Tuple[int, int] = (128, 128),
time_mask_param: int = 80,
freq_mask_param: int = 80,
time_mask_param: Optional[int] = None,
freq_mask_param: Optional[int] = None,
deserializer: Optional["Deserializer"] = None,
):
self.spectrogram_size = spectrogram_size
Expand All @@ -76,6 +90,8 @@ def __init__(
DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(),
"data_frame": AudioClassificationDataFrameDataSource(),
DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(),
DefaultDataSources.NUMPY: AudioClassificationNumpyDataSource(),
DefaultDataSources.TENSORS: AudioClassificationTensorDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand Down
23 changes: 13 additions & 10 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -43,14 +43,17 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]


def train_default_transforms(
spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int
spectrogram_size: Tuple[int, int], time_mask_param: Optional[int], freq_mask_param: Optional[int]
) -> Dict[str, Callable]:
"""During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``"""
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)),
)
}
"""During training we apply the default transforms with optional ``TimeMasking`` and ``Frequency Masking``."""
augs = []

if time_mask_param is not None:
augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)))

if freq_mask_param is not None:
augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)))

return merge_transforms(default_transforms(spectrogram_size), transforms)
if len(augs) > 0:
return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)})
return default_transforms(spectrogram_size)
2 changes: 1 addition & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


NP_EXTENSIONS = (".npy", ".npz")
NP_EXTENSIONS = (".npy",)


def image_loader(filepath: str):
Expand Down
76 changes: 72 additions & 4 deletions tests/audio/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def _rand_image(size: Tuple[int, int] = None):
def test_from_filepaths_smoke(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a_1.png")
_rand_image().save(tmpdir / "b_1.png")

Expand All @@ -70,11 +68,82 @@ def test_from_filepaths_smoke(tmpdir):
assert sorted(list(labels.numpy())) == [1, 2]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
@pytest.mark.parametrize(
"data,from_function",
[
(torch.rand(3, 3, 64, 64), AudioClassificationData.from_tensors),
(np.random.rand(3, 3, 64, 64), AudioClassificationData.from_numpy),
],
)
def test_from_data(data, from_function):
img_data = from_function(
train_data=data,
train_targets=[0, 3, 6],
val_data=data,
val_targets=[1, 4, 7],
test_data=data,
test_targets=[2, 5, 8],
batch_size=2,
num_workers=0,
)

# check training data
data = next(iter(img_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here

# check validation data
data = next(iter(img_data.val_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert list(labels.numpy()) == [1, 4]

# check test data
data = next(iter(img_data.test_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert list(labels.numpy()) == [2, 5]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_from_filepaths_numpy(tmpdir):
tmpdir = Path(tmpdir)

np.save(str(tmpdir / "a_1.npy"), np.random.rand(64, 64, 3))
np.save(str(tmpdir / "b_1.npy"), np.random.rand(64, 64, 3))

train_images = [
str(tmpdir / "a_1.npy"),
str(tmpdir / "b_1.npy"),
]

spectrograms_data = AudioClassificationData.from_files(
train_files=train_images,
train_targets=[1, 2],
batch_size=2,
num_workers=0,
)
assert spectrograms_data.train_dataloader() is not None
assert spectrograms_data.val_dataloader() is None
assert spectrograms_data.test_dataloader() is None

data = next(iter(spectrograms_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [1, 2]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_from_filepaths_list_image_paths(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
Expand Down Expand Up @@ -122,7 +191,6 @@ def test_from_filepaths_list_image_paths(tmpdir):
def test_from_filepaths_visualise(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
Expand Down