Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fn.experimental.audio_resample GPU #3911

Merged
merged 39 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0778f61
Initial effort.
mzient May 10, 2022
740bd14
Add signal resampling GPU kernel
jantonguirao May 10, 2022
5b410ed
Remove downmixing
jantonguirao May 11, 2022
6065354
Code review fixes
jantonguirao May 11, 2022
7c0162d
Add benchmark
jantonguirao May 11, 2022
7a39c95
Avoid precision issue & add shared memory usage
jantonguirao May 11, 2022
3ef0f58
Move double in_block_f calculation inside the loop
jantonguirao May 12, 2022
a66a763
Fix benchmark
jantonguirao May 12, 2022
44354b2
Update benchmark
jantonguirao May 12, 2022
e4eb226
ROI & input conversion to float & limit tmp shared mem
jantonguirao May 12, 2022
13ff3db
Add comments
jantonguirao May 16, 2022
3467768
Use floorf and ceilf in CUDA code
jantonguirao May 16, 2022
b10d860
Improve tests & fix bugs
jantonguirao May 16, 2022
19e58e8
Move resampling GPU to cu file & add sync to Initialize
jantonguirao May 17, 2022
bd8212a
Add audio_resample GPU operator
jantonguirao May 17, 2022
bf32abc
Initial effort.
mzient May 10, 2022
a983fdd
Add signal resampling GPU kernel
jantonguirao May 10, 2022
9a403ea
Remove downmixing
jantonguirao May 11, 2022
c6374a4
Code review fixes
jantonguirao May 11, 2022
d88e78b
Add benchmark
jantonguirao May 11, 2022
074bc13
Avoid precision issue & add shared memory usage
jantonguirao May 11, 2022
5248048
Move double in_block_f calculation inside the loop
jantonguirao May 12, 2022
3f74625
Fix benchmark
jantonguirao May 12, 2022
177fced
Update benchmark
jantonguirao May 12, 2022
744f570
ROI & input conversion to float & limit tmp shared mem
jantonguirao May 12, 2022
12e177c
Add comments
jantonguirao May 16, 2022
ef495e1
Use floorf and ceilf in CUDA code
jantonguirao May 16, 2022
73c6cee
Improve tests & fix bugs
jantonguirao May 16, 2022
44cd216
Move resampling GPU to cu file & add sync to Initialize
jantonguirao May 17, 2022
ab0470d
Fix reference to temp member
jantonguirao May 17, 2022
a10bd6b
Merge branch 'audio_resampling_gpu_kernel' into audio_resampling_gpu_op
jantonguirao May 17, 2022
6be84a4
Call Initialize
jantonguirao May 17, 2022
3ad6ea1
Merge remote-tracking branch 'upstream/main' into audio_resampling_gp…
jantonguirao May 18, 2022
c5cec3b
Rebase
jantonguirao May 18, 2022
7a0cc7a
Code review fixes
jantonguirao May 18, 2022
934c17e
Test full GPU pipe
jantonguirao May 18, 2022
d76cd88
Merge remote-tracking branch 'upstream/main' into audio_resampling_gp…
jantonguirao May 18, 2022
23525bf
Code review fixes
jantonguirao May 20, 2022
b2c2c57
Code review fixes
jantonguirao May 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dali/operators/audio/resample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class ResampleCPU : public ResampleBase<CPUBackend> {
const auto &in_shape = in.shape();
out.SetLayout(in.GetLayout());
int N = in.num_samples();
assert(N == static_cast<int>(scales_.size()));
assert(N == static_cast<int>(args_.size()));
assert(out.type() == dtype_);

auto &tp = ws.GetThreadPool();
Expand All @@ -129,17 +129,17 @@ class ResampleCPU : public ResampleBase<CPUBackend> {
make_string("Unsupported output type: ", dtype_,
"\nSupported types are : ", ListTypeNames<AUDIO_RESAMPLE_TYPES>()));));
TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES),
(ResampleTyped<T>(view<T>(out[s]), in_view, scales_[s]);),
(ResampleTyped<T>(view<T>(out[s]), in_view, args_[s]);),
(assert(!"Unreachable code.")));
});
}
tp.RunAll();
}

