Skip to content

Commit

Permalink
feat: add fastscan dir, add corrupt check, add silence check
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 24, 2022
1 parent 1d24721 commit 348c6d5
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ WAVDataset(
recursive: bool = False, # Recursively load files from provided paths
with_sample_rate: bool = False, # Returns sample rate as second argument
transforms: Optional[Callable] = None, # Transforms to apply to audio files
check_silence: bool = True # Discards silent samples if true
)
```

Expand Down
3 changes: 1 addition & 2 deletions audio_data_pytorch/datasets/audio_web_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def process_wav(self):
waveform = self.transforms(waveform)

wav_dest_path = f"{self.path_prefix}.wav"
print(wav_dest_path)
torchaudio.save(wav_dest_path, waveform, rate)

self.wav_dest_path = wav_dest_path
Expand Down Expand Up @@ -106,7 +105,7 @@ async def preprocess(self):
waveform_id = 0

async with Downloader(urls, path=path) as files:
async with Decompressor(files, path=path) as folders:
async with Decompressor(files, path=path, remove_on_exit=True) as folders:
with tarfile.open(tarfile_name, "w") as archive:
for folder in tqdm(folders):
for wav in tqdm(glob.glob(folder + "/**/*.wav")):
Expand Down
4 changes: 3 additions & 1 deletion audio_data_pytorch/datasets/clotho_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ async def preprocess(self):
async with Downloader(urls, path=path) as files:
to_decompress = [f for f in files if f.endswith(".7z")]
caption_csv_file = [f for f in files if f.endswith(".csv")][0]
async with Decompressor(to_decompress, path=path) as folders:
async with Decompressor(
to_decompress, path=path, remove_on_exit=True
) as folders:
captions = pd.read_csv(caption_csv_file)
length = len(captions.index)

Expand Down
38 changes: 30 additions & 8 deletions audio_data_pytorch/datasets/wav_dataset.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import glob
import os
import random
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset

from ..utils import fast_scandir, is_silence


def get_all_wav_filenames(paths: Sequence[str], recursive: bool) -> List[str]:
    """Collect the ``.wav`` and ``.flac`` files found under each of ``paths``.

    Every directory in ``paths`` is scanned with ``fast_scandir``; the file
    lists it returns are concatenated in the order the paths were given.

    Args:
        paths: Directories to scan.
        recursive: Forwarded to ``fast_scandir``; when true, sub-directories
            are scanned as well.

    Returns:
        The paths of all matching files.
    """
    # Extensions carry their leading dot because fast_scandir compares them
    # against the result of os.path.splitext.
    extensions = [".wav", ".flac"]
    filenames: List[str] = []
    for path in paths:
        # fast_scandir also returns the sub-folders it saw; only files matter.
        _, files = fast_scandir(path, extensions, recursive=recursive)
        filenames.extend(files)
    return filenames


Expand All @@ -25,26 +25,48 @@ def __init__(
recursive: bool = False,
transforms: Optional[Callable] = None,
sample_rate: Optional[int] = None,
check_silence: bool = True,
):
self.paths = path if isinstance(path, (list, tuple)) else [path]
self.wavs = get_all_wav_filenames(self.paths, recursive=recursive)
self.transforms = transforms
self.sample_rate = sample_rate
self.check_silence = check_silence

def __getitem__(
    self, idx: Union[Tensor, int]
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Load, validate and transform the audio file at position ``idx``.

    A file that cannot be loaded, or that ``is_silence`` reports as silent
    (only checked when ``self.check_silence`` is true), is replaced by a
    sample drawn at a random index. The silence check runs a second time
    after the transforms, since a transform such as a random crop may
    leave only a silent excerpt.

    Args:
        idx: Index into ``self.wavs``; a tensor index is converted with
            ``tolist()`` first.

    Returns:
        The waveform after optional resampling to ``self.sample_rate`` and
        after ``self.transforms``.

    Note:
        Replacement samples are fetched by calling ``self[...]`` again, with
        no retry limit: if no file yields a valid sample the recursion ends
        in a ``RecursionError``.
    """
    idx = idx.tolist() if torch.is_tensor(idx) else idx  # type: ignore

    # Load inside the try only: an unguarded load ahead of it would let a
    # corrupt file raise before the fallback below can run.
    try:
        waveform, sample_rate = torchaudio.load(self.wavs[idx])
    except Exception:
        # Deliberately broad: any failure to load is treated as a corrupt
        # sample and swapped for a random one.
        return self[random.randrange(len(self))]

    # Check that the raw sample is not silent
    if self.check_silence and is_silence(waveform):
        return self[random.randrange(len(self))]

    # Apply sample rate transform if necessary
    if self.sample_rate and sample_rate != self.sample_rate:
        resample = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=self.sample_rate
        )
        waveform = resample(waveform)

    # Apply other transforms
    if self.transforms:
        waveform = self.transforms(waveform)

    # Check silence after transforms (useful for random crops)
    if self.check_silence and is_silence(waveform):
        return self[random.randrange(len(self))]

    return waveform

def __len__(self) -> int:
Expand Down
43 changes: 43 additions & 0 deletions audio_data_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import aiohttp
import torch
from torch import Tensor
from torch.utils.data.dataset import Dataset, Subset
from tqdm import tqdm
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -46,6 +47,48 @@ def fractional_random_split(
return splits


"""
Audio utils
"""


def is_silence(audio: Tensor, thresh: int = -60) -> bool:
    """Report whether the peak level of ``audio`` lies below ``thresh`` dB.

    The peak level is ``20 * log10(max(|audio|))`` taken over every element,
    so an all-zero tensor has a level of ``-inf`` and counts as silent.

    Args:
        audio: Tensor of samples; the maximum is taken over all elements, so
            any shape is accepted.
        thresh: Level in dB that the peak must stay below to count as silence.

    Returns:
        ``True`` if the peak level is below ``thresh`` or if ``audio`` has no
        elements, ``False`` otherwise.
    """
    # An empty tensor has no maximum (``max`` would raise) and nothing audible.
    if audio.numel() == 0:
        return True
    peak_db = 20 * torch.log10(audio.abs().max())
    # Convert the 0-dim tensor so callers get a plain Python bool.
    return bool(peak_db < thresh)


"""
Data/async utils
"""


def fast_scandir(
    path: str, exts: List[str], recursive: bool = False
) -> Tuple[List[str], List[str]]:
    """List the sub-folders of ``path`` and the files whose extension matches.

    Built on ``os.scandir``; adapted from
    github.com/drscotthawley/aeiou/blob/main/aeiou/core.py.

    Args:
        path: Directory to scan.
        exts: Extensions to keep, each including its leading dot (for example
            ``".wav"``). Matching ignores case on both sides.
        recursive: When true, every sub-folder found is scanned as well and
            its results are appended.

    Returns:
        A ``(subfolders, files)`` tuple of path lists. Directories or entries
        that cannot be read (missing path, permission denied, too many levels
        of symbolic links, ...) are skipped, so both lists may be empty.
    """
    # File extensions are lower-cased below, so the wanted ones must be too.
    wanted = {ext.lower() for ext in exts}
    subfolders: List[str] = []
    files: List[str] = []

    try:  # an unreadable or missing directory simply yields nothing
        with os.scandir(path) as entries:
            for entry in entries:
                try:  # e.g. 'too many levels of symbolic links'
                    if entry.is_dir():
                        subfolders.append(entry.path)
                    elif entry.is_file():
                        if os.path.splitext(entry.name)[1].lower() in wanted:
                            files.append(entry.path)
                except OSError:
                    continue
    except OSError:
        pass

    if recursive:
        # Iterate over a copy: ``subfolders`` grows while nested results are
        # merged in, and those nested folders were already visited.
        for folder in list(subfolders):
            nested_folders, nested_files = fast_scandir(
                folder, exts, recursive=recursive
            )
            subfolders.extend(nested_folders)
            files.extend(nested_files)

    return subfolders, files


class RunThread(threading.Thread):
def __init__(self, func):
self.func = func
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-data-pytorch",
packages=find_packages(exclude=[]),
version="0.0.16",
version="0.0.17",
license="MIT",
description="Audio Data - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 348c6d5

Please sign in to comment.