-
Notifications
You must be signed in to change notification settings - Fork 664
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SignalInfo typedef, and extension module
- Loading branch information
Showing
5 changed files
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef TORCHAUDIO_REGISTER_H | ||
#define TORCHAUDIO_REGISTER_H | ||
|
||
#include <torchaudio/csrc/typedefs.h> | ||
|
||
namespace torchaudio { | ||
namespace { | ||
|
||
static auto registerSignalInfo = | ||
torch::class_<SignalInfo>("torchaudio", "SignalInfo") | ||
.def(torch::init<int64_t, int64_t, int64_t>()) | ||
.def("get_sample_rate", &SignalInfo::getSampleRate) | ||
.def("get_num_channels", &SignalInfo::getNumChannels) | ||
.def("get_num_samples", &SignalInfo::getNumSamples); | ||
|
||
} // namespace | ||
} // namespace torchaudio | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include <torchaudio/csrc/typedefs.h> | ||
|
||
namespace torchaudio { | ||
SignalInfo::SignalInfo( | ||
const int64_t sample_rate_, | ||
const int64_t num_channels_, | ||
const int64_t num_samples_) | ||
: sample_rate(sample_rate_), | ||
num_channels(num_channels_), | ||
num_samples(num_samples_){}; | ||
|
||
int64_t SignalInfo::getSampleRate() const { | ||
return sample_rate; | ||
} | ||
|
||
int64_t SignalInfo::getNumChannels() const { | ||
return num_channels; | ||
} | ||
|
||
int64_t SignalInfo::getNumSamples() const { | ||
return num_samples; | ||
} | ||
} // namespace torchaudio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef TORCHAUDIO_TYPDEFS_H | ||
#define TORCHAUDIO_TYPDEFS_H | ||
|
||
#include <torch/script.h> | ||
|
||
namespace torchaudio { | ||
struct SignalInfo : torch::CustomClassHolder { | ||
int64_t sample_rate; | ||
int64_t num_channels; | ||
int64_t num_samples; | ||
|
||
SignalInfo( | ||
const int64_t sample_rate_, | ||
const int64_t num_channels_, | ||
const int64_t num_samples_); | ||
int64_t getSampleRate() const; | ||
int64_t getNumChannels() const; | ||
int64_t getNumSamples() const; | ||
}; | ||
|
||
} // namespace torchaudio | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .extension import ( | ||
_init_extension, | ||
) | ||
|
||
_init_extension() | ||
|
||
del _init_extension |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import importlib | ||
from collections import namedtuple | ||
|
||
import torch | ||
from torchaudio._internal import module_utils as _mod_utils | ||
|
||
|
||
def _init_extension(): | ||
ext = 'torchaudio._torchaudio' | ||
if _mod_utils.is_module_available(ext): | ||
_init_script_module(ext) | ||
else: | ||
_init_dummy_module() | ||
|
||
|
||
def _init_script_module(module): | ||
path = importlib.util.find_spec(module).origin | ||
torch.classes.load_library(path) | ||
torch.ops.load_library(path) | ||
|
||
|
||
def _init_dummy_module(): | ||
class SignalInfo: | ||
"""Container class for audio format information | ||
data class used when torchaudio C++ extension is not available. | ||
This class will be used for backends other than SoX, such as soundfile | ||
and required to annotate related functions. | ||
This class has to implement the same interface as C++ equivalent. | ||
""" | ||
def __init__(self, sample_rate: int, num_channels: int, num_samples: int): | ||
self.sample_rate = sample_rate | ||
self.num_channels = num_channels | ||
self.num_samples = num_samples | ||
|
||
def get_sample_rate(self): | ||
return self.sample_rate | ||
|
||
def get_num_channels(self): | ||
return self.num_channels | ||
|
||
def get_num_samples(self): | ||
return self.num_samples | ||
|
||
DummyModule = namedtuple('torchaudio', ['SignalInfo']) | ||
module = DummyModule(SignalInfo) | ||
setattr(torch.classes, 'torchaudio', module) |