Skip to content

Commit

Permalink
modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Aug 24, 2022
1 parent bb94647 commit b3a1a4a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 39 deletions.
4 changes: 2 additions & 2 deletions torchaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .gtzan import GTZAN
from .librilight_limited import LibriLightLimited
from .librimix import LibriMix
from .librispeech import LIBRISPEECH, LIBRISPEECHBase
from .librispeech import LIBRISPEECH, LibriSpeechBase
from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .musdb_hq import MUSDB_HQ
Expand All @@ -21,7 +21,7 @@
__all__ = [
"COMMONVOICE",
"LIBRISPEECH",
"LIBRISPEECHBase",
"LibriSpeechBase",
"LibriLightLimited",
"SPEECHCOMMANDS",
"VCTK_092",
Expand Down
69 changes: 32 additions & 37 deletions torchaudio/datasets/librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ def download_librispeech(root, url):


def get_librispeech_metadata(
fileid: str, path: str, ext_audio: str, ext_txt: str
fileid: str, root: str, subset: str, ext_audio: str, ext_txt: str
) -> Tuple[str, int, str, int, int, int]:
speaker_id, chapter_id, utterance_id = fileid.split("-")

# Get audio path and sample rate
fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
filepath = os.path.join(speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
full_path = os.path.join(path, filepath)
filepath = os.path.join(subset, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
full_path = os.path.join(root, filepath)
sample_rate = torchaudio.info(full_path).sample_rate

# Load text
file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
file_text = os.path.join(path, speaker_id, chapter_id, file_text)
file_text = os.path.join(root, subset, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
Expand All @@ -79,11 +79,14 @@ def get_librispeech_metadata(
def load_librispeech_item(
fileid: str, path: str, ext_audio: str, ext_txt: str
) -> Tuple[Tensor, int, str, int, int, int]:
path = os.path.normpath(path)
root = os.path.dirname(path)
subset = os.path.basename(path)

filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id = get_librispeech_metadata(
fileid, path, ext_audio, ext_txt
fileid, root, subset, ext_audio, ext_txt
)
full_path = os.path.join(path, filepath)
waveform, _ = torchaudio.load(full_path)
waveform, _ = torchaudio.load(os.path.join(root, filepath))

return (
waveform,
Expand All @@ -95,47 +98,33 @@ def load_librispeech_item(
)


class LIBRISPEECHBase(Dataset):
class LibriSpeechBase(Dataset):
"""Create a Dataset for *LibriSpeech* [:footcite:`7178964`].
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
root (str or Path): Path to the directory where the dataset is found.
subset (str, optional): Subset of LibriSpeech to use.
Valid options: ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
``"train-other-500"``. (default: ``"train-clean-100"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"LibriSpeech"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""

_ext_txt = ".trans.txt"
_ext_audio = ".flac"

def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
) -> None:
if url not in _DATA_SUBSETS:
raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
def __init__(self, root: Union[str, Path], subset: str = "train-clean-100") -> None:
if subset not in _DATA_SUBSETS:
raise ValueError(f"Invalid subset '{subset}' given; please provide one of {_DATA_SUBSETS}.")

root = os.fspath(root)
self._path = os.path.join(root, folder_in_archive, url)
self._root = root
self._subset = subset
subset_path = os.path.join(root, subset)

if not os.path.isdir(self._path):
if download:
download_librispeech(root, url)
else:
raise RuntimeError(
f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
)
if not os.path.isdir(subset_path):
raise RuntimeError(f"Dataset not found at {subset_path}.")

self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
self._walker = sorted(str(p.stem) for p in Path(subset_path).glob("*/*/*" + self._ext_audio))

def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset.
Expand All @@ -148,13 +137,13 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
``(filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
"""
fileid = self._walker[n]
return get_librispeech_metadata(fileid, self._path, self._ext_audio, self._ext_txt)
return get_librispeech_metadata(fileid, self._root, self._subset, self._ext_audio, self._ext_txt)

def __len__(self) -> int:
return len(self._walker)


class LIBRISPEECH(LIBRISPEECHBase):
class LIBRISPEECH(LibriSpeechBase):
"""Create a Dataset for *LibriSpeech* [:footcite:`7178964`].
Args:
Expand All @@ -177,7 +166,13 @@ def __init__(
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
) -> None:
super().__init__(root)
archive = os.path.join(root, folder_in_archive)
self._path = os.path.join(archive, url)

if not os.path.isdir(self._path) and download:
download_librispeech(archive, url)

super().__init__(root=root, subset=url)

def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset.
Expand Down

0 comments on commit b3a1a4a

Please sign in to comment.