Skip to content

Commit

Permalink
Add fn.experimental.audio_resample GPU (#3911)
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao authored May 23, 2022
1 parent aa0196f commit 9918cb5
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 41 deletions.
3 changes: 1 addition & 2 deletions dali/kernels/signal/resampling_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,12 @@ void ResamplerGPU<Out, In>::Run(KernelContext &context, const OutListGPU<Out> &o
desc.out = out_sample.data;
desc.window = window_gpu_;
const auto &in_sh = in_sample.shape;
const auto &out_sh = out_sample.shape;
desc.in_len = in_sh[0];
auto &arg = args[i];
desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0;
desc.out_end =
arg.out_end > 0 ? arg.out_end : resampled_length(in_sh[0], arg.in_rate, arg.out_rate);
assert((desc.out_end - desc.out_begin) == out_sh[0]);
assert((desc.out_end - desc.out_begin) == out_sample.shape[0]);
desc.nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1;
desc.scale = arg.in_rate / arg.out_rate;
any_multichannel |= desc.nchannels > 1;
Expand Down
14 changes: 7 additions & 7 deletions dali/operators/audio/resample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/operators/audio/resample.h"
#include <map>
#include <vector>
#include "dali/core/convert.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/signal/resampling_cpu.h"
#include "dali/operators/audio/resample.h"
#include "dali/operators/audio/resampling_params.h"
#include "dali/kernels/kernel_params.h"
#include "dali/core/convert.h"

namespace dali {

Expand Down Expand Up @@ -117,7 +117,7 @@ class ResampleCPU : public ResampleBase<CPUBackend> {
const auto &in_shape = in.shape();
out.SetLayout(in.GetLayout());
int N = in.num_samples();
assert(N == static_cast<int>(scales_.size()));
assert(N == static_cast<int>(args_.size()));
assert(out.type() == dtype_);

auto &tp = ws.GetThreadPool();
Expand All @@ -131,17 +131,17 @@ class ResampleCPU : public ResampleBase<CPUBackend> {
make_string("Unsupported output type: ", dtype_,
"\nSupported types are : ", ListTypeNames<AUDIO_RESAMPLE_TYPES>()));));
TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES),
(ResampleTyped<T>(view<T>(out[s]), in_view, scales_[s]);),
(ResampleTyped<T>(view<T>(out[s]), in_view, args_[s]);),
(assert(!"Unreachable code.")));
});
}
tp.RunAll();
}

