From 34878e1763e5cdfca06a78b8d0c4fd47b94a997e Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 19 Jun 2020 23:23:41 +0000 Subject: [PATCH] Fix SignalInfo member name to frame --- test/sox_io_backend/test_info.py | 10 +++++----- test/sox_io_backend/test_torchscript.py | 2 +- torchaudio/csrc/register.cpp | 2 +- torchaudio/csrc/typedefs.cpp | 8 ++++---- torchaudio/csrc/typedefs.h | 6 +++--- torchaudio/extension/extension.py | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/sox_io_backend/test_info.py b/test/sox_io_backend/test_info.py index 917feb0618..7954af782f 100644 --- a/test/sox_io_backend/test_info.py +++ b/test/sox_io_backend/test_info.py @@ -35,7 +35,7 @@ def test_wav(self, dtype, sample_rate, num_channels): ) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate - assert info.get_num_samples() == sample_rate * duration + assert info.get_num_frames() == sample_rate * duration assert info.get_num_channels() == num_channels @parameterized.expand(list(itertools.product( @@ -55,7 +55,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): ) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate - assert info.get_num_samples() == sample_rate * duration + assert info.get_num_frames() == sample_rate * duration assert info.get_num_channels() == num_channels @parameterized.expand(list(itertools.product( @@ -74,7 +74,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate # mp3 does not preserve the number of samples - # assert info.get_num_samples() == sample_rate * duration + # assert info.get_num_frames() == sample_rate * duration assert info.get_num_channels() == num_channels @parameterized.expand(list(itertools.product( @@ -92,7 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level): ) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate - assert info.get_num_samples() == sample_rate * duration + assert info.get_num_frames() == sample_rate * duration assert info.get_num_channels() == num_channels @parameterized.expand(list(itertools.product( @@ -110,5 +110,5 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): ) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate - assert info.get_num_samples() == sample_rate * duration + assert info.get_num_frames() == sample_rate * duration assert info.get_num_channels() == num_channels diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index 1446c245f1..aff488126c 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -44,5 +44,5 @@ def test_info_wav(self, dtype, sample_rate, num_channels): ts_info = ts_info_func(audio_path) assert py_info.get_sample_rate() == ts_info.get_sample_rate() - assert py_info.get_num_samples() == ts_info.get_num_samples() + assert py_info.get_num_frames() == ts_info.get_num_frames() assert py_info.get_num_channels() == ts_info.get_num_channels() diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 9de97505a1..4ab3fde639 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -12,7 +12,7 @@ static auto registerSignalInfo = .def(torch::init()) .def("get_sample_rate", &SignalInfo::getSampleRate) .def("get_num_channels", &SignalInfo::getNumChannels) - .def("get_num_samples", &SignalInfo::getNumSamples); + .def("get_num_frames", &SignalInfo::getNumFrames); static auto registerGetInfo = torch::RegisterOperators().op( torch::RegisterOperators::options() diff --git a/torchaudio/csrc/typedefs.cpp b/torchaudio/csrc/typedefs.cpp index 7b81d665dc..f4136cc918 100644 --- a/torchaudio/csrc/typedefs.cpp +++ b/torchaudio/csrc/typedefs.cpp @@ -4,10 +4,10 @@ namespace torchaudio { SignalInfo::SignalInfo( const int64_t sample_rate_, const int64_t num_channels_, - const int64_t num_samples_) + const int64_t num_frames_) : sample_rate(sample_rate_), num_channels(num_channels_), - num_samples(num_samples_){}; + num_frames(num_frames_){}; int64_t SignalInfo::getSampleRate() const { return sample_rate; @@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const { return num_channels; } -int64_t SignalInfo::getNumSamples() const { - return num_samples; +int64_t SignalInfo::getNumFrames() const { + return num_frames; } } // namespace torchaudio diff --git a/torchaudio/csrc/typedefs.h b/torchaudio/csrc/typedefs.h index 646ed09f3d..ddd210e647 100644 --- a/torchaudio/csrc/typedefs.h +++ b/torchaudio/csrc/typedefs.h @@ -7,15 +7,15 @@ namespace torchaudio { struct SignalInfo : torch::CustomClassHolder { int64_t sample_rate; int64_t num_channels; - int64_t num_samples; + int64_t num_frames; SignalInfo( const int64_t sample_rate_, const int64_t num_channels_, - const int64_t num_samples_); + const int64_t num_frames_); int64_t getSampleRate() const; int64_t getNumChannels() const; - int64_t getNumSamples() const; + int64_t getNumFrames() const; }; } // namespace torchaudio diff --git a/torchaudio/extension/extension.py b/torchaudio/extension/extension.py index 4a2ab82124..4d8ac4dcba 100644 --- a/torchaudio/extension/extension.py +++ b/torchaudio/extension/extension.py @@ -30,10 +30,10 @@ class SignalInfo: without extension. This class has to implement the same interface as C++ equivalent. """ - def __init__(self, sample_rate: int, num_channels: int, num_samples: int): + def __init__(self, sample_rate: int, num_channels: int, num_frames: int): self.sample_rate = sample_rate self.num_channels = num_channels - self.num_samples = num_samples + self.num_frames = num_frames def get_sample_rate(self): return self.sample_rate @@ -41,8 +41,8 @@ def get_sample_rate(self): def get_num_channels(self): return self.num_channels - def get_num_samples(self): - return self.num_samples + def get_num_frames(self): + return self.num_frames DummyModule = namedtuple('torchaudio', ['SignalInfo']) module = DummyModule(SignalInfo)