-
-
Notifications
You must be signed in to change notification settings - Fork 38
/
model.py
159 lines (129 loc) · 5.1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl
import pickle
import os
class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that pads by ``(kernel_size - 1) * dilation`` and trims the tail.

    The padding is handed to ``torch.nn.Conv1d`` and ``forward`` then cuts
    the same number of samples off the end of the time axis (dim 2). With
    the default ``stride=1`` the output therefore has the same length as
    the input, and an output sample does not depend on later input samples.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # Number of samples to pad with, and later to cut from the output.
        causal_pad = dilation * (kernel_size - 1)
        self.__padding = causal_pad
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=causal_pad,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        """Run the padded convolution and drop the trailing padded samples."""
        padded_out = super().forward(input)
        trim = self.__padding
        if trim == 0:
            # kernel_size == 1: nothing was padded, so nothing to cut
            # (a ``[:-0]`` slice would return an empty tensor).
            return padded_out
        return padded_out[:, :, :-trim]
def _conv_stack(dilations, in_channels, out_channels, kernel_size):
    """
    Build one CausalConv1d per entry of `dilations`, in order, and return
    them as an nn.ModuleList. All layers share the same channel counts and
    kernel size; only the dilation differs.

    Stack of dilated convolutions as outlined in the WaveNet paper:
    https://arxiv.org/pdf/1609.03499.pdf
    """
    layers = nn.ModuleList()
    for rate in dilations:
        layers.append(
            CausalConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                dilation=rate,
                kernel_size=kernel_size,
            )
        )
    return layers
class WaveNet(nn.Module):
    """Stack of gated, dilated causal convolutions with a 1x1 output mix.

    The dilation schedule is 1, 2, 4, ..., 2 ** (dilation_depth - 1),
    repeated `num_repeat` times. Every layer contributes its gated output
    as a skip connection; the skips are concatenated on the channel axis
    and reduced to one channel by `linear_mix`.
    """

    def __init__(self, num_channels, dilation_depth, num_repeat, kernel_size=2):
        super().__init__()
        self.num_channels = num_channels

        one_cycle = [2 ** power for power in range(dilation_depth)]
        dilations = one_cycle * num_repeat

        # Each hidden conv emits twice the channels: one half feeds tanh,
        # the other half feeds sigmoid in the gated activation.
        gated_channels = int(num_channels * 2)
        self.hidden = _conv_stack(dilations, num_channels, gated_channels, kernel_size)
        # Kernel-size-1 convs applied to the gated output of each layer.
        self.residuals = _conv_stack(dilations, num_channels, num_channels, 1)

        self.input_layer = CausalConv1d(
            in_channels=1,
            out_channels=num_channels,
            kernel_size=1,
        )
        # One skip tensor of `num_channels` channels per layer goes in.
        self.linear_mix = nn.Conv1d(
            in_channels=num_channels * dilation_depth * num_repeat,
            out_channels=1,
            kernel_size=1,
        )

    def forward(self, x):
        """Map `x` through the input layer, all gated layers and the mix."""
        skips = []
        out = self.input_layer(x)

        for gate_conv, mix_conv in zip(self.hidden, self.residuals):
            layer_in = out

            # gated activation: split the channel axis (dim 1) into chunks
            # of `num_channels`, first chunk -> tanh, second -> sigmoid
            tanh_half, sigm_half = torch.split(gate_conv(layer_in), self.num_channels, dim=1)
            gated = torch.tanh(tanh_half) * torch.sigmoid(sigm_half)
            skips.append(gated)

            # residual connection, aligned on the last time steps
            mixed = mix_conv(gated)
            out = mixed + layer_in[:, :, -mixed.size(2) :]

        # modified "postprocess" step: crop every skip to the final length,
        # stack them on the channel axis and mix down to a single channel
        final_len = out.size(2)
        stacked = torch.cat([skip[:, :, -final_len:] for skip in skips], dim=1)
        return self.linear_mix(stacked)
