[Migration] Torchaudio Complex Tensor Support and Migration #1337
[BC-Breaking] Default to native complex type when returning raw spectrogram (#1549)

Part of #1337.
- This code changes the return type of spectrogram to native complex dtype when (and only when) returning a raw (complex-valued) spectrogram.
- Change `return_complex=False` to `return_complex=True` in spectrogram ops.
- `return_complex` is only effective when `power` is `None`. It is ignored when `power` is not `None`, because the returned Tensor is then a power spectrogram, which is real-valued.
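The behavior described in this commit can be illustrated with `torch.stft`. The sketch below uses a hypothetical `spectrogram_sketch` helper (not torchaudio's implementation) and assumes a Hann window and `torch.stft`'s default centering; it only shows how the return type depends on `power`.

```python
from typing import Optional

import torch


def spectrogram_sketch(
    waveform: torch.Tensor, n_fft: int, hop_length: int, power: Optional[float] = None
) -> torch.Tensor:
    """Hypothetical helper mirroring the described behavior; not torchaudio code."""
    window = torch.hann_window(n_fft, dtype=waveform.dtype)
    # Always request a native complex STFT
    spec = torch.stft(waveform, n_fft, hop_length=hop_length, window=window, return_complex=True)
    if power is None:
        # Raw spectrogram: returned as a native complex tensor
        return spec
    # Power spectrogram: real-valued, so there is no complex format to choose
    return spec.abs().pow(power)


waveform = torch.randn(400)
raw = spectrogram_sketch(waveform, n_fft=64, hop_length=16)                    # complex64, (33, 26)
power_spec = spectrogram_sketch(waveform, n_fft=64, hop_length=16, power=2.0)  # float32, (33, 26)
pseudo = torch.view_as_real(raw)  # legacy (..., 2) layout, (33, 26, 2)
```

Code that still expects the pseudo complex layout can call `torch.view_as_real` on the native result, as in the last line.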
Following the plan #1337, this commit drops support for the pseudo complex type from `F.spectrogram` and `T.Spectrogram`. It also deprecates the `return_complex` argument.
Before going forward with the deprecation, please note that the complex32 format is poorly supported on CUDA. I had to rewrite the whole code back to pseudo complex to be able to work with half precision...
Thanks for letting us know. TorchAudio is not tested on fp16 or complex32, and they are not officially supported types, so we did not realize that pseudo complex could be used as a workaround for complex32. Unfortunately, we have wrapped up release v0.11 (scheduled to be out in about one week), PyTorch removed the complex32 type, and torchaudio removed support for the pseudo complex type, so I assume it will be unusable for you. I can try reverting the pseudo complex support if that is the best course of action. (However, some operations with pseudo complex are known to have accuracy issues, so it is not the best workaround; that is one reason we wanted to migrate to the native complex type.) The treatment of complex32 indeed needs improvement, and there is an issue for this in PyTorch: pytorch/pytorch#71680. The ideal outcome is that PyTorch core adds complex32 support quickly, but I am not sure that can happen soon. However, your voice matters a lot here, so would you be willing to describe the features most relevant to you in pytorch/pytorch#71680? That way, the PyTorch core team can prioritize them if they decide to work on this.
There were some typos here, so following the migration steps made the program crash. The correct versions should be:

- `power = spectrogram.abs().pow(2)`
- `norm = spectrogram.abs().pow(norm)`
- `magnitude, phase = spectrogram.abs().pow(n), spectrogram.angle()`
@jeffeuxMartin Thanks for the report. Fixed it.
Torchaudio Complex Tensor Support and Migration
Overview
`torchaudio` has been expressing complex numbers by having an extra dimension for the real part and the imaginary part. (We will refer to this format as "pseudo complex type".) PyTorch 1.6 introduced complex Tensor types, such as `torch.complex64` (`torch.cfloat`) and `torch.complex128` (`torch.cdouble`). (We will refer to these as "native complex type".) The native complex type comes with handy methods for complex operations such as `abs`, `angle` and `magphase`. (Please refer to the official documentation for details.)

Over the next few releases, we plan to migrate torchaudio's functions and transforms to the native complex type. This issue describes the planned approaches, work, changes and timeline. If you have a question, a concern or a suggestion, feel free to leave a comment.
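To make the two formats concrete, here is a minimal sketch using PyTorch's `torch.view_as_real` and `torch.view_as_complex`, which convert between the native complex layout and the pseudo complex `(..., 2)` layout without copying data. The values are arbitrary.

```python
import torch

# Native complex type: torch.cfloat is an alias for torch.complex64
native = torch.tensor([3 + 4j, 1 - 1j], dtype=torch.cfloat)

# Pseudo complex type: the same values stored as a real tensor whose
# trailing dimension of size 2 holds (real, imaginary)
pseudo = torch.view_as_real(native)  # shape (2, 2), dtype torch.float32

# Going back requires a trailing dimension of size 2
restored = torch.view_as_complex(pseudo)

# Native tensors expose complex operations directly
magnitude = native.abs()  # tensor([5.0000, 1.4142])
phase = native.angle()    # radians, equivalent to atan2(imag, real)
```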
Migration Stages
We will perform the migration in multiple stages. At this moment, the completion of the later migration stages is not tied to specific releases yet.
✅ Stage 0 (~ 0.8)

Up to release 0.8, torchaudio exclusively used the pseudo complex type. In PyTorch 1.7, PyTorch started adopting the native complex type and migrating the `torch.fft` namespace. Because of this, torchaudio already uses the native complex type in some implementations (`F.vad`, `T.Vad`, `kaldi.spectrogram` and `kaldi.fbank`), but all the user-facing APIs use the pseudo complex type.

✅ Stage 1 (Add support for native complex type and deprecate pseudo complex type)
Completed: PyTorch 1.9 / torchaudio 0.9
Library code change
In this stage, torchaudio will support both the pseudo complex type and the native complex type. A `return_complex` argument will be added so that users can switch between the two behaviors.

Test code update

In addition to the above library code changes, we are going to add a set of tests to make sure that native complex types work in common use cases. This includes:

- `nn.Module` compatibility

✅ Stage 2 (Switch to native complex type by default)
Completed: `main` branch. To be released as part of PyTorch 1.10 / torchaudio 0.10.

The default value for `return_complex` is changed to `True`.

👉 Stage 3 (Remove the support for pseudo complex type)
In this stage, we will remove the support for pseudo complex type.
The `return_complex` argument added in Stage 1 is deprecated and eventually removed.

Affected Functions
The following figure illustrates the functions that handle complex values and their dependencies.
Utility functions
`F.angle`, `F.complex_norm`, `T.ComplexNorm`, `F.magphase`

These functions are deprecated in Stage 1 and will be removed in Stage 3.
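A minimal sketch of the native replacements for these utilities. `legacy_complex_norm` and `legacy_angle` are hypothetical helpers written here to mimic the pseudo complex computations for comparison; they are not torchaudio code.

```python
import torch


# Hypothetical helpers operating on the pseudo complex layout (..., 2),
# used only to check that the native replacements compute the same values.
def legacy_complex_norm(pseudo: torch.Tensor, power: float = 1.0) -> torch.Tensor:
    return pseudo.pow(2.0).sum(-1).pow(0.5 * power)


def legacy_angle(pseudo: torch.Tensor) -> torch.Tensor:
    return torch.atan2(pseudo[..., 1], pseudo[..., 0])


torch.manual_seed(0)
spec = torch.randn(4, 5, dtype=torch.cfloat)  # native complex, spectrogram-like
pseudo = torch.view_as_real(spec)             # same data in pseudo complex layout
power = 2.0

# Native replacements
norm = spec.abs().pow(power)                            # instead of F.complex_norm / T.ComplexNorm
angle = spec.angle()                                    # instead of F.angle
magnitude, phase = spec.abs().pow(power), spec.angle()  # instead of F.magphase
```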
For `F.angle`, native complex tensors provide the `angle()` method. For `F.complex_norm` / `T.ComplexNorm`, the equivalent computation can be performed with `abs().pow(n)`. `F.magphase` is a convenience function that calls `F.angle` and `F.complex_norm`; therefore, this function is deprecated as well.

Real to real functions
`F.griffinlim`, `T.GriffinLim`

Changes to these functions are kept internal; therefore, we can simply change the internals without disturbing downstream users.
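Because only the internals change, native complex tensors can be used throughout the Griffin-Lim loop. The following is a minimal sketch of the fast Griffin-Lim iteration (with momentum) written with native complex ops; `griffinlim_sketch` is a hypothetical name, and the Hann window, momentum and iteration count are illustrative assumptions rather than torchaudio's exact code.

```python
import math

import torch


def griffinlim_sketch(magnitude: torch.Tensor, n_fft: int, hop_length: int,
                      n_iter: int = 32, momentum: float = 0.99) -> torch.Tensor:
    """Estimate a waveform from a (freq, time) magnitude spectrogram."""
    window = torch.hann_window(n_fft, dtype=magnitude.dtype)
    length = hop_length * (magnitude.size(-1) - 1)

    # Random initial phase, held as a unit-magnitude native complex tensor
    angles = torch.polar(torch.ones_like(magnitude), 2 * math.pi * torch.rand_like(magnitude))
    tprev = torch.zeros_like(angles)
    for _ in range(n_iter):
        inverse = torch.istft(magnitude * angles, n_fft, hop_length=hop_length,
                              window=window, length=length)
        rebuilt = torch.stft(inverse, n_fft, hop_length=hop_length,
                             window=window, return_complex=True)
        # Momentum step, then project back onto unit magnitude
        angles = rebuilt - (momentum / (1 + momentum)) * tprev
        angles = angles / (angles.abs() + 1e-16)
        tprev = rebuilt
    return torch.istft(magnitude * angles, n_fft, hop_length=hop_length,
                       window=window, length=length)


torch.manual_seed(0)
n_fft, hop_length = 64, 16
wave = torch.randn(800)
mag = torch.stft(wave, n_fft, hop_length=hop_length,
                 window=torch.hann_window(n_fft), return_complex=True).abs()
estimate = griffinlim_sketch(mag, n_fft, hop_length, n_iter=4)
```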
Complex to complex functions
`F.phase_vocoder`, `T.TimeStretch`

When adding support for the native complex type, we can simplify the interface as follows.
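With native complex input, a complex-to-complex function such as a phase vocoder can take and return a single complex tensor of shape `(..., freq, time)` instead of `(..., freq, time, 2)`. Below is a minimal sketch of the standard phase-vocoder time stretch under that interface; `phase_vocoder_sketch` is hypothetical, and `phase_advance` is assumed to be a real `(freq, 1)` tensor of expected per-hop phase advances.

```python
import math

import torch
import torch.nn.functional as F


def phase_vocoder_sketch(spec: torch.Tensor, rate: float, phase_advance: torch.Tensor) -> torch.Tensor:
    """Time-stretch a native complex (..., freq, time) spectrogram by `rate`."""
    time_steps = torch.arange(0, spec.size(-1), rate, dtype=spec.real.dtype)
    alphas = time_steps % 1.0

    # Work on magnitude and phase (real tensors), padded with two extra frames
    norm = F.pad(spec.abs(), [0, 2])
    angle = F.pad(spec.angle(), [0, 2])

    idx_0 = time_steps.long()
    idx_1 = idx_0 + 1
    norm_0, norm_1 = norm.index_select(-1, idx_0), norm.index_select(-1, idx_1)
    angle_0, angle_1 = angle.index_select(-1, idx_0), angle.index_select(-1, idx_1)

    # Deviation from the expected phase advance, wrapped to [-pi, pi]
    delta = angle_1 - angle_0 - phase_advance
    delta = delta - 2 * math.pi * torch.round(delta / (2 * math.pi))
    delta = delta + phase_advance

    # Accumulate phase, starting from the phase of the first input frame
    phase = torch.cat([angle[..., :1], delta[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, dim=-1)

    # Linearly interpolate magnitudes between neighboring frames
    mag = alphas * norm_1 + (1 - alphas) * norm_0
    return torch.polar(mag, phase_acc)  # complex in, complex out


torch.manual_seed(0)
n_freq, hop_length = 33, 16
spec = torch.randn(n_freq, 20, dtype=torch.cfloat)
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
stretched = phase_vocoder_sketch(spec, rate=2.0, phase_advance=phase_advance)  # (33, 10)
```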
Real to complex functions
`F.spectrogram`, `T.Spectrogram`

These functions return either a real-valued Tensor (power, energy) or a complex-valued Tensor (frequency representation), depending on the `power` argument. When `power` is not provided, these functions return a complex-valued Tensor. In this case, users have the option to receive the result in the pseudo complex type or the native complex type, and the `return_complex` argument will be added for this choice. If `return_complex` is `True`, then the native complex type is returned. See #1009 for the discussion.

Timeline
| Functions | Stage 1 | Stage 2 | Stage 3 |
| --- | --- | --- | --- |
| `F.angle`, `F.complex_norm`, `F.magphase`, `T.ComplexNorm` | Deprecated. | | Removed. |
| `F.griffinlim`, `T.GriffinLim` | | | |
| `F.phase_vocoder`, `T.TimeStretch` | Handles both types (native for native, pseudo for pseudo). | | Only handles native complex type. |
| `F.spectrogram`, `T.Spectrogram` (`power=None`) | `return_complex` is added (default value is `False`). When the return value is complex-valued (`power=None`), the type of the returned Tensor can be switched with `return_complex`. | The default value of `return_complex` is changed to `True`. | The `return_complex` argument is deprecated. |

Migration steps
- `F.angle`, `F.complex_norm`, `F.magphase` and `T.ComplexNorm`
- `F.phase_vocoder`, `T.TimeStretch`
- `F.spectrogram`, `T.Spectrogram`

PRs - TODO (@mthrok)
Migration
Phase 1
Code Change
- `F.griffinlim`: Adopt native complex dtype in griffnlim #1368
- `F.phase_vocoder`, `T.TimeStretch`: Make `F.phase_vocoder` and `T.TimeStretch` handle complex dtype #1410; Use complex tensors in phase_vocoder #758
- `F.spectrogram`, `T.Spectrogram`, `T.MelSpectrogram`: Add `return_complex` to F.spectrogram and T.Spectrogram #1366; [D] Update spectrogram to use complex #1009

Add deprecation warnings

- `F.angle`, `F.complex_norm`, `T.ComplexNorm`, `F.phase_vocoder`, `T.TimeStretch`, `F.spectrogram`, `T.Spectrogram`: Add deprecation warnings to functions for complex #1445; Add warning about complex tensor to spectrogram #1431
- `F.magphase`: Add deprecation warnings to magphase and ComplexNorm #1492

Phase 2
Change the default value of `return_complex` to `True`.

- `F.spectrogram`, `T.Spectrogram`: [BC-Breaking] Default to native complex type when returning raw spect… #1549

Update the deprecation warnings to indicate the version of removal.

- `F.angle`, `F.complex_norm`, `F.magphase`, `F.phase_vocoder`, `T.TimeStretch`, `T.ComplexNorm`: Set removal version of pseudo complex support #1553

Phase 3
Remove the support for pseudo complex type.
- `F.magphase`: [BC-Breaking] Remove deprecated `F.magphase` #1934
- `F.angle`: [BC-Breaking] Remove deprecated F.angle #1935
- `F.complex_norm` and `T.ComplexNorm`: [BC-Breaking] Remove F.complex_norm and T.ComplexNorm #1942
- `F.phase_vocoder` and `T.TimeStretch`: [BC-Breaking] Drop pseudo complex support from phase_vocoder / TimeStretch #1957
- `F.spectrogram` and `T.Spectrogram`: [BC-Breaking] Drop pseudo complex support from spectrogram #1958
- `return_complex` argument from `F.spectrogram` and `T.Spectrogram`

Surrounding works
Conjugate input tests
Autograd tests
- `T.GriffinLim`
- `F.griffinlim`
- `T.TimeStretch`, `F.phase_vocoder`
  NOTE: They are not differentiable around zero.
- `T.Spectrogram`, `T.MelSpectrogram`
  `T.Spectrogram` is not deterministic on its backward pass. [error -> issue pytorch/#54093]
- `F.spectrogram`
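An autograd test of this kind can be sketched with `torch.autograd.gradcheck`, which compares analytical gradients against finite differences in double precision. `magnitude_spectrogram` is a hypothetical stand-in for a spectrogram transform; the input is random because the complex magnitude is not differentiable at zero.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
n_fft, hop_length = 16, 4
# gradcheck relies on finite differences, so use float64 throughout
window = torch.hann_window(n_fft, dtype=torch.float64)


def magnitude_spectrogram(waveform: torch.Tensor) -> torch.Tensor:
    spec = torch.stft(waveform, n_fft, hop_length=hop_length, window=window, return_complex=True)
    return spec.abs()  # |z| has no gradient at z = 0; random input avoids that point


waveform = torch.randn(64, dtype=torch.float64, requires_grad=True)
passed = gradcheck(magnitude_spectrogram, (waveform,))  # raises on mismatch, else True
```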
Ensuring TorchScript support
Check if all the functionals/transforms are covered by TorchScript consistency test and add if missing
- `F.spectrogram`
- `F.griffinlim`
- `F.phase_vocoder`
- `T.Spectrogram`
- `T.GriffinLim`
- `T.MelSpectrogram`
- `T.TimeStretch`
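A TorchScript consistency check can be sketched by comparing a function's eager output with its `torch.jit.script` output on the same complex input. The two functions are hypothetical examples covering a complex-to-real and a complex-to-complex case; they are not torchaudio's test suite.

```python
import torch


def power_spectrogram(spec: torch.Tensor, power: float) -> torch.Tensor:
    # Complex input, real output
    return spec.abs().pow(power)


def unit_phase(spec: torch.Tensor) -> torch.Tensor:
    # Complex input, complex output
    return spec / (spec.abs() + 1e-8)


torch.manual_seed(0)
spec = torch.randn(4, 6, dtype=torch.cfloat)

scripted_power = torch.jit.script(power_spectrogram)
scripted_unit = torch.jit.script(unit_phase)

eager_power, script_power = power_spectrogram(spec, 2.0), scripted_power(spec, 2.0)
eager_unit, script_unit = unit_phase(spec), scripted_unit(spec)
```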
Benchmark
- `F.phase_vocoder`: Make `F.phase_vocoder` and `T.TimeStretch` handle complex dtype #1410: x6 speed up on CPU, not much difference on GPU.

cc @anjali411 @vincentqb