Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH MuReNNDirect.to_conv1d() #56

Merged
merged 19 commits into from
Jul 25, 2024
102 changes: 86 additions & 16 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,50 @@ class MuReNNDirect(torch.nn.Module):
"""
Args:
J (int): Number of levels (octaves) in the DTCWT decomposition.
Q (int): Number of Conv1D filters per octave.
Q (int or list): Number of Conv1D filters per octave.
T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to
T * Q[j] where Q[j] is the number of filters.
J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J.
in_channels (int): Number of channels in the input signal.
padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate',
and 'circular'. Padding scheme for the DTCWT decomposition.
"""
def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"):
super().__init__()
self.Q = Q
self.C = in_channels
if isinstance(Q, int):
self.Q = [Q for j in range(J)]
elif isinstance(Q, list):
assert len(Q) == J
self.Q = Q
else:
raise TypeError(f"Q must to be int or list, got {type(Q)}")
if J_phi is None:
J_phi = J
if J_phi < J:
raise ValueError("J_phi must be greater or equal to J")
self.T = [T*self.Q[j] for j in range(J)]
self.in_channels = in_channels
down = []
conv1d = []
self.dtcwt = murenn.DTCWT(
J=J,
padding_mode=padding_mode,
alternate_gh=False,
)

for j in range(J):
down_j = murenn.DTCWT(
J=J-j,
J=J_phi-j,
padding_mode=padding_mode,
skip_hps=True,
alternate_gh=False,
)
down.append(down_j)

conv1d_j = torch.nn.Conv1d(
in_channels=in_channels,
out_channels=Q*in_channels,
kernel_size=T,
out_channels=self.Q[j]*in_channels,
kernel_size=self.T[j],
bias=False,
groups=in_channels,
padding="same",
Expand All @@ -51,24 +67,78 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
def forward(self, x):
"""
Args:
x (PyTorch tensor): A tensor of shape `(B, C, T)`. B is a batch size,
C denotes a number of channels, T is a length of signal sequence.
x (PyTorch tensor): A tensor of shape `(B, in_channels, T)`. B is a batch size.
in_channels is the number of channels in the input tensor, this should match
the in_channels attribute of the class instance. T is a length of signal sequence.
Returns:
y (PyTorch tensor): A tensor of shape `(B, C, Q, J, T_out)`
y (PyTorch tensor): A tensor of shape `(B, in_channels, sum(Q), T_out)`; the Q[j] filters of each octave j are concatenated along dimension 2.
"""
assert self.C == x.shape[1]
assert self.in_channels == x.shape[1]
lp, bps = self.dtcwt(x)
output = []
UWx = []
for j in range(self.dtcwt.J):
Wx_j_r = self.conv1d[j](bps[j].real)
Wx_j_i = self.conv1d[j](bps[j].imag)
UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i)
# Average over time
UWx_j, _ = self.down[j](UWx_j)
B, _, N = UWx_j.shape
# reshape from (B, C*Q, N) to (B, C, Q, N)
UWx_j = UWx_j.view(B, self.C, self.Q, N)
output.append(UWx_j)
return torch.stack(output, dim=3)
UWx_j = UWx_j.view(B, self.in_channels, self.Q[j], N)
UWx.append(UWx_j)
UWx = torch.cat(UWx, dim=2)
return UWx

