Use complex tensors in phase_vocoder #758
Changes from all commits: 8918a2e, 1d86ee8, 05abd96, 47e2b28, 7f98428, 4509fe8, 6ab1a2f, cd2f8eb, 924ab47, da9edeb, d75f6fc, c3d9b8e, efe1644, f3a6ac9
@@ -458,14 +458,28 @@ def phase_vocoder(
        factor of ``rate``.

    Args:
        complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
Review comment: How about something like this? We were using "complex tensor" to mean …

Review comment: Yeah, "tensor with a complex dtype" sounds good. However, this "or" way of documenting could be problematic in cases where the function takes more than one complex tensor. Perhaps in those cases we can add a note stating that either all inputs should be real tensors or all inputs should be of complex dtype. I think it might be nicer to add a separate example with complex-dtype tensors so that it is also clear that the returned output would be complex as well (if applicable), especially since we are planning to switch to using complex-dtype tensors in the release after the upcoming release.

Review comment: I do find this "or" way of discussing this a little cumbersome, and I agree this will get long if many tensors are involved. We could add a note in each, and just define the args/returns with complex dtype. We can still keep the example for clarity.

"""
We are migrating to complex-dtype tensors. For backward compatibility reasons,
this function still supports the legacy convention of a trailing dimension of 2
to represent a complex tensor.
Args:
complex_specgrams (Tensor): A tensor of dimension `(..., freq, time)` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
Returns:
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
with a complex dtype.
Example
Example - Legacy
""" Thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P.S. Good suggestion below for example naming |
||
        complex_specgrams (Tensor): Either a real tensor of dimension `(..., freq, time, complex=2)`
            or a tensor of dimension `(..., freq, time)` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
        Tensor: Complex Specgrams Stretch with either a real dtype and dimension of
            `(..., freq, ceil(time/rate), complex=2)` or
            a complex dtype and dimension of `(..., freq, ceil(time/rate))`.
    Example
    Example - New API (using tensors with complex dtype)
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
Review comment: This is neat!
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>     0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])
Review comment: We might not need to change the example. We could add a second example, or a comment next to each.

Review comment: I think we should have an example with tensors of complex dtype so that users know how to deal with complex tensors. This is an option: … What do you think?

Review comment: Good suggestion :) How about standardizing on this?
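(As background for this example discussion, one common way to obtain a complex-dtype spectrogram in practice is `torch.stft` with `return_complex=True`. A short sketch, assuming a PyTorch version that supports that flag; the values are illustrative.)

    import torch

    waveform = torch.randn(2, 16000)                      # (channel, time) dummy audio
    n_fft, hop_length = 2048, 512
    specgram = torch.stft(waveform, n_fft=n_fft, hop_length=hop_length,
                          window=torch.hann_window(n_fft), return_complex=True)
    print(specgram.shape, specgram.dtype)                 # (2, n_fft // 2 + 1, n_frames) torch.complex64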
    Example - Old API (using real tensors with shape (..., complex=2))
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
@@ -476,50 +490,71 @@ def phase_vocoder(
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    # pack batch
    use_complex = complex_specgrams.is_complex()
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))

    time_steps = torch.arange(0,
                              complex_specgrams.size(-2),
                              rate,
                              device=complex_specgrams.device,
                              dtype=complex_specgrams.dtype)

    alphas = time_steps % 1.0
    phase_0 = angle(complex_specgrams[..., :1, :])

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

    angle_0 = angle(complex_specgrams_0)
    angle_1 = angle(complex_specgrams_1)

    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
    if use_complex:
Review comment: As an alternative to duplicating all the logic here, you could have instead taken the real tensor, viewed it as complex, and then used the complex codepath (viewing it back as real in the end). Something to consider?

Review comment: I think that's a possibility, however the goal is to be able to remove the code in …

Review comment: Well, in my suggestion, you'd delete the real code immediately :) Anyway, this is NBD.

Review comment: Oh sorry, I misread your comment and thought you meant the other way round! Yeah, that sounds reasonable to me. cc @mthrok, thoughts?

Review comment: On more thought, the current lack of autograd support and testing for complex might introduce BC-breaking changes.

Review comment: I am in favor of not duplicating the logic; however, if that introduces BC breakage for real-valued tensor input, then I think we can wait until the autograd support arrives.

Review comment: If it's R2R with complex insides, the choice of JAX/TF convention doesn't matter; you'll always get the same gradients in the end.
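(For concreteness, a minimal sketch of the alternative discussed above, i.e. keeping a single complex code path and converting legacy `(..., 2)` inputs at the boundary. This is not the code in the PR; the wrapper name is hypothetical, and it assumes a torchaudio build whose `phase_vocoder` already accepts complex input, as this PR adds.)

    import torch
    import torchaudio.functional as F

    def phase_vocoder_any(specgrams, rate, phase_advance):
        """Accept either convention, but run only the complex-dtype code path."""
        if specgrams.is_complex():
            return F.phase_vocoder(specgrams, rate, phase_advance)
        # Legacy (..., 2) real input: view as complex, process, view back as real.
        out = F.phase_vocoder(torch.view_as_complex(specgrams.contiguous()), rate, phase_advance)
        return torch.view_as_real(out)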
        # pack batch
        complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
        time_steps = torch.arange(0,
                                  complex_specgrams.size(-1),
                                  rate,
                                  device=complex_specgrams.device,
                                  dtype=torch.real(complex_specgrams).dtype)
        phase_0 = complex_specgrams[..., :1].angle()
        alphas = time_steps % 1.0
        # Time Padding
        complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])
        # (new_bins, freq, 2)
        complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
        complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

        angle_0 = complex_specgrams_0.angle()
        angle_1 = complex_specgrams_1.angle()
        norm_0 = complex_specgrams_0.abs()
        norm_1 = complex_specgrams_1.abs()
    else:
        complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
        time_steps = torch.arange(0,
                                  complex_specgrams.size(-2),
                                  rate,
                                  device=complex_specgrams.device,
                                  dtype=complex_specgrams.dtype)
        alphas = time_steps % 1.0
        phase_0 = angle(complex_specgrams[..., :1, :])
        # Time Padding
        complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
        # (new_bins, freq, 2)
        complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
        complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

        angle_0 = angle(complex_specgrams_0)
        angle_1 = angle(complex_specgrams_1)
        norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
        norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)

    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    real_stretch = mag * torch.cos(phase_acc)
    imag_stretch = mag * torch.sin(phase_acc)

    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
    if use_complex:
        complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))
    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
        # unpack batch
        complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
    else:
        complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
        # unpack batch
        complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])

    return complex_specgrams_stretch
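(Putting the pieces together, a rough end-to-end sketch of time stretching with the new API: STFT, phase_vocoder, inverse STFT. All values are illustrative, and it assumes a PyTorch/torchaudio combination where `return_complex=True` and complex input to `phase_vocoder` are both supported.)

    import math
    import torch
    import torchaudio.functional as F

    waveform = torch.randn(1, 16000)                       # dummy mono audio
    n_fft, hop_length = 2048, 512
    window = torch.hann_window(n_fft)
    rate = 1.3                                             # speed up by 30%

    spec = torch.stft(waveform, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)  # (1, freq, frames), complex dtype
    freq = spec.size(1)
    phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

    stretched = F.phase_vocoder(spec, rate, phase_advance)  # (1, freq, ceil(frames / rate))
    stretched_audio = torch.istft(stretched, n_fft=n_fft, hop_length=hop_length, window=window)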
Review comment: Is this change related or necessary? My understanding is that the `_assert_consistency` method will change the dtype/device to the appropriate ones, so this change has no effect.

Review comment: Oh okay, removed it!

Review comment: I still see `dtype=torch.double` …