Skip to content

Commit

Permalink
FilterbankFeaturesTA to match FilterbankFeatures (NVIDIA#5913)
Browse files Browse the repository at this point in the history
Signed-off-by: Mohamed Saad Ibn Seddik <[email protected]>
Signed-off-by: Jason <[email protected]>
  • Loading branch information
msis authored and blisc committed Feb 10, 2023
1 parent fe6d14f commit c99a261
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,22 +529,22 @@ def __init__(
if window not in self.torch_windows:
raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")

self._win_length = n_window_size
self._hop_length = n_window_stride
self.win_length = n_window_size
self.hop_length = n_window_stride
self._sample_rate = sample_rate
self._normalize_strategy = normalize
self._use_log = log
self._preemphasis_value = preemph
self._log_zero_guard_type = log_zero_guard_type
self._log_zero_guard_value: Union[str, float] = log_zero_guard_value
self._dither_value = dither
self._pad_to = pad_to
self._pad_value = pad_value
self._num_fft = n_fft
self.log_zero_guard_type = log_zero_guard_type
self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
self.dither = dither
self.pad_to = pad_to
self.pad_value = pad_value
self.n_fft = n_fft
self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self._sample_rate,
win_length=self._win_length,
hop_length=self._hop_length,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=nfilt,
window_fn=self.torch_windows[window],
mel_scale="slaney",
Expand All @@ -561,13 +561,13 @@ def filter_banks(self):
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
    """Return the numeric guard value applied to features before the log.

    Args:
        dtype: Floating-point dtype used to look up a named ``torch.finfo`` field.

    Returns:
        ``self.log_zero_guard_value`` unchanged when it is already a float;
        otherwise it is treated as the name of a ``torch.finfo(dtype)``
        attribute and that attribute's value is returned.
    """
    if isinstance(self.log_zero_guard_value, float):
        return self.log_zero_guard_value
    # Non-float values are attribute names on torch.finfo, resolved per dtype.
    return getattr(torch.finfo(dtype), self.log_zero_guard_value)

def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
    """Add Gaussian noise scaled by ``self.dither`` to ``signals`` while training.

    Args:
        signals: Input tensor; the noise is drawn with the same shape and dtype.

    Returns:
        ``signals`` plus ``randn_like(signals) * self.dither`` when
        ``self.training`` is true and ``self.dither`` is positive; otherwise
        the input tensor itself, untouched.
    """
    if self.training and self.dither > 0.0:
        noise = torch.randn_like(signals) * self.dither
        signals = signals + noise
    return signals

Expand All @@ -578,25 +578,25 @@ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
return signals

def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
    """Convert input lengths into output lengths using ``self.hop_length``.

    Args:
        input_lengths: Tensor of input lengths.

    Returns:
        ``floor(input_lengths / self.hop_length) + 1`` as an int64 tensor.
    """
    out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
    return out_lengths

def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
    """Right-pad the last dimension of ``features`` to a multiple of ``self.pad_to``.

    Args:
        features: Tensor whose last dimension is padded.

    Returns:
        ``features`` unchanged when not training, when ``self.pad_to`` is 0, or
        when the last dimension is already a multiple of ``self.pad_to``;
        otherwise a tensor padded at the end of the last dimension with
        ``self.pad_value``.
    """
    # Only apply during training; else need to capture dynamic shape for exported models
    # The ``pad_to == 0`` check must precede the modulo to avoid dividing by zero.
    if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
        return features
    pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
    return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)

def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
    """Take the natural log of ``features`` after guarding against zeros.

    Args:
        features: Tensor to transform.

    Returns:
        ``features`` unchanged when ``self._use_log`` is false. Otherwise the
        log of the guarded features, where the guard value (resolved for the
        features' dtype) is either added (``"add"``) or used as a lower clamp
        (``"clamp"``) according to ``self.log_zero_guard_type``.

    Raises:
        ValueError: If ``self.log_zero_guard_type`` is neither ``"add"`` nor
            ``"clamp"``.
    """
    if self._use_log:
        zero_guard = self._resolve_log_zero_guard_value(features.dtype)
        if self.log_zero_guard_type == "add":
            features = features + zero_guard
        elif self.log_zero_guard_type == "clamp":
            features = features.clamp(min=zero_guard)
        else:
            raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
        features = features.log()
    return features

Expand Down

0 comments on commit c99a261

Please sign in to comment.