Skip to content

Commit

Permalink
complex conv1d-->real conv1d
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Jun 23, 2024
1 parent d7b5744 commit d824ad9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 40 deletions.
39 changes: 4 additions & 35 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,6 @@ def to_conv1d(self):
-------
Return:
conv1d (torch.nn.Conv1d): A Pytorch Conv1d instance with weights initialized to y_jq.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> J = 8
>>> Q = 5
>>> N = 2**10
>>> tfm = murenn.MuReNNDirect(J=8, Q=5, T=32, in_channels=1)
>>> conv1d = tfm.to_conv1d
>>> x = torch.zeros(1,1,N)
>>> x[0,0,N//2]=1
>>> x = x*(1-1j)
>>> w = conv1d(x).reshape(J,Q,-1).detach()
>>> colors = [
>>> 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
>>> 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
>>> for j in range(J):
>>> for q in range(Q):
>>> plt.semilogx(torch.abs(torch.fft.fft(w[j,q,:])), color=colors[j])
>>> plt.xlim(0, N//2)
"""
# T the filter length
T = self.conv1d[0].kernel_size[0]
Expand All @@ -112,34 +93,22 @@ def to_conv1d(self):
N = 2**J * T
x = torch.zeros(1, self.C, N)
x[:, :, N//2] = 1
# Get the padding mode
padding_mode = self.dtcwt.padding_mode
if padding_mode == "constant":
padding_mode = "zeros"

inv = murenn.IDTCWT(
J=J,
padding_mode=padding_mode,
normalize=False,
J = J,
normalize=True,
)
# Get DTCWT impulse reponses
phi, psis = self.dtcwt(x)
# Set phi to a zero valued tensor
zeros_phi = phi.new_zeros(size=(1, self.C*self.Q, phi.shape[-1]))
# Create an empty list for {w_jq}
ws = []
for j in range(J):
# Wpsi_jr = Re[psi_j] * w_jq
Wpsi_jr = self.conv1d[j](psis[j].real)
# W_ji = Im[psi_j] * w_jq
Wpsi_ji = self.conv1d[j](psis[j].imag)
# Set the coefficients besides this scale to zero
Wpsis_jr = [Wpsi_jr * (1 + 0j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)]
Wpsis_ji = [Wpsi_ji * (0 + 1j) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)]
Wpsis_j = [torch.complex(Wpsi_jr, Wpsi_ji) if k == j else psis[k].new_zeros(size=psis[k].shape).repeat(1, self.Q, 1) for k in range(J)]
# Get the impulse response
w_jr = inv(zeros_phi, Wpsis_jr)
w_ji = inv(zeros_phi, Wpsis_ji)
w_j = torch.complex(w_jr, w_ji)
w_j = inv(zeros_phi, Wpsis_j)
# We only need data form one channel
w_j = w_j.reshape(self.C, self.Q, 1, N)[0,...]
ws.append(w_j)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def test_inv():
lp, bp = dtcwt(Xt)
lp = lp.new_zeros(lp.shape)
X_rec = idtcwt(lp, bp)
bp_r = [(bp[j].real)*(1+0j) for j in range(8)]
bp_i = [(bp[j].imag)*(0+1j) for j in range(8)]
bp_r = [(bp[j].real)*(1+0j) for j in range(dtcwt.J)]
bp_i = [(bp[j].imag)*(0+1j) for j in range(dtcwt.J)]
X_rec_r = idtcwt(lp, bp_r)
X_rec_i = idtcwt(lp, bp_i)
assert torch.allclose((X_rec_r+X_rec_i), X_rec, atol=1e-3)
5 changes: 2 additions & 3 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,13 @@ def test_modulus():

@pytest.mark.parametrize("Q", [1, 2])
@pytest.mark.parametrize("T", [1, 2])
@pytest.mark.parametrize("C", [1, 2])
def test_toconv1d_shape(Q, T, C):
def test_toconv1d_shape(Q, T):
J = 4
tfm = murenn.MuReNNDirect(
J=J,
Q=Q,
T=T,
in_channels=C,
in_channels=2,
)
conv1d = tfm.to_conv1d
assert isinstance(conv1d, torch.nn.Conv1d)
Expand Down

0 comments on commit d824ad9

Please sign in to comment.