Add fn.experimental.audio_resample GPU (#3911)
@@ -31,15 +31,16 @@ class ResampleBase : public Operator<Backend> {
  public:
   explicit ResampleBase(const OpSpec &spec) : Operator<Backend>(spec) {
     DALI_ENFORCE(in_rate_.HasValue() == out_rate_.HasValue(),
-                 "The parameters ``in_rate`` and ``out_rate`` must be specified together.");
+        "The parameters ``in_rate`` and ``out_rate`` must be specified together.");
     if (in_rate_.HasValue() + scale_.HasValue() + out_length_.HasValue() > 1)
       DALI_FAIL("The sampling rates, ``scale`` and ``out_length`` cannot be used together.");
     if (!in_rate_.HasValue() && !scale_.HasValue() && !out_length_.HasValue())
-      DALI_FAIL("No resampling factor specified! Please supply either the scale, "
-                "the output length or the input and output sampling rates.");
+      DALI_FAIL(
+          "No resampling factor specified! Please supply either the scale, "
+          "the output length or the input and output sampling rates.");
     quality_ = spec_.template GetArgument<float>("quality");
-    DALI_ENFORCE(quality_ >= 0 && quality_ <= 100, make_string("``quality`` out of range: ",
-                 quality_, "\nValid range is [0..100]."));
+    DALI_ENFORCE(quality_ >= 0 && quality_ <= 100,
+        make_string("``quality`` out of range: ", quality_, "\nValid range is [0..100]."));
     if (spec_.TryGetArgument(dtype_, "dtype")) {
       // silence useless warning -----------------------------vvvvvvvvvvvvvvvv
       TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES), (T x; (void)x;),
@@ -58,20 +59,20 @@ class ResampleBase : public Operator<Backend> {
     dtype_ = ws.template Input<Backend>(0).type();

     outputs[0].type = dtype_;
-    CalculateScaleAndShape(outputs[0].shape, ws);
+    CalculateShapeAndArgs(outputs[0].shape, ws);

     return true;
   }

-  void CalculateScaleAndShape(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
+  void CalculateShapeAndArgs(TensorListShape<> &out_shape, const workspace_t<Backend> &ws) {
     const auto &input = ws.template Input<Backend>(0);
     const TensorListShape<> &shape = input.shape();
     DALI_ENFORCE(shape.sample_dim() == 1 || shape.sample_dim() == 2,
                  "Audio resampling supports only time series data, with an optional innermost "
                  "channel dimension.");
     out_shape = shape;
     int N = shape.num_samples();
-    scales_.resize(N);
+    args_.resize(N);
     if (in_rate_.HasValue()) {
       assert(out_rate_.HasValue());
       in_rate_.Acquire(spec_, ws, N);
@@ -93,7 +94,8 @@ class ResampleBase : public Operator<Backend> {
           error << " for sample " << s;
           DALI_FAIL(error.str());
         }
-        scales_[s] = out_rate / in_rate;
+        args_[s].in_rate = in_rate;
+        args_[s].out_rate = out_rate;
         int64_t in_length = shape.tensor_shape_span(s)[0];
         int64_t out_length =
             kernels::signal::resampling::resampled_length(in_length, in_rate, out_rate);
@@ -110,7 +112,8 @@ class ResampleBase : public Operator<Backend> {
           error << " for sample " << s;
           DALI_FAIL(error.str());
         }
-        scales_[s] = scale;
+        args_[s].in_rate = 1.0;
+        args_[s].out_rate = scale;
         int64_t in_length = shape.tensor_shape_span(s)[0];
         int64_t out_length =
             kernels::signal::resampling::resampled_length(in_length, 1, scale);
@@ -125,7 +128,8 @@ class ResampleBase : public Operator<Backend> {
           DALI_FAIL(make_string("Cannot produce a non-empty signal from an empty input.\n"
                                 "Error at sample ", s));
         }
-        scales_[s] = in_length ? 1.0 * out_length / in_length : 0.0;
+        args_[s].in_rate = 1.0;
+        args_[s].out_rate = in_length ? 1.0 * out_length / in_length : 0.0;
         out_shape.tensor_shape_span(s)[0] = out_length;
       }
     } else {
@@ -144,7 +148,8 @@ class ResampleBase : public Operator<Backend> {
   ArgValue<float> scale_{"scale", spec_};
   ArgValue<int64_t> out_length_{"out_length", spec_};

-  std::vector<double> scales_;
+  using Args = kernels::signal::resampling::Args;
+  SmallVector<Args, 128> args_;
Review comment: A plain vector will do just fine. Large SmallVectors (however weird that sounds) are better suited for temporary local buffers, where the difference between (frequent) stack allocation (free) and heap allocation (thousands of cycles) is of the essence. Here, the vector will be reallocated a few times per operator lifetime at worst (typically it will be allocated just once) - and chances are we'll still run over 128 anyway.
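The suggested change itself is elided in this view; going by the comment, it presumably amounts to swapping the SmallVector member for a plain std::vector (a sketch, keeping the Args alias from the diff above):

  using Args = kernels::signal::resampling::Args;
  std::vector<Args> args_;  // heap-allocated; typically resized just once per operator lifetime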
 };

 }  // namespace audio
@@ -0,0 +1,83 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <vector>
#include "dali/operators/audio/resample.h"
#include "dali/operators/audio/resampling_params.h"
#include "dali/kernels/signal/resampling_gpu.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/kernel_manager.h"

namespace dali {
namespace audio {

using kernels::InListGPU;
using kernels::OutListGPU;

class ResampleGPU : public ResampleBase<GPUBackend> {
 public:
  using Base = ResampleBase<GPUBackend>;
  explicit ResampleGPU(const OpSpec &spec) : Base(spec) {}

  void RunImpl(DeviceWorkspace &ws) override {
    auto &out = ws.template Output<GPUBackend>(0);
    const auto &in = ws.template Input<GPUBackend>(0);
    out.SetLayout(in.GetLayout());

    int N = in.num_samples();
    assert(N == static_cast<int>(args_.size()));
    assert(out.type() == dtype_);

    TYPE_SWITCH(dtype_, type2id, Out, (AUDIO_RESAMPLE_TYPES), (
      TYPE_SWITCH(in.type(), type2id, In, (AUDIO_RESAMPLE_TYPES), (
        ResampleTyped<Out, In>(view<Out>(out), view<const In>(in), ws.stream());
      ), (  // NOLINT
        DALI_FAIL(
            make_string("Unsupported input type: ", in.type(), "\nSupported types are : ",
                        ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
      ));  // NOLINT
    ), (  // NOLINT
      DALI_FAIL(
          make_string("Unsupported output type: ", dtype_, "\nSupported types are : ",
                      ListTypeNames<AUDIO_RESAMPLE_TYPES>()));
Review comment: The constructor already checks ``dtype_`` against the supported types, so this error branch is unreachable. It could simply be dropped - or turned into an assertion, if you want to be defensive.
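Both suggested diffs are elided in this view; a sketch of the defensive variant the comment describes (an assumption - the unreachable DALI_FAIL replaced with an assertion):

    ), (  // NOLINT
      // Unreachable: dtype_ was validated against AUDIO_RESAMPLE_TYPES in the constructor.
      assert(!"Unsupported output type");
    ));  // NOLINT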
    ));  // NOLINT
  }
||||||||||||||||||
template <typename Out, typename In> | ||||||||||||||||||
void ResampleTyped(const OutListGPU<Out> &out, const InListGPU<const In> &in, | ||||||||||||||||||
cudaStream_t stream) { | ||||||||||||||||||
using Kernel = kernels::signal::resampling::ResamplerGPU<Out, In>; | ||||||||||||||||||
if (kmgr_.NumInstances() == 0) { | ||||||||||||||||||
kmgr_.Resize<Kernel>(1); | ||||||||||||||||||
auto params = ResamplingParams::FromQuality(quality_); | ||||||||||||||||||
kmgr_.Get<Kernel>(0).Initialize(params.lobes, params.lookup_size); | ||||||||||||||||||
} | ||||||||||||||||||
auto args = make_cspan(args_); | ||||||||||||||||||
kernels::KernelContext ctx; | ||||||||||||||||||
ctx.gpu.stream = stream; | ||||||||||||||||||
kmgr_.Setup<Kernel>(0, ctx, in, args); | ||||||||||||||||||
kmgr_.Run<Kernel>(0, ctx, out, in, args); | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
private: | ||||||||||||||||||
kernels::KernelManager kmgr_; | ||||||||||||||||||
}; | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
} // namespace audio | ||||||||||||||||||
|
||||||||||||||||||
DALI_REGISTER_OPERATOR(experimental__AudioResample, audio::ResampleGPU, GPU); | ||||||||||||||||||
|
||||||||||||||||||
} // namespace dali |
@@ -17,8 +17,8 @@

 import numpy as np
 import scipy.io.wavfile
-from test_audio_decoder_utils import generate_waveforms, rosa_resample
-from test_utils import compare_pipelines, get_files, check_batch, dali_type_to_np
+from test_audio_decoder_utils import generate_waveforms
+from test_utils import check_batch, dali_type_to_np, as_array

 names = [
     "/tmp/dali_test_1C.wav",
@@ -34,64 +34,73 @@
 rates = [ 16000, 22050, 12347 ]
 lengths = [ 10000, 54321, 12345 ]

-def create_test_files():
Review comment: Note: nose was picking this up as a test.

Review comment: Making it "private" (with a leading underscore) would help, too.
+def create_files():
Review comment: I think this would be another way to do it, more verbose probably:
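The suggested change is elided in this view; one plausible "more verbose" reading (an assumption - nose skips anything explicitly marked as a non-test) keeps the original name and flags it for the collector:

    from nose.tools import nottest

    @nottest  # sets create_test_files.__test__ = False, so nose won't collect it
    def create_test_files():
        ...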
     for i in range(len(names)):
         wave = generate_waveforms(lengths[i], freqs[i])
         wave = (wave * 32767).round().astype(np.int16)
         scipy.io.wavfile.write(names[i], rates[i], wave)

-create_test_files()
+create_files()
 @pipeline_def
-def audio_decoder_pipe():
+def audio_decoder_pipe(device):
     encoded, _ = fn.readers.file(files=names)
     audio0, sr0 = fn.decoders.audio(encoded, dtype=types.FLOAT)
     out_sr = 15000
     audio1, sr1 = fn.decoders.audio(encoded, dtype=types.FLOAT, sample_rate=out_sr)
+    if device == 'gpu':
+        audio0 = audio0.gpu()
     audio2 = fn.experimental.audio_resample(audio0, in_rate=sr0, out_rate=out_sr)
     audio3 = fn.experimental.audio_resample(audio0, scale=out_sr/sr0)
     audio4 = fn.experimental.audio_resample(audio0, out_length=fn.shapes(audio1)[0])
     return audio1, audio2, audio3, audio4
-def test_standalone_vs_fused():
-    pipe = audio_decoder_pipe(batch_size=2, num_threads=1, device_id=0)
+def _test_standalone_vs_fused(device):
+    pipe = audio_decoder_pipe(device=device, batch_size=2, num_threads=1, device_id=0)
     pipe.build()
+    is_gpu = device == 'gpu'
     for _ in range(2):
         outs = pipe.run()
         # two sampling rates - should be bit-exact
-        check_batch(outs[0], outs[1], eps=0, max_allowed_error=0)
+        check_batch(outs[0], outs[1], eps=1e-6 if is_gpu else 0,
+                    max_allowed_error=1e-4 if is_gpu else 0)
         # numerical round-off error in rate
         check_batch(outs[0], outs[2], eps=1e-6, max_allowed_error=1e-4)
         # here, the sampling rate is slightly different, so we can tolerate larger errors
         check_batch(outs[0], outs[3], eps=1e-4, max_allowed_error=1)
-def _test_type_conversion(src_type, in_values, dst_type, out_values, eps):
+def test_standalone_vs_fused():
+    for device in ('gpu', 'cpu'):
+        yield _test_standalone_vs_fused, device
+
+def _test_type_conversion(device, src_type, in_values, dst_type, out_values, eps):
     src_nptype = dali_type_to_np(src_type)
     dst_nptype = dali_type_to_np(dst_type)
     assert len(out_values) == len(in_values)
     in_data = [np.full((100 + 10 * i,), x, src_nptype) for i, x in enumerate(in_values)]
     @pipeline_def(batch_size=len(in_values))
-    def test_pipe():
-        input = fn.external_source(in_data, batch = False, cycle = 'quiet')
+    def test_pipe(device):
+        input = fn.external_source(in_data, batch=False, cycle='quiet', device=device)
         return fn.experimental.audio_resample(input, dtype=dst_type, scale=1, quality=0)
-    pipe = test_pipe(device_id=0, num_threads=4)
+    pipe = test_pipe(device, device_id=0, num_threads=4)
     pipe.build()
+    is_gpu = device == 'gpu'
     for _ in range(2):
         out, = pipe.run()
         assert len(out) == len(out_values)
         assert out.dtype == dst_type
         for i in range(len(out_values)):
             ref = np.full_like(in_data[i], out_values[i], dst_nptype)
-            if not np.allclose(out.at(i), ref, 1e-6, eps):
-                print("Actual: ", out.at(i))
-                print(out.at(i).dtype, out.at(i).shape)
+            out_arr = as_array(out[i])
+            if not np.allclose(out_arr, ref, 1e-6, eps):
+                print("Actual: ", out_arr)
+                print(out_arr.dtype, out_arr.shape)
                 print("Reference: ", ref)
                 print(ref.dtype, ref.shape)
-                print("Diff: ", out.at(i).astype(np.float) - ref)
-                assert np.allclose(out.at(i), ref, 1e-6, eps)
+                print("Diff: ", out_arr.astype(np.float) - ref)
+                assert False
Review comment (nitpick): the original code deliberately repeated the check here, so that a failure in nosetests would show up as the full ``np.allclose(...)`` expression rather than a bare ``assert False``.
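A sketch of the pattern the nitpick defends, adapted to the new out_arr variable (assuming the intent is simply a more informative failure message):

    if not np.allclose(out_arr, ref, 1e-6, eps):
        print("Actual: ", out_arr)
        print("Reference: ", ref)
        # repeating the check makes nose report the failing expression itself
        assert np.allclose(out_arr, ref, 1e-6, eps)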
 def test_dynamic_ranges():
     for type, values, eps in [(types.FLOAT, [-1.e30, -1-1.e-6, -1, -0.5, -1.e-30, 0, 1.e-30, 0.5, 1, 1+1.e-6, 1e30], 0),

@@ -101,7 +110,8 @@ def test_dynamic_ranges():
                               (types.INT16, [-32768, -32767, -100, -1, 0, 1, 100, 32767], 0),
                               (types.UINT32, [0, 1, 0x7fffffff, 0x80000000, 0xfffffffe, 0xffffffff], 128),
                               (types.INT32, [-0x80000000, -0x7fffffff, -100, -1, 0, 1, 0x7fffffff], 128)]:
-        yield _test_type_conversion, type, values, type, values, eps
+        for device in ('cpu', 'gpu'):
+            yield _test_type_conversion, device, type, values, type, values, eps
 def test_type_conversion():
     type_ranges = [(types.FLOAT, [-1, 1]),

@@ -138,4 +148,5 @@ def test_type_conversion():
         if eps < 1 and (o_lo != -o_hi or (i_hi != i_lo and dst_type != types.FLOAT)):
             eps = 1

-        yield _test_type_conversion, src_type, in_values, dst_type, out_values, eps
+        for device in ('cpu', 'gpu'):
+            yield _test_type_conversion, device, src_type, in_values, dst_type, out_values, eps
Review comment (nitpick): Now that we have ``args``, passing the rates through unchanged would increase precision - we use the inverse scale in the kernel (in_rate/out_rate), so performing the other division here and the reciprocal there will decrease (albeit very slightly) the precision.
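The concrete diff the nitpick alludes to isn't shown; presumably it means storing the raw lengths in the out_length branch, so the kernel's single in_rate/out_rate division is the only one performed (a sketch, not the author's exact change):

    // instead of in_rate = 1.0, out_rate = out_length / in_length:
    args_[s].in_rate = in_length;    // kernel computes in_rate / out_rate directly,
    args_[s].out_rate = out_length;  // avoiding a division followed by a reciprocal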