Skip to content

Commit

Permalink
Switch string formatting to str.format to be TorchScript friendly. (#850
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gmagogsfm authored Aug 3, 2020
1 parent 3bab2b2 commit 08a7127
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion test/compliance/generate_fbank_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
fn_split = fn.split('-')
assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
assert len(fn_split) == len(arr), ('Len mismatch: {} and {}'.format(len(fn_split), len(arr)))
inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))}

# print flags for C++
Expand Down
4 changes: 3 additions & 1 deletion test/test_compliance_kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
sound, sr = torchaudio.load_wav(sound_filepath)
files = self.test_filepaths[filepath_key]

assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
assert len(files) == expected_num_files, \
('number of kaldi {} file changed to {}'.format(
filepath_key, len(files)))

for f in files:
print(f)
Expand Down
12 changes: 7 additions & 5 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,15 @@ def _get_waveform_and_window_properties(waveform: Tensor,
r"""Gets the waveform and window properties
"""
channel = max(channel, 0)
assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0)))
assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size

assert 2 <= window_size <= len(waveform), ('choose a window size %d that is [2, %d]' % (window_size, len(waveform)))
assert 2 <= window_size <= len(
waveform), ('choose a window size {} that is [2, {}]'
.format(window_size, len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
Expand Down Expand Up @@ -430,7 +432,7 @@ def get_mel_banks(num_bins: int,
high_freq += nyquist

assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq %f and high-freq %f vs. nyquist %f' % (low_freq, high_freq, nyquist))
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))

# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
Expand All @@ -446,8 +448,8 @@ def get_mel_banks(num_bins: int,

assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
(vtln_low, vtln_high, low_freq, high_freq))
('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))

bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def griffinlim(
Returns:
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum >= 0, 'momentum=%s < 0' % momentum
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)

# pack batch
shape = specgram.size()
Expand Down
10 changes: 5 additions & 5 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def __init__(self,
rand_init: bool = True) -> None:
super(GriffinLim, self).__init__()

assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum > 0, 'momentum=%s < 0' % momentum
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum > 0, 'momentum={} < 0'.format(momentum)

self.n_fft = n_fft
self.n_iter = n_iter
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(self,
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min

assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(self,
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)
Expand Down Expand Up @@ -607,7 +607,7 @@ def forward(self, waveform: Tensor) -> Tensor:

return waveform

raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))


class ComplexNorm(torch.nn.Module):
Expand Down

0 comments on commit 08a7127

Please sign in to comment.