Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Feature/task a thon audio classification spectrograms #594

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: ['graph']
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: ['audio']

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down Expand Up @@ -128,6 +132,11 @@ jobs:
run: |
pip install '.[all]' --pre --upgrade

- name: Install audio test dependencies
if: matrix.topic[0] == 'audio'
run: |
pip install '.[image]' --pre --upgrade

- name: Cache datasets
uses: actions/cache@v2
with:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))

- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
6 changes: 6 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ Lightning Flash
reference/style_transfer
reference/video_classification

.. toctree::
:maxdepth: 1
:caption: Audio

reference/audio_classification

.. toctree::
:maxdepth: 1
:caption: Tabular
Expand Down
73 changes: 73 additions & 0 deletions docs/source/reference/audio_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

.. _audio_classification:

####################
Audio Classification
####################

********
The Task
********

The task of identifying what is in an audio file is called audio classification.
Typically, Audio Classification is used to identify audio files containing sounds or words.
The task predicts which ‘class’ the sound or words most likely belongs to with a degree of certainty.
A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc.

------

*******
Example
*******

Let's look at the task of predicting whether audio file contains sounds of an airconditioner, carhorn, childrenplaying, dogbark, drilling, engingeidling, gunshot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset.
The dataset contains ``train``, ``val`` and ``test`` folders, and then each folder contains a **airconditioner** folder, with spectrograms generated from air-conditioner sounds, **siren** folder with spectrograms generated from siren sounds and the same goes for the other classes.

.. code-block::

urban8k_images
├── train
│ ├── air_conditioner
│ ├── car_horn
│ ├── children_playing
│ ├── dog_bark
│ ├── drilling
│ ├── engine_idling
│ ├── gun_shot
│ ├── jackhammer
│ ├── siren
│ └── street_music
├── test
│ ├── air_conditioner
│ ├── car_horn
│ ├── children_playing
│ ├── dog_bark
│ ├── drilling
│ ├── engine_idling
│ ├── gun_shot
│ ├── jackhammer
│ ├── siren
│ └── street_music
└── val
├── air_conditioner
├── car_horn
├── children_playing
├── dog_bark
├── drilling
├── engine_idling
├── gun_shot
├── jackhammer
├── siren
└── street_music

...

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data.
We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/audio_classification.py
:language: python
:lines: 14-
1 change: 1 addition & 0 deletions flash/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
1 change: 1 addition & 0 deletions flash/audio/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
87 changes: 87 additions & 0 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource


class AudioClassificationPreprocess(Preprocess):

@requires_extras(["audio", "image"])
def __init__(
self,
train_transform: Optional[Dict[str, Callable]],
val_transform: Optional[Dict[str, Callable]],
test_transform: Optional[Dict[str, Callable]],
predict_transform: Optional[Dict[str, Callable]],
spectrogram_size: Tuple[int, int] = (196, 196),
time_mask_param: int = 80,
freq_mask_param: int = 80,
deserializer: Optional['Deserializer'] = None,
):
self.spectrogram_size = spectrogram_size
self.time_mask_param = time_mask_param
self.freq_mask_param = freq_mask_param

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource()
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
return {
**self.transforms,
"spectrogram_size": self.spectrogram_size,
"time_mask_param": self.time_mask_param,
"freq_mask_param": self.freq_mask_param,
}

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def default_transforms(self) -> Optional[Dict[str, Callable]]:
return default_transforms(self.spectrogram_size)

def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
"""Data module for audio classification."""

preprocess_cls = AudioClassificationPreprocess

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return MatplotlibVisualization(*args, **kwargs)
72 changes: 72 additions & 0 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, Tuple

import torch
from torch import nn

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia as K

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision import transforms as T

if _TORCHAUDIO_AVAILABLE:
from torchaudio import transforms as TAudio


def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default transforms for audio classification for spectrograms: resize the spectrogram,
convert the spectrogram and target to a tensor, and collate the batch."""
if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
K.geometry.Resize(spectrogram_size),
),
"collate": kornia_collate,
}
return {
"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
}


def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
freq_mask_param: int) -> Dict[str, Callable]:
"""During training we apply the default transforms with aditional ``TimeMasking`` and ``Frequency Masking``"""
if os.getenv("FLASH_TESTING", "0") != 1:
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
)
}

return merge_transforms(default_transforms(spectrogram_size), transforms)
26 changes: 19 additions & 7 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import operator
import types
from importlib.util import find_spec
from typing import Callable, List, Union

from pkg_resources import DistributionNotFound

Expand Down Expand Up @@ -89,6 +90,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand All @@ -108,6 +110,7 @@ def _compare_version(package: str, op, version) -> bool:
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
_AUDIO_AVAILABLE = all([
_ASTEROID_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

Expand All @@ -123,15 +126,22 @@ def _compare_version(package: str, op, version) -> bool:
}


def _requires(module_path: str, module_available: bool):
def _requires(
module_paths: Union[str, List],
module_available: Callable[[str], bool],
formatter: Callable[[List[str]], str],
):

if not isinstance(module_paths, list):
module_paths = [module_paths]

def decorator(func):
if not module_available:
if not all(module_available(module_path) for module_path in module_paths):

@functools.wraps(func)
def wrapper(*args, **kwargs):
raise ModuleNotFoundError(
f"Required dependencies not available. Please run: pip install '{module_path}'"
f"Required dependencies not available. Please run: pip install {formatter(module_paths)}"
)

return wrapper
Expand All @@ -141,12 +151,14 @@ def wrapper(*args, **kwargs):
return decorator


def requires(module_path: str):
return _requires(module_path, _module_available(module_path))
def requires(module_paths: Union[str, List]):
return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths))


def requires_extras(extras: str):
return _requires(f"lightning-flash[{extras}]", _EXTRAS_AVAILABLE[extras])
def requires_extras(extras: Union[str, List]):
return _requires(
extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'"
)


def lazy_import(module_name, callback=None):
Expand Down
45 changes: 45 additions & 0 deletions flash_examples/audio_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")

datamodule = AudioClassificationData.from_folders(
train_folder="data/urban8k_images/train",
val_folder="data/urban8k_images/val",
spectrogram_size=(64, 64),
)

# 2. Build the model.
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c
predictions = model.predict([
"data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
"data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
"data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("audio_classification_model.pt")
Loading