ENH MuReNNDirect.to_conv1d() #56
Merged on Jul 25, 2024 (19 commits)
102 changes: 86 additions & 16 deletions murenn/dtcwt/nn.py

@@ -9,34 +9,49 @@ class MuReNNDirect(torch.nn.Module):
"""
Args:
J (int): Number of levels (octaves) in the DTCWT decomposition.
Q (int): Number of Conv1D filters per octave.
Q (int, dict or list): Number of Conv1D filters per octave. An int uses the
same Q at every octave; a list (or dict keyed by octave) gives one value
per level and must have length J.
T (int): Conv1D kernel size multiplier. The kernel size at octave j is T*Q[j].
J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J.
in_channels (int): Number of channels in the input signal.
padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate',
and 'circular'. Padding scheme for the DTCWT decomposition.
"""
def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"):
super().__init__()
self.Q = Q
self.C = in_channels
if isinstance(Q, int):
self.Q = [Q for j in range(J)]
elif isinstance(Q, (dict, list)):
[Review comment from a Contributor, on the elif branch above]
i think supporting dict is not necessary. list is sufficient
[3, 4, 5, 2] is more explicit than {0: 3, 1: 3, 3: 5, 4: 2}

assert len(Q) == J
self.Q = Q
else:
raise TypeError(f"Q must be an int, dict, or list, got {type(Q)}")
if J_phi is None:
J_phi = J
if J_phi < J:
raise ValueError("J_phi must be greater than or equal to J")
self.T = [T*self.Q[j] for j in range(J)]
self.in_channels = in_channels
down = []
conv1d = []
self.dtcwt = murenn.DTCWT(
J=J,
padding_mode=padding_mode,
alternate_gh=False,
)

for j in range(J):
down_j = murenn.DTCWT(
J=J-j,
J=J_phi-j,
padding_mode=padding_mode,
skip_hps=True,
alternate_gh=False,
)
down.append(down_j)

conv1d_j = torch.nn.Conv1d(
in_channels=in_channels,
out_channels=Q*in_channels,
kernel_size=T,
out_channels=self.Q[j]*in_channels,
kernel_size=self.T[j],
bias=False,
groups=in_channels,
padding="same",
@@ -51,24 +66,79 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
def forward(self, x):
"""
Args:
x (PyTorch tensor): A tensor of shape `(B, C, T)`. B is a batch size,
C denotes a number of channels, T is a length of signal sequence.
x (PyTorch tensor): A tensor of shape `(B, in_channels, T)`. B is the batch size,
in_channels is the number of channels (it must match the in_channels attribute
of this instance), and T is the length of the signal.
Returns:
y (PyTorch tensor): A tensor of shape `(B, C, Q, J, T_out)`
y (PyTorch tensor): A tensor of shape `(B, in_channels, Q*J, T_out)`, with the
Q*J (more generally, sum(Q)) filter responses concatenated along dim 2.
"""
assert self.C == x.shape[1]
assert self.in_channels == x.shape[1]
lp, bps = self.dtcwt(x)
output = []
UWx = []
for j in range(self.dtcwt.J):
Wx_j_r = self.conv1d[j](bps[j].real)
Wx_j_i = self.conv1d[j](bps[j].imag)
UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i)
# Average over time
UWx_j, _ = self.down[j](UWx_j)
B, _, N = UWx_j.shape
# reshape from (B, C*Q, N) to (B, C, Q, N)
UWx_j = UWx_j.view(B, self.C, self.Q, N)
output.append(UWx_j)
return torch.stack(output, dim=3)
UWx_j = UWx_j.view(B, self.in_channels, self.Q[j], N)
UWx.append(UWx_j)
UWx = torch.cat(UWx, dim=2)
return UWx
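A forward-pass sketch of the new output layout, in which the J octaves and Q filters share a single axis; the shapes follow the updated test_direct_shape below, and the parameter values are illustrative:

import torch
import murenn

B, C = 2, 1
scat = murenn.MuReNNDirect(J=4, Q=3, T=4, in_channels=C)
x = torch.randn(B, C, 2**12)
y = scat(x)
# Octaves and filters are concatenated along dim 2: Q*J = 12 channels.
assert y.shape[:3] == (B, C, 12)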

@property
def to_conv1d(self):
"""
Compute the single-resolution equivalent impulse response of the MuReNN layer.
This is useful for visualization in the Fourier domain, for measuring receptive
fields, and for comparing computational costs.

   DTCWT        conv1d        IDTCWT
δ -------> ψ_j --------> w_jq -------> y_jq

Returns:
conv1d (torch.nn.Conv1d): A PyTorch Conv1d instance with weights initialized to y_jq.
"""

device = self.conv1d[0].weight.data.device
# T: the largest Conv1d kernel size
T = max(self.T)
# J: the number of decomposition levels
J = self.dtcwt.J
# Generate the impulse signal
N = 2 ** J * T
x = torch.zeros(1, self.in_channels, N).to(device)
x[:, :, N//2] = 1
inv = murenn.IDTCWT(
J = J,
alternate_gh=False
).to(device)
# Get DTCWT impulse responses
phi, psis = self.dtcwt(x)
# Set phi to a zero-valued tensor
zeros_phi = phi.new_zeros(1, 1, phi.shape[-1])
ws = []
for j in range(J):
Wpsi_jr = self.conv1d[j](psis[j].real).reshape(self.in_channels, self.Q[j], -1)
Wpsi_ji = self.conv1d[j](psis[j].imag).reshape(self.in_channels, self.Q[j], -1)
for q in range(self.Q[j]):
Wpsi_jqr = Wpsi_jr[0, q, :].reshape(1,1,-1)
Wpsi_jqi = Wpsi_ji[0, q, :].reshape(1,1,-1)
Wpsis_r = [Wpsi_jqr * (1+0j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
Wpsis_i = [Wpsi_jqi * (0+1j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
w_r = inv(zeros_phi, Wpsis_r)
w_i = inv(zeros_phi, Wpsis_i)
ws.append(torch.complex(w_r, w_i))
ws = torch.cat(ws, dim=0)
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=ws.shape[0],
kernel_size=N,
bias=False,
padding="same",
)
conv1d.weight.data = torch.nn.parameter.Parameter(ws)
return conv1d
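A sketch of the intended use of to_conv1d for Fourier-domain visualization; it assumes the weights stay complex-valued, as constructed above, and the parameter values are illustrative:

import torch
import murenn

scat = murenn.MuReNNDirect(J=4, Q=2, T=4, in_channels=1)
equiv = scat.to_conv1d                 # a property, so no parentheses
w = equiv.weight.data                  # complex tensor of shape (Q*J, 1, N)
mag = torch.fft.fft(w, dim=-1).abs()   # magnitude responses of the y_jq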


class ModulusStable(torch.autograd.Function):
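The body of ModulusStable is collapsed in this diff. Conceptually, it computes the pointwise complex modulus sqrt(r**2 + i**2) with a zero gradient at the origin, which is the property exercised by test_modulus below. A minimal sketch of the idea, not the repository's exact implementation:

import torch

class ModulusStableSketch(torch.autograd.Function):
    """Pointwise complex modulus with a zero gradient at the origin."""

    @staticmethod
    def forward(ctx, x_r, x_i):
        out = torch.sqrt(x_r * x_r + x_i * x_i)
        ctx.save_for_backward(x_r, x_i, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x_r, x_i, out = ctx.saved_tensors
        # d|x|/dx_r = x_r / |x|; define the gradient as 0 where |x| == 0.
        safe = torch.where(out == 0, torch.ones_like(out), out)
        scale = grad_out / safe
        dx_r = torch.where(out == 0, torch.zeros_like(x_r), x_r * scale)
        dx_i = torch.where(out == 0, torch.zeros_like(x_i), x_i * scale)
        return dx_r, dx_i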
14 changes: 8 additions & 6 deletions murenn/dtcwt/transform1d.py

@@ -156,7 +156,8 @@ def forward(self, x):
# Ensure the lowpass length is divisible by 4
if x_phi.shape[-1] % 4 != 0:
x_phi = torch.cat((x_phi[:,:,0:1], x_phi, x_phi[:,:,-1:]), dim=-1)

if self.normalize:
x_phi = 1/np.sqrt(2) * x_phi
x_phi, x_psi_r, x_psi_i = FWD_J2PLUS.apply(
x_phi,
h0a,
@@ -165,15 +166,12 @@ def forward(self, x):
h1b,
self.skip_hps[j],
self.padding_mode,
self.normalize,
)

if (j % 2 == 1) and self.alternate_gh:
# The result is anti-analytic in the Hilbert sense.
# We conjugate the result to bring the spectrum back to (0, pi).
# This is purely by convention and for consistency through j.
x_psi_i = -1 * x_psi_i

x_psis.append(x_psi_r + 1j * x_psi_i)

if self.include_scale[j]:
@@ -295,10 +293,14 @@ def forward(self, yl, yh):
g0b,
g1b,
self.padding_mode,
self.normalize,
)
if self.normalize:
x_phi = np.sqrt(2) * x_phi

# LEVEL 1 ##
if x_phi.shape[-1] != x_psis[0].shape[-1] * 2:
x_phi = x_phi[:,:,1:-1]

## LEVEL 1 ##
x_psi_r, x_psi_i = x_psis[0].real, x_psis[0].imag

x_phi = INV_J1.apply(
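With this change, the normalize rescaling (1/sqrt(2) per forward level, sqrt(2) on inversion) is applied once per level here in transform1d.py instead of inside the autograd functions; the round trip should be unchanged. A quick check, assuming both classes accept normalize as exercised by test_pr below:

import torch
import murenn

x = torch.randn(2, 2, 2**10)
fwd = murenn.DTCWTDirect(J=3, normalize=True)
inv = murenn.DTCWTInverse(J=3, normalize=True)
lp, bp = fwd(x)
assert torch.allclose(inv(lp, bp), x, atol=1e-3)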
25 changes: 5 additions & 20 deletions murenn/dtcwt/transform_funcs.py

@@ -75,7 +75,7 @@ class FWD_J2PLUS(torch.autograd.Function):
high-pass output of tree b."""

@staticmethod
def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize):
def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode):
"""
Forward dual-tree complex wavelet transform at levels 2 and coarser.

@@ -104,7 +104,6 @@ def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize):
ctx.save_for_backward(h0a_rep, h1a_rep, h0b_rep, h1b_rep)
ctx.skip_hps = skip_hps
ctx.mode = mode_to_int(padding_mode)
ctx.normalize = normalize

# Apply low-pass filtering on trees a (real) and b (imaginary).
lo = coldfilt(x_phi, h0a_rep, h0b_rep, padding_mode)
@@ -122,17 +121,13 @@ def forward(ctx, x_phi, h0a, h1a, h0b, h1b, skip_hps, padding_mode, normalize):

# Return low-pass output, and band-pass output in conjunction:
# real part for tree a and imaginary part for tree b.
if normalize:
return 1/np.sqrt(2) * lo, 1/np.sqrt(2) * bp_r, 1/np.sqrt(2) * bp_i
else:
return lo, bp_r, bp_i
return lo, bp_r, bp_i

@staticmethod
def backward(ctx, dx_phi, dx_psi_r, dx_psi_i):
g0b, g1b, g0a, g1a = ctx.saved_tensors
skip_hps = ctx.skip_hps
padding_mode = int_to_mode(ctx.mode)
normalize = ctx.normalize
b, ch, T = dx_phi.shape
if not ctx.needs_input_grad[0]:
dx = None
@@ -141,8 +136,6 @@ def backward(ctx, dx_phi, dx_psi_r, dx_psi_i):
if not skip_hps:
dx_psi = torch.stack((dx_psi_i, dx_psi_r), dim=-1).view(b, ch, T)
dx += colifilt(dx_psi, g1a, g1b, padding_mode)
if normalize:
dx *= 1/np.sqrt(2)
return dx, None, None, None, None, None, None, None


@@ -212,7 +205,7 @@ class INV_J2PLUS(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode, normalize):
def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode):
"""
Inverse dual-tree complex wavelet transform at levels 2 and coarser.

@@ -237,33 +230,25 @@ def forward(ctx, lo, bp_r, bp_i, g0a, g1a, g0b, g1b, padding_mode, normalize):
g0b_rep = g0b.repeat(ch, 1, 1)
g1b_rep = g1b.repeat(ch, 1, 1)
ctx.save_for_backward(g0a_rep, g1a_rep, g0b_rep, g1b_rep)
ctx.normalize = normalize
ctx.mode = mode_to_int(padding_mode)

bp = torch.stack((bp_i, bp_r), dim=-1).view(b, ch, T)
lo = colifilt(lo, g0a_rep, g0b_rep, padding_mode) + colifilt(bp, g1a_rep, g1b_rep, padding_mode)

if normalize:
return np.sqrt(2) * lo
else:
return lo
return lo


@staticmethod
def backward(ctx, dx):
g0b, g1b, g0a, g1a = ctx.saved_tensors
padding_mode = int_to_mode(ctx.mode)
normalize = ctx.normalize
b, ch, T = dx.shape
dlo, dbp = None, None
if ctx.needs_input_grad[0]:
dlo = coldfilt(dx, g0a, g0b, padding_mode)
dlo = torch.stack([dlo[:,:ch], dlo[:,ch:2*ch]], dim=-1).view(b, ch, T//2)
if normalize:
dlo *= np.sqrt(2)
if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
dbp = coldfilt(dx, g1a, g1b, padding_mode)
if normalize:
dbp *= np.sqrt(2)
if ctx.needs_input_grad[1]:
dbp_r = dbp[:,ch:2*ch]
if ctx.needs_input_grad[2]:
16 changes: 15 additions & 1 deletion tests/test_dtcwt.py

@@ -39,7 +39,7 @@ def test_fwd_same(J):
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("J", list(range(1, 5)))
@pytest.mark.parametrize("T", [44099, 44100])
def test_inv(level1, qshift, J, T, alternate_gh, normalize):
def test_pr(level1, qshift, J, T, alternate_gh, normalize):
Xt = torch.randn(2, 2, T)
xfm_murenn = murenn.DTCWTDirect(
J=J,
@@ -75,3 +75,17 @@ def test_skip_hps(skip_hps, include_scale):
inv = murenn.DTCWTInverse(J=J, skip_hps=skip_hps, include_scale=include_scale)
X_rec = inv(lp, bp)
assert X_rec.shape == Xt.shape

def test_inv():
T = 2**10
Xt = torch.randn(2, 2, T)
dtcwt = murenn.DTCWTDirect()
idtcwt = murenn.DTCWTInverse()
lp, bp = dtcwt(Xt)
lp = lp.new_zeros(lp.shape)
X_rec = idtcwt(lp, bp)
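# The inverse is linear in the band-pass coefficients, so inverting the
# real and imaginary parts separately (below) should sum to X_rec.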
bp_r = [(bp[j].real)*(1+0j) for j in range(dtcwt.J)]
bp_i = [(bp[j].imag)*(0+1j) for j in range(dtcwt.J)]
X_rec_r = idtcwt(lp, bp_r)
X_rec_i = idtcwt(lp, bp_i)
assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3)
14 changes: 4 additions & 10 deletions tests/test_grad.py
@@ -1,8 +1,6 @@
import pytest
import torch
from torch.autograd import gradcheck
import numpy as np
import dtcwt
import murenn
import murenn.dtcwt.transform_funcs as tf
from contextlib import contextmanager
@@ -36,15 +34,14 @@ def test_fwd_j1(skip_hps):
gradcheck(tf.FWD_J1.apply, input, eps=eps, atol=atol)


@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("skip_hps", [[0, 1], [1, 0]])
def test_fwd_j2(skip_hps, normalize):
def test_fwd_j2(skip_hps):
J = 2
eps = 1e-3
atol = 1e-4
with set_double_precision():
x = torch.randn(2, 2, 4, device=dev, requires_grad=True)
fwd = murenn.DTCWTDirect(J=J, skip_hps=skip_hps, normalize=normalize).to(dev)
fwd = murenn.DTCWTDirect(J=J, skip_hps=skip_hps).to(dev)
input = (
x,
fwd.h0a,
@@ -53,7 +50,6 @@ def test_fwd_j2(skip_hps, normalize):
fwd.h1b,
fwd.skip_hps[1],
fwd.padding_mode,
fwd.normalize,
)
gradcheck(tf.FWD_J2PLUS.apply, input, eps=eps, atol=atol)

@@ -71,16 +67,15 @@ def test_inv_j1():
gradcheck(tf.INV_J1.apply, input, eps=eps, atol=atol)


@pytest.mark.parametrize("normalize", [True, False])
def test_inv_j2(normalize):
def test_inv_j2():
J = 2
eps = 1e-3
atol = 1e-4
with set_double_precision():
lo = torch.randn(2, 2, 8, device=dev, requires_grad=True)
bp_r = torch.randn(2, 2, 4, device=dev, requires_grad=True)
bp_i = torch.randn(2, 2, 4, device=dev, requires_grad=True)
inv = murenn.DTCWTInverse(J=J, normalize=normalize).to(dev)
inv = murenn.DTCWTInverse(J=J).to(dev)

input = (
lo,
@@ -91,7 +86,6 @@ def test_inv_j2(normalize):
inv.g0b,
inv.g1b,
inv.padding_mode,
inv.normalize,
)
gradcheck(tf.INV_J2PLUS.apply, input, eps=eps, atol=atol)

17 changes: 15 additions & 2 deletions tests/test_nn.py

@@ -26,7 +26,7 @@ def test_direct_shape(J, Q, T, N, padding_mode):
padding_mode=padding_mode,
)
y = graph(x)
assert y.shape[:4] == (B, C, Q, J)
assert y.shape[:3] == (B, C, Q*J)


def test_direct_diff():
@@ -86,4 +86,17 @@ def test_modulus():
loss0 = torch.sum(Ux0)
loss0.backward()
assert torch.max(torch.abs(x0r.grad)) <= 1e-7
assert torch.max(torch.abs(x0i.grad)) <= 1e-7

@pytest.mark.parametrize("Q", [1, 2])
@pytest.mark.parametrize("T", [1, 2])
def test_toconv1d_shape(Q, T):
J = 4
tfm = murenn.MuReNNDirect(
J=J,
Q=Q,
T=T,
in_channels=2,
)
conv1d = tfm.to_conv1d
assert isinstance(conv1d, torch.nn.Conv1d)
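A possible extension of this test, assuming (per the implementation above) one equivalent filter per (j, q) pair and complex-valued kernels:

assert conv1d.out_channels == Q * J
assert conv1d.weight.data.is_complex()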