Skip to content

Commit

Permalink
Add SignalInfo typedef, and extension module
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 18, 2020
1 parent b17da7a commit 3c6cdec
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 0 deletions.
18 changes: 18 additions & 0 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H

#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace {

static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);

} // namespace
} // namespace torchaudio
#endif
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}

int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumSamples() const {
return num_samples;
}
} // namespace torchaudio
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H

#include <torch/script.h>

namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
};

} // namespace torchaudio

#endif
7 changes: 7 additions & 0 deletions torchaudio/extension/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .extension import (
_init_extension,
)

_init_extension()

del _init_extension
49 changes: 49 additions & 0 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import importlib
from collections import namedtuple

import torch
from torchaudio._internal import module_utils as _mod_utils


def _init_extension():
ext = 'torchaudio._torchaudio'
if _mod_utils.is_module_available(ext):
_init_script_module(ext)
else:
_init_dummy_module()


def _init_script_module(module):
path = importlib.util.find_spec(module).origin
torch.classes.load_library(path)
torch.ops.load_library(path)


def _init_dummy_module():
class SignalInfo:
"""Container class for audio format information
data class used when torchaudio C++ extension is not available.
This class will be used for backends other than SoX, such as soundfile
and required to annotate related functions.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples

def get_sample_rate(self):
return self.sample_rate

def get_num_channels(self):
return self.num_channels

def get_num_samples(self):
return self.num_samples

DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
setattr(torch.classes, 'torchaudio', module)

0 comments on commit 3c6cdec

Please sign in to comment.