template <typename T>
void ResampleTyped(const OutTensorCPU<T> &out, const InTensorCPU<float> &in, double scale) {
void ResampleTyped(const OutTensorCPU<T> &out, const InTensorCPU<float> &in, const Args& args) {
int ch = out.shape.sample_dim() > 1 ? out.shape[1] : 1;
R.Resample(out.data, 0, out.shape[0], scale, in.data, in.shape[0], 1, ch);
R.Resample(out.data, 0, out.shape[0], args.out_rate, in.data, in.shape[0], args.in_rate, ch);
}

template <typename T>
Expand Down
29 changes: 17 additions & 12 deletions dali/operators/audio/resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ class ResampleBase : public Operator<Backend> {
public:
explicit ResampleBase(const OpSpec &spec) : Operator<Backend>(spec) {
DALI_ENFORCE(in_rate_.HasValue() == out_rate_.HasValue(),
"The parameters ``in_rate`` and ``out_rate`` must be specified together.");
"The parameters ``in_rate`` and ``out_rate`` must be specified together.");
if (in_rate_.HasValue() + scale_.HasValue() + out_length_.HasValue() > 1)
DALI_FAIL("The sampling rates, ``scale`` and ``out_length`` cannot be used together.");
if (!in_rate_.HasValue() && !scale_.HasValue() && !out_length_.HasValue())
DALI_FAIL("No resampling factor specified! Please supply either the scale, "
"the output length or the input and output sampling rates.");
DALI_FAIL(
"No resampling factor specified! Please supply either the scale, "
"the output length or the input and output sampling rates.");
quality_ = spec_.template GetArgument<float>("quality");
DALI_ENFORCE(quality_ >= 0 && quality_ <= 100, make_string("``quality`` out of range: ",
quality_, "\nValid range is [0..100]."));
DALI_ENFORCE(quality_ >= 0 && quality_ <= 100,
make_string("``quality`` out of range: ", quality_, "\nValid range is [0..100]."));
if (spec_.TryGetArgument(dtype_, "dtype")) {
// silence useless warning -----------------------------vvvvvvvvvvvvvvvv
TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES), (T x; (void)x;),
Expand All @@ -58,20 +59,20 @@ class ResampleBase : public Operator<Backend> {
dtype_ = ws.template Input<Backend>(0).type();

outputs[0].type = dtype_;
CalculateScaleAndShape(outputs[0].shape, ws);
CalculateShapeAndArgs(outputs[0].shape, ws);

return true;
}

void CalculateScaleAndShape(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
void CalculateShapeAndArgs(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
const auto &input = ws.template Input<Backend>(0);
const TensorListShape<> &shape = input.shape();
DALI_ENFORCE(shape.sample_dim() == 1 || shape.sample_dim() == 2,
"Audio resampling supports only time series data, with an optional innermost "
"channel dimension.");
out_shape = shape;
int N = shape.num_samples();
scales_.resize(N);
args_.resize(N);
if (in_rate_.HasValue()) {
assert(out_rate_.HasValue());
in_rate_.Acquire(spec_, ws, N);
Expand All @@ -93,7 +94,8 @@ class ResampleBase : public Operator<Backend> {
error << " for sample " << s;
DALI_FAIL(error.str());
}
scales_[s] = out_rate / in_rate;
args_[s].in_rate = in_rate;
args_[s].out_rate = out_rate;
int64_t in_length = shape.tensor_shape_span(s)[0];
int64_t out_length =
kernels::signal::resampling::resampled_length(in_length, in_rate, out_rate);
Expand All @@ -110,7 +112,8 @@ class ResampleBase : public Operator<Backend> {
error << " for sample " << s;
DALI_FAIL(error.str());
}
scales_[s] = scale;
args_[s].in_rate = 1.0;
args_[s].out_rate = scale;
int64_t in_length = shape.tensor_shape_span(s)[0];
int64_t out_length =
kernels::signal::resampling::resampled_length(in_length, 1, scale);
Expand All @@ -125,7 +128,8 @@ class ResampleBase : public Operator<Backend> {
DALI_FAIL(make_string("Cannot produce a non-empty signal from an empty input.\n"
"Error at sample ", s));
}
scales_[s] = in_length ? 1.0 * out_length / in_length : 0.0;
args_[s].in_rate = 1.0;
args_[s].out_rate = in_length ? 1.0 * out_length / in_length : 0.0;
Copy link
Contributor

@mzient mzient May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick:
Now that we have args, passing the lengths directly would increase precision: the kernel uses the inverse scale (in_rate/out_rate), so performing the division here and then taking the reciprocal there decreases the precision (albeit very slightly).

Suggested change
args_[s].in_rate = 1.0;
args_[s].out_rate = in_length ? 1.0 * out_length / in_length : 0.0;
args_[s].in_rate = in_length ? in_length : 1; // avoid division by 0
args_[s].out_rate = out_length ? out_length : 1; // avoid division by 0

out_shape.tensor_shape_span(s)[0] = out_length;
}
} else {
Expand All @@ -144,7 +148,8 @@ class ResampleBase : public Operator<Backend> {
ArgValue<float> scale_{"scale", spec_};
ArgValue<int64_t> out_length_{"out_length", spec_};

std::vector<double> scales_;
using Args = kernels::signal::resampling::Args;
SmallVector<Args, 128> args_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plain vector will do just fine. Large SmallVectors (however weird that sounds) are better suited for temporary local buffers, where the difference between (frequent) stack allocation (free) and heap allocation (thousands of cycles) is of the essence. Here, the vector will be reallocated a few times per operator lifetime at worst (typically it will be allocated just once) - and chances are, we'll still run over 128.

Suggested change
SmallVector<Args, 128> args_;
std::vector<Args> args_;

};

} // namespace audio
Expand Down
83 changes: 83 additions & 0 deletions dali/operators/audio/resample_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <vector>
#include "dali/operators/audio/resample.h"
#include "dali/operators/audio/resampling_params.h"
#include "dali/kernels/signal/resampling_gpu.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/kernel_manager.h"

namespace dali {
namespace audio {

using kernels::InListGPU;
using kernels::OutListGPU;

class ResampleGPU : public ResampleBase<GPUBackend> {
public:
using Base = ResampleBase<GPUBackend>;
explicit ResampleGPU(const OpSpec &spec) : Base(spec) {}

void RunImpl(DeviceWorkspace &ws) override {
auto &out = ws.template Output<GPUBackend>(0);
const auto &in = ws.template Input<GPUBackend>(0);
out.SetLayout(in.GetLayout());

int N = in.num_samples();
assert(N == static_cast<int>(args_.size()));
assert(out.type() == dtype_);

TYPE_SWITCH(dtype_, type2id, Out, (AUDIO_RESAMPLE_TYPES), (
TYPE_SWITCH(in.type(), type2id, In, (AUDIO_RESAMPLE_TYPES), (
ResampleTyped<Out, In>(view<Out>(out), view<const In>(in), ws.stream());
), ( // NOLINT
DALI_FAIL(
make_string("Unsupported input type: ", in.type(), "\nSupported types are : ",
ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
)); // NOLINT
), ( // NOLINT
DALI_FAIL(
make_string("Unsupported output type: ", dtype_, "\nSupported types are : ",
ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor already checks dtype and produces a message - no need to duplicate it. Also, since this is already verified, reaching this case is an indication of an internal error, not a misuse.

Suggested change
DALI_FAIL(
make_string("Unsupported output type: ", dtype_, "\nSupported types are : ",
ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
assert(!"Internal error: inconsistent output type.");

or

Suggested change
DALI_FAIL(
make_string("Unsupported output type: ", dtype_, "\nSupported types are : ",
ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
throw std::logic_error("Internal error: inconsistent output type.");

if you want to be defensive.

)); // NOLINT
}

// Runs the GPU resampling kernel on the batch `in`, writing the result to `out`.
// Out / In are the output and input element types selected by the TYPE_SWITCH in RunImpl.
// The resampling arguments are taken from args_ and the kernel is run on `stream`.
template <typename Out, typename In>
void ResampleTyped(const OutListGPU<Out> &out, const InListGPU<const In> &in,
cudaStream_t stream) {
using Kernel = kernels::signal::resampling::ResamplerGPU<Out, In>;
// Create and initialize the kernel instance only when the manager has none yet.
// NOTE(review): the instance is created for the first <Out, In> pair seen here;
// presumably the types do not change between calls - confirm.
if (kmgr_.NumInstances() == 0) {
kmgr_.Resize<Kernel>(1);
// The kernel parameters (lobes, lookup_size) are derived from quality_.
auto params = ResamplingParams::FromQuality(quality_);
kmgr_.Get<Kernel>(0).Initialize(params.lobes, params.lookup_size);
}
auto args = make_cspan(args_);
kernels::KernelContext ctx;
// The kernel context carries the stream passed in by the caller.
ctx.gpu.stream = stream;
// Setup is called before Run, with the same input and arguments.
kmgr_.Setup<Kernel>(0, ctx, in, args);
kmgr_.Run<Kernel>(0, ctx, out, in, args);
}

private:
kernels::KernelManager kmgr_;
};


} // namespace audio

DALI_REGISTER_OPERATOR(experimental__AudioResample, audio::ResampleGPU, GPU);

} // namespace dali
51 changes: 31 additions & 20 deletions dali/test/python/test_operator_audio_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import numpy as np
import scipy.io.wavfile
from test_audio_decoder_utils import generate_waveforms, rosa_resample
from test_utils import compare_pipelines, get_files, check_batch, dali_type_to_np
from test_audio_decoder_utils import generate_waveforms
from test_utils import check_batch, dali_type_to_np, as_array

names = [
"/tmp/dali_test_1C.wav",
Expand All @@ -34,64 +34,73 @@
rates = [ 16000, 22050, 12347 ]
lengths = [ 10000, 54321, 12345 ]

def create_test_files():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: nose was picking this function up as a test, hence the rename.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it "private" (with leading underscore) would help, too.

def create_files():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be another way, though probably more verbose:

Suggested change
def create_files():
from nose.tools import nottest
@nottest
def create_files():

for i in range(len(names)):
wave = generate_waveforms(lengths[i], freqs[i])
wave = (wave * 32767).round().astype(np.int16)
scipy.io.wavfile.write(names[i], rates[i], wave)


create_test_files()
create_files()


@pipeline_def
def audio_decoder_pipe():
def audio_decoder_pipe(device):
encoded, _ = fn.readers.file(files=names)
audio0, sr0 = fn.decoders.audio(encoded, dtype=types.FLOAT)
out_sr = 15000
audio1, sr1 = fn.decoders.audio(encoded, dtype=types.FLOAT, sample_rate=out_sr)
if device == 'gpu':
audio0 = audio0.gpu()
audio2 = fn.experimental.audio_resample(audio0, in_rate=sr0, out_rate=out_sr)
audio3 = fn.experimental.audio_resample(audio0, scale=out_sr/sr0)
audio4 = fn.experimental.audio_resample(audio0, out_length=fn.shapes(audio1)[0])
return audio1, audio2, audio3, audio4

def test_standalone_vs_fused():
pipe = audio_decoder_pipe(batch_size=2, num_threads=1, device_id=0)
def _test_standalone_vs_fused(device):
pipe = audio_decoder_pipe(device=device, batch_size=2, num_threads=1, device_id=0)
pipe.build()
is_gpu = device == 'gpu'
for _ in range(2):
outs = pipe.run()
# two sampling rates - should be bit-exact
check_batch(outs[0], outs[1], eps=0, max_allowed_error=0)
check_batch(outs[0], outs[1], eps=1e-6 if is_gpu else 0,
max_allowed_error=1e-4 if is_gpu else 0)
# numerical round-off error in rate
check_batch(outs[0], outs[2], eps=1e-6, max_allowed_error=1e-4)
# here, the sampling rate is slightly different, so we can tolerate larger errors
check_batch(outs[0], outs[3], eps=1e-4, max_allowed_error=1)

def _test_type_conversion(src_type, in_values, dst_type, out_values, eps):
def test_standalone_vs_fused():
for device in ('gpu', 'cpu'):
yield _test_standalone_vs_fused, device

def _test_type_conversion(device, src_type, in_values, dst_type, out_values, eps):
src_nptype = dali_type_to_np(src_type)
dst_nptype = dali_type_to_np(dst_type)
assert len(out_values) == len(in_values)
in_data = [np.full((100 + 10 * i,), x, src_nptype) for i, x in enumerate(in_values)]
@pipeline_def(batch_size=len(in_values))
def test_pipe():
input = fn.external_source(in_data, batch = False, cycle = 'quiet')
def test_pipe(device):
input = fn.external_source(in_data, batch=False, cycle='quiet', device=device)
return fn.experimental.audio_resample(input, dtype=dst_type, scale=1, quality=0)
pipe = test_pipe(device_id=0, num_threads=4)
pipe = test_pipe(device, device_id=0, num_threads=4)
pipe.build()
is_gpu = device == 'gpu'
for _ in range(2):
out, = pipe.run()
assert len(out) == len(out_values)
assert out.dtype == dst_type
for i in range(len(out_values)):
ref = np.full_like(in_data[i], out_values[i], dst_nptype)
if not np.allclose(out.at(i), ref, 1e-6, eps):
print("Actual: ", out.at(i))
print(out.at(i).dtype, out.at(i).shape)
out_arr = as_array(out[i])
if not np.allclose(out_arr, ref, 1e-6, eps):
print("Actual: ", out_arr)
print(out_arr.dtype, out_arr.shape)
print("Reference: ", ref)
print(ref.dtype, ref.shape)
print("Diff: ", out.at(i).astype(np.float) - ref)
assert np.allclose(out.at(i), ref, 1e-6, eps)

print("Diff: ", out_arr.astype(np.float) - ref)
assert False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: the original code deliberately repeated the check here, so that in non-verbose runs the error reported by nosetests would show more than just `False`.


def test_dynamic_ranges():
for type, values, eps in [(types.FLOAT, [-1.e30, -1-1.e-6, -1, -0.5, -1.e-30, 0, 1.e-30, 0.5, 1, 1+1.e-6, 1e30], 0),
Expand All @@ -101,7 +110,8 @@ def test_dynamic_ranges():
(types.INT16, [-32768, -32767, -100, -1, 0, 1, 100, 32767], 0),
(types.UINT32, [0, 1, 0x7fffffff, 0x80000000, 0xfffffffe, 0xffffffff], 128),
(types.INT32, [-0x80000000, -0x7fffffff, -100, -1, 0, 1, 0x7fffffff], 128)]:
yield _test_type_conversion, type, values, type, values, eps
for device in ('cpu', 'gpu'):
yield _test_type_conversion, device, type, values, type, values, eps

def test_type_conversion():
type_ranges = [(types.FLOAT, [-1, 1]),
Expand Down Expand Up @@ -138,4 +148,5 @@ def test_type_conversion():
if eps < 1 and (o_lo != -o_hi or (i_hi != i_lo and dst_type != types.FLOAT)):
eps = 1

yield _test_type_conversion, src_type, in_values, dst_type, out_values, eps
for device in ('cpu', 'gpu'):
yield _test_type_conversion, device, src_type, in_values, dst_type, out_values, eps