
Make F.phase_vocoder and T.TimeStretch handle complex dtype #1410

Merged (8 commits) on Apr 2, 2021

Conversation

@mthrok (Collaborator) commented Mar 29, 2021

This PR supersedes #758 and includes further updates.
See #1337 for the overview.

  1. F.phase_vocoder accepts Tensors with complex dtype.
    • The implementation has been updated from Use complex tensors in phase_vocoder #758 so that real and complex inputs share the same code path: the input Tensor is internally converted to complex dtype and all the operations are performed on it (see the sketch after this list).
    • Adopted torch.polar for simpler Tensor generation from magnitude and angle.
  2. Updated tests
    • librosa compatibility test for complex dtype and pseudo complex dtype.
      • Extracted the output shape check and moved it to the functional test suite so that it runs on all combinations of {CPU | CUDA} x {complex64 | complex128}.
    • TorchScript compatibility test for F.phase_vocoder and T.TimeStretch.
    • Batch consistency test for T.TimeStretch.
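
A minimal sketch of the shared code path described in item 1 (illustrative simplification, not the actual torchaudio implementation):

import torch

def phase_vocoder_sketch(specgrams: torch.Tensor) -> torch.Tensor:
    # Pseudo-complex input of shape (..., 2) is converted to native
    # complex dtype up front, so both input formats share the code below.
    if not specgrams.is_complex():
        specgrams = torch.view_as_complex(specgrams)
    magnitude, phase = specgrams.abs(), specgrams.angle()
    # ... the actual time-stretching of magnitude/phase happens here ...
    # torch.polar rebuilds a complex tensor from magnitude and angle.
    return torch.polar(magnitude, phase)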

Benchmark of phase_vocoder

env
Collecting environment information...
PyTorch version: 1.9.0.dev20210329
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.0.dev20210329
[pip3] torchaudio==0.9.0a0+550dc90
[pip3] torchtext==0.9.0a0+c072ba6
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] magma-cuda101             2.5.2                         1    pytorch
[conda] mkl                       2020.2                      256
[conda] mkl-include               2020.4             h726a3e6_304    conda-forge
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.3.0            py38h54f3939_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.19.2           py38h54aff64_0
[conda] numpy-base                1.19.2           py38hfa32c7d_0
[conda] pytorch                   1.9.0.dev20210329 py3.8_cuda10.1_cudnn7.6.3_0    pytorch-nightly
[conda] pytorch-sphinx-theme      0.0.24                    dev_0    <develop>
[conda] torch                     1.7.1                    pypi_0    pypi
[conda] torchaudio                0.9.0a0+550dc90           dev_0    <develop>
[conda] torchtext                 0.9.0a0+c072ba6           dev_0    <develop>

CPU

code
#!/usr/bin/env bash

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -s """
import torch
import torchaudio;

num_channels = 2
num_freq = 1025
num_frames = 400

rate = 0.5
hop_length = 256

torch.manual_seed(0);
spec = torch.randn(num_channels, num_freq, num_frames, 2, dtype=torch.float32);
phase_advance = torch.linspace(0, 3.14 * hop_length, num_freq, dtype=torch.float32)[..., None];
""" """
torchaudio.functional.phase_vocoder(spec, rate, phase_advance)
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -s """
import torch
import torchaudio;

num_channels = 2
num_freq = 1025
num_frames = 400

rate = 0.5
hop_length = 256

torch.manual_seed(0);
spec = torch.view_as_complex(torch.randn(num_channels, num_freq, num_frames, 2, dtype=torch.float32));
phase_advance = torch.linspace(0, 3.14 * hop_length, num_freq, dtype=torch.float32)[..., None];
""" """
torchaudio.functional.phase_vocoder(spec, rate, phase_advance)
"""
Code                      | Pseudo Complex | Native Complex
--------------------------|----------------|---------------
512c2fa (before this PR)  | 758            | N/A
With this PR applied      | 97             | 95

unit: msec

CUDA

code
#!/usr/bin/env bash

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -s """
import torch
import torchaudio;

num_channels = 2
num_freq = 1025
num_frames = 400

rate = 0.5
hop_length = 256

torch.manual_seed(0);
spec = torch.randn(num_channels, num_freq, num_frames, 2, dtype=torch.float32, device='cuda');
phase_advance = torch.linspace(0, 3.14 * hop_length, num_freq, dtype=torch.float32, device='cuda')[..., None];
""" """
torchaudio.functional.phase_vocoder(spec, rate, phase_advance)
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -s """
import torch
import torchaudio;

num_channels = 2
num_freq = 1025
num_frames = 400

rate = 0.5
hop_length = 256

torch.manual_seed(0);
spec = torch.view_as_complex(torch.randn(num_channels, num_freq, num_frames, 2, dtype=torch.float32, device='cuda'));
phase_advance = torch.linspace(0, 3.14 * hop_length, num_freq, dtype=torch.float32, device='cuda')[..., None];
""" """
torchaudio.functional.phase_vocoder(spec, rate, phase_advance)
"""
Code                      | Pseudo Complex | Native Complex
--------------------------|----------------|---------------
512c2fa (before this PR)  | 1.85           | N/A
With this PR applied      | 1.22           | 1.21

unit: msec

@mthrok force-pushed the migrate-phase-vocoder branch from 92f7755 to 88fba5c (March 30, 2021 00:01)
@mthrok force-pushed the migrate-phase-vocoder branch from 88fba5c to d27a820 (March 30, 2021 15:56)
@mthrok changed the title from "[WIP] Migrate phase_vocoder to complex dtypes" to "Make F.phase_vocoder and T.TimeStretch handle complex dtype" (Mar 30, 2021)
@mthrok marked this pull request as ready for review (March 30, 2021 16:50)
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
raise ValueError(
"complex_specgrams must be either complex dtype or "

nit - must be either native complex tensors or real valued tensors with shape (..., 2)
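
For illustration, the check with the suggested wording might read as follows (the diff above is truncated by the page; this sketch is not the committed code):

if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
    raise ValueError(
        "complex_specgrams must be either native complex tensors or "
        "real valued tensors with shape (..., 2)")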

Contributor:
nit: in particular "e.g." should be removed here

@anjali411 left a comment

Left a minor comment regarding docs, but looks good overall. cc @anjali411's suggestion aside, @vincentqb in case you have any comments or concerns, or want to take a look.

@anjali411

Actually, just realized: we don't have serialization tests for the TimeStretch module, and should probably add one.

@vincentqb (Contributor) left a comment

thx @anjali411 for the ping -- LGTM too! (with minor phrasing as you pointed out)

About the serialization test:

  • @anjali411: Are there tests on the PyTorch side around serialization? Can/should we assume that serialization is covered on that front, say by supporting the operations in the doc?
  • Adding a serialization test could help catch errors eventually, especially when complex dtype parameters change (say, across versions). However, since the parameters of TimeStretch are not changing in this pull request, I do not expect the nn.Module state to break BC (say, a user saved the state prior to this pull request and re-loads it afterwards).


else:
rate = overriding_rate

if rate == 1.0:


could you clarify this change?

@mthrok (Collaborator, Author):

TimeStretch is a thin wrapper around phase_vocoder, and returning the input Tensor unchanged when rate == 1.0 is valid branching for phase_vocoder itself, not just for TimeStretch, so it makes more sense for the check to live in phase_vocoder.
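
A hedged sketch of the branching in question (simplified, not the exact torchaudio source):

import torch

def phase_vocoder(complex_specgrams: torch.Tensor, rate: float,
                  phase_advance: torch.Tensor) -> torch.Tensor:
    # Not altering the input when rate == 1.0 is valid for phase_vocoder
    # itself, so the branch lives here rather than in the TimeStretch wrapper.
    if rate == 1.0:
        return complex_specgrams
    ...  # time-stretch the spectrogram by the given rate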

@anjali411

> thx @anjali411 for the ping -- LGTM too! (with minor phrasing as you pointed out)
>
> About the serialization test:
>
>   • @anjali411: Are there tests on the PyTorch side around serialization? Can/should we assume that serialization is covered on that front, say by supporting the operations in the doc?

No, because JIT support for complex numbers is very new and, some may say, still "under construction".

>   • Adding a serialization test could help catch errors eventually, especially when complex dtype parameters change (say, across versions). However, since the parameters of TimeStretch are not changing in this pull request, I do not expect the nn.Module state to break BC (say, a user saved the state prior to this pull request and re-loads it afterwards).

That makes total sense, and I agree. But I think it would be nice to have serialization tests in general, and especially as a sanity check now that we are migrating to complex types.
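
A sketch of the kind of serialization sanity check discussed here (hypothetical test, not the one actually added to torchaudio; parameter values are arbitrary):

import io
import torch
import torchaudio

# Script TimeStretch, save and re-load it, then compare outputs on a
# native complex input.
module = torchaudio.transforms.TimeStretch(hop_length=256, n_freq=1025, fixed_rate=0.9)
scripted = torch.jit.script(module)

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

spec = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
assert torch.allclose(loaded(spec), module(spec))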

@mthrok (Collaborator, Author) commented Apr 2, 2021

@anjali411 I am merging this so that I can resolve the conflicts in other PRs. Let me know if you have further updates; we can follow up in a separate PR.

@mthrok merged commit 0433b7a into pytorch:master on Apr 2, 2021
@mthrok deleted the migrate-phase-vocoder branch on April 2, 2021 13:48
@mthrok modified the milestone: v0.9 (Apr 2, 2021)
@mthrok modified the milestones: Complex Tensor Migration, v0.9 (Apr 5, 2021)
mthrok added a commit to mthrok/audio that referenced this pull request Apr 5, 2021
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
Contributor:

FWIW torch also supports ceil
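
For reference, a torch-only version of the shape computation from the snippet above could look like this (illustrative, reusing the same num_frames, rate, batch_size, and num_freq variables):

import torch

# torch.ceil instead of np.ceil for the expected number of output frames
num_stretched_frames = int(torch.ceil(torch.tensor(num_frames / rate)).item())
expected_shape = torch.Size([batch_size, num_freq, num_stretched_frames])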
