From 08a71271b2c778fb79cc47b8213203432bf1dc1e Mon Sep 17 00:00:00 2001 From: gmagogsfm Date: Mon, 3 Aug 2020 12:22:39 -0700 Subject: [PATCH] Switch string formatting to str.format to be TorchScript friendly. (#850) --- test/compliance/generate_fbank_data.py | 2 +- test/test_compliance_kaldi.py | 4 +++- torchaudio/compliance/kaldi.py | 12 +++++++----- torchaudio/functional.py | 4 ++-- torchaudio/transforms.py | 10 +++++----- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/test/compliance/generate_fbank_data.py b/test/compliance/generate_fbank_data.py index 446ebea284..96a935d116 100644 --- a/test/compliance/generate_fbank_data.py +++ b/test/compliance/generate_fbank_data.py @@ -92,7 +92,7 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir): 'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank', 'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type'] fn_split = fn.split('-') - assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr))) + assert len(fn_split) == len(arr), ('Len mismatch: {} and {}'.format(len(fn_split), len(arr))) inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))} # print flags for C++ diff --git a/test/test_compliance_kaldi.py b/test/test_compliance_kaldi.py index 450183d6b3..c1b9250612 100644 --- a/test/test_compliance_kaldi.py +++ b/test/test_compliance_kaldi.py @@ -148,7 +148,9 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil sound, sr = torchaudio.load_wav(sound_filepath) files = self.test_filepaths[filepath_key] - assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files))) + assert len(files) == expected_num_files, \ + ('number of kaldi {} file changed to {}'.format( + filepath_key, len(files))) for f in files: print(f) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 6211b09f2d..792f439ccd 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -135,13 +135,15 @@ def _get_waveform_and_window_properties(waveform: Tensor, r"""Gets the waveform and window properties """ channel = max(channel, 0) - assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0))) + assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0))) waveform = waveform[channel, :] # size (n) window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size - assert 2 <= window_size <= len(waveform), ('choose a window size %d that is [2, %d]' % (window_size, len(waveform))) + assert 2 <= window_size <= len( + waveform), ('choose a window size {} that is [2, {}]' + .format(window_size, len(waveform))) assert 0 < window_shift, '`window_shift` must be greater than 0' assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \ ' use `round_to_power_of_two` or change `frame_length`' @@ -430,7 +432,7 @@ def get_mel_banks(num_bins: int, high_freq += nyquist assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \ - ('Bad values in options: low-freq %f and high-freq %f vs. nyquist %f' % (low_freq, high_freq, nyquist)) + ('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist)) # fft-bin width [think of it as Nyquist-freq / half-window-length] fft_bin_width = sample_freq / window_length_padded @@ -446,8 +448,8 @@ def get_mel_banks(num_bins: int, assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \ - ('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' % - (vtln_low, vtln_high, low_freq, high_freq)) + ('Bad values in options: vtln-low {} and vtln-high {}, versus ' + 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) bin = torch.arange(num_bins).unsqueeze(1) left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 3b70c8efe2..4b79c01960 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -149,8 +149,8 @@ def griffinlim( Returns: torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given. """ - assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum - assert momentum >= 0, 'momentum=%s < 0' % momentum + assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum) + assert momentum >= 0, 'momentum={} < 0'.format(momentum) # pack batch shape = specgram.size() diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 72e55c94c8..ef0a58e637 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -141,8 +141,8 @@ def __init__(self, rand_init: bool = True) -> None: super(GriffinLim, self).__init__() - assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum - assert momentum > 0, 'momentum=%s < 0' % momentum + assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum) + assert momentum > 0, 'momentum={} < 0'.format(momentum) self.n_fft = n_fft self.n_iter = n_iter @@ -237,7 +237,7 @@ def __init__(self, self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min - assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) @@ -313,7 +313,7 @@ def __init__(self, self.tolerance_change = tolerance_change self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9} - assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) self.register_buffer('fb', fb) @@ -607,7 +607,7 @@ def forward(self, waveform: Tensor) -> Tensor: return waveform - raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) + raise ValueError('Invalid resampling method: {}'.format(self.resampling_method)) class ComplexNorm(torch.nn.Module):