def error_to_signal(y, y_pred):
    """
    Error to signal ratio with pre-emphasis filter:
    https://www.mdpi.com/2076-3417/10/3/766/htm

    Both inputs go through `pre_emphasis_filter` first. The squared error
    and the squared target are each summed over dim 2, and their ratio is
    returned (one value per remaining index, not averaged).
    """
    target = pre_emphasis_filter(y)
    estimate = pre_emphasis_filter(y_pred)
    error_energy = (target - estimate).pow(2).sum(dim=2)
    signal_energy = target.pow(2).sum(dim=2)
    # the small constant keeps the division finite for an all-zero target
    return error_energy / (signal_energy + 1e-10)
def pre_emphasis_filter(x, coeff=0.95):
    """
    First-order pre-emphasis along dim 2: out[t] = x[t] - coeff * x[t - 1].
    The first sample has no predecessor and is passed through unchanged,
    so the output has the same size as `x`.
    """
    first = x[:, :, :1]
    current = x[:, :, 1:]
    previous = x[:, :, :-1]
    return torch.cat((first, current - coeff * previous), dim=2)
class PedalNet(pl.LightningModule):
    """LightningModule that trains a WaveNet with the error-to-signal loss.

    `hparams` is read with item access in `__init__` (`num_channels`,
    `dilation_depth`, `num_repeat`, `kernel_size`) and with attribute access
    afterwards (`model`, `learning_rate`, `batch_size`), so it must support
    both once stored on the module.
    """

    def __init__(self, hparams):
        super().__init__()
        self.wavenet = WaveNet(
            num_channels=hparams["num_channels"],
            dilation_depth=hparams["dilation_depth"],
            num_repeat=hparams["num_repeat"],
            kernel_size=hparams["kernel_size"],
        )
        # NOTE(review): direct assignment to `hparams` presumably relies on
        # the installed Lightning version accepting it and exposing
        # attribute access -- confirm before upgrading Lightning.
        self.hparams = hparams

    @staticmethod
    def _to_dataset(x, y):
        """Wrap two numpy arrays as a TensorDataset of (input, target)."""
        return TensorDataset(torch.from_numpy(x), torch.from_numpy(y))

    def _esr_loss(self, batch):
        """Mean error-to-signal ratio of the prediction for one batch."""
        x, y = batch
        y_pred = self.forward(x)
        # Compare against the last `y_pred.size(2)` target samples so both
        # tensors have the same length on the time axis.
        return error_to_signal(y[:, :, -y_pred.size(2) :], y_pred).mean()

    def prepare_data(self):
        """Load the train/validation arrays from `data.pickle`.

        The file is looked up in the directory part of `hparams.model` and
        must hold the keys `x_train`, `y_train`, `x_valid` and `y_valid`.
        """
        # os.path.join keeps the path relative when `model` has no directory
        # part; plain string concatenation produced "/data.pickle", i.e. a
        # file in the filesystem root.
        data_path = os.path.join(os.path.dirname(self.hparams.model), "data.pickle")
        # NOTE: unpickling can execute arbitrary code -- only point this at
        # data files you created yourself.
        with open(data_path, "rb") as f:
            data = pickle.load(f)
        self.train_ds = self._to_dataset(data["x_train"], data["y_train"])
        self.valid_ds = self._to_dataset(data["x_valid"], data["y_valid"])

    def configure_optimizers(self):
        """Adam over the WaveNet parameters with `hparams.learning_rate`."""
        return torch.optim.Adam(self.wavenet.parameters(), lr=self.hparams.learning_rate)

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return DataLoader(
            self.train_ds,
            shuffle=True,
            batch_size=self.hparams.batch_size,
            num_workers=4,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation set."""
        return DataLoader(self.valid_ds, batch_size=self.hparams.batch_size, num_workers=4)

    def forward(self, x):
        """Delegate to the wrapped WaveNet."""
        return self.wavenet(x)

    def training_step(self, batch, batch_idx):
        """Return the batch loss plus a `log` entry carrying the same value."""
        loss = self._esr_loss(batch)
        logs = {"loss": loss}
        return {"loss": loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        """Return the batch loss under the key `val_loss`."""
        return {"val_loss": self._esr_loss(batch)}

    def validation_epoch_end(self, outs):
        """Average the per-batch `val_loss` values collected in `outs`."""
        avg_loss = torch.stack([out["val_loss"] for out in outs]).mean()
        logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": logs}