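"""Dataset utilities for keyword spotting on torchaudio's SpeechCommands.

`SubsetSC` resamples, pads/trims, and normalizes each clip to a fixed
one-second waveform; `AudioDataModule` wraps the train/validation/test
splits for PyTorch Lightning.
"""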
import os

import torch
import torch.utils.data
import torchaudio.transforms as T
from pytorch_lightning import LightningDataModule
from torchaudio.datasets import SPEECHCOMMANDS

# The 35 SpeechCommands keywords plus the background-noise class, listed alphabetically.
SC_CLASSES = [
    "background_noise_",
    "backward",
    "bed",
    "bird",
    "cat",
    "dog",
    "down",
    "eight",
    "five",
    "follow",
    "forward",
    "four",
    "go",
    "happy",
    "house",
    "learn",
    "left",
    "marvin",
    "nine",
    "no",
    "off",
    "on",
    "one",
    "right",
    "seven",
    "sheila",
    "six",
    "stop",
    "three",
    "tree",
    "two",
    "up",
    "visual",
    "wow",
    "yes",
    "zero",
]

class SubsetSC(SPEECHCOMMANDS):
    """SpeechCommands subset yielding fixed-length, resampled, normalized clips."""

    def __init__(self, transform=None, subset: str = "", new_sample_rate=8000):
        super().__init__("./", download=True)
        self.transform = transform

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in fileobj
                ]

        self.new_sample_rate = new_sample_rate
        self.resample = T.Resample(orig_freq=16000, new_freq=self.new_sample_rate)
        # Precomputed global mean/std used to normalize waveforms
        self.mean = torch.tensor(-2.7432e-06)
        self.std = torch.tensor(0.7073)
        # Map each unique label to a unique integer (and back)
        self.label_to_int = {label: i for i, label in enumerate(sorted(SC_CLASSES))}
        self.int_to_label = {i: label for i, label in enumerate(sorted(SC_CLASSES))}

        # Restrict the file walker to the requested split
        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]
    def __getitem__(self, index):
        item = super().__getitem__(index)
        waveform = item[0]
        label = item[2]
        waveform = self.resample(waveform)
        # Pad or trim the waveform to a fixed length of one second at the new sample rate
        if waveform.size(1) > self.new_sample_rate:
            waveform = waveform[:, : self.new_sample_rate]  # Trim
        elif waveform.size(1) < self.new_sample_rate:
            # Zero-pad up to the target length
            padding_size = self.new_sample_rate - waveform.size(1)
            padding = torch.zeros((waveform.size(0), padding_size))
            waveform = torch.cat((waveform, padding), dim=1)
        # Normalize with the precomputed mean/std constants
        waveform = (waveform - self.mean) / self.std
        # Apply the optional caller-supplied transform, if one was given
        if self.transform is not None:
            waveform = self.transform(waveform)
        label = torch.tensor(self.label_to_int[label])
        return waveform, label

    @staticmethod
    def num_labels() -> int:
        return len(SC_CLASSES)

class AudioDataModule(LightningDataModule):
    """LightningDataModule wrapping the SpeechCommands train/val/test splits."""

    def __init__(self, batch_size, num_workers, pin_memory, sample_rate_hz: int = 8000):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.new_sample_rate = sample_rate_hz
        # Datasets are built eagerly here (downloading the archive if needed),
        # so setup() has nothing left to do.
        self.train_set = SubsetSC(subset="training", new_sample_rate=self.new_sample_rate)
        self.val_set = SubsetSC(subset="validation", new_sample_rate=self.new_sample_rate)
        self.test_set = SubsetSC(subset="testing", new_sample_rate=self.new_sample_rate)

    def setup(self, stage=None):
        pass

    @staticmethod
    def num_classes() -> int:
        return SubsetSC.num_labels()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )