Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
AudioClassificationData from_numpy and from_tensors (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 7, 2021
1 parent 75afa79 commit 282d43d
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for MP3 files to the `SpeechRecognition` task with librosa ([#726](https://github.com/PyTorchLightning/lightning-flash/pull/726))

- Added support for `from_numpy` and `from_tensors` to `AudioClassificationData` ([#745](https://github.com/PyTorchLightning/lightning-flash/pull/745))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
20 changes: 18 additions & 2 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.data_source import (
DefaultDataKeys,
DefaultDataSources,
has_file_allowed_extension,
LoaderDataFrameDataSource,
NumpyDataSource,
PathsDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
Expand All @@ -40,6 +42,18 @@ def spectrogram_loader(filepath: str):
return data


class AudioClassificationNumpyDataSource(NumpyDataSource):
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = np.transpose(sample[DefaultDataKeys.INPUT], (1, 2, 0))
return sample


class AudioClassificationTensorDataSource(AudioClassificationNumpyDataSource):
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].numpy()
return super().load_sample(sample, dataset=dataset)


class AudioClassificationPathsDataSource(PathsDataSource):
def __init__(self):
super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
Expand All @@ -58,8 +72,8 @@ def __init__(
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
spectrogram_size: Tuple[int, int] = (128, 128),
time_mask_param: int = 80,
freq_mask_param: int = 80,
time_mask_param: Optional[int] = None,
freq_mask_param: Optional[int] = None,
deserializer: Optional["Deserializer"] = None,
):
self.spectrogram_size = spectrogram_size
Expand All @@ -76,6 +90,8 @@ def __init__(
DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(),
"data_frame": AudioClassificationDataFrameDataSource(),
DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(),
DefaultDataSources.NUMPY: AudioClassificationNumpyDataSource(),
DefaultDataSources.TENSORS: AudioClassificationTensorDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand Down
23 changes: 13 additions & 10 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -43,14 +43,17 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]


def train_default_transforms(
spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int
spectrogram_size: Tuple[int, int], time_mask_param: Optional[int], freq_mask_param: Optional[int]
) -> Dict[str, Callable]:
"""During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``"""
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)),
)
}
"""During training we apply the default transforms with optional ``TimeMasking`` and ``Frequency Masking``."""
augs = []

if time_mask_param is not None:
augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)))

if freq_mask_param is not None:
augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)))

return merge_transforms(default_transforms(spectrogram_size), transforms)
if len(augs) > 0:
return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)})
return default_transforms(spectrogram_size)
2 changes: 1 addition & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


NP_EXTENSIONS = (".npy", ".npz")
NP_EXTENSIONS = (".npy",)


def image_loader(filepath: str):
Expand Down
76 changes: 72 additions & 4 deletions tests/audio/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def _rand_image(size: Tuple[int, int] = None):
def test_from_filepaths_smoke(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a_1.png")
_rand_image().save(tmpdir / "b_1.png")

Expand All @@ -70,11 +68,82 @@ def test_from_filepaths_smoke(tmpdir):
assert sorted(list(labels.numpy())) == [1, 2]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
@pytest.mark.parametrize(
"data,from_function",
[
(torch.rand(3, 3, 64, 64), AudioClassificationData.from_tensors),
(np.random.rand(3, 3, 64, 64), AudioClassificationData.from_numpy),
],
)
def test_from_data(data, from_function):
img_data = from_function(
train_data=data,
train_targets=[0, 3, 6],
val_data=data,
val_targets=[1, 4, 7],
test_data=data,
test_targets=[2, 5, 8],
batch_size=2,
num_workers=0,
)

# check training data
data = next(iter(img_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here

# check validation data
data = next(iter(img_data.val_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert list(labels.numpy()) == [1, 4]

# check test data
data = next(iter(img_data.test_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert list(labels.numpy()) == [2, 5]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_from_filepaths_numpy(tmpdir):
tmpdir = Path(tmpdir)

np.save(str(tmpdir / "a_1.npy"), np.random.rand(64, 64, 3))
np.save(str(tmpdir / "b_1.npy"), np.random.rand(64, 64, 3))

train_images = [
str(tmpdir / "a_1.npy"),
str(tmpdir / "b_1.npy"),
]

spectrograms_data = AudioClassificationData.from_files(
train_files=train_images,
train_targets=[1, 2],
batch_size=2,
num_workers=0,
)
assert spectrograms_data.train_dataloader() is not None
assert spectrograms_data.val_dataloader() is None
assert spectrograms_data.test_dataloader() is None

data = next(iter(spectrograms_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [1, 2]


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_from_filepaths_list_image_paths(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
Expand Down Expand Up @@ -122,7 +191,6 @@ def test_from_filepaths_list_image_paths(tmpdir):
def test_from_filepaths_visualise(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
Expand Down

0 comments on commit 282d43d

Please sign in to comment.