Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC-Breaking] Remove F.complex_norm and T.ComplexNorm #1942

Merged
merged 4 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,6 @@ resample

.. autofunction:: resample

:hidden:`Complex Utility`
~~~~~~~~~~~~~~~~~~~~~~~~~

Utilities for pseudo complex tensor. This is not for the native complex dtype, such as `cfloat64`, but for tensors with real-value type and have extra dimension at the end for real and imaginary parts.


complex_norm
------------

.. autofunction:: complex_norm


:hidden:`Filtering`
~~~~~~~~~~~~~~~~~~~
Expand Down
10 changes: 0 additions & 10 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,6 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`Complex Utility`
~~~~~~~~~~~~~~~~~~~~~~~~~

:hidden:`ComplexNorm`
---------------------

.. autoclass:: ComplexNorm

.. automethod:: forward

:hidden:`Feature Extractions`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
10 changes: 0 additions & 10 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,6 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
f"No values were close to the limit. Did it over-clamp?\n{decibels}"
)

@parameterized.expand(
list(itertools.product([(1, 2, 1025, 400, 2), (1025, 400, 2)], [1, 2, 0.7]))
)
def test_complex_norm(self, shape, power):
torch.random.manual_seed(42)
complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
norm_tensor = F.complex_norm(complex_tensor, power)
self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)

@parameterized.expand(
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,6 @@ def func(tensor):
tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)

def test_complex_norm(self):
def func(tensor):
power = 2.
return F.complex_norm(tensor, power)

tensor = torch.randn(1, 2, 1025, 400, 2)
self._assert_consistency(func, tensor)

def test_mask_along_axis(self):
def func(tensor):
mask_param = 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def test_Resample(self):
tensor = common_utils.get_whitenoise(sample_rate=sr1)
self._assert_consistency(T.Resample(sr1, sr2), tensor)

def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2))
self._assert_consistency(T.ComplexNorm(), tensor)

def test_MuLawEncoding(self):
tensor = common_utils.get_whitenoise()
self._assert_consistency(T.MuLawEncoding(), tensor)
Expand Down
2 changes: 0 additions & 2 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .functional import (
amplitude_to_DB,
complex_norm,
compute_deltas,
compute_kaldi_pitch,
create_dct,
Expand Down Expand Up @@ -52,7 +51,6 @@

__all__ = [
'amplitude_to_DB',
'complex_norm',
'compute_deltas',
'compute_kaldi_pitch',
'create_dct',
Expand Down
27 changes: 0 additions & 27 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"DB_to_amplitude",
"mu_law_encoding",
"mu_law_decoding",
"complex_norm",
"phase_vocoder",
'mask_along_axis',
'mask_along_axis_iid',
Expand Down Expand Up @@ -722,32 +721,6 @@ def mu_law_decoding(
return x


@_mod_utils.deprecated(
"Please convert the input Tensor to complex type with `torch.view_as_complex` then "
"use `torch.abs`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type.",
version="0.11",
)
def complex_norm(
complex_tensor: Tensor,
power: float = 1.0
) -> Tensor:
r"""Compute the norm of complex tensor input.

Args:
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
power (float, optional): Power of the norm. (Default: `1.0`).

Returns:
Tensor: Power of the normed input tensor. Shape of `(..., )`
"""

# Replace by torch.norm once issue is fixed
# https://github.com/pytorch/pytorch/issues/34279
return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


def phase_vocoder(
complex_specgrams: Tensor,
rate: float,
Expand Down
37 changes: 0 additions & 37 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
'MuLawEncoding',
'MuLawDecoding',
'Resample',
'ComplexNorm',
'TimeStretch',
'Fade',
'FrequencyMasking',
Expand Down Expand Up @@ -900,42 +899,6 @@ def forward(self, waveform: Tensor) -> Tensor:
self.kernel, self.width)


class ComplexNorm(torch.nn.Module):
r"""Compute the norm of complex tensor input.

Args:
power (float, optional): Power of the norm. (Default: to ``1.0``)

Example
>>> complex_tensor = ... # Tensor shape of (…, complex=2)
>>> transform = transforms.ComplexNorm(power=2)
>>> complex_norm = transform(complex_tensor)
"""
__constants__ = ['power']

def __init__(self, power: float = 1.0) -> None:
warnings.warn(
'torchaudio.transforms.ComplexNorm has been deprecated '
'and will be removed from future release.'
'Please convert the input Tensor to complex type with `torch.view_as_complex` then '
'use `torch.abs` and `torch.angle`. '
'Please refer to https://github.com/pytorch/audio/issues/1337 '
"for more details about torchaudio's plan to migrate to native complex type."
)
super(ComplexNorm, self).__init__()
self.power = power

def forward(self, complex_tensor: Tensor) -> Tensor:
r"""
Args:
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

Returns:
Tensor: norm of the input tensor, shape of `(..., )`.
"""
return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
r"""Compute delta coefficients of a tensor, usually a spectrogram.

Expand Down