Skip to content

Commit

Permalink
Add docstring and remove deprecated/debug arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Feb 8, 2021
1 parent d0d5e49 commit 317835e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
7 changes: 1 addition & 6 deletions torchaudio/csrc/kaldi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ torch::Tensor denormalize(const torch::Tensor& t) {

torch::Tensor compute_kaldi_pitch(
const torch::Tensor& wave,
::kaldi::PitchExtractionOptions& opts) {
const ::kaldi::PitchExtractionOptions& opts) {
::kaldi::VectorBase<::kaldi::BaseFloat> input(wave);
::kaldi::Matrix<::kaldi::BaseFloat> output;
::kaldi::ComputeKaldiPitch(opts, input, &output);
Expand All @@ -30,7 +30,6 @@ torch::Tensor ComputeKaldiPitch(
double sample_frequency,
double frame_length,
double frame_shift,
double preemphasis_coefficient,
double min_f0,
double max_f0,
double soft_min_f0,
Expand All @@ -45,7 +44,6 @@ torch::Tensor ComputeKaldiPitch(
int64_t frames_per_chunk,
bool simulate_first_pass_online,
int64_t recompute_frame,
bool nccf_ballast_online,
bool snip_edges) {

TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional.");
Expand All @@ -56,23 +54,20 @@ torch::Tensor ComputeKaldiPitch(
opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency);
opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift);
opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length);
opts.preemph_coeff = static_cast<::kaldi::BaseFloat>(preemphasis_coefficient);
opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0);
opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0);
opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0);
opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor);
opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff);
opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency);
opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch);
opts.nccf_ballast = static_cast<::kaldi::BaseFloat>(nccf_ballast);
opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width);
opts.upsample_filter_width =
static_cast<::kaldi::int32>(upsample_filter_width);
opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency);
opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk);
opts.simulate_first_pass_online = simulate_first_pass_online;
opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame);
opts.nccf_ballast_online = nccf_ballast_online;
opts.snip_edges = snip_edges;

// Kaldi's float type expects value range of int16 expressed as float
Expand Down
76 changes: 70 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,6 @@ def compute_kaldi_pitch(
sample_rate: float,
frame_length: float = 25.0,
frame_shift: float = 10.0,
preemph_coeff: float = 0.0,
min_f0: float = 50,
max_f0: float = 400,
soft_min_f0: float = 10.0,
Expand All @@ -1014,19 +1013,84 @@ def compute_kaldi_pitch(
frames_per_chunk: int = 0,
simulate_first_pass_online: bool = False,
recompute_frame: int = 500,
nccf_ballast_online: bool = False,
snip_edges: bool = True,
):
"""Equivalent of `compute-kaldi-pitch-feats`"""
) -> torch.Tensor:
"""Extract pitch based on method described in [1].
This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.
Args:
waveform (Tensor):
The input waveform of shape `(..., time)`.
sample_rate (float):
Sample rate of `waveform`.
frame_length (float, optional):
Frame length in milliseconds.
frame_shift (float, optional):
Frame shift in milliseconds.
min_f0 (float, optional):
Minimum F0 to search for (Hz)
max_f0 (float, optional):
Maximum F0 to search for (Hz)
soft_min_f0 (float, optional):
Minimum f0, applied in soft way, must not exceed min-f0
penalty_factor (float, optional):
Cost factor for FO change.
lowpass_cutoff (float, optional):
Cutoff frequency for LowPass filter (Hz)
resample_frequency (float, optional):
Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
delta_pitch( float, optional):
Smallest relative change in pitch that our algorithm measures.
nccf_ballast (float, optional):
Increasing this factor reduces NCCF for quiet frames
lowpass_filter_width (int, optional):
Integer that determines filter width of lowpass filter, more gives sharper filter.
upsample_filter_width (int, optional):
Integer that determines filter width when upsampling NCCF.
max_frames_latency (int, optional):
Maximum number of frames of latency that we allow pitch tracking to introduce into
the feature processing (affects output only if ``frames_per_chunk > 0`` and
``simulate_first_pass_online=True``)
frames_per_chunk (int, optional):
The number of frames used for energy normalization.
simulate_first_pass_online (bool, optional):
If true, the function will output features that correspond to what an online decoder
would see in the first pass of decoding -- not the final version of the features,
which is the default.
Relevant if ``frames_per_chunk > 0``.
recompute_frame (int, optional):
Only relevant for compatibility with online pitch extraction.
A non-critical parameter; the frame at which we recompute some of the forward pointers,
after revising our estimate of the signal energy.
Relevant if ``frames_per_chunk > 0``.
snip_edges (bool, optional):
If this is set to false, the incomplete frames near the ending edge won't be snipped,
so that the number of frames is the file size divided by the frame-shift.
This makes different types of features give the same number of frames.
Returns:
Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
corresponds to pitch and NCCF.
Reference:
- A pitch extraction algorithm tuned for automatic speech recognition
P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur
2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
"""
shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1])
result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
waveform, sample_rate, frame_length, frame_shift, preemph_coeff,
waveform, sample_rate, frame_length, frame_shift,
min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
resample_frequency, delta_pitch, nccf_ballast,
lowpass_filter_width, upsample_filter_width, max_frames_latency,
frames_per_chunk, simulate_first_pass_online, recompute_frame,
nccf_ballast_online, snip_edges,
snip_edges,
)
result = result.reshape(shape[:-1] + result.shape[-2:])
return result

0 comments on commit 317835e

Please sign in to comment.