Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SignalInfo member name to frame #734

Merged
merged 1 commit into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -55,7 +55,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -74,7 +74,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration
# assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -92,7 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
Expand All @@ -110,5 +110,5 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
2 changes: 1 addition & 1 deletion test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
ts_info = ts_info_func(audio_path)

assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels()
2 changes: 1 addition & 1 deletion torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ static auto registerSignalInfo =
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);
.def("get_num_frames", &SignalInfo::getNumFrames);

static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
Expand Down
8 changes: 4 additions & 4 deletions torchaudio/csrc/typedefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};
num_frames(num_frames_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
Expand All @@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumSamples() const {
return num_samples;
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
} // namespace torchaudio
6 changes: 3 additions & 3 deletions torchaudio/csrc/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;
int64_t num_frames;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
int64_t getNumFrames() const;
};

} // namespace torchaudio
Expand Down
8 changes: 4 additions & 4 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ class SignalInfo:
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples
self.num_frames = num_frames

def get_sample_rate(self):
return self.sample_rate

def get_num_channels(self):
return self.num_channels

def get_num_samples(self):
return self.num_samples
def get_num_frames(self):
return self.num_frames

DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
Expand Down