template <typename T>
void ResampleTyped(const OutTensorCPU<T> &out, const InTensorCPU<float> &in, double scale) {
void ResampleTyped(const OutTensorCPU<T> &out, const InTensorCPU<float> &in, const Args& args) {
int ch = out.shape.sample_dim() > 1 ? out.shape[1] : 1;
R.Resample(out.data, 0, out.shape[0], scale, in.data, in.shape[0], 1, ch);
R.Resample(out.data, 0, out.shape[0], args.out_rate, in.data, in.shape[0], args.in_rate, ch);
}

template <typename T>
Expand Down
29 changes: 17 additions & 12 deletions dali/operators/audio/resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ class ResampleBase : public Operator<Backend> {
public:
explicit ResampleBase(const OpSpec &spec) : Operator<Backend>(spec) {
DALI_ENFORCE(in_rate_.HasValue() == out_rate_.HasValue(),
"The parameters ``in_rate`` and ``out_rate`` must be specified together.");
"The parameters ``in_rate`` and ``out_rate`` must be specified together.");
if (in_rate_.HasValue() + scale_.HasValue() + out_length_.HasValue() > 1)
DALI_FAIL("The sampling rates, ``scale`` and ``out_length`` cannot be used together.");
if (!in_rate_.HasValue() && !scale_.HasValue() && !out_length_.HasValue())
DALI_FAIL("No resampling factor specified! Please supply either the scale, "
"the output length or the input and output sampling rates.");
DALI_FAIL(
"No resampling factor specified! Please supply either the scale, "
"the output length or the input and output sampling rates.");
quality_ = spec_.template GetArgument<float>("quality");
DALI_ENFORCE(quality_ >= 0 && quality_ <= 100, make_string("``quality`` out of range: ",
quality_, "\nValid range is [0..100]."));
DALI_ENFORCE(quality_ >= 0 && quality_ <= 100,
make_string("``quality`` out of range: ", quality_, "\nValid range is [0..100]."));
if (spec_.TryGetArgument(dtype_, "dtype")) {
// silence useless warning -----------------------------vvvvvvvvvvvvvvvv
TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES), (T x; (void)x;),
Expand All @@ -58,20 +59,20 @@ class ResampleBase : public Operator<Backend> {
dtype_ = ws.template Input<Backend>(0).type();

outputs[0].type = dtype_;
CalculateScaleAndShape(outputs[0].shape, ws);
CalculateShapeAndArgs(outputs[0].shape, ws);

return true;
}

void CalculateScaleAndShape(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
void CalculateShapeAndArgs(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
const auto &input = ws.template Input<Backend>(0);
const TensorListShape<> &shape = input.shape();
DALI_ENFORCE(shape.sample_dim() == 1 || shape.sample_dim() == 2,
"Audio resampling supports only time series data, with an optional innermost "
"channel dimension.");
out_shape = shape;
int N = shape.num_samples();
scales_.resize(N);
args_.resize(N);
if (in_rate_.HasValue()) {
assert(out_rate_.HasValue());
in_rate_.Acquire(spec_, ws, N);
Expand All @@ -93,7 +94,8 @@ class ResampleBase : public Operator<Backend> {
error << " for sample " << s;
DALI_FAIL(error.str());
}
scales_[s] = out_rate / in_rate;
args_[s].in_rate = in_rate;
args_[s].out_rate = out_rate;
int64_t in_length = shape.tensor_shape_span(s)[0];
int64_t out_length =
kernels::signal::resampling::resampled_length(in_length, in_rate, out_rate);
Expand All @@ -110,7 +112,8 @@ class ResampleBase : public Operator<Backend> {
error << " for sample " << s;
DALI_FAIL(error.str());
}
scales_[s] = scale;
args_[s].in_rate = 1.0;
args_[s].out_rate = scale;
int64_t in_length = shape.tensor_shape_span(s)[0];
int64_t out_length =
kernels::signal::resampling::resampled_length(in_length, 1, scale);
Expand All @@ -125,7 +128,8 @@ class ResampleBase : public Operator<Backend> {
DALI_FAIL(make_string("Cannot produce a non-empty signal from an empty input.\n"
"Error at sample ", s));
}
scales_[s] = in_length ? 1.0 * out_length / in_length : 0.0;
args_[s].in_rate = in_length ? in_length : 1; // avoid division by 0
args_[s].out_rate = out_length ? out_length : 1; // avoid division by 0
out_shape.tensor_shape_span(s)[0] = out_length;
}
} else {
Expand All @@ -144,7 +148,8 @@ class ResampleBase : public Operator<Backend> {
ArgValue<float> scale_{"scale", spec_};
ArgValue<int64_t> out_length_{"out_length", spec_};

std::vector<double> scales_;
using Args = kernels::signal::resampling::Args;
std::vector<Args> args_;
};

} // namespace audio
Expand Down
79 changes: 79 additions & 0 deletions dali/operators/audio/resample_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <vector>
#include "dali/operators/audio/resample.h"
#include "dali/operators/audio/resampling_params.h"
#include "dali/kernels/signal/resampling_gpu.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/kernel_manager.h"

namespace dali {
namespace audio {

using kernels::InListGPU;
using kernels::OutListGPU;

class ResampleGPU : public ResampleBase<GPUBackend> {
public:
using Base = ResampleBase<GPUBackend>;
explicit ResampleGPU(const OpSpec &spec) : Base(spec) {}

void RunImpl(DeviceWorkspace &ws) override {
auto &out = ws.template Output<GPUBackend>(0);
const auto &in = ws.template Input<GPUBackend>(0);
out.SetLayout(in.GetLayout());

int N = in.num_samples();
assert(N == static_cast<int>(args_.size()));
assert(out.type() == dtype_);

TYPE_SWITCH(dtype_, type2id, Out, (AUDIO_RESAMPLE_TYPES), (
TYPE_SWITCH(in.type(), type2id, In, (AUDIO_RESAMPLE_TYPES), (
ResampleTyped<Out, In>(view<Out>(out), view<const In>(in), ws.stream());
), ( // NOLINT
DALI_FAIL(
make_string("Unsupported input type: ", in.type(), "\nSupported types are : ",
ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
)); // NOLINT
), (assert(!"Unreachable code."))); // NOLINT
}

template <typename Out, typename In>
void ResampleTyped(const OutListGPU<Out> &out, const InListGPU<const In> &in,
cudaStream_t stream) {
using Kernel = kernels::signal::resampling::ResamplerGPU<Out, In>;
if (kmgr_.NumInstances() == 0) {
kmgr_.Resize<Kernel>(1);
auto params = ResamplingParams::FromQuality(quality_);
kmgr_.Get<Kernel>(0).Initialize(params.lobes, params.lookup_size);
}
auto args = make_cspan(args_);
kernels::KernelContext ctx;
ctx.gpu.stream = stream;
kmgr_.Setup<Kernel>(0, ctx, in, args);
kmgr_.Run<Kernel>(0, ctx, out, in, args);
}

private:
kernels::KernelManager kmgr_;
};


} // namespace audio

DALI_REGISTER_OPERATOR(experimental__AudioResample, audio::ResampleGPU, GPU);

} // namespace dali
51 changes: 31 additions & 20 deletions dali/test/python/test_operator_audio_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import numpy as np
import scipy.io.wavfile
from test_audio_decoder_utils import generate_waveforms, rosa_resample
from test_utils import compare_pipelines, get_files, check_batch, dali_type_to_np
from test_audio_decoder_utils import generate_waveforms
from test_utils import check_batch, dali_type_to_np, as_array

names = [
"/tmp/dali_test_1C.wav",
Expand All @@ -34,64 +34,73 @@
rates = [ 16000, 22050, 12347 ]
lengths = [ 10000, 54321, 12345 ]

def create_test_files():
def create_files():
for i in range(len(names)):
wave = generate_waveforms(lengths[i], freqs[i])
wave = (wave * 32767).round().astype(np.int16)
scipy.io.wavfile.write(names[i], rates[i], wave)


create_test_files()
create_files()


@pipeline_def
def audio_decoder_pipe():
def audio_decoder_pipe(device):
encoded, _ = fn.readers.file(files=names)
audio0, sr0 = fn.decoders.audio(encoded, dtype=types.FLOAT)
out_sr = 15000
audio1, sr1 = fn.decoders.audio(encoded, dtype=types.FLOAT, sample_rate=out_sr)
if device == 'gpu':
audio0 = audio0.gpu()
audio2 = fn.experimental.audio_resample(audio0, in_rate=sr0, out_rate=out_sr)
audio3 = fn.experimental.audio_resample(audio0, scale=out_sr/sr0)
audio4 = fn.experimental.audio_resample(audio0, out_length=fn.shapes(audio1)[0])
return audio1, audio2, audio3, audio4

def test_standalone_vs_fused():
pipe = audio_decoder_pipe(batch_size=2, num_threads=1, device_id=0)
def _test_standalone_vs_fused(device):
pipe = audio_decoder_pipe(device=device, batch_size=2, num_threads=1, device_id=0)
pipe.build()
is_gpu = device == 'gpu'
for _ in range(2):
outs = pipe.run()
# two sampling rates - should be bit-exact
check_batch(outs[0], outs[1], eps=0, max_allowed_error=0)
check_batch(outs[0], outs[1], eps=1e-6 if is_gpu else 0,
max_allowed_error=1e-4 if is_gpu else 0)
# numerical round-off error in rate
check_batch(outs[0], outs[2], eps=1e-6, max_allowed_error=1e-4)
# here, the sampling rate is slightly different, so we can tolerate larger errors
check_batch(outs[0], outs[3], eps=1e-4, max_allowed_error=1)

def _test_type_conversion(src_type, in_values, dst_type, out_values, eps):
def test_standalone_vs_fused():
for device in ('gpu', 'cpu'):
yield _test_standalone_vs_fused, device

def _test_type_conversion(device, src_type, in_values, dst_type, out_values, eps):
src_nptype = dali_type_to_np(src_type)
dst_nptype = dali_type_to_np(dst_type)
assert len(out_values) == len(in_values)
in_data = [np.full((100 + 10 * i,), x, src_nptype) for i, x in enumerate(in_values)]
@pipeline_def(batch_size=len(in_values))
def test_pipe():
input = fn.external_source(in_data, batch = False, cycle = 'quiet')
def test_pipe(device):
input = fn.external_source(in_data, batch=False, cycle='quiet', device=device)
return fn.experimental.audio_resample(input, dtype=dst_type, scale=1, quality=0)
pipe = test_pipe(device_id=0, num_threads=4)
pipe = test_pipe(device, device_id=0, num_threads=4)
pipe.build()
is_gpu = device == 'gpu'
for _ in range(2):
out, = pipe.run()
assert len(out) == len(out_values)
assert out.dtype == dst_type
for i in range(len(out_values)):
ref = np.full_like(in_data[i], out_values[i], dst_nptype)
if not np.allclose(out.at(i), ref, 1e-6, eps):
print("Actual: ", out.at(i))
print(out.at(i).dtype, out.at(i).shape)
out_arr = as_array(out[i])
if not np.allclose(out_arr, ref, 1e-6, eps):
print("Actual: ", out_arr)
print(out_arr.dtype, out_arr.shape)
print("Reference: ", ref)
print(ref.dtype, ref.shape)
print("Diff: ", out.at(i).astype(np.float) - ref)
assert np.allclose(out.at(i), ref, 1e-6, eps)

print("Diff: ", out_arr.astype(np.float) - ref)
assert np.allclose(out_arr, ref, 1e-6, eps)

def test_dynamic_ranges():
for type, values, eps in [(types.FLOAT, [-1.e30, -1-1.e-6, -1, -0.5, -1.e-30, 0, 1.e-30, 0.5, 1, 1+1.e-6, 1e30], 0),
Expand All @@ -101,7 +110,8 @@ def test_dynamic_ranges():
(types.INT16, [-32768, -32767, -100, -1, 0, 1, 100, 32767], 0),
(types.UINT32, [0, 1, 0x7fffffff, 0x80000000, 0xfffffffe, 0xffffffff], 128),
(types.INT32, [-0x80000000, -0x7fffffff, -100, -1, 0, 1, 0x7fffffff], 128)]:
yield _test_type_conversion, type, values, type, values, eps
for device in ('cpu', 'gpu'):
yield _test_type_conversion, device, type, values, type, values, eps

def test_type_conversion():
type_ranges = [(types.FLOAT, [-1, 1]),
Expand Down Expand Up @@ -138,4 +148,5 @@ def test_type_conversion():
if eps < 1 and (o_lo != -o_hi or (i_hi != i_lo and dst_type != types.FLOAT)):
eps = 1

yield _test_type_conversion, src_type, in_values, dst_type, out_values, eps
for device in ('cpu', 'gpu'):
yield _test_type_conversion, device, src_type, in_values, dst_type, out_values, eps

0 comments on commit 9918cb5

Please sign in to comment.