Adding test for T.SlidingWindowCmn #1482
Conversation
@carolineechen Can you please review this?
@parameterized.expand([
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': False, 'norm_vars': False}, ),
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': True, 'norm_vars': False}, ),
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': False, 'norm_vars': False}, ),
this set of params is a duplicate of the first one -- could you change or remove it?
Thanks for catching that, will remove it.
Removed the duplicate test case. Thanks!
@@ -157,7 +157,6 @@ def test_vol(self, gain, gain_type):
    @parameterized.expand([
        ({'cmn_window': 600, 'min_cmn_window': 100, 'center': False, 'norm_vars': False}, ),
        ({'cmn_window': 600, 'min_cmn_window': 100, 'center': True, 'norm_vars': False}, ),
        ({'cmn_window': 600, 'min_cmn_window': 100, 'center': False, 'norm_vars': True}, ),
@mthrok this set of params causes the test to fail with RuntimeError: Jacobian mismatch for output 0 with respect to input 0. Any idea why?
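(For reference, a minimal standalone sketch of the checks involved; this is hypothetical and not part of the PR — the input shape, dtype, and window values are assumptions, and it calls the functional directly rather than the transform.)

# Hypothetical reproduction sketch of the gradient checks that assert_grad runs.
import torch
import torchaudio.functional as F

# Small double-precision "spectrogram" of shape [channel, time, freq].
specgram = torch.randn(2, 10, 5, dtype=torch.float64, requires_grad=True)

def fn(x):
    return F.sliding_window_cmn(x, cmn_window=600, min_cmn_window=100,
                                center=False, norm_vars=True)

torch.autograd.gradcheck(fn, (specgram,))      # first-order check
torch.autograd.gradgradcheck(fn, (specgram,))  # second-order check; per the reply
                                               # below, this is where the Jacobian
                                               # mismatch is expected with norm_vars=True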
This means that when norm_vars=True, some operation is not differentiable to the 2nd degree. It's somewhere in the block below, but it is not immediately clear to me which operation.
audio/torchaudio/functional/functional.py
Lines 1073 to 1100 in 0c263a9
    if norm_vars:
        cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
else:
    if window_start > last_window_start:
        frame_to_remove = specgram[:, last_window_start, :]
        cur_sum -= frame_to_remove
        if norm_vars:
            cur_sumsq -= (frame_to_remove ** 2)
    if window_end > last_window_end:
        frame_to_add = specgram[:, last_window_end, :]
        cur_sum += frame_to_add
        if norm_vars:
            cur_sumsq += (frame_to_add ** 2)
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
if norm_vars:
    if window_frames == 1:
        cmn_specgram[:, t, :] = torch.zeros(
            num_channels, num_feats, dtype=dtype, device=device)
    else:
        variance = cur_sumsq
        variance = variance / window_frames
        variance -= ((cur_sum ** 2) / (window_frames ** 2))
        variance = torch.pow(variance, -0.5)
        cmn_specgram[:, t, :] *= variance
What we want to do is:
- Identify which part is causing this.
- Change the code, if that is possible without performance degradation.
However, these are beyond the scope of this PR, so here we can set nondet_tol and add a docstring saying it is not 2nd-order differentiable when norm_vars=True, like in Spectrogram.
audio/test/torchaudio_unittest/transforms/autograd_test_impl.py
Lines 71 to 77 in 0c263a9
# replication_pad1d_backward_cuda is not deterministic and
# gives very small (~2.7756e-17) difference.
#
# See https://github.com/pytorch/pytorch/issues/54093
transform = T.Spectrogram(**kwargs)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
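(A hypothetical sketch of that suggestion applied to the new test, mirroring the Spectrogram pattern above; whether nondet_tol plus a docstring note fully addresses the norm_vars=True case is exactly what this thread is discussing.)

# Hypothetical: same structure as the Spectrogram test, with a tolerance passed
# to assert_grad and a docstring note that SlidingWindowCmn is not 2nd-order
# differentiable when norm_vars=True.
transform = T.SlidingWindowCmn(**kwargs)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)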
def test_sliding_window_cmn(self, kwargs):
    sample_rate = 8000
    transform = T.SlidingWindowCmn(**kwargs)
    waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
The input to SlidingWindowCmn is supposed to be a spectrogram. This has been fixed in the master documentation: https://pytorch.org/audio/master/functional.html#torchaudio.functional.sliding_window_cmn
Can you use get_spectrogram, then flip the last axis so that the Tensor dimension is [..., time, freq]?
From the doc of torch.stft I see that it returns a tensor of shape (* × N × T), so do you suggest using torch.transpose(-2, -1) on the output?
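(A sketch of what that could look like; hypothetical — it assumes the test-suite helper get_spectrogram accepts an n_fft argument and, like torch.stft, returns a [..., freq, time] tensor.)

# Hypothetical input construction: spectrogram with shape [..., time, freq].
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
specgram = get_spectrogram(waveform, n_fft=64)  # [..., freq, time], helper assumed
specgram = specgram.transpose(-2, -1)           # flip last two axes -> [..., time, freq]
self.assert_grad(transform, [specgram])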
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': True, 'norm_vars': False}, ),
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': False, 'norm_vars': True}, ),
    ({'cmn_window': 600, 'min_cmn_window': 100, 'center': True, 'norm_vars': True}, ),
    ({'cmn_window': 500, 'min_cmn_window': 50, 'center': False, 'norm_vars': False}, ),
cmn_window=600 and min_cmn_window=100 look too big for an input with 8000 * 0.05 == 400 samples (and even fewer frames once the FFT is applied). Can you make them somewhat smaller than the number of frames in the time axis of the input tensor?
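(For reference, a rough frame-count calculation; the n_fft and hop_length values are assumed T.Spectrogram defaults, not values from this PR.)

# 0.05 s at 8000 Hz gives 400 samples of white noise.
num_samples = int(8000 * 0.05)              # 400
n_fft, hop_length = 400, 200                # assumed Spectrogram defaults
num_frames = num_samples // hop_length + 1  # ~3 frames on the time axis
# cmn_window=600 and min_cmn_window=100 far exceed num_frames; smaller window
# values (and/or a smaller n_fft such as 64, giving ~13 frames) keep the test meaningful.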
Thanks @mthrok, will try to implement your suggestions.
The unittest seems to pass with @mthrok's suggestions.
Autograd tests for Transforms #1414