Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor][Dataset] YesNo implementation #1127

Merged
merged 9 commits into from
Jan 19, 2021
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 29 additions & 36 deletions torchaudio/datasets/yesno.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,14 @@
extract_archive,
)

# Default download location of the YesNo archive; used as the default value of
# the ``url`` argument of ``YESNO.__init__``.
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
# Default value of the ``folder_in_archive`` argument; it is joined onto
# ``root`` to locate the extracted audio files.
FOLDER_IN_ARCHIVE = "waves_yesno"
# Expected checksum keyed by download URL, looked up with
# ``_CHECKSUMS.get(url, None)`` and passed to ``download_url`` together with
# ``hash_type="md5"``. The value is 32 hex characters long.
_CHECKSUMS = {
"http://www.openslr.org/resources/1/waves_yesno.tar.gz":
"962ff6e904d2df1126132ecec6978786"
}


def load_yesno_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, List[int]]:
    """Load a single YesNo sample.

    Args:
        fileid: File name without extension. Its underscore-separated fields
            are parsed as the integer labels of the sample.
        path: Directory that contains the audio file.
        ext_audio: Extension appended to ``fileid`` to build the file name.

    Returns:
        tuple: ``(waveform, sample_rate, labels)``
    """
    # The labels are encoded in the file name itself; parse them before
    # touching the audio file.
    labels = list(map(int, fileid.split("_")))

    # Read the audio stored at ``<path>/<fileid><ext_audio>``.
    audio_path = os.path.join(path, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(audio_path)

    return waveform, sample_rate, labels
# Per-release dataset metadata. The "release1" entry supplies the default
# ``url`` and ``folder_in_archive`` arguments of ``YESNO.__init__`` and the
# checksum handed to ``download_url``.
# NOTE(review): the checksum below is 64 hex characters (the length of a
# SHA-256 digest), while the visible ``download_url`` call still passes
# ``hash_type="md5"`` -- confirm that the hash type matches this value.
_RELEASE_CONFIGS = {
"release1": {
"folder_in_archive": "waves_yesno",
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
}
}


class YESNO(Dataset):
Expand All @@ -43,25 +34,26 @@ class YESNO(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""

_ext_audio = ".wav"

def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
def __init__(
self,
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False
) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mthrok I need your help again here. I am not sure how to resolve the style issue on this signature. This is the formatting I was using.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flake8 is not happy when the indentation of a continuation line is 4 spaces, because that is visually the same as the logical indentation. This often happens when an if statement spans multiple lines, or with a multi-line function signature like this one.

I think indenting 4 more spaces will resolve the issue.

def __init__(
        self,
        ...
) -> None:
    ...


# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self._parse_filesystem(root, url, folder_in_archive, download)

def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
root = Path(root)
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
archive = root / archive

self._path = root / folder_in_archive
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive)

Expand All @@ -70,7 +62,13 @@ def __init__(self,
"Dataset not found. Please use `download=True` to download it."
)

self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*' + self._ext_audio))
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))

def _load_item(self, fileid: str, path: str):
    """Load the sample identified by ``fileid`` from the directory ``path``.

    The labels are the underscore-separated integers in ``fileid``; the audio
    is read from ``<path>/<fileid>.wav``.

    Returns:
        tuple: ``(waveform, sample_rate, labels)``
    """
    # Parse the labels first, so a malformed file id fails before any I/O.
    labels = list(map(int, fileid.split("_")))
    wav_path = os.path.join(path, f"{fileid}.wav")
    waveform, sample_rate = torchaudio.load(wav_path)
    return waveform, sample_rate, labels

def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
"""Load the n-th sample from the dataset.
Expand All @@ -82,13 +80,8 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
tuple: ``(waveform, sample_rate, labels)``
"""
fileid = self._walker[n]
item = load_yesno_item(fileid, self._path, self._ext_audio)

# TODO Upon deprecation, uncomment line below and remove following code
# return item

waveform, sample_rate, labels = item
return waveform, sample_rate, labels
item = self._load_item(fileid, self._path)
return item

def __len__(self) -> int:
return len(self._walker)