def to_conv1d(self):
    """
    Build a single-resolution Conv1d equivalent to this MuReNN layer.

    An impulse is pushed through the DTCWT, then through the learned
    per-octave Conv1d filters, and finally each filtered subband is sent
    back to the time domain on its own with the inverse DTCWT:

              DTCWT        conv1d         IDTCWT
        δ -------> ψ_j --------> w_jq -------> y_jq

    The resulting impulse responses y_jq are useful for visualization in
    the Fourier domain, for inspecting receptive fields, and for comparing
    computational costs.

    Note: only the responses computed for input channel 0 are kept, and
    the returned layer has a single input channel.

    Return:
        conv1d (torch.nn.Conv1d): A bias-free PyTorch Conv1d instance with
            `padding="same"` whose weights are set to the complex impulse
            responses y_jq, stacked along the output-channel axis in
            (j, q) order.
    """
    device = self.conv1d[0].weight.data.device
    # Number of decomposition levels of the analysis transform.
    n_levels = self.dtcwt.J
    # Impulse length: the longest learned kernel scaled by 2**J.
    length = 2 ** n_levels * max(self.T)

    # Unit impulse located at the middle of the signal, on every channel.
    impulse = torch.zeros(1, self.in_channels, length, device=device)
    impulse[:, :, length // 2] = 1

    inverse = murenn.IDTCWT(
        J=n_levels,
        alternate_gh=False,
    ).to(device)

    # DTCWT impulse responses: low-pass residual and band-pass subbands.
    phi, psis = self.dtcwt(impulse)
    # The low-pass branch is silenced during reconstruction.
    silent_phi = phi.new_zeros(1, 1, phi.shape[-1])

    def only_level(level, band):
        # Band-pass list in which every level except `level` is zeroed out.
        bands = []
        for k in range(n_levels):
            if k == level:
                bands.append(band)
            else:
                bands.append(psis[k].new_zeros(1, 1, psis[k].shape[-1]))
        return bands

    responses = []
    for j in range(n_levels):
        n_filters = self.Q[j]
        conv_j = self.conv1d[j]
        # Filter real and imaginary parts separately, then split the
        # channel axis into (in_channels, Q[j]).
        filtered_re = conv_j(psis[j].real).reshape(self.in_channels, n_filters, -1)
        filtered_im = conv_j(psis[j].imag).reshape(self.in_channels, n_filters, -1)
        for q in range(n_filters):
            band_re = filtered_re[0, q, :].reshape(1, 1, -1)
            band_im = filtered_im[0, q, :].reshape(1, 1, -1)
            # Reconstruct the real-part and imaginary-part contributions
            # independently, each from level j alone.
            y_re = inverse(silent_phi, only_level(j, band_re * (1+0j)))
            y_im = inverse(silent_phi, only_level(j, band_im * (0+1j)))
            responses.append(torch.complex(y_re, y_im))

    kernels = torch.cat(responses, dim=0)
    conv1d = torch.nn.Conv1d(
        in_channels=1,
        out_channels=kernels.shape[0],
        kernel_size=length,
        bias=False,
        padding="same",
    )
    conv1d.weight.data = torch.nn.parameter.Parameter(kernels)
    return conv1d


class ModulusStable(torch.autograd.Function):
Expand Down
5 changes: 4 additions & 1 deletion murenn/dtcwt/transform1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,10 @@ def forward(self, yl, yh):
if self.normalize:
x_phi = np.sqrt(2) * x_phi

## LEVEL 1 ##
# LEVEL 1 ##
if x_phi.shape[-1] != x_psis[0].shape[-1] * 2:
x_phi = x_phi[:,:,1:-1]

x_psi_r, x_psi_i = x_psis[0].real, x_psis[0].imag

x_phi = INV_J1.apply(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ def test_avrg_energy(alternate_gh):
Ppsi_j = torch.linalg.norm(torch.abs(psi)) ** 2 / psi.shape[-1]
P_Ux = P_Ux + Ppsi_j
ratio = P_Ux / P_x
assert torch.abs(ratio - 1) <= 0.01
assert torch.abs(ratio - 1) <= 0.01
17 changes: 15 additions & 2 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_direct_shape(J, Q, T, N, padding_mode):
padding_mode=padding_mode,
)
y = graph(x)
assert y.shape[:4] == (B, C, Q, J)
assert y.shape[:3] == (B, C, Q*J)


def test_direct_diff():
Expand Down Expand Up @@ -86,4 +86,17 @@ def test_modulus():
loss0 = torch.sum(Ux0)
loss0.backward()
assert torch.max(torch.abs(x0r.grad)) <= 1e-7
assert torch.max(torch.abs(x0i.grad)) <= 1e-7
assert torch.max(torch.abs(x0i.grad)) <= 1e-7

@pytest.mark.parametrize("Q", [1, 2])
@pytest.mark.parametrize("T", [1, 2])
def test_toconv1d_shape(Q, T):
    """Check that `MuReNNDirect.to_conv1d` returns a `torch.nn.Conv1d`."""
    # Four octaves and two input channels; Q and T come from the parametrization.
    layer_kwargs = dict(J=4, Q=Q, T=T, in_channels=2)
    layer = murenn.MuReNNDirect(**layer_kwargs)
    equivalent = layer.to_conv1d()
    assert isinstance(equivalent, torch.nn.Conv1d)