From 0778f6128277c94601baed14732d2546b6b93bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Tue, 10 May 2022 07:57:48 +0200 Subject: [PATCH 01/36] Initial effort. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/kernels/signal/resampling.h | 23 ++++++---- dali/kernels/signal/resampling_gpu.cu | 0 dali/kernels/signal/resampling_gpu.h | 63 +++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 8 deletions(-) create mode 100644 dali/kernels/signal/resampling_gpu.cu create mode 100644 dali/kernels/signal/resampling_gpu.h diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 42cd452cdc4..28c1c161d14 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -61,12 +61,12 @@ struct ResamplingWindow { return {i0, i1}; } - inline float operator()(float x) const { + inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; float floori = std::floor(fi); float di = fi - floori; int i = floori; - assert(i >= 0 && i < static_cast(lookup.size())); + assert(i >= 0 && i < lookup_size); return lookup[i] + di * (lookup[i + 1] - lookup[i]); } @@ -112,25 +112,32 @@ struct ResamplingWindow { float scale = 1, center = 1; int lobes = 0, coeffs = 0; - std::vector lookup; + int lookup_size = 0; + const float *lookup = nullptr; }; -inline void windowed_sinc(ResamplingWindow &window, +struct ResamplingWindowCPU : ResamplingWindow { + std::vector storage; +}; + +inline void windowed_sinc(ResamplingWindowCPU &window, int coeffs, int lobes, std::function envelope = Hann) { assert(coeffs > 1 && lobes > 0 && "Degenerate parameters specified."); float scale = 2.0f * lobes / (coeffs - 1); float scale_envelope = 2.0f / coeffs; window.coeffs = coeffs; window.lobes = lobes; - window.lookup.clear(); - window.lookup.resize(coeffs + 5); // add zeros and a full 4-lane vector + window.storage.clear(); + window.storage.resize(coeffs + 5); // add zeros and a full 4-lane vector int center = (coeffs - 1) * 0.5f; for (int i = 0; i < coeffs; i++) { float x = (i - center) * scale; float y = (i - center) * scale_envelope; float w = sinc(x) * envelope(y); - window.lookup[i + 1] = w; + window.storage[i + 1] = w; } + window.lookup = window.storage.data(); + window.lookup_size = window.storage.size(); window.center = center + 1; // allow for leading zero window.scale = 1 / scale; } @@ -141,7 +148,7 @@ inline int64_t resampled_length(int64_t in_length, double in_rate, double out_ra } struct Resampler { - ResamplingWindow window; + ResamplingWindowCPU window; void Initialize(int lobes = 16, int lookup_size = 2048) { windowed_sinc(window, lookup_size, lobes); diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h new file mode 100644 index 00000000000..4390b34f665 --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ + +#include "dali/kernels/signal/resampling.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +ResamplingWindow ToGPU(Scratchpad &scratch, const ResamplingWindow &cpu_window) { + ResamplingWindow wnd = cpu_window; + wnd.lookup = scratch.ToGPU(make_span(cpu_window.lookup, cpu_windwo.lookup_size)); + return wnd; +} + +struct ResamplerGPU { + ResamplingWindowCPU window; + + void Initialize(int lobes = 16, int lookup_size = 2048) { + windowed_sinc(window, lookup_size, lobes); + } + + + /** + * @brief Resample multi-channel signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can seamlessly resample the input and produce the result in chunks. + * To reuse memory and still simulate chunk processing, adjust the in/out pointers. + */ + template + void Resample( + Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, + const float *__restrict__ in, int64_t n_in, double in_rate, + int num_channels, + cudaStream_t stream); +}; + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali + +} // namespace dali + + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ From 740bd14aa16f8c27c361018eb4a8707f082506ec Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 10 May 2022 17:38:33 +0200 Subject: [PATCH 02/36] Add signal resampling GPU kernel Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 42 +++-- dali/kernels/signal/resampling_gpu.cu | 0 dali/kernels/signal/resampling_gpu.cuh | 117 +++++++++++++ dali/kernels/signal/resampling_gpu.h | 104 ++++++++--- dali/kernels/signal/resampling_gpu_test.cu | 78 +++++++++ dali/kernels/signal/resampling_test.cc | 190 +++++++++++++-------- dali/kernels/signal/resampling_test.h | 67 ++++++++ 7 files changed, 489 insertions(+), 109 deletions(-) delete mode 100644 dali/kernels/signal/resampling_gpu.cu create mode 100644 dali/kernels/signal/resampling_gpu.cuh create mode 100644 dali/kernels/signal/resampling_gpu_test.cu create mode 100644 dali/kernels/signal/resampling_test.h diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 28c1c161d14..6fd8917209a 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -29,6 +29,7 @@ #include "dali/core/small_vector.h" #include "dali/core/convert.h" #include "dali/core/static_switch.h" +#include "dali/core/geom/vec.h" namespace dali { namespace kernels { @@ -54,7 +55,7 @@ inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { #endif struct ResamplingWindow { - inline std::pair input_range(float x) const { + inline DALI_HOST_DEV ivec<2> input_range(float x) const { int xc = std::ceil(x); int i0 = xc - lobes; int i1 = xc + lobes; @@ -220,8 +221,9 @@ struct Resampler { float in_pos = in_block_f - in_block_i; const float *__restrict__ in_block_ptr = in + in_block_i; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - int i0, i1; - std::tie(i0, i1) = window.input_range(in_pos); + auto irange = window.input_range(in_pos); + int i0 = irange[0]; + int i1 = irange[1]; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -251,8 +253,9 @@ struct Resampler { * To reuse memory and still simulate chunk processing, adjust the in/out pointers. * * @tparam static_channels number of channels, if known at compile time, or -1 + * @tparam downmix whether to downmix all channels in the output */ - template + template void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, @@ -282,8 +285,9 @@ struct Resampler { float in_pos = in_block_f - in_block_i; const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - int i0, i1; - std::tie(i0, i1) = window.input_range(in_pos); + auto irange = window.input_range(in_pos); + int i0 = irange[0]; + int i1 = irange[1]; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -304,8 +308,16 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); - for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + if (downmix) { + float out_val = 0; + for (int c = 0; c < num_channels; c++) + out_val += tmp[c]; + out_val /= num_channels; + out[out_pos] = ConvertSatNorm(out_val); + } else { + for (int c = 0; c < num_channels; c++) + out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + } } } } @@ -321,12 +333,14 @@ struct Resampler { void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels) { - VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), - (Resample(out, out_begin, out_end, out_rate, - in, n_in, in_rate, static_channels);), - (Resample<-1, Out>(out, out_begin, out_end, out_rate, - in, n_in, in_rate, num_channels))); + int num_channels, bool downmix = false) { + BOOL_SWITCH(downmix, Downmix, ( + VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), + (Resample(out, out_begin, out_end, out_rate, + in, n_in, in_rate, static_channels);), + (Resample<-1, Downmix, Out>(out, out_begin, out_end, out_rate, + in, n_in, in_rate, num_channels))); + )); // NOLINT } }; diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh new file mode 100644 index 00000000000..6d32f6d227c --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -0,0 +1,117 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ + +#include +#include "dali/kernels/signal/resampling.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +struct SampleDesc { + void *out; + const void *in; + ResamplingWindow window; + int64_t in_len; // num samples in input + int64_t out_len; // num samples in output + int64_t nchannels; // number of channels + double scale; // in_sampling_rate / out_sampling_rate +}; + +/** + * @brief Resamples 1D signal (single or multi-channel), optionally downmixing and converting to a different data type. + * + * @param samples sample descriptors + */ +template +__global__ void ResampleGPUKernel(const SampleDesc *samples) { + auto sample = samples[blockIdx.y]; + double scale = sample.scale; + float fscale = scale; + int nchannels = SingleChannel ? 1 : sample.nchannels; + auto &window = sample.window; + + Out* out = reinterpret_cast(sample.out); + const In* in = reinterpret_cast(sample.in); + + int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; + int64_t out_block = static_cast(blockIdx.x) * blockDim.x; + int64_t start_out_pos = out_block + threadIdx.x; + + double in_block_f = out_block * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos_start = in_block_f - in_block_i; + const In* in_blk_ptr = in + in_block_i * nchannels; + + for (int64_t out_pos = start_out_pos; out_pos < sample.out_len; + out_pos += grid_stride, in_pos_start += fscale * grid_stride) { + float in_pos = in_pos_start + fscale * threadIdx.x; + auto i_range = window.input_range(in_pos); + int i0 = i_range[0]; + int i1 = i_range[1]; + if (i0 + in_block_i < 0) + i0 = -in_block_i; + if (i1 + in_block_i > sample.in_len) + i1 = sample.in_len - in_block_i; + + float out_val = 0; + if (SingleChannel) { + for (int i = i0; i < i1; i++) { + In in_val = in_blk_ptr[i]; + float x = i - in_pos; + float w = window(x); + out_val = fma(in_val, w, out_val); + } + out[out_pos] = ConvertSatNorm(out_val); + } else { // multiple channels + float tmp[32]; // more than enough + for (int c = 0; c < nchannels; c++) { + tmp[c] = 0; + } + + for (int i = i0; i < i1; i++) { + float x = i - in_pos; + float w = window(x); + for (int c = 0; c < nchannels; c++) { + In in_val = in_blk_ptr[i * nchannels + c]; + tmp[c] = fma(in_val, w, tmp[c]); + } + } + + if (Downmix) { + for (int c = 0; c < nchannels; c++) { + out_val += tmp[c]; + } + out_val /= nchannels; + out[out_pos] = ConvertSatNorm(out_val); + } else { + for (int c = 0; c < nchannels; c++) { + out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); + } + } + } + } +} + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index 4390b34f665..b9bc6ba220b 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -15,7 +15,14 @@ #ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ #define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ +#include #include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_gpu.cuh" +#include "dali/kernels/kernel.h" +#include "dali/kernels/dynamic_scratchpad.h" +#include "dali/core/mm/memory.h" +#include "dali/core/dev_buffer.h" +#include "dali/core/static_switch.h" namespace dali { namespace kernels { @@ -23,33 +30,81 @@ namespace signal { namespace resampling { -ResamplingWindow ToGPU(Scratchpad &scratch, const ResamplingWindow &cpu_window) { - ResamplingWindow wnd = cpu_window; - wnd.lookup = scratch.ToGPU(make_span(cpu_window.lookup, cpu_windwo.lookup_size)); - return wnd; -} - -struct ResamplerGPU { - ResamplingWindowCPU window; - +template +class ResamplerGPU { + public: void Initialize(int lobes = 16, int lookup_size = 2048) { - windowed_sinc(window, lookup_size, lobes); + windowed_sinc(window_cpu_, lookup_size, lobes); + window_gpu_storage_.from_host(window_cpu_.storage); + window_gpu_ = window_cpu_; + window_gpu_.lookup = window_gpu_storage_.data(); + } + + KernelRequirements Setup(KernelContext &context, const InListGPU &in, + span in_rate, span out_rate, bool downmix) { + KernelRequirements req; + auto out_shape = in.shape; + for (int i = 0; i < in.num_samples(); i++) { + auto in_sh = in.shape.tensor_shape_span(i); + auto out_sh = out_shape.tensor_shape_span(i); + out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); + if (downmix) + out_sh[1] = 1; + } + req.output_shapes = {out_shape}; + return req; } + void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span in_rates, span out_rates, + bool downmix) { + if (window_gpu_storage_.empty()) + Initialize(); + + DynamicScratchpad dyn_scratchpad({}, AccessOrder(context.gpu.stream)); + if (!context.scratchpad) + context.scratchpad = &dyn_scratchpad; + auto &scratch = *context.scratchpad; + + int nsamples = in.num_samples(); + auto samples_cpu = + make_span(scratch.Allocate(nsamples), nsamples); + + bool any_multichannel = false; + for (int i = 0; i < nsamples; i++) { + auto &desc = samples_cpu[i]; + desc.in = in[i].data; + desc.out = out[i].data; + desc.window = window_gpu_; + const auto &in_sh = in[i].shape; + const auto &out_sh = out[i].shape; + desc.in_len = in_sh[0]; + desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); + assert(desc.out_len == out_sh[0]); + desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + desc.scale = static_cast(in_rates[i]) / out_rates[i]; + any_multichannel |= desc.nchannels > 1; + } - /** - * @brief Resample multi-channel signal and convert to Out - * - * Calculates a range of resampled signal. - * The function can seamlessly resample the input and produce the result in chunks. - * To reuse memory and still simulate chunk processing, adjust the in/out pointers. - */ - template - void Resample( - Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, - const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels, - cudaStream_t stream); + auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); + + dim3 block(256, 1); + int blocks_per_sample = std::max(32, 1024 / nsamples); + dim3 grid(blocks_per_sample, nsamples); + + BOOL_SWITCH(downmix && any_multichannel, Downmix, ( + BOOL_SWITCH(!any_multichannel, SingleChannel, ( + ResampleGPUKernel + <<>>(samples_gpu); + )); // NOLINT + )); // NOLINT + CUDA_CALL(cudaGetLastError()); + } + + private: + ResamplingWindowCPU window_cpu_; + ResamplingWindow window_gpu_; + DeviceBuffer window_gpu_storage_; }; } // namespace resampling @@ -57,7 +112,4 @@ struct ResamplerGPU { } // namespace kernels } // namespace dali -} // namespace dali - - #endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu new file mode 100644 index 00000000000..aae6033ff8b --- /dev/null +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -0,0 +1,78 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/kernels/signal/resampling_gpu.h" +#include "dali/kernels/signal/resampling_test.h" + + +namespace dali { +namespace kernels { +namespace signal { +namespace resampling { + +class ResamplingGPUTest : public ResamplingTest { + public: + void RunResampling(span in_rates, span out_rates, + bool downmix) override { + ResamplerGPU R; + R.Initialize(16); + + KernelContext ctx; + ctx.gpu.stream = 0; + + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates, downmix); + auto outref_sh = ttl_outref_.cpu().shape; + auto in_batch_sh = ttl_in_.cpu().shape; + for (int s = 0; s < outref_sh.size(); s++) { + auto sh = req.output_shapes[0].tensor_shape_span(s); + auto expected_sh = outref_sh.tensor_shape_span(s); + auto in_sh = in_batch_sh.tensor_shape_span(s); + ASSERT_EQ(sh.size(), in_sh.size()); + if (downmix) { + ASSERT_EQ(sh[1], 1); + } else { + if (sh.size() > 1) + ASSERT_EQ(sh[1], in_sh[1]); + } + } + + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates, downmix); + + CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); + } +}; + +TEST_F(ResamplingGPUTest, SingleChannel) { + this->RunTest(8, 1, false); +} + +TEST_F(ResamplingGPUTest, TwoChannel) { + this->RunTest(3, 2, false); +} + +TEST_F(ResamplingGPUTest, EightChannel) { + this->RunTest(3, 8, false); +} + +TEST_F(ResamplingGPUTest, ThreeChannelDownmix) { + this->RunTest(3, 3, true); +} + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 9dd2233867d..642177563c9 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,91 +16,143 @@ #include #include #include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_test.h" namespace dali { namespace kernels { namespace signal { namespace resampling { -namespace { +void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, + span out_rates) { + TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); + TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); + for (int s = 0; s < nsamples; s++) { + double in_rate = in_rates[s]; + double out_rate = out_rates[s]; + double scale = static_cast(in_rate) / out_rate; + int n_in = in_rate + 12345 * s; // different lengths + int n_out = std::ceil(n_in / scale); + in_sh.tensor_shape_span(s)[0] = n_in; + out_sh.tensor_shape_span(s)[0] = n_out; + if (nchannels > 1) { + in_sh.tensor_shape_span(s)[1] = nchannels; + out_sh.tensor_shape_span(s)[1] = nchannels; + } + } + ttl_in_.reshape(in_sh); + ttl_out_.reshape(out_sh); + ttl_outref_.reshape(out_sh); + for (int s = 0; s < nsamples; s++) { + double in_rate = in_rates[s]; + double out_rate = out_rates[s]; + double scale = static_cast(in_rate) / out_rate; + for (int c = 0; c < nchannels; c++) { + float f_in = 0.1f + 0.01 * s + 0.001 * c; + float f_out = f_in * scale; + int n_in = in_sh.tensor_shape_span(s)[0]; + int n_out = out_sh.tensor_shape_span(s)[0]; + TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out); + } + } +} + +void ResamplingTest::Verify(bool downmix) { + auto in_sh = ttl_in_.cpu().shape; + auto out_sh = ttl_outref_.cpu().shape; + int nsamples = in_sh.num_samples(); + double err = 0, max_diff = 0; + + for (int s = 0; s < nsamples; s++) { + float *out_data = ttl_out_.cpu()[s].data; + float *out_ref = ttl_outref_.cpu()[s].data; + int n_out = out_sh.tensor_shape_span(s)[0]; + int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; + for (int i = 0; i < n_out; i++) { + float ref_val = 0; + if (downmix) { + for (int c = 0; c < nchannels; c++) { + ref_val += out_ref[i * nchannels + c]; + } + ref_val /= nchannels; + } else { + ref_val = out_ref[i]; + } -double HannWindow(int i, int n) { - assert(n > 0); - return Hann(2.0*i / n - 1); + ASSERT_NEAR(out_data[i], ref_val, eps()) + << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; + float diff = std::abs(out_data[i] - ref_val); + if (diff > max_diff) + max_diff = diff; + err += diff * diff; + } + + err = std::sqrt(err / n_out); + EXPECT_LE(err, max_avg_err()) << "Average error too big"; + std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" + "\n max difference vs fresh signal: " + << max_diff << "\n RMS error: " << err << std::endl; + } } -template -void TestWave(T *out, int n, int stride, float freq) { - for (int i = 0; i < n; i++) { - float x = i * freq; - float f = std::sin(i* freq) * HannWindow(i, n); - out[i*stride] = ConvertSatNorm(f); +void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { + std::vector in_rates_v; + for (int i = 0; i < nsamples; i++) { + if (i % 2 == 0) + in_rates_v.push_back(22050.0f); + else + in_rates_v.push_back(44100.0f); } + auto in_rates = make_cspan(in_rates_v); + + std::vector out_rates_v(nsamples, 16000.0f); + auto out_rates = make_cspan(out_rates_v); + + PrepareData(nsamples, nchannels, in_rates, out_rates); + + RunResampling(in_rates, out_rates, downmix); + + Verify(downmix); } -} // namespace - -TEST(ResampleSinc, SingleChannel) { - int n_in = 22050, n_out = 16000; // typical downsampling - std::vector in(n_in); - std::vector out(n_out); - std::vector ref(out.size()); - float f_in = 0.1f; - float f_out = f_in * n_in / n_out; - double in_rate = n_in; - double out_rate = n_out; - TestWave(in.data(), n_in, 1, f_in); - TestWave(ref.data(), n_out, 1, f_out); - Resampler R; - R.Initialize(16); - R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate); +class ResamplingCPUTest : public ResamplingTest { + public: + void RunResampling(span in_rates, span out_rates, bool downmix) override { + Resampler R; + R.Initialize(16); - double err = 0, max_diff = 0; - for (int i = 0; i < n_out; i++) { - ASSERT_NEAR(out[i], ref[i], 1e-3) << "Sample error too big @" << i << std::endl; - float diff = std::abs(out[i] - ref[i]); - if (diff > max_diff) - max_diff = diff; - err += diff*diff; + int nsamples = in_rates.size(); + assert(nsamples == out_rates.size()); + + auto in_view = ttl_in_.cpu(); + auto out_view = ttl_out_.cpu(); + for (int s = 0; s < nsamples; s++) { + auto out_sh = out_view.shape[s]; + auto in_sh = in_view.shape[s]; + int n_out = out_sh[0]; + int n_in = in_sh[0]; + int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; + R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], + nchannels, downmix); + } } - err = std::sqrt(err/n_out); - EXPECT_LE(err, 1e-3) << "Average error too big"; - std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" - "\n max difference vs fresh signal: " << max_diff << - "\n RMS error: " << err << std::endl; +}; + +TEST_F(ResamplingCPUTest, SingleChannel) { + this->RunTest(1, 1, false); } -TEST(ResampleSinc, MultiChannel) { - int n_in = 22050, n_out = 22053; // some weird upsampling - int ch = 5; - std::vector in(n_in * ch); - std::vector out(n_out * ch); - std::vector ref(out.size()); - double in_rate = n_in; - double out_rate = n_out; - for (int c = 0; c < ch; c++) { - float f_in = 0.1f * (1 + c * 0.012345); // different signal in each channel - float f_out = f_in * n_in / n_out; - TestWave(in.data() + c, n_in, ch, f_in); - TestWave(ref.data() + c, n_out, ch, f_out); - } - Resampler R; - R.Initialize(16); - R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate, ch); +TEST_F(ResamplingCPUTest, TwoChannel) { + this->RunTest(1, 2, false); +} - double err = 0, max_diff = 0; - for (int i = 0; i < n_out * ch; i++) { - ASSERT_NEAR(out[i], ref[i], 2e-3) << "Sample error too big @" << i << std::endl; - float diff = std::abs(out[i] - ref[i]); - if (diff > max_diff) - max_diff = diff; - err += diff*diff; - } - err = std::sqrt(err/(n_out * ch)); - EXPECT_LE(err, 1e-3) << "Average error too big"; - std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" - "\n max difference vs fresh signal: " << max_diff << - "\n RMS error: " << err << std::endl; +TEST_F(ResamplingCPUTest, EightChannel) { + this->RunTest(1, 8, false); +} + +TEST_F(ResamplingCPUTest, ThreeChannelDownmix) { + this->RunTest(1, 3, true); } } // namespace resampling diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h new file mode 100644 index 00000000000..0163410daf1 --- /dev/null +++ b/dali/kernels/signal/resampling_test.h @@ -0,0 +1,67 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/kernels/signal/resampling.h" +#include "dali/test/tensor_test_utils.h" +#include "dali/test/test_tensors.h" + +namespace dali { +namespace kernels { +namespace signal { +namespace resampling { + +namespace { + +double HannWindow(int i, int n) { + assert(n > 0); + return Hann(2.0*i / n - 1); +} + +template +void TestWave(T *out, int n, int stride, float freq) { + for (int i = 0; i < n; i++) { + float f = std::sin(i* freq) * HannWindow(i, n); + out[i*stride] = ConvertSatNorm(f); + } +} + +} // namespace + +class ResamplingTest : public ::testing::Test { + public: + void PrepareData(int nsamples, int nchannels, + span in_rates, span out_rates); + + virtual float eps() const { return 2e-3; } + virtual float max_avg_err() const { return 1e-3; } + void Verify(bool downmix); + + virtual void RunResampling(span in_rates, span out_rates, bool downmix) = 0; + + void RunTest(int nsamples, int nchannels, bool downmix); + + + TestTensorList ttl_in_; + TestTensorList ttl_out_; + TestTensorList ttl_outref_; +}; + + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali From 5b410ed6c33bdeefd2166e686edbfd9d1b9c45db Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 12:55:50 +0200 Subject: [PATCH 03/36] Remove downmixing Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 29 +++++---------- dali/kernels/signal/resampling_gpu.cuh | 17 +++------ dali/kernels/signal/resampling_gpu.h | 16 +++----- dali/kernels/signal/resampling_gpu_test.cu | 29 +++++---------- dali/kernels/signal/resampling_test.cc | 43 +++++++++------------- dali/kernels/signal/resampling_test.h | 22 +++++------ 6 files changed, 58 insertions(+), 98 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 6fd8917209a..802c8a4c225 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -253,9 +253,8 @@ struct Resampler { * To reuse memory and still simulate chunk processing, adjust the in/out pointers. * * @tparam static_channels number of channels, if known at compile time, or -1 - * @tparam downmix whether to downmix all channels in the output */ - template + template void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, @@ -308,16 +307,8 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); - if (downmix) { - float out_val = 0; - for (int c = 0; c < num_channels; c++) - out_val += tmp[c]; - out_val /= num_channels; - out[out_pos] = ConvertSatNorm(out_val); - } else { - for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); - } + for (int c = 0; c < num_channels; c++) + out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); } } } @@ -333,14 +324,12 @@ struct Resampler { void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels, bool downmix = false) { - BOOL_SWITCH(downmix, Downmix, ( - VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), - (Resample(out, out_begin, out_end, out_rate, - in, n_in, in_rate, static_channels);), - (Resample<-1, Downmix, Out>(out, out_begin, out_end, out_rate, - in, n_in, in_rate, num_channels))); - )); // NOLINT + int num_channels) { + VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), + (Resample(out, out_begin, out_end, out_rate, + in, n_in, in_rate, static_channels);), + (Resample<-1, Out>(out, out_begin, out_end, out_rate, + in, n_in, in_rate, num_channels))); } }; diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 6d32f6d227c..39ad0f7aec2 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -35,11 +35,11 @@ struct SampleDesc { }; /** - * @brief Resamples 1D signal (single or multi-channel), optionally downmixing and converting to a different data type. + * @brief Resamples 1D signal (single or multi-channel), optionally converting to a different data type. * * @param samples sample descriptors */ -template +template __global__ void ResampleGPUKernel(const SampleDesc *samples) { auto sample = samples[blockIdx.y]; double scale = sample.scale; @@ -80,6 +80,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos] = ConvertSatNorm(out_val); } else { // multiple channels + assert(nchannels <= 32); float tmp[32]; // more than enough for (int c = 0; c < nchannels; c++) { tmp[c] = 0; @@ -94,16 +95,8 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } } - if (Downmix) { - for (int c = 0; c < nchannels; c++) { - out_val += tmp[c]; - } - out_val /= nchannels; - out[out_pos] = ConvertSatNorm(out_val); - } else { - for (int c = 0; c < nchannels; c++) { - out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); - } + for (int c = 0; c < nchannels; c++) { + out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index b9bc6ba220b..ed7db6ff2e9 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -41,23 +41,21 @@ class ResamplerGPU { } KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span in_rate, span out_rate, bool downmix) { + span in_rate, span out_rate) { KernelRequirements req; auto out_shape = in.shape; for (int i = 0; i < in.num_samples(); i++) { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); - if (downmix) - out_sh[1] = 1; } req.output_shapes = {out_shape}; return req; } void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span in_rates, span out_rates, - bool downmix) { + const InListGPU &in, span in_rates, + span out_rates) { if (window_gpu_storage_.empty()) Initialize(); @@ -92,11 +90,9 @@ class ResamplerGPU { int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); - BOOL_SWITCH(downmix && any_multichannel, Downmix, ( - BOOL_SWITCH(!any_multichannel, SingleChannel, ( - ResampleGPUKernel - <<>>(samples_gpu); - )); // NOLINT + BOOL_SWITCH(!any_multichannel, SingleChannel, ( + ResampleGPUKernel + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index aae6033ff8b..23f1fef28d9 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -18,60 +18,49 @@ #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" - namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { class ResamplingGPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates, - bool downmix) override { + void RunResampling(span in_rates, span out_rates) override { ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates, downmix); + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); auto outref_sh = ttl_outref_.cpu().shape; auto in_batch_sh = ttl_in_.cpu().shape; for (int s = 0; s < outref_sh.size(); s++) { auto sh = req.output_shapes[0].tensor_shape_span(s); auto expected_sh = outref_sh.tensor_shape_span(s); - auto in_sh = in_batch_sh.tensor_shape_span(s); - ASSERT_EQ(sh.size(), in_sh.size()); - if (downmix) { - ASSERT_EQ(sh[1], 1); - } else { - if (sh.size() > 1) - ASSERT_EQ(sh[1], in_sh[1]); - } + ASSERT_EQ(sh, expected_sh); } - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates, downmix); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } }; TEST_F(ResamplingGPUTest, SingleChannel) { - this->RunTest(8, 1, false); + this->RunTest(8, 1); } TEST_F(ResamplingGPUTest, TwoChannel) { - this->RunTest(3, 2, false); + this->RunTest(3, 2); } TEST_F(ResamplingGPUTest, EightChannel) { - this->RunTest(3, 8, false); -} - -TEST_F(ResamplingGPUTest, ThreeChannelDownmix) { - this->RunTest(3, 3, true); + this->RunTest(3, 8); } +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 642177563c9..538506a283e 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -22,6 +22,12 @@ namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { + +double HannWindow(int i, int n) { + assert(n > 0); + return Hann(2.0*i / n - 1); +} void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, span out_rates) { @@ -58,7 +64,7 @@ void ResamplingTest::PrepareData(int nsamples, int nchannels, span } } -void ResamplingTest::Verify(bool downmix) { +void ResamplingTest::Verify() { auto in_sh = ttl_in_.cpu().shape; auto out_sh = ttl_outref_.cpu().shape; int nsamples = in_sh.num_samples(); @@ -70,19 +76,9 @@ void ResamplingTest::Verify(bool downmix) { int n_out = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; for (int i = 0; i < n_out; i++) { - float ref_val = 0; - if (downmix) { - for (int c = 0; c < nchannels; c++) { - ref_val += out_ref[i * nchannels + c]; - } - ref_val /= nchannels; - } else { - ref_val = out_ref[i]; - } - - ASSERT_NEAR(out_data[i], ref_val, eps()) + ASSERT_NEAR(out_data[i], out_ref[i], eps()) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; - float diff = std::abs(out_data[i] - ref_val); + float diff = std::abs(out_data[i] - out_ref[i]); if (diff > max_diff) max_diff = diff; err += diff * diff; @@ -96,7 +92,7 @@ void ResamplingTest::Verify(bool downmix) { } } -void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { +void ResamplingTest::RunTest(int nsamples, int nchannels) { std::vector in_rates_v; for (int i = 0; i < nsamples; i++) { if (i % 2 == 0) @@ -111,14 +107,14 @@ void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { PrepareData(nsamples, nchannels, in_rates, out_rates); - RunResampling(in_rates, out_rates, downmix); + RunResampling(in_rates, out_rates); - Verify(downmix); + Verify(); } class ResamplingCPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates, bool downmix) override { + void RunResampling(span in_rates, span out_rates) override { Resampler R; R.Initialize(16); @@ -134,27 +130,24 @@ class ResamplingCPUTest : public ResamplingTest { int n_in = in_sh[0]; int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], - nchannels, downmix); + nchannels); } } }; TEST_F(ResamplingCPUTest, SingleChannel) { - this->RunTest(1, 1, false); + this->RunTest(1, 1); } TEST_F(ResamplingCPUTest, TwoChannel) { - this->RunTest(1, 2, false); + this->RunTest(1, 2); } TEST_F(ResamplingCPUTest, EightChannel) { - this->RunTest(1, 8, false); -} - -TEST_F(ResamplingCPUTest, ThreeChannelDownmix) { - this->RunTest(1, 3, true); + this->RunTest(1, 8); } +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 0163410daf1..2eea4965528 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ + #include #include #include @@ -23,13 +26,9 @@ namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { -namespace { - -double HannWindow(int i, int n) { - assert(n > 0); - return Hann(2.0*i / n - 1); -} +double HannWindow(int i, int n); template void TestWave(T *out, int n, int stride, float freq) { @@ -39,8 +38,6 @@ void TestWave(T *out, int n, int stride, float freq) { } } -} // namespace - class ResamplingTest : public ::testing::Test { public: void PrepareData(int nsamples, int nchannels, @@ -48,11 +45,11 @@ class ResamplingTest : public ::testing::Test { virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } - void Verify(bool downmix); + void Verify(); - virtual void RunResampling(span in_rates, span out_rates, bool downmix) = 0; + virtual void RunResampling(span in_rates, span out_rates) = 0; - void RunTest(int nsamples, int nchannels, bool downmix); + void RunTest(int nsamples, int nchannels); TestTensorList ttl_in_; @@ -61,7 +58,10 @@ class ResamplingTest : public ::testing::Test { }; +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels } // namespace dali + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ From 6065354cc055c558fafda9e53696ba60ebce03e8 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 13:11:21 +0200 Subject: [PATCH 04/36] Code review fixes Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 15 +++++++++------ dali/kernels/signal/resampling_gpu.cuh | 4 ++-- dali/kernels/signal/resampling_gpu_test.cu | 2 +- dali/kernels/signal/resampling_test.h | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 802c8a4c225..410ca30a3d7 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -29,7 +29,6 @@ #include "dali/core/small_vector.h" #include "dali/core/convert.h" #include "dali/core/static_switch.h" -#include "dali/core/geom/vec.h" namespace dali { namespace kernels { @@ -55,7 +54,11 @@ inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { #endif struct ResamplingWindow { - inline DALI_HOST_DEV ivec<2> input_range(float x) const { + struct InputRange { + int i0, i1; + }; + + inline DALI_HOST_DEV InputRange input_range(float x) const { int xc = std::ceil(x); int i0 = xc - lobes; int i1 = xc + lobes; @@ -222,8 +225,8 @@ struct Resampler { const float *__restrict__ in_block_ptr = in + in_block_i; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { auto irange = window.input_range(in_pos); - int i0 = irange[0]; - int i1 = irange[1]; + int i0 = irange.i0; + int i1 = irange.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -285,8 +288,8 @@ struct Resampler { const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { auto irange = window.input_range(in_pos); - int i0 = irange[0]; - int i1 = irange[1]; + int i0 = irange.i0; + int i1 = irange.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 39ad0f7aec2..3f5878fe3f6 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -63,8 +63,8 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { out_pos += grid_stride, in_pos_start += fscale * grid_stride) { float in_pos = in_pos_start + fscale * threadIdx.x; auto i_range = window.input_range(in_pos); - int i0 = i_range[0]; - int i1 = i_range[1]; + int i0 = i_range.i0; + int i1 = i_range.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > sample.in_len) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 23f1fef28d9..bf6bcf349bc 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 2eea4965528..2afadd281df 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 7c0162d12ddd064103c36a7481fae4f6e9cc1eb5 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 19:39:24 +0200 Subject: [PATCH 05/36] Add benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu_test.cu | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index bf6bcf349bc..c6bee6a9d6b 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" +#include "dali/core/cuda_event.h" namespace dali { namespace kernels { @@ -46,6 +47,44 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } + + void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { + std::vector in_rates_v(batch_size, 22050.0f); + auto in_rates = make_cspan(in_rates_v); + std::vector out_rates_v(batch_size, 16000.0f); + auto out_rates = make_cspan(out_rates_v); + + this->PrepareData(batch_size, nchannels, in_rates, out_rates); + + ResamplerGPU R; + R.Initialize(16); + + KernelContext ctx; + ctx.gpu.stream = 0; + + CUDAEvent start = CUDAEvent::CreateWithFlags(0); + CUDAEvent end = CUDAEvent::CreateWithFlags(0); + double avg_time = 0; + int64_t in_elems = ttl_in_.cpu().shape.num_elements(); + int64_t in_bytes = in_elems * sizeof(float); + std::cout << "Resampling GPU Perf test.\n" + << "\nInput contains " << in_elems << " floats.\n"; + + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); + + for (int i = 0; i < n_iters; ++i) { + CUDA_CALL(cudaEventRecord(start)); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + CUDA_CALL(cudaEventRecord(end)); + CUDA_CALL(cudaDeviceSynchronize()); + float time; + CUDA_CALL(cudaEventElapsedTime(&time, start, end)); + + avg_time += time; + } + std::cout << "Processed " << in_bytes / avg_time << " bytes/sec" << std::endl; + } }; TEST_F(ResamplingGPUTest, SingleChannel) { @@ -60,6 +99,10 @@ TEST_F(ResamplingGPUTest, EightChannel) { this->RunTest(3, 8); } +TEST_F(ResamplingGPUTest, PerfTest) { + this->RunPerfTest(64, 1, 1000); +} + } // namespace test } // namespace resampling } // namespace signal From 7a39c95d906db4a440384e3d2a59f77fb33f4801 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 19:39:46 +0200 Subject: [PATCH 06/36] Avoid precision issue & add shared memory usage Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cuh | 26 +++++++++++++++++++------- dali/kernels/signal/resampling_gpu.h | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 3f5878fe3f6..118828382f1 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling.h" +#include "dali/core/util.h" namespace dali { namespace kernels { @@ -45,22 +46,33 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { double scale = sample.scale; float fscale = scale; int nchannels = SingleChannel ? 1 : sample.nchannels; - auto &window = sample.window; + auto& window = sample.window; + + extern __shared__ float window_coeffs_sh[]; + for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { + window_coeffs_sh[k] = window.lookup[k]; + } + __syncthreads(); + window.lookup = window_coeffs_sh; Out* out = reinterpret_cast(sample.out); const In* in = reinterpret_cast(sample.in); - int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; - int64_t out_block = static_cast(blockIdx.x) * blockDim.x; - int64_t start_out_pos = out_block + threadIdx.x; + int blocks = div_ceil(sample.out_len, static_cast(blockDim.x)); + int64_t start_block_idx = blockIdx.x * blocks / gridDim.x; + int64_t end_block_idx = (blockIdx.x + 1) * blocks / gridDim.x; + int64_t out_stride = blockDim.x; + float in_stride = fscale * blockDim.x; + int64_t out_block_start = start_block_idx * blockDim.x; + int64_t out_block_end = cuda_min(end_block_idx * blockDim.x, sample.out_len); - double in_block_f = out_block * scale; + double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; - for (int64_t out_pos = start_out_pos; out_pos < sample.out_len; - out_pos += grid_stride, in_pos_start += fscale * grid_stride) { + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < out_block_end; + out_pos += out_stride, in_pos_start += in_stride) { float in_pos = in_pos_start + fscale * threadIdx.x; auto i_range = window.input_range(in_pos); int i0 = i_range.i0; diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index ed7db6ff2e9..b4a7715d170 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -89,10 +89,11 @@ class ResamplerGPU { dim3 block(256, 1); int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); + size_t shm_size = window_gpu_storage_.size() * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu); + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } From 3ef0f581baca681eabfed46e15f265c923622cd0 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 13:06:18 +0200 Subject: [PATCH 07/36] Move double in_block_f calculation inside the loop Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cuh | 36 ++++++++++++-------------- dali/kernels/signal/resampling_gpu.h | 8 ++++-- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 118828382f1..fa982ba6552 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -31,7 +31,7 @@ struct SampleDesc { ResamplingWindow window; int64_t in_len; // num samples in input int64_t out_len; // num samples in output - int64_t nchannels; // number of channels + int nchannels; // number of channels double scale; // in_sampling_rate / out_sampling_rate }; @@ -41,14 +41,17 @@ struct SampleDesc { * @param samples sample descriptors */ template -__global__ void ResampleGPUKernel(const SampleDesc *samples) { +__global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) { auto sample = samples[blockIdx.y]; double scale = sample.scale; float fscale = scale; int nchannels = SingleChannel ? 1 : sample.nchannels; auto& window = sample.window; - extern __shared__ float window_coeffs_sh[]; + extern __shared__ float sh_mem[]; + float *window_coeffs_sh = sh_mem; + float *tmp = sh_mem + window.lookup_size + + threadIdx.x * max_nchannels; // used to accummulate per-channel out values for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { window_coeffs_sh[k] = window.lookup[k]; } @@ -58,22 +61,17 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { Out* out = reinterpret_cast(sample.out); const In* in = reinterpret_cast(sample.in); - int blocks = div_ceil(sample.out_len, static_cast(blockDim.x)); - int64_t start_block_idx = blockIdx.x * blocks / gridDim.x; - int64_t end_block_idx = (blockIdx.x + 1) * blocks / gridDim.x; - int64_t out_stride = blockDim.x; - float in_stride = fscale * blockDim.x; - int64_t out_block_start = start_block_idx * blockDim.x; - int64_t out_block_end = cuda_min(end_block_idx * blockDim.x, sample.out_len); - - double in_block_f = out_block_start * scale; - int64_t in_block_i = std::floor(in_block_f); - float in_pos_start = in_block_f - in_block_i; - const In* in_blk_ptr = in + in_block_i * nchannels; - - for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < out_block_end; - out_pos += out_stride, in_pos_start += in_stride) { + int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; + int64_t out_block_start = static_cast(blockIdx.x) * blockDim.x; + + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_len; + out_block_start += grid_stride, out_pos += grid_stride) { + double in_block_f = out_block_start * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos_start = in_block_f - in_block_i; + const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; + auto i_range = window.input_range(in_pos); int i0 = i_range.i0; int i1 = i_range.i1; @@ -92,8 +90,6 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos] = ConvertSatNorm(out_val); } else { // multiple channels - assert(nchannels <= 32); - float tmp[32]; // more than enough for (int c = 0; c < nchannels; c++) { tmp[c] = 0; } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index b4a7715d170..a648d0b4356 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -69,6 +69,7 @@ class ResamplerGPU { make_span(scratch.Allocate(nsamples), nsamples); bool any_multichannel = false; + int max_nchannels = 0; for (int i = 0; i < nsamples; i++) { auto &desc = samples_cpu[i]; desc.in = in[i].data; @@ -80,6 +81,7 @@ class ResamplerGPU { desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); assert(desc.out_len == out_sh[0]); desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + max_nchannels = std::max(desc.nchannels, max_nchannels); desc.scale = static_cast(in_rates[i]) / out_rates[i]; any_multichannel |= desc.nchannels > 1; } @@ -89,11 +91,13 @@ class ResamplerGPU { dim3 block(256, 1); int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); - size_t shm_size = window_gpu_storage_.size() * sizeof(float); + + // window coefficients and temporary per channel out values + size_t shm_size = (window_gpu_storage_.size() + max_nchannels * block.x) * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu); + <<>>(samples_gpu, max_nchannels); )); // NOLINT CUDA_CALL(cudaGetLastError()); } From a66a7631ec9e20b928c505a2b50a843101c0dd7e Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 13:10:08 +0200 Subject: [PATCH 08/36] Fix benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu_test.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index c6bee6a9d6b..b5c386eb006 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -64,7 +64,7 @@ class ResamplingGPUTest : public ResamplingTest { CUDAEvent start = CUDAEvent::CreateWithFlags(0); CUDAEvent end = CUDAEvent::CreateWithFlags(0); - double avg_time = 0; + double total_time = 0; int64_t in_elems = ttl_in_.cpu().shape.num_elements(); int64_t in_bytes = in_elems * sizeof(float); std::cout << "Resampling GPU Perf test.\n" @@ -80,10 +80,9 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaDeviceSynchronize()); float time; CUDA_CALL(cudaEventElapsedTime(&time, start, end)); - - avg_time += time; + total_time += time; } - std::cout << "Processed " << in_bytes / avg_time << " bytes/sec" << std::endl; + std::cout << "Processed " << n_iters * in_bytes / (total_time * 1e6) << " MBs/sec" << std::endl; } }; From 44354b25c17c73fc60352b195f711cf40c9b060f Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 15:10:04 +0200 Subject: [PATCH 09/36] Update benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.h | 4 +-- dali/kernels/signal/resampling_gpu_test.cu | 32 ++++++++++++++-------- dali/kernels/signal/resampling_test.cc | 4 +-- dali/kernels/signal/resampling_test.h | 2 +- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index a648d0b4356..8548b4c1570 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -59,9 +59,7 @@ class ResamplerGPU { if (window_gpu_storage_.empty()) Initialize(); - DynamicScratchpad dyn_scratchpad({}, AccessOrder(context.gpu.stream)); - if (!context.scratchpad) - context.scratchpad = &dyn_scratchpad; + assert(context.scratchpad); auto &scratch = *context.scratchpad; int nsamples = in.num_samples(); diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index b5c386eb006..278dced32d2 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -33,6 +33,8 @@ class ResamplingGPUTest : public ResamplingTest { KernelContext ctx; ctx.gpu.stream = 0; + DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); + ctx.scratchpad = &dyn_scratchpad; auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); auto outref_sh = ttl_outref_.cpu().shape; @@ -53,36 +55,44 @@ class ResamplingGPUTest : public ResamplingTest { auto in_rates = make_cspan(in_rates_v); std::vector out_rates_v(batch_size, 16000.0f); auto out_rates = make_cspan(out_rates_v); + int nsec = 30; - this->PrepareData(batch_size, nchannels, in_rates, out_rates); + this->PrepareData(batch_size, nchannels, in_rates, out_rates, nsec); ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); CUDAEvent start = CUDAEvent::CreateWithFlags(0); CUDAEvent end = CUDAEvent::CreateWithFlags(0); - double total_time = 0; + double total_time_ms = 0; int64_t in_elems = ttl_in_.cpu().shape.num_elements(); - int64_t in_bytes = in_elems * sizeof(float); + int64_t out_elems = ttl_out_.cpu().shape.num_elements(); + int64_t out_bytes = out_elems * sizeof(float); std::cout << "Resampling GPU Perf test.\n" - << "\nInput contains " << in_elems << " floats.\n"; - - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); - ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); + << "Input contains " << in_elems << " floats.\n" + << "Output contains " << out_elems << " floats.\n"; for (int i = 0; i < n_iters; ++i) { + CUDA_CALL(cudaDeviceSynchronize()); + + DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); + ctx.scratchpad = &dyn_scratchpad; + CUDA_CALL(cudaEventRecord(start)); R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); CUDA_CALL(cudaEventRecord(end)); CUDA_CALL(cudaDeviceSynchronize()); - float time; - CUDA_CALL(cudaEventElapsedTime(&time, start, end)); - total_time += time; + float time_ms; + CUDA_CALL(cudaEventElapsedTime(&time_ms, start, end)); + total_time_ms += time_ms; } - std::cout << "Processed " << n_iters * in_bytes / (total_time * 1e6) << " MBs/sec" << std::endl; + std::cout << "Processed " << n_iters * out_bytes / (total_time_ms * 1e6) << " GBs/sec" + << std::endl; } }; diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 538506a283e..886e7bfe8bf 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -30,14 +30,14 @@ double HannWindow(int i, int n) { } void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, - span out_rates) { + span out_rates, int nsec) { TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); for (int s = 0; s < nsamples; s++) { double in_rate = in_rates[s]; double out_rate = out_rates[s]; double scale = static_cast(in_rate) / out_rate; - int n_in = in_rate + 12345 * s; // different lengths + int n_in = nsec * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); in_sh.tensor_shape_span(s)[0] = n_in; out_sh.tensor_shape_span(s)[0] = n_out; diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 2afadd281df..eefa90fbba8 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -41,7 +41,7 @@ void TestWave(T *out, int n, int stride, float freq) { class ResamplingTest : public ::testing::Test { public: void PrepareData(int nsamples, int nchannels, - span in_rates, span out_rates); + span in_rates, span out_rates, int nsec = 1); virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } From e4eb226d5335214f53b8c15cba39694b77a38fe5 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 18:49:42 +0200 Subject: [PATCH 10/36] ROI & input conversion to float & limit tmp shared mem Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 11 +++- dali/kernels/signal/resampling_gpu.cuh | 60 +++++++++++------ dali/kernels/signal/resampling_gpu.h | 27 ++++---- dali/kernels/signal/resampling_gpu_test.cu | 32 +++++---- dali/kernels/signal/resampling_test.cc | 77 ++++++++++++---------- dali/kernels/signal/resampling_test.h | 20 +++--- 6 files changed, 141 insertions(+), 86 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 410ca30a3d7..9112b602784 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -36,6 +36,11 @@ namespace signal { namespace resampling { +struct Args { + double in_rate = 1, out_rate = 1; + int64_t out_begin = 0, out_end = -1; // default values result in the whole range +}; + inline double Hann(double x) { return 0.5 * (1 + std::cos(x * M_PI)); } @@ -241,7 +246,8 @@ struct Resampler { f += in_block_ptr[i] * w; } assert(out_pos >= out_begin && out_pos < out_end); - out[out_pos] = ConvertSatNorm(f); + auto rel_pos = out_pos - out_begin; + out[rel_pos] = ConvertSatNorm(f); } } } @@ -310,8 +316,9 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); + auto rel_pos = out_pos - out_begin; for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + out[rel_pos * num_channels + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index fa982ba6552..b501eaaab8f 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -19,6 +19,8 @@ #include "dali/kernels/signal/resampling.h" #include "dali/core/util.h" +#define SHM_NCHANNELS 16 + namespace dali { namespace kernels { namespace signal { @@ -30,18 +32,33 @@ struct SampleDesc { const void *in; ResamplingWindow window; int64_t in_len; // num samples in input - int64_t out_len; // num samples in output + int64_t out_begin; // output region-of-interest start + int64_t out_end; // output region-of-interest end int nchannels; // number of channels double scale; // in_sampling_rate / out_sampling_rate }; +/** + * @brief Gets intermediate floating point representation depending on the input/output types + */ +template +__device__ float ConvertInput(In in_val) { + if (std::is_unsigned::value && std::is_signed::value) { + return (ConvertSatNorm(in_val) + 1.0f) * 0.5f; + } else if (std::is_signed::value && std::is_unsigned::value) { + return ConvertSatNorm(in_val) * 2.0f - 1.0f; // treat half-range as 0 + } else { + return ConvertSatNorm(in_val); // just normalize + } +} + /** * @brief Resamples 1D signal (single or multi-channel), optionally converting to a different data type. * * @param samples sample descriptors */ template -__global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) { +__global__ void ResampleGPUKernel(const SampleDesc *samples) { auto sample = samples[blockIdx.y]; double scale = sample.scale; float fscale = scale; @@ -51,7 +68,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) extern __shared__ float sh_mem[]; float *window_coeffs_sh = sh_mem; float *tmp = sh_mem + window.lookup_size + - threadIdx.x * max_nchannels; // used to accummulate per-channel out values + threadIdx.x * (SHM_NCHANNELS+1); // used to accummulate per-channel out values for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { window_coeffs_sh[k] = window.lookup[k]; } @@ -62,9 +79,9 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) const In* in = reinterpret_cast(sample.in); int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; - int64_t out_block_start = static_cast(blockIdx.x) * blockDim.x; + int64_t out_block_start = sample.out_begin + static_cast(blockIdx.x) * blockDim.x; - for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_len; + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_end; out_block_start += grid_stride, out_pos += grid_stride) { double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); @@ -83,28 +100,31 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) float out_val = 0; if (SingleChannel) { for (int i = i0; i < i1; i++) { - In in_val = in_blk_ptr[i]; + float in_val = ConvertInput(in_blk_ptr[i]); float x = i - in_pos; float w = window(x); out_val = fma(in_val, w, out_val); } - out[out_pos] = ConvertSatNorm(out_val); + out[out_pos - sample.out_begin] = ConvertSatNorm(out_val); } else { // multiple channels - for (int c = 0; c < nchannels; c++) { - tmp[c] = 0; - } - - for (int i = i0; i < i1; i++) { - float x = i - in_pos; - float w = window(x); - for (int c = 0; c < nchannels; c++) { - In in_val = in_blk_ptr[i * nchannels + c]; - tmp[c] = fma(in_val, w, tmp[c]); + for (int c0 = 0; c0 < nchannels; c0 += SHM_NCHANNELS) { + int nc = cuda_min(SHM_NCHANNELS, nchannels - c0); + for (int c = 0; c < nc; c++) { + tmp[c] = 0; } - } - for (int c = 0; c < nchannels; c++) { - out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); + for (int i = i0; i < i1; i++) { + float x = i - in_pos; + float w = window(x); + for (int c = 0; c < nc; c++) { + float in_val = ConvertInput(in_blk_ptr[i * nchannels + c]); + tmp[c] = fma(in_val, w, tmp[c]); + } + } + Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; + for (int c = 0; c < nc; c++) { + out_ptr[c + c0] = ConvertSatNorm(tmp[c]); + } } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index 8548b4c1570..a02fe6a855f 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -41,21 +41,25 @@ class ResamplerGPU { } KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span in_rate, span out_rate) { + span args) { KernelRequirements req; auto out_shape = in.shape; for (int i = 0; i < in.num_samples(); i++) { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); - out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); + auto &arg = args[i]; + if (arg.out_begin > 0 || arg.out_end > 0) { + out_sh[0] = arg.out_end - arg.out_begin; + } else { + out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + } } req.output_shapes = {out_shape}; return req; } void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span in_rates, - span out_rates) { + const InListGPU &in, span args) { if (window_gpu_storage_.empty()) Initialize(); @@ -67,7 +71,6 @@ class ResamplerGPU { make_span(scratch.Allocate(nsamples), nsamples); bool any_multichannel = false; - int max_nchannels = 0; for (int i = 0; i < nsamples; i++) { auto &desc = samples_cpu[i]; desc.in = in[i].data; @@ -76,11 +79,13 @@ class ResamplerGPU { const auto &in_sh = in[i].shape; const auto &out_sh = out[i].shape; desc.in_len = in_sh[0]; - desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); - assert(desc.out_len == out_sh[0]); + auto &arg = args[i]; + desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + desc.out_end = arg.out_end > 0 ? arg.out_end : + resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + assert((desc.out_end - desc.out_begin) == out_sh[0]); desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; - max_nchannels = std::max(desc.nchannels, max_nchannels); - desc.scale = static_cast(in_rates[i]) / out_rates[i]; + desc.scale = arg.in_rate / arg.out_rate; any_multichannel |= desc.nchannels > 1; } @@ -91,11 +96,11 @@ class ResamplerGPU { dim3 grid(blocks_per_sample, nsamples); // window coefficients and temporary per channel out values - size_t shm_size = (window_gpu_storage_.size() + max_nchannels * block.x) * sizeof(float); + size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu, max_nchannels); + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 278dced32d2..a4dc4f962c7 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -27,7 +27,7 @@ namespace test { class ResamplingGPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates) override { + void RunResampling(span args) override { ResamplerGPU R; R.Initialize(16); @@ -36,7 +36,7 @@ class ResamplingGPUTest : public ResamplingTest { DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); ctx.scratchpad = &dyn_scratchpad; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + auto req = R.Setup(ctx, ttl_in_.gpu(), args); auto outref_sh = ttl_outref_.cpu().shape; auto in_batch_sh = ttl_in_.cpu().shape; for (int s = 0; s < outref_sh.size(); s++) { @@ -45,26 +45,24 @@ class ResamplingGPUTest : public ResamplingTest { ASSERT_EQ(sh, expected_sh); } - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), args); CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { - std::vector in_rates_v(batch_size, 22050.0f); - auto in_rates = make_cspan(in_rates_v); - std::vector out_rates_v(batch_size, 16000.0f); - auto out_rates = make_cspan(out_rates_v); + std::vector args_v(batch_size, {22050.0f, 16000.0f}); + auto args = make_cspan(args_v); int nsec = 30; - this->PrepareData(batch_size, nchannels, in_rates, out_rates, nsec); + this->PrepareData(batch_size, nchannels, args, nsec); ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + auto req = R.Setup(ctx, ttl_in_.gpu(), args); ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); CUDAEvent start = CUDAEvent::CreateWithFlags(0); @@ -84,7 +82,7 @@ class ResamplingGPUTest : public ResamplingTest { ctx.scratchpad = &dyn_scratchpad; CUDA_CALL(cudaEventRecord(start)); - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), args); CUDA_CALL(cudaEventRecord(end)); CUDA_CALL(cudaDeviceSynchronize()); float time_ms; @@ -108,7 +106,19 @@ TEST_F(ResamplingGPUTest, EightChannel) { this->RunTest(3, 8); } -TEST_F(ResamplingGPUTest, PerfTest) { +TEST_F(ResamplingGPUTest, HundredChannel) { + this->RunTest(3, 100); +} + +TEST_F(ResamplingGPUTest, OutBeginEnd) { + this->RunTest(3, 1, true); +} + +TEST_F(ResamplingGPUTest, EightChannelOutBeginEnd) { + this->RunTest(3, 8, true); +} + +TEST_F(ResamplingGPUTest, DISABLED_PerfTest) { this->RunPerfTest(64, 1, 1000); } diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 886e7bfe8bf..a4070db3b4b 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -26,21 +26,23 @@ namespace test { double HannWindow(int i, int n) { assert(n > 0); - return Hann(2.0*i / n - 1); + return Hann(2.0 * i / n - 1); } -void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, - span out_rates, int nsec) { +void ResamplingTest::PrepareData(int nsamples, int nchannels, span args, int nsec) { TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); for (int s = 0; s < nsamples; s++) { - double in_rate = in_rates[s]; - double out_rate = out_rates[s]; - double scale = static_cast(in_rate) / out_rate; + double in_rate = args[s].in_rate; + double out_rate = args[s].out_rate; + double scale = in_rate / out_rate; int n_in = nsec * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; + ASSERT_GT(out_end, out_begin); in_sh.tensor_shape_span(s)[0] = n_in; - out_sh.tensor_shape_span(s)[0] = n_out; + out_sh.tensor_shape_span(s)[0] = out_end - out_begin; if (nchannels > 1) { in_sh.tensor_shape_span(s)[1] = nchannels; out_sh.tensor_shape_span(s)[1] = nchannels; @@ -50,21 +52,23 @@ void ResamplingTest::PrepareData(int nsamples, int nchannels, span ttl_out_.reshape(out_sh); ttl_outref_.reshape(out_sh); for (int s = 0; s < nsamples; s++) { - double in_rate = in_rates[s]; - double out_rate = out_rates[s]; + double in_rate = args[s].in_rate; + double out_rate = args[s].out_rate; double scale = static_cast(in_rate) / out_rate; + int64_t n_in = in_sh.tensor_shape_span(s)[0]; + int n_out = std::ceil(n_in / scale); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; for (int c = 0; c < nchannels; c++) { float f_in = 0.1f + 0.01 * s + 0.001 * c; float f_out = f_in * scale; - int n_in = in_sh.tensor_shape_span(s)[0]; - int n_out = out_sh.tensor_shape_span(s)[0]; TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); - TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out, out_begin, out_end); } } } -void ResamplingTest::Verify() { +void ResamplingTest::Verify(span args) { auto in_sh = ttl_in_.cpu().shape; auto out_sh = ttl_outref_.cpu().shape; int nsamples = in_sh.num_samples(); @@ -73,9 +77,9 @@ void ResamplingTest::Verify() { for (int s = 0; s < nsamples; s++) { float *out_data = ttl_out_.cpu()[s].data; float *out_ref = ttl_outref_.cpu()[s].data; - int n_out = out_sh.tensor_shape_span(s)[0]; + int64_t out_len = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; - for (int i = 0; i < n_out; i++) { + for (int64_t i = 0; i < out_len; i++) { ASSERT_NEAR(out_data[i], out_ref[i], eps()) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; float diff = std::abs(out_data[i] - out_ref[i]); @@ -84,7 +88,7 @@ void ResamplingTest::Verify() { err += diff * diff; } - err = std::sqrt(err / n_out); + err = std::sqrt(err / out_len); EXPECT_LE(err, max_avg_err()) << "Average error too big"; std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" "\n max difference vs fresh signal: " @@ -92,45 +96,44 @@ void ResamplingTest::Verify() { } } -void ResamplingTest::RunTest(int nsamples, int nchannels) { - std::vector in_rates_v; +void ResamplingTest::RunTest(int nsamples, int nchannels, bool use_roi) { + std::vector args_v; for (int i = 0; i < nsamples; i++) { + int roi_start = use_roi ? 100 : 0; + int roi_end = use_roi ? 8000 : -1; if (i % 2 == 0) - in_rates_v.push_back(22050.0f); + args_v.push_back({22050.0f, 16000.0f, roi_start, roi_end}); else - in_rates_v.push_back(44100.0f); + args_v.push_back({44100.0f, 16000.0f, roi_start, roi_end}); } - auto in_rates = make_cspan(in_rates_v); + auto args = make_cspan(args_v); - std::vector out_rates_v(nsamples, 16000.0f); - auto out_rates = make_cspan(out_rates_v); + PrepareData(nsamples, nchannels, args); - PrepareData(nsamples, nchannels, in_rates, out_rates); + RunResampling(args); - RunResampling(in_rates, out_rates); - - Verify(); + Verify(args); } class ResamplingCPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates) override { + void RunResampling(span args) override { Resampler R; R.Initialize(16); - int nsamples = in_rates.size(); - assert(nsamples == out_rates.size()); + int nsamples = args.size(); auto in_view = ttl_in_.cpu(); auto out_view = ttl_out_.cpu(); for (int s = 0; s < nsamples; s++) { auto out_sh = out_view.shape[s]; auto in_sh = in_view.shape[s]; - int n_out = out_sh[0]; int n_in = in_sh[0]; int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; - R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], - nchannels); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : out_sh[0]; + R.Resample(out_view[s].data, out_begin, out_end, args[s].out_rate, + in_view[s].data, n_in, args[s].in_rate, nchannels); } } }; @@ -147,6 +150,14 @@ TEST_F(ResamplingCPUTest, EightChannel) { this->RunTest(1, 8); } +TEST_F(ResamplingCPUTest, OutBeginEnd) { + this->RunTest(1, 1, true); +} + +TEST_F(ResamplingCPUTest, EightChannelOutBeginEnd) { + this->RunTest(3, 8, true); +} + } // namespace test } // namespace resampling } // namespace signal diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index eefa90fbba8..51f7c4be7b2 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -31,25 +31,27 @@ namespace test { double HannWindow(int i, int n); template -void TestWave(T *out, int n, int stride, float freq) { - for (int i = 0; i < n; i++) { - float f = std::sin(i* freq) * HannWindow(i, n); - out[i*stride] = ConvertSatNorm(f); +void TestWave(T *out, int n, int stride, float freq, int i_start = 0, int i_end = -1) { + if (i_end <= 0) i_end = n; + assert(i_start >= 0 && i_start <= n); + assert(i_end >= 0 && i_end <= n); + for (int i = i_start; i < i_end; i++) { + float f = std::sin(i * freq) * HannWindow(i, n); + out[(i - i_start) * stride] = ConvertSatNorm(f); } } class ResamplingTest : public ::testing::Test { public: - void PrepareData(int nsamples, int nchannels, - span in_rates, span out_rates, int nsec = 1); + void PrepareData(int nsamples, int nchannels, span args, int nsec = 1); virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } - void Verify(); + void Verify(span args); - virtual void RunResampling(span in_rates, span out_rates) = 0; + virtual void RunResampling(span args) = 0; - void RunTest(int nsamples, int nchannels); + void RunTest(int nsamples, int nchannels, bool use_roi = false); TestTensorList ttl_in_; From 13ff3db7442ec9784149371f6981acae6c7765a2 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 13:24:48 +0200 Subject: [PATCH 11/36] Add comments Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 ++++ dali/kernels/signal/resampling_gpu.cuh | 3 +++ 2 files changed, 7 insertions(+) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 9112b602784..d03ebe896dc 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -70,6 +70,10 @@ struct ResamplingWindow { return {i0, i1}; } + /** + * @brief Calculates the window coefficient at an arbitrary floating point position + * by interpolating between two samples. + */ inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; float floori = std::floor(fi); diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index b501eaaab8f..f08473aa295 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -83,6 +83,9 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_end; out_block_start += grid_stride, out_pos += grid_stride) { + // A floating point distance `in_pos_start` is calculated from an arbitrary relative + // position, keeping the floats small in order to keep precision. `in_block_f`, used to + // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); float in_pos_start = in_block_f - in_block_i; From 3467768bfda574dc57c6232c3f47c17dbba3f639 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 13:47:55 +0200 Subject: [PATCH 12/36] Use floorf and ceilf in CUDA code Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 ++-- dali/kernels/signal/resampling_gpu.cuh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index d03ebe896dc..0c0a968be50 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -64,7 +64,7 @@ struct ResamplingWindow { }; inline DALI_HOST_DEV InputRange input_range(float x) const { - int xc = std::ceil(x); + int xc = ceilf(x); int i0 = xc - lobes; int i1 = xc + lobes; return {i0, i1}; @@ -76,7 +76,7 @@ struct ResamplingWindow { */ inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; - float floori = std::floor(fi); + float floori = floorf(fi); float di = fi - floori; int i = floori; assert(i >= 0 && i < lookup_size); diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index f08473aa295..6c612fbad7b 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -87,7 +87,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { // position, keeping the floats small in order to keep precision. `in_block_f`, used to // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; - int64_t in_block_i = std::floor(in_block_f); + int64_t in_block_i = floorf(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; From b10d860f4462409e5a366f846ffa8c5a8a8b7477 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 18:32:59 +0200 Subject: [PATCH 13/36] Improve tests & fix bugs Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 +- dali/kernels/signal/resampling_gpu.cuh | 10 ++- dali/kernels/signal/resampling_gpu.h | 16 ++-- dali/kernels/signal/resampling_gpu_test.cu | 57 ++++++++++--- dali/kernels/signal/resampling_test.cc | 99 +++++++++++++++------- dali/kernels/signal/resampling_test.h | 27 ++++-- 6 files changed, 150 insertions(+), 63 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 0c0a968be50..dc47d601ea4 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -222,7 +222,7 @@ struct Resampler { Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate) const { assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); - int64_t block = 1 << 10; // still leaves 13 significant bits for fractional part + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part double scale = in_rate / out_rate; float fscale = scale; @@ -285,7 +285,7 @@ struct Resampler { const int num_channels = static_channels < 0 ? dynamic_num_channels : static_channels; assert(num_channels > 0); - int64_t block = 1 << 10; // still leaves 13 significant bits for fractional part + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part double scale = in_rate / out_rate; float fscale = scale; SmallVector tmp; diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 6c612fbad7b..452f7f1c72d 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -87,7 +87,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { // position, keeping the floats small in order to keep precision. `in_block_f`, used to // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; - int64_t in_block_i = floorf(in_block_f); + int64_t in_block_i = floor(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; @@ -110,6 +110,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos - sample.out_begin] = ConvertSatNorm(out_val); } else { // multiple channels + Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; for (int c0 = 0; c0 < nchannels; c0 += SHM_NCHANNELS) { int nc = cuda_min(SHM_NCHANNELS, nchannels - c0); for (int c = 0; c < nc; c++) { @@ -119,14 +120,15 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { for (int i = i0; i < i1; i++) { float x = i - in_pos; float w = window(x); + const In *in_ptr = in_blk_ptr + i * nchannels + c0; for (int c = 0; c < nc; c++) { - float in_val = ConvertInput(in_blk_ptr[i * nchannels + c]); + float in_val = ConvertInput(in_ptr[c]); tmp[c] = fma(in_val, w, tmp[c]); } } - Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; + for (int c = 0; c < nc; c++) { - out_ptr[c + c0] = ConvertSatNorm(tmp[c]); + out_ptr[c0 + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index a02fe6a855f..e8fe697ab62 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -48,11 +48,17 @@ class ResamplerGPU { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); auto &arg = args[i]; - if (arg.out_begin > 0 || arg.out_end > 0) { - out_sh[0] = arg.out_end - arg.out_begin; - } else { - out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - } + auto out_len = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + auto out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + auto out_end = arg.out_end > 0 ? arg.out_end : out_len; + if (out_end < out_begin) + throw std::invalid_argument( + make_string("out_begin can't be larger than out_end. Got out_begin=", out_begin, + ", out_end=", out_end)); + if (out_end > out_len) + throw std::invalid_argument(make_string( + "out_end can't be outside of the range of the output signal: [0, ", out_len, ")")); + out_sh[0] = out_end - out_begin; } req.output_shapes = {out_shape}; return req; diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index a4dc4f962c7..46a2302d818 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -27,6 +27,10 @@ namespace test { class ResamplingGPUTest : public ResamplingTest { public: + ResamplingGPUTest() { + this->nsamples_ = 8; + } + void RunResampling(span args) override { ResamplerGPU R; R.Initialize(16); @@ -50,12 +54,11 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } - void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { - std::vector args_v(batch_size, {22050.0f, 16000.0f}); + void RunPerfTest(int n_iters = 1000) { + std::vector args_v(nsamples_, {22050.0f, 16000.0f}); auto args = make_cspan(args_v); - int nsec = 30; - - this->PrepareData(batch_size, nchannels, args, nsec); + this->nsec_ = 30; + this->PrepareData(args); ResamplerGPU R; R.Initialize(16); @@ -95,31 +98,57 @@ class ResamplingGPUTest : public ResamplingTest { }; TEST_F(ResamplingGPUTest, SingleChannel) { - this->RunTest(8, 1); + this->nchannels_ = 1; + this->RunTest(); } TEST_F(ResamplingGPUTest, TwoChannel) { - this->RunTest(3, 2); + this->nchannels_ = 2; + this->RunTest(); } TEST_F(ResamplingGPUTest, EightChannel) { - this->RunTest(3, 8); + this->nchannels_ = 8; + this->RunTest(); } -TEST_F(ResamplingGPUTest, HundredChannel) { - this->RunTest(3, 100); +TEST_F(ResamplingGPUTest, ThirtyChannel) { + this->nchannels_ = 30; + this->RunTest(); } TEST_F(ResamplingGPUTest, OutBeginEnd) { - this->RunTest(3, 1, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->RunTest(); } TEST_F(ResamplingGPUTest, EightChannelOutBeginEnd) { - this->RunTest(3, 8, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingGPUTest, PerfTest) { + this->RunPerfTest(1000); +} + +TEST_F(ResamplingGPUTest, SingleChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->roi_start_ = 4000000; // enough to look long into the signal + this->roi_end_ = 4010000; + this->RunTest(); } -TEST_F(ResamplingGPUTest, DISABLED_PerfTest) { - this->RunPerfTest(64, 1, 1000); +TEST_F(ResamplingGPUTest, ThreeChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->nchannels_ = 3; + this->roi_start_ = 4000000; // enough to look long into the signal + this->roi_end_ = 4010000; + this->RunTest(); } } // namespace test diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index a4070db3b4b..723676f1cca 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -29,41 +29,46 @@ double HannWindow(int i, int n) { return Hann(2.0 * i / n - 1); } -void ResamplingTest::PrepareData(int nsamples, int nchannels, span args, int nsec) { - TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); - TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); - for (int s = 0; s < nsamples; s++) { +void ResamplingTest::PrepareData(span args) { + TensorListShape<> in_sh(nsamples_, nchannels_ > 1 ? 2 : 1); + TensorListShape<> out_sh(nsamples_, nchannels_ > 1 ? 2 : 1); + for (int s = 0; s < nsamples_; s++) { double in_rate = args[s].in_rate; double out_rate = args[s].out_rate; double scale = in_rate / out_rate; - int n_in = nsec * in_rate + 12345 * s; // different lengths + int n_in = nsec_ * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; ASSERT_GT(out_end, out_begin); in_sh.tensor_shape_span(s)[0] = n_in; out_sh.tensor_shape_span(s)[0] = out_end - out_begin; - if (nchannels > 1) { - in_sh.tensor_shape_span(s)[1] = nchannels; - out_sh.tensor_shape_span(s)[1] = nchannels; + if (nchannels_ > 1) { + in_sh.tensor_shape_span(s)[1] = nchannels_; + out_sh.tensor_shape_span(s)[1] = nchannels_; } } ttl_in_.reshape(in_sh); ttl_out_.reshape(out_sh); ttl_outref_.reshape(out_sh); - for (int s = 0; s < nsamples; s++) { + for (int s = 0; s < nsamples_; s++) { double in_rate = args[s].in_rate; double out_rate = args[s].out_rate; - double scale = static_cast(in_rate) / out_rate; + double scale = in_rate / out_rate; int64_t n_in = in_sh.tensor_shape_span(s)[0]; int n_out = std::ceil(n_in / scale); int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; - for (int c = 0; c < nchannels; c++) { - float f_in = 0.1f + 0.01 * s + 0.001 * c; - float f_out = f_in * scale; - TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); - TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out, out_begin, out_end); + for (int c = 0; c < nchannels_; c++) { + double f_in = default_freq_in_ + 0.01 * s + 0.001 * c; + double f_out = f_in * scale; + // enough input samples for a given output region + int64_t in_begin = std::max(out_begin * scale - 200, 0); + int64_t in_end = std::min(out_end * scale + 200, n_in); + TestWave(ttl_in_.cpu()[s].data + in_begin * nchannels_ + c, n_in, nchannels_, f_in, + use_envelope_, in_begin, in_end); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels_, f_out, use_envelope_, out_begin, + out_end); } } } @@ -80,7 +85,7 @@ void ResamplingTest::Verify(span args) { int64_t out_len = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; for (int64_t i = 0; i < out_len; i++) { - ASSERT_NEAR(out_data[i], out_ref[i], eps()) + ASSERT_NEAR(out_data[i], out_ref[i], eps_) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; float diff = std::abs(out_data[i] - out_ref[i]); if (diff > max_diff) @@ -89,26 +94,24 @@ void ResamplingTest::Verify(span args) { } err = std::sqrt(err / out_len); - EXPECT_LE(err, max_avg_err()) << "Average error too big"; + EXPECT_LE(err, max_avg_err_) << "Average error too big"; std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" "\n max difference vs fresh signal: " << max_diff << "\n RMS error: " << err << std::endl; } } -void ResamplingTest::RunTest(int nsamples, int nchannels, bool use_roi) { +void ResamplingTest::RunTest() { std::vector args_v; - for (int i = 0; i < nsamples; i++) { - int roi_start = use_roi ? 100 : 0; - int roi_end = use_roi ? 8000 : -1; + for (int i = 0; i < nsamples_; i++) { if (i % 2 == 0) - args_v.push_back({22050.0f, 16000.0f, roi_start, roi_end}); + args_v.push_back({22050.0f, 16000.0f, roi_start_, roi_end_}); else - args_v.push_back({44100.0f, 16000.0f, roi_start, roi_end}); + args_v.push_back({44100.0f, 16000.0f, roi_start_, roi_end_}); } auto args = make_cspan(args_v); - PrepareData(nsamples, nchannels, args); + PrepareData(args); RunResampling(args); @@ -121,17 +124,19 @@ class ResamplingCPUTest : public ResamplingTest { Resampler R; R.Initialize(16); - int nsamples = args.size(); + ASSERT_EQ(args.size(), nsamples_); auto in_view = ttl_in_.cpu(); auto out_view = ttl_out_.cpu(); - for (int s = 0; s < nsamples; s++) { + for (int s = 0; s < nsamples_; s++) { auto out_sh = out_view.shape[s]; auto in_sh = in_view.shape[s]; int n_in = in_sh[0]; + int n_out = resampled_length(n_in, args[s].in_rate, args[s].out_rate); int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; - int64_t out_end = args[s].out_end > 0 ? args[s].out_end : out_sh[0]; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; + ASSERT_EQ(out_sh[0], out_end - out_begin); R.Resample(out_view[s].data, out_begin, out_end, args[s].out_rate, in_view[s].data, n_in, args[s].in_rate, nchannels); } @@ -139,23 +144,53 @@ class ResamplingCPUTest : public ResamplingTest { }; TEST_F(ResamplingCPUTest, SingleChannel) { - this->RunTest(1, 1); + this->nchannels_ = 1; + this->RunTest(); } TEST_F(ResamplingCPUTest, TwoChannel) { - this->RunTest(1, 2); + this->nchannels_ = 2; + this->RunTest(); } TEST_F(ResamplingCPUTest, EightChannel) { - this->RunTest(1, 8); + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, ThirtyChannel) { + this->nchannels_ = 30; + this->RunTest(); } TEST_F(ResamplingCPUTest, OutBeginEnd) { - this->RunTest(1, 1, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->RunTest(); } TEST_F(ResamplingCPUTest, EightChannelOutBeginEnd) { - this->RunTest(3, 8, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, SingleChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->roi_start_ = 4000000; // enough to look at the tail + this->roi_end_ = -1; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, ThreeChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->nchannels_ = 3; + this->roi_start_ = 4000000; // enough to look at the tail + this->roi_end_ = -1; + this->RunTest(); } } // namespace test diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 51f7c4be7b2..66abc8fd800 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -31,28 +31,43 @@ namespace test { double HannWindow(int i, int n); template -void TestWave(T *out, int n, int stride, float freq, int i_start = 0, int i_end = -1) { - if (i_end <= 0) i_end = n; +void TestWave(T *out, int n, int stride, double freq, bool envelope = true, int i_start = 0, + int i_end = -1) { + if (i_end <= 0) + i_end = n; assert(i_start >= 0 && i_start <= n); assert(i_end >= 0 && i_end <= n); for (int i = i_start; i < i_end; i++) { - float f = std::sin(i * freq) * HannWindow(i, n); + float f; + if (envelope) + f = std::sin(i * freq) * HannWindow(i, n); + else + f = std::sin(i * freq); out[(i - i_start) * stride] = ConvertSatNorm(f); } } class ResamplingTest : public ::testing::Test { public: - void PrepareData(int nsamples, int nchannels, span args, int nsec = 1); + void PrepareData(span args); - virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } void Verify(span args); virtual void RunResampling(span args) = 0; - void RunTest(int nsamples, int nchannels, bool use_roi = false); + void RunTest(); + protected: + int nsamples_ = 1; + int nchannels_ = 1; + double default_freq_in_ = 0.1; + int nsec_ = 1; + float eps_ = 2e-3; + float max_avg_err_ = 1e-3; + bool use_envelope_ = true; + int64_t roi_start_ = 0; + int64_t roi_end_ = -1; // means end-of-signal TestTensorList ttl_in_; TestTensorList ttl_out_; From 19e58e86a0cc9885e8110f568a855ab23e2129db Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 17 May 2022 11:43:33 +0200 Subject: [PATCH 14/36] Move resampling GPU to cu file & add sync to Initialize Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cu | 110 +++++++++++++++++++++ dali/kernels/signal/resampling_gpu.h | 107 +++++--------------- dali/kernels/signal/resampling_gpu_test.cu | 1 + 3 files changed, 137 insertions(+), 81 deletions(-) create mode 100644 dali/kernels/signal/resampling_gpu.cu diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu new file mode 100644 index 00000000000..d68e111aeed --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.cu @@ -0,0 +1,110 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/core/dev_buffer.h" +#include "dali/core/mm/memory.h" +#include "dali/core/static_switch.h" +#include "dali/kernels/dynamic_scratchpad.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/resampling_gpu.cuh" +#include "dali/kernels/signal/resampling_gpu.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +template +void ResamplerGPU::Initialize(int lobes, int lookup_size) { + windowed_sinc(window_cpu_, lookup_size, lobes); + window_gpu_storage_.from_host(window_cpu_.storage); + window_gpu_ = window_cpu_; + window_gpu_.lookup = window_gpu_storage_.data(); + CUDA_CALL(cudaStreamSynchronize(0)); +} + +template +KernelRequirements ResamplerGPU::Setup(KernelContext &context, const InListGPU &in, + span args) { + KernelRequirements req; + auto out_shape = in.shape; + for (int i = 0; i < in.num_samples(); i++) { + auto in_sh = in.shape.tensor_shape_span(i); + auto out_sh = out_shape.tensor_shape_span(i); + auto &arg = args[i]; + if (arg.out_begin > 0 || arg.out_end > 0) { + out_sh[0] = arg.out_end - arg.out_begin; + } else { + out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + } + } + req.output_shapes = {out_shape}; + return req; +} + +template +void ResamplerGPU::Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span args) { + if (window_gpu_storage_.empty()) + Initialize(); + + assert(context.scratchpad); + auto &scratch = *context.scratchpad; + + int nsamples = in.num_samples(); + auto samples_cpu = + make_span(scratch.Allocate(nsamples), nsamples); + + bool any_multichannel = false; + for (int i = 0; i < nsamples; i++) { + auto &desc = samples_cpu[i]; + desc.in = in[i].data; + desc.out = out[i].data; + desc.window = window_gpu_; + const auto &in_sh = in[i].shape; + const auto &out_sh = out[i].shape; + desc.in_len = in_sh[0]; + auto &arg = args[i]; + desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + desc.out_end = + arg.out_end > 0 ? arg.out_end : resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + assert((desc.out_end - desc.out_begin) == out_sh[0]); + desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + desc.scale = arg.in_rate / arg.out_rate; + any_multichannel |= desc.nchannels > 1; + } + + auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); + + dim3 block(256, 1); + int blocks_per_sample = std::max(32, 1024 / nsamples); + dim3 grid(blocks_per_sample, nsamples); + + // window coefficients and temporary per channel out values + size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); + + BOOL_SWITCH(!any_multichannel, SingleChannel, + (ResampleGPUKernel + <<>>(samples_gpu);)); // NOLINT + CUDA_CALL(cudaGetLastError()); +} + +DALI_INSTANTIATE_RESAMPLER_GPU(); + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index e8fe697ab62..f82ca7cf905 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -17,12 +17,8 @@ #include #include "dali/kernels/signal/resampling.h" -#include "dali/kernels/signal/resampling_gpu.cuh" #include "dali/kernels/kernel.h" -#include "dali/kernels/dynamic_scratchpad.h" -#include "dali/core/mm/memory.h" #include "dali/core/dev_buffer.h" -#include "dali/core/static_switch.h" namespace dali { namespace kernels { @@ -30,86 +26,15 @@ namespace signal { namespace resampling { -template -class ResamplerGPU { +template +class DLL_PUBLIC ResamplerGPU { public: - void Initialize(int lobes = 16, int lookup_size = 2048) { - windowed_sinc(window_cpu_, lookup_size, lobes); - window_gpu_storage_.from_host(window_cpu_.storage); - window_gpu_ = window_cpu_; - window_gpu_.lookup = window_gpu_storage_.data(); - } + void Initialize(int lobes = 16, int lookup_size = 2048); - KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span args) { - KernelRequirements req; - auto out_shape = in.shape; - for (int i = 0; i < in.num_samples(); i++) { - auto in_sh = in.shape.tensor_shape_span(i); - auto out_sh = out_shape.tensor_shape_span(i); - auto &arg = args[i]; - auto out_len = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - auto out_begin = arg.out_begin > 0 ? arg.out_begin : 0; - auto out_end = arg.out_end > 0 ? arg.out_end : out_len; - if (out_end < out_begin) - throw std::invalid_argument( - make_string("out_begin can't be larger than out_end. Got out_begin=", out_begin, - ", out_end=", out_end)); - if (out_end > out_len) - throw std::invalid_argument(make_string( - "out_end can't be outside of the range of the output signal: [0, ", out_len, ")")); - out_sh[0] = out_end - out_begin; - } - req.output_shapes = {out_shape}; - return req; - } + KernelRequirements Setup(KernelContext &context, const InListGPU &in, span args); - void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span args) { - if (window_gpu_storage_.empty()) - Initialize(); - - assert(context.scratchpad); - auto &scratch = *context.scratchpad; - - int nsamples = in.num_samples(); - auto samples_cpu = - make_span(scratch.Allocate(nsamples), nsamples); - - bool any_multichannel = false; - for (int i = 0; i < nsamples; i++) { - auto &desc = samples_cpu[i]; - desc.in = in[i].data; - desc.out = out[i].data; - desc.window = window_gpu_; - const auto &in_sh = in[i].shape; - const auto &out_sh = out[i].shape; - desc.in_len = in_sh[0]; - auto &arg = args[i]; - desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; - desc.out_end = arg.out_end > 0 ? arg.out_end : - resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - assert((desc.out_end - desc.out_begin) == out_sh[0]); - desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; - desc.scale = arg.in_rate / arg.out_rate; - any_multichannel |= desc.nchannels > 1; - } - - auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); - - dim3 block(256, 1); - int blocks_per_sample = std::max(32, 1024 / nsamples); - dim3 grid(blocks_per_sample, nsamples); - - // window coefficients and temporary per channel out values - size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); - - BOOL_SWITCH(!any_multichannel, SingleChannel, ( - ResampleGPUKernel - <<>>(samples_gpu); - )); // NOLINT - CUDA_CALL(cudaGetLastError()); - } + void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span args); private: ResamplingWindowCPU window_cpu_; @@ -117,6 +42,26 @@ class ResamplerGPU { DeviceBuffer window_gpu_storage_; }; +#define DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, Out)\ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + +#define DALI_INSTANTIATE_RESAMPLER_GPU(linkage) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, float) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int8_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint8_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int16_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint16_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int32_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint32_t) + +DALI_INSTANTIATE_RESAMPLER_GPU(extern) + } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 46a2302d818..632635c7c09 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" +#include "dali/kernels/dynamic_scratchpad.h" #include "dali/core/cuda_event.h" namespace dali { From bd8212a4203ba9935f700233d6f73c3686f49ede Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 17 May 2022 13:34:42 +0200 Subject: [PATCH 15/36] Add audio_resample GPU operator Signed-off-by: Joaquin Anton --- dali/operators/audio/resample.cc | 8 +- dali/operators/audio/resample.h | 29 ++++--- dali/operators/audio/resample_gpu.cc | 78 +++++++++++++++++++ .../python/test_operator_audio_resample.py | 54 ++++++++----- 4 files changed, 133 insertions(+), 36 deletions(-) create mode 100644 dali/operators/audio/resample_gpu.cc diff --git a/dali/operators/audio/resample.cc b/dali/operators/audio/resample.cc index 0e49fe61e50..5e40b8c0119 100644 --- a/dali/operators/audio/resample.cc +++ b/dali/operators/audio/resample.cc @@ -115,7 +115,7 @@ class ResampleCPU : public ResampleBase { const auto &in_shape = in.shape(); out.SetLayout(in.GetLayout()); int N = in.num_samples(); - assert(N == static_cast(scales_.size())); + assert(N == static_cast(args_.size())); assert(out.type() == dtype_); auto &tp = ws.GetThreadPool(); @@ -129,7 +129,7 @@ class ResampleCPU : public ResampleBase { make_string("Unsupported output type: ", dtype_, "\nSupported types are : ", ListTypeNames()));)); TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES), - (ResampleTyped(view(out[s]), in_view, scales_[s]);), + (ResampleTyped(view(out[s]), in_view, args_[s]);), (assert(!"Unreachable code."))); }); } @@ -137,9 +137,9 @@ class ResampleCPU : public ResampleBase { } template - void ResampleTyped(const OutTensorCPU &out, const InTensorCPU &in, double scale) { + void ResampleTyped(const OutTensorCPU &out, const InTensorCPU &in, const Args& args) { int ch = out.shape.sample_dim() > 1 ? out.shape[1] : 1; - R.Resample(out.data, 0, out.shape[0], scale, in.data, in.shape[0], 1, ch); + R.Resample(out.data, 0, out.shape[0], args.out_rate, in.data, in.shape[0], args.in_rate, ch); } template diff --git a/dali/operators/audio/resample.h b/dali/operators/audio/resample.h index be2d94bf1ad..1ca8ee3c979 100644 --- a/dali/operators/audio/resample.h +++ b/dali/operators/audio/resample.h @@ -31,15 +31,16 @@ class ResampleBase : public Operator { public: explicit ResampleBase(const OpSpec &spec) : Operator(spec) { DALI_ENFORCE(in_rate_.HasValue() == out_rate_.HasValue(), - "The parameters ``in_rate`` and ``out_rate`` must be specified together."); + "The parameters ``in_rate`` and ``out_rate`` must be specified together."); if (in_rate_.HasValue() + scale_.HasValue() + out_length_.HasValue() > 1) DALI_FAIL("The sampling rates, ``scale`` and ``out_length`` cannot be used together."); if (!in_rate_.HasValue() && !scale_.HasValue() && !out_length_.HasValue()) - DALI_FAIL("No resampling factor specified! Please supply either the scale, " - "the output length or the input and output sampling rates."); + DALI_FAIL( + "No resampling factor specified! Please supply either the scale, " + "the output length or the input and output sampling rates."); quality_ = spec_.template GetArgument("quality"); - DALI_ENFORCE(quality_ >= 0 && quality_ <= 100, make_string("``quality`` out of range: ", - quality_, "\nValid range is [0..100].")); + DALI_ENFORCE(quality_ >= 0 && quality_ <= 100, + make_string("``quality`` out of range: ", quality_, "\nValid range is [0..100].")); if (spec_.TryGetArgument(dtype_, "dtype")) { // silence useless warning -----------------------------vvvvvvvvvvvvvvvv TYPE_SWITCH(dtype_, type2id, T, (AUDIO_RESAMPLE_TYPES), (T x; (void)x;), @@ -58,12 +59,12 @@ class ResampleBase : public Operator { dtype_ = ws.template Input(0).type(); outputs[0].type = dtype_; - CalculateScaleAndShape(outputs[0].shape, ws); + CalculateShapeAndArgs(outputs[0].shape, ws); return true; } - void CalculateScaleAndShape(TensorListShape<> &out_shape, const workspace_t &ws) { + void CalculateShapeAndArgs(TensorListShape<> &out_shape, const workspace_t &ws) { const auto &input = ws.template Input(0); const TensorListShape<> &shape = input.shape(); DALI_ENFORCE(shape.sample_dim() == 1 || shape.sample_dim() == 2, @@ -71,7 +72,7 @@ class ResampleBase : public Operator { "channel dimension."); out_shape = shape; int N = shape.num_samples(); - scales_.resize(N); + args_.resize(N); if (in_rate_.HasValue()) { assert(out_rate_.HasValue()); in_rate_.Acquire(spec_, ws, N); @@ -93,7 +94,8 @@ class ResampleBase : public Operator { error << " for sample " << s; DALI_FAIL(error.str()); } - scales_[s] = out_rate / in_rate; + args_[s].in_rate = in_rate; + args_[s].out_rate = out_rate; int64_t in_length = shape.tensor_shape_span(s)[0]; int64_t out_length = kernels::signal::resampling::resampled_length(in_length, in_rate, out_rate); @@ -110,7 +112,8 @@ class ResampleBase : public Operator { error << " for sample " << s; DALI_FAIL(error.str()); } - scales_[s] = scale; + args_[s].in_rate = 1.0; + args_[s].out_rate = scale; int64_t in_length = shape.tensor_shape_span(s)[0]; int64_t out_length = kernels::signal::resampling::resampled_length(in_length, 1, scale); @@ -125,7 +128,8 @@ class ResampleBase : public Operator { DALI_FAIL(make_string("Cannot produce a non-empty signal from an empty input.\n" "Error at sample ", s)); } - scales_[s] = in_length ? 1.0 * out_length / in_length : 0.0; + args_[s].in_rate = 1.0; + args_[s].out_rate = in_length ? 1.0 * out_length / in_length : 0.0; out_shape.tensor_shape_span(s)[0] = out_length; } } else { @@ -144,7 +148,8 @@ class ResampleBase : public Operator { ArgValue scale_{"scale", spec_}; ArgValue out_length_{"out_length", spec_}; - std::vector scales_; + using Args = kernels::signal::resampling::Args; + SmallVector args_; }; } // namespace audio diff --git a/dali/operators/audio/resample_gpu.cc b/dali/operators/audio/resample_gpu.cc new file mode 100644 index 00000000000..f327e2da80d --- /dev/null +++ b/dali/operators/audio/resample_gpu.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "dali/operators/audio/resample.h" +#include "dali/operators/audio/resampling_params.h" +#include "dali/kernels/signal/resampling_gpu.h" +#include "dali/kernels/kernel_params.h" +#include "dali/kernels/kernel_manager.h" + +namespace dali { +namespace audio { + +using kernels::InListGPU; +using kernels::OutListGPU; + +class ResampleGPU : public ResampleBase { + public: + using Base = ResampleBase; + explicit ResampleGPU(const OpSpec &spec) : Base(spec) {} + + void RunImpl(DeviceWorkspace &ws) override { + auto &out = ws.template Output(0); + const auto &in = ws.template Input(0); + out.SetLayout(in.GetLayout()); + + int N = in.num_samples(); + assert(N == static_cast(args_.size())); + assert(out.type() == dtype_); + + TYPE_SWITCH(dtype_, type2id, Out, (AUDIO_RESAMPLE_TYPES), ( + TYPE_SWITCH(in.type(), type2id, In, (AUDIO_RESAMPLE_TYPES), ( + ResampleTyped(view(out), view(in), ws.stream()); + ), ( + DALI_FAIL( + make_string("Unsupported input type: ", in.type(), "\nSupported types are : ", + ListTypeNames())); + )); + ), ( + DALI_FAIL( + make_string("Unsupported output type: ", dtype_, "\nSupported types are : ", + ListTypeNames())); + )); + } + + template + void ResampleTyped(const OutListGPU &out, const InListGPU &in, cudaStream_t stream) { + using Kernel = kernels::signal::resampling::ResamplerGPU; + kmgr_.Resize(1); + auto args = make_cspan(args_); + kernels::KernelContext ctx; + ctx.gpu.stream = stream; + kmgr_.Setup(0, ctx, in, args); + kmgr_.Run(0, ctx, out, in, args); + } + + private: + kernels::KernelManager kmgr_; +}; + + +} // namespace audio + +DALI_REGISTER_OPERATOR(experimental__AudioResample, audio::ResampleGPU, GPU); + +} // namespace dali diff --git a/dali/test/python/test_operator_audio_resample.py b/dali/test/python/test_operator_audio_resample.py index 9ea0ffd3c73..d1b542b02fd 100644 --- a/dali/test/python/test_operator_audio_resample.py +++ b/dali/test/python/test_operator_audio_resample.py @@ -17,8 +17,8 @@ import numpy as np import scipy.io.wavfile -from test_audio_decoder_utils import generate_waveforms, rosa_resample -from test_utils import compare_pipelines, get_files, check_batch, dali_type_to_np +from test_audio_decoder_utils import generate_waveforms +from test_utils import check_batch, dali_type_to_np, as_array names = [ "/tmp/dali_test_1C.wav", @@ -34,64 +34,76 @@ rates = [ 16000, 22050, 12347 ] lengths = [ 10000, 54321, 12345 ] -def create_test_files(): +def create_files(): for i in range(len(names)): wave = generate_waveforms(lengths[i], freqs[i]) wave = (wave * 32767).round().astype(np.int16) scipy.io.wavfile.write(names[i], rates[i], wave) -create_test_files() +create_files() @pipeline_def -def audio_decoder_pipe(): +def audio_decoder_pipe(device): encoded, _ = fn.readers.file(files=names) audio0, sr0 = fn.decoders.audio(encoded, dtype=types.FLOAT) out_sr = 15000 audio1, sr1 = fn.decoders.audio(encoded, dtype=types.FLOAT, sample_rate=out_sr) + if device == 'gpu': + audio0 = audio0.gpu() audio2 = fn.experimental.audio_resample(audio0, in_rate=sr0, out_rate=out_sr) audio3 = fn.experimental.audio_resample(audio0, scale=out_sr/sr0) audio4 = fn.experimental.audio_resample(audio0, out_length=fn.shapes(audio1)[0]) return audio1, audio2, audio3, audio4 -def test_standalone_vs_fused(): - pipe = audio_decoder_pipe(batch_size=2, num_threads=1, device_id=0) +def _test_standalone_vs_fused(device): + pipe = audio_decoder_pipe(device=device, batch_size=2, num_threads=1, device_id=0) pipe.build() + is_gpu = device == 'gpu' for _ in range(2): outs = pipe.run() # two sampling rates - should be bit-exact - check_batch(outs[0], outs[1], eps=0, max_allowed_error=0) + check_batch(outs[0], outs[1], eps=0, max_allowed_error=1e-4 if is_gpu else 0) # numerical round-off error in rate check_batch(outs[0], outs[2], eps=1e-6, max_allowed_error=1e-4) # here, the sampling rate is slightly different, so we can tolerate larger errors check_batch(outs[0], outs[3], eps=1e-4, max_allowed_error=1) -def _test_type_conversion(src_type, in_values, dst_type, out_values, eps): +def test_standalone_vs_fused(): + for device in ('gpu', 'cpu'): + yield _test_standalone_vs_fused, device + +def _test_type_conversion(device, src_type, in_values, dst_type, out_values, rtol=1e-6, atol=None): src_nptype = dali_type_to_np(src_type) dst_nptype = dali_type_to_np(dst_type) assert len(out_values) == len(in_values) in_data = [np.full((100 + 10 * i,), x, src_nptype) for i, x in enumerate(in_values)] @pipeline_def(batch_size=len(in_values)) - def test_pipe(): - input = fn.external_source(in_data, batch = False, cycle = 'quiet') + def test_pipe(device): + input = fn.external_source(in_data, batch=False, cycle='quiet', device=device) return fn.experimental.audio_resample(input, dtype=dst_type, scale=1, quality=0) - pipe = test_pipe(device_id=0, num_threads=4) + pipe = test_pipe(device, device_id=0, num_threads=4) pipe.build() + is_gpu = device == 'gpu' for _ in range(2): out, = pipe.run() assert len(out) == len(out_values) assert out.dtype == dst_type for i in range(len(out_values)): ref = np.full_like(in_data[i], out_values[i], dst_nptype) - if not np.allclose(out.at(i), ref, 1e-6, eps): - print("Actual: ", out.at(i)) - print(out.at(i).dtype, out.at(i).shape) + out_arr = as_array(out[i]) + if atol is not None: + ok = np.allclose(out_arr, ref, rtol, atol) + else: + ok = np.allclose(out_arr, ref, rtol) + if not ok: + print("Actual: ", out_arr) + print(out_arr.dtype, out_arr.shape) print("Reference: ", ref) print(ref.dtype, ref.shape) - print("Diff: ", out.at(i).astype(np.float) - ref) - assert np.allclose(out.at(i), ref, 1e-6, eps) - + print("Diff: ", out_arr.astype(np.float) - ref) + assert False def test_dynamic_ranges(): for type, values, eps in [(types.FLOAT, [-1.e30, -1-1.e-6, -1, -0.5, -1.e-30, 0, 1.e-30, 0.5, 1, 1+1.e-6, 1e30], 0), @@ -101,7 +113,8 @@ def test_dynamic_ranges(): (types.INT16, [-32768, -32767, -100, -1, 0, 1, 100, 32767], 0), (types.UINT32, [0, 1, 0x7fffffff, 0x80000000, 0xfffffffe, 0xffffffff], 128), (types.INT32, [-0x80000000, -0x7fffffff, -100, -1, 0, 1, 0x7fffffff], 128)]: - yield _test_type_conversion, type, values, type, values, eps + yield _test_type_conversion, 'gpu', type, values, type, values, 2e-5, eps + yield _test_type_conversion, 'cpu', type, values, type, values, 1e-6, eps def test_type_conversion(): type_ranges = [(types.FLOAT, [-1, 1]), @@ -138,4 +151,5 @@ def test_type_conversion(): if eps < 1 and (o_lo != -o_hi or (i_hi != i_lo and dst_type != types.FLOAT)): eps = 1 - yield _test_type_conversion, src_type, in_values, dst_type, out_values, eps + yield _test_type_conversion, 'gpu', src_type, in_values, dst_type, out_values, 2e-5, eps + yield _test_type_conversion, 'cpu', src_type, in_values, dst_type, out_values, 1e-6, eps From bf32abc4185cf6dccc7ea09a0a93aa7191be82aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Tue, 10 May 2022 07:57:48 +0200 Subject: [PATCH 16/36] Initial effort. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/kernels/signal/resampling.h | 23 ++++++---- dali/kernels/signal/resampling_gpu.cu | 0 dali/kernels/signal/resampling_gpu.h | 63 +++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 8 deletions(-) create mode 100644 dali/kernels/signal/resampling_gpu.cu create mode 100644 dali/kernels/signal/resampling_gpu.h diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 42cd452cdc4..28c1c161d14 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -61,12 +61,12 @@ struct ResamplingWindow { return {i0, i1}; } - inline float operator()(float x) const { + inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; float floori = std::floor(fi); float di = fi - floori; int i = floori; - assert(i >= 0 && i < static_cast(lookup.size())); + assert(i >= 0 && i < lookup_size); return lookup[i] + di * (lookup[i + 1] - lookup[i]); } @@ -112,25 +112,32 @@ struct ResamplingWindow { float scale = 1, center = 1; int lobes = 0, coeffs = 0; - std::vector lookup; + int lookup_size = 0; + const float *lookup = nullptr; }; -inline void windowed_sinc(ResamplingWindow &window, +struct ResamplingWindowCPU : ResamplingWindow { + std::vector storage; +}; + +inline void windowed_sinc(ResamplingWindowCPU &window, int coeffs, int lobes, std::function envelope = Hann) { assert(coeffs > 1 && lobes > 0 && "Degenerate parameters specified."); float scale = 2.0f * lobes / (coeffs - 1); float scale_envelope = 2.0f / coeffs; window.coeffs = coeffs; window.lobes = lobes; - window.lookup.clear(); - window.lookup.resize(coeffs + 5); // add zeros and a full 4-lane vector + window.storage.clear(); + window.storage.resize(coeffs + 5); // add zeros and a full 4-lane vector int center = (coeffs - 1) * 0.5f; for (int i = 0; i < coeffs; i++) { float x = (i - center) * scale; float y = (i - center) * scale_envelope; float w = sinc(x) * envelope(y); - window.lookup[i + 1] = w; + window.storage[i + 1] = w; } + window.lookup = window.storage.data(); + window.lookup_size = window.storage.size(); window.center = center + 1; // allow for leading zero window.scale = 1 / scale; } @@ -141,7 +148,7 @@ inline int64_t resampled_length(int64_t in_length, double in_rate, double out_ra } struct Resampler { - ResamplingWindow window; + ResamplingWindowCPU window; void Initialize(int lobes = 16, int lookup_size = 2048) { windowed_sinc(window, lookup_size, lobes); diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h new file mode 100644 index 00000000000..4390b34f665 --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ + +#include "dali/kernels/signal/resampling.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +ResamplingWindow ToGPU(Scratchpad &scratch, const ResamplingWindow &cpu_window) { + ResamplingWindow wnd = cpu_window; + wnd.lookup = scratch.ToGPU(make_span(cpu_window.lookup, cpu_windwo.lookup_size)); + return wnd; +} + +struct ResamplerGPU { + ResamplingWindowCPU window; + + void Initialize(int lobes = 16, int lookup_size = 2048) { + windowed_sinc(window, lookup_size, lobes); + } + + + /** + * @brief Resample multi-channel signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can seamlessly resample the input and produce the result in chunks. + * To reuse memory and still simulate chunk processing, adjust the in/out pointers. + */ + template + void Resample( + Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, + const float *__restrict__ in, int64_t n_in, double in_rate, + int num_channels, + cudaStream_t stream); +}; + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali + +} // namespace dali + + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ From a983fdddf99ebb26f844dc20ba96d91b48c74363 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 10 May 2022 17:38:33 +0200 Subject: [PATCH 17/36] Add signal resampling GPU kernel Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 42 +++-- dali/kernels/signal/resampling_gpu.cu | 0 dali/kernels/signal/resampling_gpu.cuh | 117 +++++++++++++ dali/kernels/signal/resampling_gpu.h | 104 ++++++++--- dali/kernels/signal/resampling_gpu_test.cu | 78 +++++++++ dali/kernels/signal/resampling_test.cc | 190 +++++++++++++-------- dali/kernels/signal/resampling_test.h | 67 ++++++++ 7 files changed, 489 insertions(+), 109 deletions(-) delete mode 100644 dali/kernels/signal/resampling_gpu.cu create mode 100644 dali/kernels/signal/resampling_gpu.cuh create mode 100644 dali/kernels/signal/resampling_gpu_test.cu create mode 100644 dali/kernels/signal/resampling_test.h diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 28c1c161d14..6fd8917209a 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -29,6 +29,7 @@ #include "dali/core/small_vector.h" #include "dali/core/convert.h" #include "dali/core/static_switch.h" +#include "dali/core/geom/vec.h" namespace dali { namespace kernels { @@ -54,7 +55,7 @@ inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { #endif struct ResamplingWindow { - inline std::pair input_range(float x) const { + inline DALI_HOST_DEV ivec<2> input_range(float x) const { int xc = std::ceil(x); int i0 = xc - lobes; int i1 = xc + lobes; @@ -220,8 +221,9 @@ struct Resampler { float in_pos = in_block_f - in_block_i; const float *__restrict__ in_block_ptr = in + in_block_i; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - int i0, i1; - std::tie(i0, i1) = window.input_range(in_pos); + auto irange = window.input_range(in_pos); + int i0 = irange[0]; + int i1 = irange[1]; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -251,8 +253,9 @@ struct Resampler { * To reuse memory and still simulate chunk processing, adjust the in/out pointers. * * @tparam static_channels number of channels, if known at compile time, or -1 + * @tparam downmix whether to downmix all channels in the output */ - template + template void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, @@ -282,8 +285,9 @@ struct Resampler { float in_pos = in_block_f - in_block_i; const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - int i0, i1; - std::tie(i0, i1) = window.input_range(in_pos); + auto irange = window.input_range(in_pos); + int i0 = irange[0]; + int i1 = irange[1]; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -304,8 +308,16 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); - for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + if (downmix) { + float out_val = 0; + for (int c = 0; c < num_channels; c++) + out_val += tmp[c]; + out_val /= num_channels; + out[out_pos] = ConvertSatNorm(out_val); + } else { + for (int c = 0; c < num_channels; c++) + out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + } } } } @@ -321,12 +333,14 @@ struct Resampler { void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels) { - VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), - (Resample(out, out_begin, out_end, out_rate, - in, n_in, in_rate, static_channels);), - (Resample<-1, Out>(out, out_begin, out_end, out_rate, - in, n_in, in_rate, num_channels))); + int num_channels, bool downmix = false) { + BOOL_SWITCH(downmix, Downmix, ( + VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), + (Resample(out, out_begin, out_end, out_rate, + in, n_in, in_rate, static_channels);), + (Resample<-1, Downmix, Out>(out, out_begin, out_end, out_rate, + in, n_in, in_rate, num_channels))); + )); // NOLINT } }; diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh new file mode 100644 index 00000000000..6d32f6d227c --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -0,0 +1,117 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ + +#include +#include "dali/kernels/signal/resampling.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +struct SampleDesc { + void *out; + const void *in; + ResamplingWindow window; + int64_t in_len; // num samples in input + int64_t out_len; // num samples in output + int64_t nchannels; // number of channels + double scale; // in_sampling_rate / out_sampling_rate +}; + +/** + * @brief Resamples 1D signal (single or multi-channel), optionally downmixing and converting to a different data type. + * + * @param samples sample descriptors + */ +template +__global__ void ResampleGPUKernel(const SampleDesc *samples) { + auto sample = samples[blockIdx.y]; + double scale = sample.scale; + float fscale = scale; + int nchannels = SingleChannel ? 1 : sample.nchannels; + auto &window = sample.window; + + Out* out = reinterpret_cast(sample.out); + const In* in = reinterpret_cast(sample.in); + + int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; + int64_t out_block = static_cast(blockIdx.x) * blockDim.x; + int64_t start_out_pos = out_block + threadIdx.x; + + double in_block_f = out_block * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos_start = in_block_f - in_block_i; + const In* in_blk_ptr = in + in_block_i * nchannels; + + for (int64_t out_pos = start_out_pos; out_pos < sample.out_len; + out_pos += grid_stride, in_pos_start += fscale * grid_stride) { + float in_pos = in_pos_start + fscale * threadIdx.x; + auto i_range = window.input_range(in_pos); + int i0 = i_range[0]; + int i1 = i_range[1]; + if (i0 + in_block_i < 0) + i0 = -in_block_i; + if (i1 + in_block_i > sample.in_len) + i1 = sample.in_len - in_block_i; + + float out_val = 0; + if (SingleChannel) { + for (int i = i0; i < i1; i++) { + In in_val = in_blk_ptr[i]; + float x = i - in_pos; + float w = window(x); + out_val = fma(in_val, w, out_val); + } + out[out_pos] = ConvertSatNorm(out_val); + } else { // multiple channels + float tmp[32]; // more than enough + for (int c = 0; c < nchannels; c++) { + tmp[c] = 0; + } + + for (int i = i0; i < i1; i++) { + float x = i - in_pos; + float w = window(x); + for (int c = 0; c < nchannels; c++) { + In in_val = in_blk_ptr[i * nchannels + c]; + tmp[c] = fma(in_val, w, tmp[c]); + } + } + + if (Downmix) { + for (int c = 0; c < nchannels; c++) { + out_val += tmp[c]; + } + out_val /= nchannels; + out[out_pos] = ConvertSatNorm(out_val); + } else { + for (int c = 0; c < nchannels; c++) { + out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); + } + } + } + } +} + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_CUH_ diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index 4390b34f665..b9bc6ba220b 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -15,7 +15,14 @@ #ifndef DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ #define DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ +#include #include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_gpu.cuh" +#include "dali/kernels/kernel.h" +#include "dali/kernels/dynamic_scratchpad.h" +#include "dali/core/mm/memory.h" +#include "dali/core/dev_buffer.h" +#include "dali/core/static_switch.h" namespace dali { namespace kernels { @@ -23,33 +30,81 @@ namespace signal { namespace resampling { -ResamplingWindow ToGPU(Scratchpad &scratch, const ResamplingWindow &cpu_window) { - ResamplingWindow wnd = cpu_window; - wnd.lookup = scratch.ToGPU(make_span(cpu_window.lookup, cpu_windwo.lookup_size)); - return wnd; -} - -struct ResamplerGPU { - ResamplingWindowCPU window; - +template +class ResamplerGPU { + public: void Initialize(int lobes = 16, int lookup_size = 2048) { - windowed_sinc(window, lookup_size, lobes); + windowed_sinc(window_cpu_, lookup_size, lobes); + window_gpu_storage_.from_host(window_cpu_.storage); + window_gpu_ = window_cpu_; + window_gpu_.lookup = window_gpu_storage_.data(); + } + + KernelRequirements Setup(KernelContext &context, const InListGPU &in, + span in_rate, span out_rate, bool downmix) { + KernelRequirements req; + auto out_shape = in.shape; + for (int i = 0; i < in.num_samples(); i++) { + auto in_sh = in.shape.tensor_shape_span(i); + auto out_sh = out_shape.tensor_shape_span(i); + out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); + if (downmix) + out_sh[1] = 1; + } + req.output_shapes = {out_shape}; + return req; } + void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span in_rates, span out_rates, + bool downmix) { + if (window_gpu_storage_.empty()) + Initialize(); + + DynamicScratchpad dyn_scratchpad({}, AccessOrder(context.gpu.stream)); + if (!context.scratchpad) + context.scratchpad = &dyn_scratchpad; + auto &scratch = *context.scratchpad; + + int nsamples = in.num_samples(); + auto samples_cpu = + make_span(scratch.Allocate(nsamples), nsamples); + + bool any_multichannel = false; + for (int i = 0; i < nsamples; i++) { + auto &desc = samples_cpu[i]; + desc.in = in[i].data; + desc.out = out[i].data; + desc.window = window_gpu_; + const auto &in_sh = in[i].shape; + const auto &out_sh = out[i].shape; + desc.in_len = in_sh[0]; + desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); + assert(desc.out_len == out_sh[0]); + desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + desc.scale = static_cast(in_rates[i]) / out_rates[i]; + any_multichannel |= desc.nchannels > 1; + } - /** - * @brief Resample multi-channel signal and convert to Out - * - * Calculates a range of resampled signal. - * The function can seamlessly resample the input and produce the result in chunks. - * To reuse memory and still simulate chunk processing, adjust the in/out pointers. - */ - template - void Resample( - Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, - const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels, - cudaStream_t stream); + auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); + + dim3 block(256, 1); + int blocks_per_sample = std::max(32, 1024 / nsamples); + dim3 grid(blocks_per_sample, nsamples); + + BOOL_SWITCH(downmix && any_multichannel, Downmix, ( + BOOL_SWITCH(!any_multichannel, SingleChannel, ( + ResampleGPUKernel + <<>>(samples_gpu); + )); // NOLINT + )); // NOLINT + CUDA_CALL(cudaGetLastError()); + } + + private: + ResamplingWindowCPU window_cpu_; + ResamplingWindow window_gpu_; + DeviceBuffer window_gpu_storage_; }; } // namespace resampling @@ -57,7 +112,4 @@ struct ResamplerGPU { } // namespace kernels } // namespace dali -} // namespace dali - - #endif // DALI_KERNELS_SIGNAL_RESAMPLING_GPU_H_ diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu new file mode 100644 index 00000000000..aae6033ff8b --- /dev/null +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -0,0 +1,78 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/kernels/signal/resampling_gpu.h" +#include "dali/kernels/signal/resampling_test.h" + + +namespace dali { +namespace kernels { +namespace signal { +namespace resampling { + +class ResamplingGPUTest : public ResamplingTest { + public: + void RunResampling(span in_rates, span out_rates, + bool downmix) override { + ResamplerGPU R; + R.Initialize(16); + + KernelContext ctx; + ctx.gpu.stream = 0; + + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates, downmix); + auto outref_sh = ttl_outref_.cpu().shape; + auto in_batch_sh = ttl_in_.cpu().shape; + for (int s = 0; s < outref_sh.size(); s++) { + auto sh = req.output_shapes[0].tensor_shape_span(s); + auto expected_sh = outref_sh.tensor_shape_span(s); + auto in_sh = in_batch_sh.tensor_shape_span(s); + ASSERT_EQ(sh.size(), in_sh.size()); + if (downmix) { + ASSERT_EQ(sh[1], 1); + } else { + if (sh.size() > 1) + ASSERT_EQ(sh[1], in_sh[1]); + } + } + + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates, downmix); + + CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); + } +}; + +TEST_F(ResamplingGPUTest, SingleChannel) { + this->RunTest(8, 1, false); +} + +TEST_F(ResamplingGPUTest, TwoChannel) { + this->RunTest(3, 2, false); +} + +TEST_F(ResamplingGPUTest, EightChannel) { + this->RunTest(3, 8, false); +} + +TEST_F(ResamplingGPUTest, ThreeChannelDownmix) { + this->RunTest(3, 3, true); +} + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 9dd2233867d..642177563c9 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,91 +16,143 @@ #include #include #include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_test.h" namespace dali { namespace kernels { namespace signal { namespace resampling { -namespace { +void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, + span out_rates) { + TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); + TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); + for (int s = 0; s < nsamples; s++) { + double in_rate = in_rates[s]; + double out_rate = out_rates[s]; + double scale = static_cast(in_rate) / out_rate; + int n_in = in_rate + 12345 * s; // different lengths + int n_out = std::ceil(n_in / scale); + in_sh.tensor_shape_span(s)[0] = n_in; + out_sh.tensor_shape_span(s)[0] = n_out; + if (nchannels > 1) { + in_sh.tensor_shape_span(s)[1] = nchannels; + out_sh.tensor_shape_span(s)[1] = nchannels; + } + } + ttl_in_.reshape(in_sh); + ttl_out_.reshape(out_sh); + ttl_outref_.reshape(out_sh); + for (int s = 0; s < nsamples; s++) { + double in_rate = in_rates[s]; + double out_rate = out_rates[s]; + double scale = static_cast(in_rate) / out_rate; + for (int c = 0; c < nchannels; c++) { + float f_in = 0.1f + 0.01 * s + 0.001 * c; + float f_out = f_in * scale; + int n_in = in_sh.tensor_shape_span(s)[0]; + int n_out = out_sh.tensor_shape_span(s)[0]; + TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out); + } + } +} + +void ResamplingTest::Verify(bool downmix) { + auto in_sh = ttl_in_.cpu().shape; + auto out_sh = ttl_outref_.cpu().shape; + int nsamples = in_sh.num_samples(); + double err = 0, max_diff = 0; + + for (int s = 0; s < nsamples; s++) { + float *out_data = ttl_out_.cpu()[s].data; + float *out_ref = ttl_outref_.cpu()[s].data; + int n_out = out_sh.tensor_shape_span(s)[0]; + int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; + for (int i = 0; i < n_out; i++) { + float ref_val = 0; + if (downmix) { + for (int c = 0; c < nchannels; c++) { + ref_val += out_ref[i * nchannels + c]; + } + ref_val /= nchannels; + } else { + ref_val = out_ref[i]; + } -double HannWindow(int i, int n) { - assert(n > 0); - return Hann(2.0*i / n - 1); + ASSERT_NEAR(out_data[i], ref_val, eps()) + << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; + float diff = std::abs(out_data[i] - ref_val); + if (diff > max_diff) + max_diff = diff; + err += diff * diff; + } + + err = std::sqrt(err / n_out); + EXPECT_LE(err, max_avg_err()) << "Average error too big"; + std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" + "\n max difference vs fresh signal: " + << max_diff << "\n RMS error: " << err << std::endl; + } } -template -void TestWave(T *out, int n, int stride, float freq) { - for (int i = 0; i < n; i++) { - float x = i * freq; - float f = std::sin(i* freq) * HannWindow(i, n); - out[i*stride] = ConvertSatNorm(f); +void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { + std::vector in_rates_v; + for (int i = 0; i < nsamples; i++) { + if (i % 2 == 0) + in_rates_v.push_back(22050.0f); + else + in_rates_v.push_back(44100.0f); } + auto in_rates = make_cspan(in_rates_v); + + std::vector out_rates_v(nsamples, 16000.0f); + auto out_rates = make_cspan(out_rates_v); + + PrepareData(nsamples, nchannels, in_rates, out_rates); + + RunResampling(in_rates, out_rates, downmix); + + Verify(downmix); } -} // namespace - -TEST(ResampleSinc, SingleChannel) { - int n_in = 22050, n_out = 16000; // typical downsampling - std::vector in(n_in); - std::vector out(n_out); - std::vector ref(out.size()); - float f_in = 0.1f; - float f_out = f_in * n_in / n_out; - double in_rate = n_in; - double out_rate = n_out; - TestWave(in.data(), n_in, 1, f_in); - TestWave(ref.data(), n_out, 1, f_out); - Resampler R; - R.Initialize(16); - R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate); +class ResamplingCPUTest : public ResamplingTest { + public: + void RunResampling(span in_rates, span out_rates, bool downmix) override { + Resampler R; + R.Initialize(16); - double err = 0, max_diff = 0; - for (int i = 0; i < n_out; i++) { - ASSERT_NEAR(out[i], ref[i], 1e-3) << "Sample error too big @" << i << std::endl; - float diff = std::abs(out[i] - ref[i]); - if (diff > max_diff) - max_diff = diff; - err += diff*diff; + int nsamples = in_rates.size(); + assert(nsamples == out_rates.size()); + + auto in_view = ttl_in_.cpu(); + auto out_view = ttl_out_.cpu(); + for (int s = 0; s < nsamples; s++) { + auto out_sh = out_view.shape[s]; + auto in_sh = in_view.shape[s]; + int n_out = out_sh[0]; + int n_in = in_sh[0]; + int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; + R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], + nchannels, downmix); + } } - err = std::sqrt(err/n_out); - EXPECT_LE(err, 1e-3) << "Average error too big"; - std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" - "\n max difference vs fresh signal: " << max_diff << - "\n RMS error: " << err << std::endl; +}; + +TEST_F(ResamplingCPUTest, SingleChannel) { + this->RunTest(1, 1, false); } -TEST(ResampleSinc, MultiChannel) { - int n_in = 22050, n_out = 22053; // some weird upsampling - int ch = 5; - std::vector in(n_in * ch); - std::vector out(n_out * ch); - std::vector ref(out.size()); - double in_rate = n_in; - double out_rate = n_out; - for (int c = 0; c < ch; c++) { - float f_in = 0.1f * (1 + c * 0.012345); // different signal in each channel - float f_out = f_in * n_in / n_out; - TestWave(in.data() + c, n_in, ch, f_in); - TestWave(ref.data() + c, n_out, ch, f_out); - } - Resampler R; - R.Initialize(16); - R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate, ch); +TEST_F(ResamplingCPUTest, TwoChannel) { + this->RunTest(1, 2, false); +} - double err = 0, max_diff = 0; - for (int i = 0; i < n_out * ch; i++) { - ASSERT_NEAR(out[i], ref[i], 2e-3) << "Sample error too big @" << i << std::endl; - float diff = std::abs(out[i] - ref[i]); - if (diff > max_diff) - max_diff = diff; - err += diff*diff; - } - err = std::sqrt(err/(n_out * ch)); - EXPECT_LE(err, 1e-3) << "Average error too big"; - std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" - "\n max difference vs fresh signal: " << max_diff << - "\n RMS error: " << err << std::endl; +TEST_F(ResamplingCPUTest, EightChannel) { + this->RunTest(1, 8, false); +} + +TEST_F(ResamplingCPUTest, ThreeChannelDownmix) { + this->RunTest(1, 3, true); } } // namespace resampling diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h new file mode 100644 index 00000000000..0163410daf1 --- /dev/null +++ b/dali/kernels/signal/resampling_test.h @@ -0,0 +1,67 @@ +// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/kernels/signal/resampling.h" +#include "dali/test/tensor_test_utils.h" +#include "dali/test/test_tensors.h" + +namespace dali { +namespace kernels { +namespace signal { +namespace resampling { + +namespace { + +double HannWindow(int i, int n) { + assert(n > 0); + return Hann(2.0*i / n - 1); +} + +template +void TestWave(T *out, int n, int stride, float freq) { + for (int i = 0; i < n; i++) { + float f = std::sin(i* freq) * HannWindow(i, n); + out[i*stride] = ConvertSatNorm(f); + } +} + +} // namespace + +class ResamplingTest : public ::testing::Test { + public: + void PrepareData(int nsamples, int nchannels, + span in_rates, span out_rates); + + virtual float eps() const { return 2e-3; } + virtual float max_avg_err() const { return 1e-3; } + void Verify(bool downmix); + + virtual void RunResampling(span in_rates, span out_rates, bool downmix) = 0; + + void RunTest(int nsamples, int nchannels, bool downmix); + + + TestTensorList ttl_in_; + TestTensorList ttl_out_; + TestTensorList ttl_outref_; +}; + + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali From 9a403ea54299d981cfac7a9947bdb7273b6cb751 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 12:55:50 +0200 Subject: [PATCH 18/36] Remove downmixing Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 29 +++++---------- dali/kernels/signal/resampling_gpu.cuh | 17 +++------ dali/kernels/signal/resampling_gpu.h | 16 +++----- dali/kernels/signal/resampling_gpu_test.cu | 29 +++++---------- dali/kernels/signal/resampling_test.cc | 43 +++++++++------------- dali/kernels/signal/resampling_test.h | 22 +++++------ 6 files changed, 58 insertions(+), 98 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 6fd8917209a..802c8a4c225 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -253,9 +253,8 @@ struct Resampler { * To reuse memory and still simulate chunk processing, adjust the in/out pointers. * * @tparam static_channels number of channels, if known at compile time, or -1 - * @tparam downmix whether to downmix all channels in the output */ - template + template void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, @@ -308,16 +307,8 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); - if (downmix) { - float out_val = 0; - for (int c = 0; c < num_channels; c++) - out_val += tmp[c]; - out_val /= num_channels; - out[out_pos] = ConvertSatNorm(out_val); - } else { - for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); - } + for (int c = 0; c < num_channels; c++) + out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); } } } @@ -333,14 +324,12 @@ struct Resampler { void Resample( Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels, bool downmix = false) { - BOOL_SWITCH(downmix, Downmix, ( - VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), - (Resample(out, out_begin, out_end, out_rate, - in, n_in, in_rate, static_channels);), - (Resample<-1, Downmix, Out>(out, out_begin, out_end, out_rate, - in, n_in, in_rate, num_channels))); - )); // NOLINT + int num_channels) { + VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), + (Resample(out, out_begin, out_end, out_rate, + in, n_in, in_rate, static_channels);), + (Resample<-1, Out>(out, out_begin, out_end, out_rate, + in, n_in, in_rate, num_channels))); } }; diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 6d32f6d227c..39ad0f7aec2 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -35,11 +35,11 @@ struct SampleDesc { }; /** - * @brief Resamples 1D signal (single or multi-channel), optionally downmixing and converting to a different data type. + * @brief Resamples 1D signal (single or multi-channel), optionally converting to a different data type. * * @param samples sample descriptors */ -template +template __global__ void ResampleGPUKernel(const SampleDesc *samples) { auto sample = samples[blockIdx.y]; double scale = sample.scale; @@ -80,6 +80,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos] = ConvertSatNorm(out_val); } else { // multiple channels + assert(nchannels <= 32); float tmp[32]; // more than enough for (int c = 0; c < nchannels; c++) { tmp[c] = 0; @@ -94,16 +95,8 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } } - if (Downmix) { - for (int c = 0; c < nchannels; c++) { - out_val += tmp[c]; - } - out_val /= nchannels; - out[out_pos] = ConvertSatNorm(out_val); - } else { - for (int c = 0; c < nchannels; c++) { - out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); - } + for (int c = 0; c < nchannels; c++) { + out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index b9bc6ba220b..ed7db6ff2e9 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -41,23 +41,21 @@ class ResamplerGPU { } KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span in_rate, span out_rate, bool downmix) { + span in_rate, span out_rate) { KernelRequirements req; auto out_shape = in.shape; for (int i = 0; i < in.num_samples(); i++) { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); - if (downmix) - out_sh[1] = 1; } req.output_shapes = {out_shape}; return req; } void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span in_rates, span out_rates, - bool downmix) { + const InListGPU &in, span in_rates, + span out_rates) { if (window_gpu_storage_.empty()) Initialize(); @@ -92,11 +90,9 @@ class ResamplerGPU { int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); - BOOL_SWITCH(downmix && any_multichannel, Downmix, ( - BOOL_SWITCH(!any_multichannel, SingleChannel, ( - ResampleGPUKernel - <<>>(samples_gpu); - )); // NOLINT + BOOL_SWITCH(!any_multichannel, SingleChannel, ( + ResampleGPUKernel + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index aae6033ff8b..23f1fef28d9 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -18,60 +18,49 @@ #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" - namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { class ResamplingGPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates, - bool downmix) override { + void RunResampling(span in_rates, span out_rates) override { ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates, downmix); + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); auto outref_sh = ttl_outref_.cpu().shape; auto in_batch_sh = ttl_in_.cpu().shape; for (int s = 0; s < outref_sh.size(); s++) { auto sh = req.output_shapes[0].tensor_shape_span(s); auto expected_sh = outref_sh.tensor_shape_span(s); - auto in_sh = in_batch_sh.tensor_shape_span(s); - ASSERT_EQ(sh.size(), in_sh.size()); - if (downmix) { - ASSERT_EQ(sh[1], 1); - } else { - if (sh.size() > 1) - ASSERT_EQ(sh[1], in_sh[1]); - } + ASSERT_EQ(sh, expected_sh); } - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates, downmix); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } }; TEST_F(ResamplingGPUTest, SingleChannel) { - this->RunTest(8, 1, false); + this->RunTest(8, 1); } TEST_F(ResamplingGPUTest, TwoChannel) { - this->RunTest(3, 2, false); + this->RunTest(3, 2); } TEST_F(ResamplingGPUTest, EightChannel) { - this->RunTest(3, 8, false); -} - -TEST_F(ResamplingGPUTest, ThreeChannelDownmix) { - this->RunTest(3, 3, true); + this->RunTest(3, 8); } +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 642177563c9..538506a283e 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -22,6 +22,12 @@ namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { + +double HannWindow(int i, int n) { + assert(n > 0); + return Hann(2.0*i / n - 1); +} void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, span out_rates) { @@ -58,7 +64,7 @@ void ResamplingTest::PrepareData(int nsamples, int nchannels, span } } -void ResamplingTest::Verify(bool downmix) { +void ResamplingTest::Verify() { auto in_sh = ttl_in_.cpu().shape; auto out_sh = ttl_outref_.cpu().shape; int nsamples = in_sh.num_samples(); @@ -70,19 +76,9 @@ void ResamplingTest::Verify(bool downmix) { int n_out = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; for (int i = 0; i < n_out; i++) { - float ref_val = 0; - if (downmix) { - for (int c = 0; c < nchannels; c++) { - ref_val += out_ref[i * nchannels + c]; - } - ref_val /= nchannels; - } else { - ref_val = out_ref[i]; - } - - ASSERT_NEAR(out_data[i], ref_val, eps()) + ASSERT_NEAR(out_data[i], out_ref[i], eps()) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; - float diff = std::abs(out_data[i] - ref_val); + float diff = std::abs(out_data[i] - out_ref[i]); if (diff > max_diff) max_diff = diff; err += diff * diff; @@ -96,7 +92,7 @@ void ResamplingTest::Verify(bool downmix) { } } -void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { +void ResamplingTest::RunTest(int nsamples, int nchannels) { std::vector in_rates_v; for (int i = 0; i < nsamples; i++) { if (i % 2 == 0) @@ -111,14 +107,14 @@ void ResamplingTest::RunTest(int nsamples, int nchannels, bool downmix) { PrepareData(nsamples, nchannels, in_rates, out_rates); - RunResampling(in_rates, out_rates, downmix); + RunResampling(in_rates, out_rates); - Verify(downmix); + Verify(); } class ResamplingCPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates, bool downmix) override { + void RunResampling(span in_rates, span out_rates) override { Resampler R; R.Initialize(16); @@ -134,27 +130,24 @@ class ResamplingCPUTest : public ResamplingTest { int n_in = in_sh[0]; int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], - nchannels, downmix); + nchannels); } } }; TEST_F(ResamplingCPUTest, SingleChannel) { - this->RunTest(1, 1, false); + this->RunTest(1, 1); } TEST_F(ResamplingCPUTest, TwoChannel) { - this->RunTest(1, 2, false); + this->RunTest(1, 2); } TEST_F(ResamplingCPUTest, EightChannel) { - this->RunTest(1, 8, false); -} - -TEST_F(ResamplingCPUTest, ThreeChannelDownmix) { - this->RunTest(1, 3, true); + this->RunTest(1, 8); } +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 0163410daf1..2eea4965528 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ + #include #include #include @@ -23,13 +26,9 @@ namespace dali { namespace kernels { namespace signal { namespace resampling { +namespace test { -namespace { - -double HannWindow(int i, int n) { - assert(n > 0); - return Hann(2.0*i / n - 1); -} +double HannWindow(int i, int n); template void TestWave(T *out, int n, int stride, float freq) { @@ -39,8 +38,6 @@ void TestWave(T *out, int n, int stride, float freq) { } } -} // namespace - class ResamplingTest : public ::testing::Test { public: void PrepareData(int nsamples, int nchannels, @@ -48,11 +45,11 @@ class ResamplingTest : public ::testing::Test { virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } - void Verify(bool downmix); + void Verify(); - virtual void RunResampling(span in_rates, span out_rates, bool downmix) = 0; + virtual void RunResampling(span in_rates, span out_rates) = 0; - void RunTest(int nsamples, int nchannels, bool downmix); + void RunTest(int nsamples, int nchannels); TestTensorList ttl_in_; @@ -61,7 +58,10 @@ class ResamplingTest : public ::testing::Test { }; +} // namespace test } // namespace resampling } // namespace signal } // namespace kernels } // namespace dali + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_TEST_H_ From c6374a471432f77c709d06e87e2c0d19849347a3 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 13:11:21 +0200 Subject: [PATCH 19/36] Code review fixes Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 15 +++++++++------ dali/kernels/signal/resampling_gpu.cuh | 4 ++-- dali/kernels/signal/resampling_gpu_test.cu | 2 +- dali/kernels/signal/resampling_test.h | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 802c8a4c225..410ca30a3d7 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -29,7 +29,6 @@ #include "dali/core/small_vector.h" #include "dali/core/convert.h" #include "dali/core/static_switch.h" -#include "dali/core/geom/vec.h" namespace dali { namespace kernels { @@ -55,7 +54,11 @@ inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { #endif struct ResamplingWindow { - inline DALI_HOST_DEV ivec<2> input_range(float x) const { + struct InputRange { + int i0, i1; + }; + + inline DALI_HOST_DEV InputRange input_range(float x) const { int xc = std::ceil(x); int i0 = xc - lobes; int i1 = xc + lobes; @@ -222,8 +225,8 @@ struct Resampler { const float *__restrict__ in_block_ptr = in + in_block_i; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { auto irange = window.input_range(in_pos); - int i0 = irange[0]; - int i1 = irange[1]; + int i0 = irange.i0; + int i1 = irange.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) @@ -285,8 +288,8 @@ struct Resampler { const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { auto irange = window.input_range(in_pos); - int i0 = irange[0]; - int i1 = irange[1]; + int i0 = irange.i0; + int i1 = irange.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > n_in) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 39ad0f7aec2..3f5878fe3f6 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -63,8 +63,8 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { out_pos += grid_stride, in_pos_start += fscale * grid_stride) { float in_pos = in_pos_start + fscale * threadIdx.x; auto i_range = window.input_range(in_pos); - int i0 = i_range[0]; - int i1 = i_range[1]; + int i0 = i_range.i0; + int i1 = i_range.i1; if (i0 + in_block_i < 0) i0 = -in_block_i; if (i1 + in_block_i > sample.in_len) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 23f1fef28d9..bf6bcf349bc 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 2eea4965528..2afadd281df 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From d88e78b3332cebdeb966c21e93c25b56a70ab988 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 19:39:24 +0200 Subject: [PATCH 20/36] Add benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu_test.cu | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index bf6bcf349bc..c6bee6a9d6b 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" +#include "dali/core/cuda_event.h" namespace dali { namespace kernels { @@ -46,6 +47,44 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } + + void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { + std::vector in_rates_v(batch_size, 22050.0f); + auto in_rates = make_cspan(in_rates_v); + std::vector out_rates_v(batch_size, 16000.0f); + auto out_rates = make_cspan(out_rates_v); + + this->PrepareData(batch_size, nchannels, in_rates, out_rates); + + ResamplerGPU R; + R.Initialize(16); + + KernelContext ctx; + ctx.gpu.stream = 0; + + CUDAEvent start = CUDAEvent::CreateWithFlags(0); + CUDAEvent end = CUDAEvent::CreateWithFlags(0); + double avg_time = 0; + int64_t in_elems = ttl_in_.cpu().shape.num_elements(); + int64_t in_bytes = in_elems * sizeof(float); + std::cout << "Resampling GPU Perf test.\n" + << "\nInput contains " << in_elems << " floats.\n"; + + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); + + for (int i = 0; i < n_iters; ++i) { + CUDA_CALL(cudaEventRecord(start)); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + CUDA_CALL(cudaEventRecord(end)); + CUDA_CALL(cudaDeviceSynchronize()); + float time; + CUDA_CALL(cudaEventElapsedTime(&time, start, end)); + + avg_time += time; + } + std::cout << "Processed " << in_bytes / avg_time << " bytes/sec" << std::endl; + } }; TEST_F(ResamplingGPUTest, SingleChannel) { @@ -60,6 +99,10 @@ TEST_F(ResamplingGPUTest, EightChannel) { this->RunTest(3, 8); } +TEST_F(ResamplingGPUTest, PerfTest) { + this->RunPerfTest(64, 1, 1000); +} + } // namespace test } // namespace resampling } // namespace signal From 074bc13cb59ddb6e03744b2aa39a10c4af664819 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 11 May 2022 19:39:46 +0200 Subject: [PATCH 21/36] Avoid precision issue & add shared memory usage Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cuh | 26 +++++++++++++++++++------- dali/kernels/signal/resampling_gpu.h | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 3f5878fe3f6..118828382f1 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling.h" +#include "dali/core/util.h" namespace dali { namespace kernels { @@ -45,22 +46,33 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { double scale = sample.scale; float fscale = scale; int nchannels = SingleChannel ? 1 : sample.nchannels; - auto &window = sample.window; + auto& window = sample.window; + + extern __shared__ float window_coeffs_sh[]; + for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { + window_coeffs_sh[k] = window.lookup[k]; + } + __syncthreads(); + window.lookup = window_coeffs_sh; Out* out = reinterpret_cast(sample.out); const In* in = reinterpret_cast(sample.in); - int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; - int64_t out_block = static_cast(blockIdx.x) * blockDim.x; - int64_t start_out_pos = out_block + threadIdx.x; + int blocks = div_ceil(sample.out_len, static_cast(blockDim.x)); + int64_t start_block_idx = blockIdx.x * blocks / gridDim.x; + int64_t end_block_idx = (blockIdx.x + 1) * blocks / gridDim.x; + int64_t out_stride = blockDim.x; + float in_stride = fscale * blockDim.x; + int64_t out_block_start = start_block_idx * blockDim.x; + int64_t out_block_end = cuda_min(end_block_idx * blockDim.x, sample.out_len); - double in_block_f = out_block * scale; + double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; - for (int64_t out_pos = start_out_pos; out_pos < sample.out_len; - out_pos += grid_stride, in_pos_start += fscale * grid_stride) { + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < out_block_end; + out_pos += out_stride, in_pos_start += in_stride) { float in_pos = in_pos_start + fscale * threadIdx.x; auto i_range = window.input_range(in_pos); int i0 = i_range.i0; diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index ed7db6ff2e9..b4a7715d170 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -89,10 +89,11 @@ class ResamplerGPU { dim3 block(256, 1); int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); + size_t shm_size = window_gpu_storage_.size() * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu); + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } From 524804813ebbb7a27880fcedbb0da0979ef42899 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 13:06:18 +0200 Subject: [PATCH 22/36] Move double in_block_f calculation inside the loop Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cuh | 36 ++++++++++++-------------- dali/kernels/signal/resampling_gpu.h | 8 ++++-- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 118828382f1..fa982ba6552 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -31,7 +31,7 @@ struct SampleDesc { ResamplingWindow window; int64_t in_len; // num samples in input int64_t out_len; // num samples in output - int64_t nchannels; // number of channels + int nchannels; // number of channels double scale; // in_sampling_rate / out_sampling_rate }; @@ -41,14 +41,17 @@ struct SampleDesc { * @param samples sample descriptors */ template -__global__ void ResampleGPUKernel(const SampleDesc *samples) { +__global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) { auto sample = samples[blockIdx.y]; double scale = sample.scale; float fscale = scale; int nchannels = SingleChannel ? 1 : sample.nchannels; auto& window = sample.window; - extern __shared__ float window_coeffs_sh[]; + extern __shared__ float sh_mem[]; + float *window_coeffs_sh = sh_mem; + float *tmp = sh_mem + window.lookup_size + + threadIdx.x * max_nchannels; // used to accummulate per-channel out values for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { window_coeffs_sh[k] = window.lookup[k]; } @@ -58,22 +61,17 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { Out* out = reinterpret_cast(sample.out); const In* in = reinterpret_cast(sample.in); - int blocks = div_ceil(sample.out_len, static_cast(blockDim.x)); - int64_t start_block_idx = blockIdx.x * blocks / gridDim.x; - int64_t end_block_idx = (blockIdx.x + 1) * blocks / gridDim.x; - int64_t out_stride = blockDim.x; - float in_stride = fscale * blockDim.x; - int64_t out_block_start = start_block_idx * blockDim.x; - int64_t out_block_end = cuda_min(end_block_idx * blockDim.x, sample.out_len); - - double in_block_f = out_block_start * scale; - int64_t in_block_i = std::floor(in_block_f); - float in_pos_start = in_block_f - in_block_i; - const In* in_blk_ptr = in + in_block_i * nchannels; - - for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < out_block_end; - out_pos += out_stride, in_pos_start += in_stride) { + int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; + int64_t out_block_start = static_cast(blockIdx.x) * blockDim.x; + + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_len; + out_block_start += grid_stride, out_pos += grid_stride) { + double in_block_f = out_block_start * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos_start = in_block_f - in_block_i; + const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; + auto i_range = window.input_range(in_pos); int i0 = i_range.i0; int i1 = i_range.i1; @@ -92,8 +90,6 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos] = ConvertSatNorm(out_val); } else { // multiple channels - assert(nchannels <= 32); - float tmp[32]; // more than enough for (int c = 0; c < nchannels; c++) { tmp[c] = 0; } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index b4a7715d170..a648d0b4356 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -69,6 +69,7 @@ class ResamplerGPU { make_span(scratch.Allocate(nsamples), nsamples); bool any_multichannel = false; + int max_nchannels = 0; for (int i = 0; i < nsamples; i++) { auto &desc = samples_cpu[i]; desc.in = in[i].data; @@ -80,6 +81,7 @@ class ResamplerGPU { desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); assert(desc.out_len == out_sh[0]); desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + max_nchannels = std::max(desc.nchannels, max_nchannels); desc.scale = static_cast(in_rates[i]) / out_rates[i]; any_multichannel |= desc.nchannels > 1; } @@ -89,11 +91,13 @@ class ResamplerGPU { dim3 block(256, 1); int blocks_per_sample = std::max(32, 1024 / nsamples); dim3 grid(blocks_per_sample, nsamples); - size_t shm_size = window_gpu_storage_.size() * sizeof(float); + + // window coefficients and temporary per channel out values + size_t shm_size = (window_gpu_storage_.size() + max_nchannels * block.x) * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu); + <<>>(samples_gpu, max_nchannels); )); // NOLINT CUDA_CALL(cudaGetLastError()); } From 3f74625a16fc96a38c0fc9cc1611ad6abb48301c Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 13:10:08 +0200 Subject: [PATCH 23/36] Fix benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu_test.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index c6bee6a9d6b..b5c386eb006 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -64,7 +64,7 @@ class ResamplingGPUTest : public ResamplingTest { CUDAEvent start = CUDAEvent::CreateWithFlags(0); CUDAEvent end = CUDAEvent::CreateWithFlags(0); - double avg_time = 0; + double total_time = 0; int64_t in_elems = ttl_in_.cpu().shape.num_elements(); int64_t in_bytes = in_elems * sizeof(float); std::cout << "Resampling GPU Perf test.\n" @@ -80,10 +80,9 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaDeviceSynchronize()); float time; CUDA_CALL(cudaEventElapsedTime(&time, start, end)); - - avg_time += time; + total_time += time; } - std::cout << "Processed " << in_bytes / avg_time << " bytes/sec" << std::endl; + std::cout << "Processed " << n_iters * in_bytes / (total_time * 1e6) << " MBs/sec" << std::endl; } }; From 177fced2242c94cbdc593ba1efa37cb177b435f4 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 15:10:04 +0200 Subject: [PATCH 24/36] Update benchmark Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.h | 4 +-- dali/kernels/signal/resampling_gpu_test.cu | 32 ++++++++++++++-------- dali/kernels/signal/resampling_test.cc | 4 +-- dali/kernels/signal/resampling_test.h | 2 +- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index a648d0b4356..8548b4c1570 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -59,9 +59,7 @@ class ResamplerGPU { if (window_gpu_storage_.empty()) Initialize(); - DynamicScratchpad dyn_scratchpad({}, AccessOrder(context.gpu.stream)); - if (!context.scratchpad) - context.scratchpad = &dyn_scratchpad; + assert(context.scratchpad); auto &scratch = *context.scratchpad; int nsamples = in.num_samples(); diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index b5c386eb006..278dced32d2 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -33,6 +33,8 @@ class ResamplingGPUTest : public ResamplingTest { KernelContext ctx; ctx.gpu.stream = 0; + DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); + ctx.scratchpad = &dyn_scratchpad; auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); auto outref_sh = ttl_outref_.cpu().shape; @@ -53,36 +55,44 @@ class ResamplingGPUTest : public ResamplingTest { auto in_rates = make_cspan(in_rates_v); std::vector out_rates_v(batch_size, 16000.0f); auto out_rates = make_cspan(out_rates_v); + int nsec = 30; - this->PrepareData(batch_size, nchannels, in_rates, out_rates); + this->PrepareData(batch_size, nchannels, in_rates, out_rates, nsec); ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; + auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); CUDAEvent start = CUDAEvent::CreateWithFlags(0); CUDAEvent end = CUDAEvent::CreateWithFlags(0); - double total_time = 0; + double total_time_ms = 0; int64_t in_elems = ttl_in_.cpu().shape.num_elements(); - int64_t in_bytes = in_elems * sizeof(float); + int64_t out_elems = ttl_out_.cpu().shape.num_elements(); + int64_t out_bytes = out_elems * sizeof(float); std::cout << "Resampling GPU Perf test.\n" - << "\nInput contains " << in_elems << " floats.\n"; - - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); - ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); + << "Input contains " << in_elems << " floats.\n" + << "Output contains " << out_elems << " floats.\n"; for (int i = 0; i < n_iters; ++i) { + CUDA_CALL(cudaDeviceSynchronize()); + + DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); + ctx.scratchpad = &dyn_scratchpad; + CUDA_CALL(cudaEventRecord(start)); R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); CUDA_CALL(cudaEventRecord(end)); CUDA_CALL(cudaDeviceSynchronize()); - float time; - CUDA_CALL(cudaEventElapsedTime(&time, start, end)); - total_time += time; + float time_ms; + CUDA_CALL(cudaEventElapsedTime(&time_ms, start, end)); + total_time_ms += time_ms; } - std::cout << "Processed " << n_iters * in_bytes / (total_time * 1e6) << " MBs/sec" << std::endl; + std::cout << "Processed " << n_iters * out_bytes / (total_time_ms * 1e6) << " GBs/sec" + << std::endl; } }; diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 538506a283e..886e7bfe8bf 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -30,14 +30,14 @@ double HannWindow(int i, int n) { } void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, - span out_rates) { + span out_rates, int nsec) { TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); for (int s = 0; s < nsamples; s++) { double in_rate = in_rates[s]; double out_rate = out_rates[s]; double scale = static_cast(in_rate) / out_rate; - int n_in = in_rate + 12345 * s; // different lengths + int n_in = nsec * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); in_sh.tensor_shape_span(s)[0] = n_in; out_sh.tensor_shape_span(s)[0] = n_out; diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 2afadd281df..eefa90fbba8 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -41,7 +41,7 @@ void TestWave(T *out, int n, int stride, float freq) { class ResamplingTest : public ::testing::Test { public: void PrepareData(int nsamples, int nchannels, - span in_rates, span out_rates); + span in_rates, span out_rates, int nsec = 1); virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } From 744f570080717f441f268066fdc23d378c010110 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 12 May 2022 18:49:42 +0200 Subject: [PATCH 25/36] ROI & input conversion to float & limit tmp shared mem Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 11 +++- dali/kernels/signal/resampling_gpu.cuh | 60 +++++++++++------ dali/kernels/signal/resampling_gpu.h | 27 ++++---- dali/kernels/signal/resampling_gpu_test.cu | 32 +++++---- dali/kernels/signal/resampling_test.cc | 77 ++++++++++++---------- dali/kernels/signal/resampling_test.h | 20 +++--- 6 files changed, 141 insertions(+), 86 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 410ca30a3d7..9112b602784 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -36,6 +36,11 @@ namespace signal { namespace resampling { +struct Args { + double in_rate = 1, out_rate = 1; + int64_t out_begin = 0, out_end = -1; // default values result in the whole range +}; + inline double Hann(double x) { return 0.5 * (1 + std::cos(x * M_PI)); } @@ -241,7 +246,8 @@ struct Resampler { f += in_block_ptr[i] * w; } assert(out_pos >= out_begin && out_pos < out_end); - out[out_pos] = ConvertSatNorm(f); + auto rel_pos = out_pos - out_begin; + out[rel_pos] = ConvertSatNorm(f); } } } @@ -310,8 +316,9 @@ struct Resampler { } } assert(out_pos >= out_begin && out_pos < out_end); + auto rel_pos = out_pos - out_begin; for (int c = 0; c < num_channels; c++) - out[out_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + out[rel_pos * num_channels + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index fa982ba6552..b501eaaab8f 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -19,6 +19,8 @@ #include "dali/kernels/signal/resampling.h" #include "dali/core/util.h" +#define SHM_NCHANNELS 16 + namespace dali { namespace kernels { namespace signal { @@ -30,18 +32,33 @@ struct SampleDesc { const void *in; ResamplingWindow window; int64_t in_len; // num samples in input - int64_t out_len; // num samples in output + int64_t out_begin; // output region-of-interest start + int64_t out_end; // output region-of-interest end int nchannels; // number of channels double scale; // in_sampling_rate / out_sampling_rate }; +/** + * @brief Gets intermediate floating point representation depending on the input/output types + */ +template +__device__ float ConvertInput(In in_val) { + if (std::is_unsigned::value && std::is_signed::value) { + return (ConvertSatNorm(in_val) + 1.0f) * 0.5f; + } else if (std::is_signed::value && std::is_unsigned::value) { + return ConvertSatNorm(in_val) * 2.0f - 1.0f; // treat half-range as 0 + } else { + return ConvertSatNorm(in_val); // just normalize + } +} + /** * @brief Resamples 1D signal (single or multi-channel), optionally converting to a different data type. * * @param samples sample descriptors */ template -__global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) { +__global__ void ResampleGPUKernel(const SampleDesc *samples) { auto sample = samples[blockIdx.y]; double scale = sample.scale; float fscale = scale; @@ -51,7 +68,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) extern __shared__ float sh_mem[]; float *window_coeffs_sh = sh_mem; float *tmp = sh_mem + window.lookup_size + - threadIdx.x * max_nchannels; // used to accummulate per-channel out values + threadIdx.x * (SHM_NCHANNELS+1); // used to accummulate per-channel out values for (int k = threadIdx.x; k < window.lookup_size; k += blockDim.x) { window_coeffs_sh[k] = window.lookup[k]; } @@ -62,9 +79,9 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) const In* in = reinterpret_cast(sample.in); int64_t grid_stride = static_cast(gridDim.x) * blockDim.x; - int64_t out_block_start = static_cast(blockIdx.x) * blockDim.x; + int64_t out_block_start = sample.out_begin + static_cast(blockIdx.x) * blockDim.x; - for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_len; + for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_end; out_block_start += grid_stride, out_pos += grid_stride) { double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); @@ -83,28 +100,31 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples, int max_nchannels) float out_val = 0; if (SingleChannel) { for (int i = i0; i < i1; i++) { - In in_val = in_blk_ptr[i]; + float in_val = ConvertInput(in_blk_ptr[i]); float x = i - in_pos; float w = window(x); out_val = fma(in_val, w, out_val); } - out[out_pos] = ConvertSatNorm(out_val); + out[out_pos - sample.out_begin] = ConvertSatNorm(out_val); } else { // multiple channels - for (int c = 0; c < nchannels; c++) { - tmp[c] = 0; - } - - for (int i = i0; i < i1; i++) { - float x = i - in_pos; - float w = window(x); - for (int c = 0; c < nchannels; c++) { - In in_val = in_blk_ptr[i * nchannels + c]; - tmp[c] = fma(in_val, w, tmp[c]); + for (int c0 = 0; c0 < nchannels; c0 += SHM_NCHANNELS) { + int nc = cuda_min(SHM_NCHANNELS, nchannels - c0); + for (int c = 0; c < nc; c++) { + tmp[c] = 0; } - } - for (int c = 0; c < nchannels; c++) { - out[out_pos * nchannels + c] = ConvertSatNorm(tmp[c]); + for (int i = i0; i < i1; i++) { + float x = i - in_pos; + float w = window(x); + for (int c = 0; c < nc; c++) { + float in_val = ConvertInput(in_blk_ptr[i * nchannels + c]); + tmp[c] = fma(in_val, w, tmp[c]); + } + } + Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; + for (int c = 0; c < nc; c++) { + out_ptr[c + c0] = ConvertSatNorm(tmp[c]); + } } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index 8548b4c1570..a02fe6a855f 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -41,21 +41,25 @@ class ResamplerGPU { } KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span in_rate, span out_rate) { + span args) { KernelRequirements req; auto out_shape = in.shape; for (int i = 0; i < in.num_samples(); i++) { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); - out_sh[0] = resampled_length(in_sh[0], in_rate[i], out_rate[i]); + auto &arg = args[i]; + if (arg.out_begin > 0 || arg.out_end > 0) { + out_sh[0] = arg.out_end - arg.out_begin; + } else { + out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + } } req.output_shapes = {out_shape}; return req; } void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span in_rates, - span out_rates) { + const InListGPU &in, span args) { if (window_gpu_storage_.empty()) Initialize(); @@ -67,7 +71,6 @@ class ResamplerGPU { make_span(scratch.Allocate(nsamples), nsamples); bool any_multichannel = false; - int max_nchannels = 0; for (int i = 0; i < nsamples; i++) { auto &desc = samples_cpu[i]; desc.in = in[i].data; @@ -76,11 +79,13 @@ class ResamplerGPU { const auto &in_sh = in[i].shape; const auto &out_sh = out[i].shape; desc.in_len = in_sh[0]; - desc.out_len = resampled_length(in_sh[0], in_rates[i], out_rates[i]); - assert(desc.out_len == out_sh[0]); + auto &arg = args[i]; + desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + desc.out_end = arg.out_end > 0 ? arg.out_end : + resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + assert((desc.out_end - desc.out_begin) == out_sh[0]); desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; - max_nchannels = std::max(desc.nchannels, max_nchannels); - desc.scale = static_cast(in_rates[i]) / out_rates[i]; + desc.scale = arg.in_rate / arg.out_rate; any_multichannel |= desc.nchannels > 1; } @@ -91,11 +96,11 @@ class ResamplerGPU { dim3 grid(blocks_per_sample, nsamples); // window coefficients and temporary per channel out values - size_t shm_size = (window_gpu_storage_.size() + max_nchannels * block.x) * sizeof(float); + size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); BOOL_SWITCH(!any_multichannel, SingleChannel, ( ResampleGPUKernel - <<>>(samples_gpu, max_nchannels); + <<>>(samples_gpu); )); // NOLINT CUDA_CALL(cudaGetLastError()); } diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 278dced32d2..a4dc4f962c7 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -27,7 +27,7 @@ namespace test { class ResamplingGPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates) override { + void RunResampling(span args) override { ResamplerGPU R; R.Initialize(16); @@ -36,7 +36,7 @@ class ResamplingGPUTest : public ResamplingTest { DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream)); ctx.scratchpad = &dyn_scratchpad; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + auto req = R.Setup(ctx, ttl_in_.gpu(), args); auto outref_sh = ttl_outref_.cpu().shape; auto in_batch_sh = ttl_in_.cpu().shape; for (int s = 0; s < outref_sh.size(); s++) { @@ -45,26 +45,24 @@ class ResamplingGPUTest : public ResamplingTest { ASSERT_EQ(sh, expected_sh); } - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), args); CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { - std::vector in_rates_v(batch_size, 22050.0f); - auto in_rates = make_cspan(in_rates_v); - std::vector out_rates_v(batch_size, 16000.0f); - auto out_rates = make_cspan(out_rates_v); + std::vector args_v(batch_size, {22050.0f, 16000.0f}); + auto args = make_cspan(args_v); int nsec = 30; - this->PrepareData(batch_size, nchannels, in_rates, out_rates, nsec); + this->PrepareData(batch_size, nchannels, args, nsec); ResamplerGPU R; R.Initialize(16); KernelContext ctx; ctx.gpu.stream = 0; - auto req = R.Setup(ctx, ttl_in_.gpu(), in_rates, out_rates); + auto req = R.Setup(ctx, ttl_in_.gpu(), args); ASSERT_EQ(ttl_out_.cpu().shape, req.output_shapes[0]); CUDAEvent start = CUDAEvent::CreateWithFlags(0); @@ -84,7 +82,7 @@ class ResamplingGPUTest : public ResamplingTest { ctx.scratchpad = &dyn_scratchpad; CUDA_CALL(cudaEventRecord(start)); - R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), in_rates, out_rates); + R.Run(ctx, ttl_out_.gpu(), ttl_in_.gpu(), args); CUDA_CALL(cudaEventRecord(end)); CUDA_CALL(cudaDeviceSynchronize()); float time_ms; @@ -108,7 +106,19 @@ TEST_F(ResamplingGPUTest, EightChannel) { this->RunTest(3, 8); } -TEST_F(ResamplingGPUTest, PerfTest) { +TEST_F(ResamplingGPUTest, HundredChannel) { + this->RunTest(3, 100); +} + +TEST_F(ResamplingGPUTest, OutBeginEnd) { + this->RunTest(3, 1, true); +} + +TEST_F(ResamplingGPUTest, EightChannelOutBeginEnd) { + this->RunTest(3, 8, true); +} + +TEST_F(ResamplingGPUTest, DISABLED_PerfTest) { this->RunPerfTest(64, 1, 1000); } diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 886e7bfe8bf..a4070db3b4b 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -26,21 +26,23 @@ namespace test { double HannWindow(int i, int n) { assert(n > 0); - return Hann(2.0*i / n - 1); + return Hann(2.0 * i / n - 1); } -void ResamplingTest::PrepareData(int nsamples, int nchannels, span in_rates, - span out_rates, int nsec) { +void ResamplingTest::PrepareData(int nsamples, int nchannels, span args, int nsec) { TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); for (int s = 0; s < nsamples; s++) { - double in_rate = in_rates[s]; - double out_rate = out_rates[s]; - double scale = static_cast(in_rate) / out_rate; + double in_rate = args[s].in_rate; + double out_rate = args[s].out_rate; + double scale = in_rate / out_rate; int n_in = nsec * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; + ASSERT_GT(out_end, out_begin); in_sh.tensor_shape_span(s)[0] = n_in; - out_sh.tensor_shape_span(s)[0] = n_out; + out_sh.tensor_shape_span(s)[0] = out_end - out_begin; if (nchannels > 1) { in_sh.tensor_shape_span(s)[1] = nchannels; out_sh.tensor_shape_span(s)[1] = nchannels; @@ -50,21 +52,23 @@ void ResamplingTest::PrepareData(int nsamples, int nchannels, span ttl_out_.reshape(out_sh); ttl_outref_.reshape(out_sh); for (int s = 0; s < nsamples; s++) { - double in_rate = in_rates[s]; - double out_rate = out_rates[s]; + double in_rate = args[s].in_rate; + double out_rate = args[s].out_rate; double scale = static_cast(in_rate) / out_rate; + int64_t n_in = in_sh.tensor_shape_span(s)[0]; + int n_out = std::ceil(n_in / scale); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; for (int c = 0; c < nchannels; c++) { float f_in = 0.1f + 0.01 * s + 0.001 * c; float f_out = f_in * scale; - int n_in = in_sh.tensor_shape_span(s)[0]; - int n_out = out_sh.tensor_shape_span(s)[0]; TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); - TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out, out_begin, out_end); } } } -void ResamplingTest::Verify() { +void ResamplingTest::Verify(span args) { auto in_sh = ttl_in_.cpu().shape; auto out_sh = ttl_outref_.cpu().shape; int nsamples = in_sh.num_samples(); @@ -73,9 +77,9 @@ void ResamplingTest::Verify() { for (int s = 0; s < nsamples; s++) { float *out_data = ttl_out_.cpu()[s].data; float *out_ref = ttl_outref_.cpu()[s].data; - int n_out = out_sh.tensor_shape_span(s)[0]; + int64_t out_len = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; - for (int i = 0; i < n_out; i++) { + for (int64_t i = 0; i < out_len; i++) { ASSERT_NEAR(out_data[i], out_ref[i], eps()) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; float diff = std::abs(out_data[i] - out_ref[i]); @@ -84,7 +88,7 @@ void ResamplingTest::Verify() { err += diff * diff; } - err = std::sqrt(err / n_out); + err = std::sqrt(err / out_len); EXPECT_LE(err, max_avg_err()) << "Average error too big"; std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" "\n max difference vs fresh signal: " @@ -92,45 +96,44 @@ void ResamplingTest::Verify() { } } -void ResamplingTest::RunTest(int nsamples, int nchannels) { - std::vector in_rates_v; +void ResamplingTest::RunTest(int nsamples, int nchannels, bool use_roi) { + std::vector args_v; for (int i = 0; i < nsamples; i++) { + int roi_start = use_roi ? 100 : 0; + int roi_end = use_roi ? 8000 : -1; if (i % 2 == 0) - in_rates_v.push_back(22050.0f); + args_v.push_back({22050.0f, 16000.0f, roi_start, roi_end}); else - in_rates_v.push_back(44100.0f); + args_v.push_back({44100.0f, 16000.0f, roi_start, roi_end}); } - auto in_rates = make_cspan(in_rates_v); + auto args = make_cspan(args_v); - std::vector out_rates_v(nsamples, 16000.0f); - auto out_rates = make_cspan(out_rates_v); + PrepareData(nsamples, nchannels, args); - PrepareData(nsamples, nchannels, in_rates, out_rates); + RunResampling(args); - RunResampling(in_rates, out_rates); - - Verify(); + Verify(args); } class ResamplingCPUTest : public ResamplingTest { public: - void RunResampling(span in_rates, span out_rates) override { + void RunResampling(span args) override { Resampler R; R.Initialize(16); - int nsamples = in_rates.size(); - assert(nsamples == out_rates.size()); + int nsamples = args.size(); auto in_view = ttl_in_.cpu(); auto out_view = ttl_out_.cpu(); for (int s = 0; s < nsamples; s++) { auto out_sh = out_view.shape[s]; auto in_sh = in_view.shape[s]; - int n_out = out_sh[0]; int n_in = in_sh[0]; int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; - R.Resample(out_view[s].data, 0, n_out, out_rates[s], in_view[s].data, n_in, in_rates[s], - nchannels); + int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : out_sh[0]; + R.Resample(out_view[s].data, out_begin, out_end, args[s].out_rate, + in_view[s].data, n_in, args[s].in_rate, nchannels); } } }; @@ -147,6 +150,14 @@ TEST_F(ResamplingCPUTest, EightChannel) { this->RunTest(1, 8); } +TEST_F(ResamplingCPUTest, OutBeginEnd) { + this->RunTest(1, 1, true); +} + +TEST_F(ResamplingCPUTest, EightChannelOutBeginEnd) { + this->RunTest(3, 8, true); +} + } // namespace test } // namespace resampling } // namespace signal diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index eefa90fbba8..51f7c4be7b2 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -31,25 +31,27 @@ namespace test { double HannWindow(int i, int n); template -void TestWave(T *out, int n, int stride, float freq) { - for (int i = 0; i < n; i++) { - float f = std::sin(i* freq) * HannWindow(i, n); - out[i*stride] = ConvertSatNorm(f); +void TestWave(T *out, int n, int stride, float freq, int i_start = 0, int i_end = -1) { + if (i_end <= 0) i_end = n; + assert(i_start >= 0 && i_start <= n); + assert(i_end >= 0 && i_end <= n); + for (int i = i_start; i < i_end; i++) { + float f = std::sin(i * freq) * HannWindow(i, n); + out[(i - i_start) * stride] = ConvertSatNorm(f); } } class ResamplingTest : public ::testing::Test { public: - void PrepareData(int nsamples, int nchannels, - span in_rates, span out_rates, int nsec = 1); + void PrepareData(int nsamples, int nchannels, span args, int nsec = 1); virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } - void Verify(); + void Verify(span args); - virtual void RunResampling(span in_rates, span out_rates) = 0; + virtual void RunResampling(span args) = 0; - void RunTest(int nsamples, int nchannels); + void RunTest(int nsamples, int nchannels, bool use_roi = false); TestTensorList ttl_in_; From 12e177cd3b173fa16577ff041dee2aefd89486de Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 13:24:48 +0200 Subject: [PATCH 26/36] Add comments Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 ++++ dali/kernels/signal/resampling_gpu.cuh | 3 +++ 2 files changed, 7 insertions(+) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 9112b602784..d03ebe896dc 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -70,6 +70,10 @@ struct ResamplingWindow { return {i0, i1}; } + /** + * @brief Calculates the window coefficient at an arbitrary floating point position + * by interpolating between two samples. + */ inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; float floori = std::floor(fi); diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index b501eaaab8f..f08473aa295 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -83,6 +83,9 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { for (int64_t out_pos = out_block_start + threadIdx.x; out_pos < sample.out_end; out_block_start += grid_stride, out_pos += grid_stride) { + // A floating point distance `in_pos_start` is calculated from an arbitrary relative + // position, keeping the floats small in order to keep precision. `in_block_f`, used to + // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; int64_t in_block_i = std::floor(in_block_f); float in_pos_start = in_block_f - in_block_i; From ef495e1dc8dabd7754ebb025ede9d8c5d36ba9ee Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 13:47:55 +0200 Subject: [PATCH 27/36] Use floorf and ceilf in CUDA code Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 ++-- dali/kernels/signal/resampling_gpu.cuh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index d03ebe896dc..0c0a968be50 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -64,7 +64,7 @@ struct ResamplingWindow { }; inline DALI_HOST_DEV InputRange input_range(float x) const { - int xc = std::ceil(x); + int xc = ceilf(x); int i0 = xc - lobes; int i1 = xc + lobes; return {i0, i1}; @@ -76,7 +76,7 @@ struct ResamplingWindow { */ inline DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; - float floori = std::floor(fi); + float floori = floorf(fi); float di = fi - floori; int i = floori; assert(i >= 0 && i < lookup_size); diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index f08473aa295..6c612fbad7b 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -87,7 +87,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { // position, keeping the floats small in order to keep precision. `in_block_f`, used to // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; - int64_t in_block_i = std::floor(in_block_f); + int64_t in_block_i = floorf(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; From 73c6cee229cc91f171a42fea7265a87d11cbb252 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 16 May 2022 18:32:59 +0200 Subject: [PATCH 28/36] Improve tests & fix bugs Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 4 +- dali/kernels/signal/resampling_gpu.cuh | 10 ++- dali/kernels/signal/resampling_gpu.h | 16 ++-- dali/kernels/signal/resampling_gpu_test.cu | 57 ++++++++++--- dali/kernels/signal/resampling_test.cc | 99 +++++++++++++++------- dali/kernels/signal/resampling_test.h | 27 ++++-- 6 files changed, 150 insertions(+), 63 deletions(-) diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index 0c0a968be50..dc47d601ea4 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -222,7 +222,7 @@ struct Resampler { Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, double in_rate) const { assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); - int64_t block = 1 << 10; // still leaves 13 significant bits for fractional part + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part double scale = in_rate / out_rate; float fscale = scale; @@ -285,7 +285,7 @@ struct Resampler { const int num_channels = static_channels < 0 ? dynamic_num_channels : static_channels; assert(num_channels > 0); - int64_t block = 1 << 10; // still leaves 13 significant bits for fractional part + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part double scale = in_rate / out_rate; float fscale = scale; SmallVector tmp; diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 6c612fbad7b..452f7f1c72d 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -87,7 +87,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { // position, keeping the floats small in order to keep precision. `in_block_f`, used to // calculate the reference for distance (in_block_i) needs to be calculated in double precision. double in_block_f = out_block_start * scale; - int64_t in_block_i = floorf(in_block_f); + int64_t in_block_i = floor(in_block_f); float in_pos_start = in_block_f - in_block_i; const In* in_blk_ptr = in + in_block_i * nchannels; float in_pos = in_pos_start + fscale * threadIdx.x; @@ -110,6 +110,7 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { } out[out_pos - sample.out_begin] = ConvertSatNorm(out_val); } else { // multiple channels + Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; for (int c0 = 0; c0 < nchannels; c0 += SHM_NCHANNELS) { int nc = cuda_min(SHM_NCHANNELS, nchannels - c0); for (int c = 0; c < nc; c++) { @@ -119,14 +120,15 @@ __global__ void ResampleGPUKernel(const SampleDesc *samples) { for (int i = i0; i < i1; i++) { float x = i - in_pos; float w = window(x); + const In *in_ptr = in_blk_ptr + i * nchannels + c0; for (int c = 0; c < nc; c++) { - float in_val = ConvertInput(in_blk_ptr[i * nchannels + c]); + float in_val = ConvertInput(in_ptr[c]); tmp[c] = fma(in_val, w, tmp[c]); } } - Out *out_ptr = out + (out_pos - sample.out_begin) * nchannels; + for (int c = 0; c < nc; c++) { - out_ptr[c + c0] = ConvertSatNorm(tmp[c]); + out_ptr[c0 + c] = ConvertSatNorm(tmp[c]); } } } diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index a02fe6a855f..e8fe697ab62 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -48,11 +48,17 @@ class ResamplerGPU { auto in_sh = in.shape.tensor_shape_span(i); auto out_sh = out_shape.tensor_shape_span(i); auto &arg = args[i]; - if (arg.out_begin > 0 || arg.out_end > 0) { - out_sh[0] = arg.out_end - arg.out_begin; - } else { - out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - } + auto out_len = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + auto out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + auto out_end = arg.out_end > 0 ? arg.out_end : out_len; + if (out_end < out_begin) + throw std::invalid_argument( + make_string("out_begin can't be larger than out_end. Got out_begin=", out_begin, + ", out_end=", out_end)); + if (out_end > out_len) + throw std::invalid_argument(make_string( + "out_end can't be outside of the range of the output signal: [0, ", out_len, ")")); + out_sh[0] = out_end - out_begin; } req.output_shapes = {out_shape}; return req; diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index a4dc4f962c7..46a2302d818 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -27,6 +27,10 @@ namespace test { class ResamplingGPUTest : public ResamplingTest { public: + ResamplingGPUTest() { + this->nsamples_ = 8; + } + void RunResampling(span args) override { ResamplerGPU R; R.Initialize(16); @@ -50,12 +54,11 @@ class ResamplingGPUTest : public ResamplingTest { CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream)); } - void RunPerfTest(int batch_size, int nchannels, int n_iters = 1000) { - std::vector args_v(batch_size, {22050.0f, 16000.0f}); + void RunPerfTest(int n_iters = 1000) { + std::vector args_v(nsamples_, {22050.0f, 16000.0f}); auto args = make_cspan(args_v); - int nsec = 30; - - this->PrepareData(batch_size, nchannels, args, nsec); + this->nsec_ = 30; + this->PrepareData(args); ResamplerGPU R; R.Initialize(16); @@ -95,31 +98,57 @@ class ResamplingGPUTest : public ResamplingTest { }; TEST_F(ResamplingGPUTest, SingleChannel) { - this->RunTest(8, 1); + this->nchannels_ = 1; + this->RunTest(); } TEST_F(ResamplingGPUTest, TwoChannel) { - this->RunTest(3, 2); + this->nchannels_ = 2; + this->RunTest(); } TEST_F(ResamplingGPUTest, EightChannel) { - this->RunTest(3, 8); + this->nchannels_ = 8; + this->RunTest(); } -TEST_F(ResamplingGPUTest, HundredChannel) { - this->RunTest(3, 100); +TEST_F(ResamplingGPUTest, ThirtyChannel) { + this->nchannels_ = 30; + this->RunTest(); } TEST_F(ResamplingGPUTest, OutBeginEnd) { - this->RunTest(3, 1, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->RunTest(); } TEST_F(ResamplingGPUTest, EightChannelOutBeginEnd) { - this->RunTest(3, 8, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingGPUTest, PerfTest) { + this->RunPerfTest(1000); +} + +TEST_F(ResamplingGPUTest, SingleChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->roi_start_ = 4000000; // enough to look long into the signal + this->roi_end_ = 4010000; + this->RunTest(); } -TEST_F(ResamplingGPUTest, DISABLED_PerfTest) { - this->RunPerfTest(64, 1, 1000); +TEST_F(ResamplingGPUTest, ThreeChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->nchannels_ = 3; + this->roi_start_ = 4000000; // enough to look long into the signal + this->roi_end_ = 4010000; + this->RunTest(); } } // namespace test diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index a4070db3b4b..723676f1cca 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -29,41 +29,46 @@ double HannWindow(int i, int n) { return Hann(2.0 * i / n - 1); } -void ResamplingTest::PrepareData(int nsamples, int nchannels, span args, int nsec) { - TensorListShape<> in_sh(nsamples, nchannels > 1 ? 2 : 1); - TensorListShape<> out_sh(nsamples, nchannels > 1 ? 2 : 1); - for (int s = 0; s < nsamples; s++) { +void ResamplingTest::PrepareData(span args) { + TensorListShape<> in_sh(nsamples_, nchannels_ > 1 ? 2 : 1); + TensorListShape<> out_sh(nsamples_, nchannels_ > 1 ? 2 : 1); + for (int s = 0; s < nsamples_; s++) { double in_rate = args[s].in_rate; double out_rate = args[s].out_rate; double scale = in_rate / out_rate; - int n_in = nsec * in_rate + 12345 * s; // different lengths + int n_in = nsec_ * in_rate + 12345 * s; // different lengths int n_out = std::ceil(n_in / scale); int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; ASSERT_GT(out_end, out_begin); in_sh.tensor_shape_span(s)[0] = n_in; out_sh.tensor_shape_span(s)[0] = out_end - out_begin; - if (nchannels > 1) { - in_sh.tensor_shape_span(s)[1] = nchannels; - out_sh.tensor_shape_span(s)[1] = nchannels; + if (nchannels_ > 1) { + in_sh.tensor_shape_span(s)[1] = nchannels_; + out_sh.tensor_shape_span(s)[1] = nchannels_; } } ttl_in_.reshape(in_sh); ttl_out_.reshape(out_sh); ttl_outref_.reshape(out_sh); - for (int s = 0; s < nsamples; s++) { + for (int s = 0; s < nsamples_; s++) { double in_rate = args[s].in_rate; double out_rate = args[s].out_rate; - double scale = static_cast(in_rate) / out_rate; + double scale = in_rate / out_rate; int64_t n_in = in_sh.tensor_shape_span(s)[0]; int n_out = std::ceil(n_in / scale); int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; - for (int c = 0; c < nchannels; c++) { - float f_in = 0.1f + 0.01 * s + 0.001 * c; - float f_out = f_in * scale; - TestWave(ttl_in_.cpu()[s].data + c, n_in, nchannels, f_in); - TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels, f_out, out_begin, out_end); + for (int c = 0; c < nchannels_; c++) { + double f_in = default_freq_in_ + 0.01 * s + 0.001 * c; + double f_out = f_in * scale; + // enough input samples for a given output region + int64_t in_begin = std::max(out_begin * scale - 200, 0); + int64_t in_end = std::min(out_end * scale + 200, n_in); + TestWave(ttl_in_.cpu()[s].data + in_begin * nchannels_ + c, n_in, nchannels_, f_in, + use_envelope_, in_begin, in_end); + TestWave(ttl_outref_.cpu()[s].data + c, n_out, nchannels_, f_out, use_envelope_, out_begin, + out_end); } } } @@ -80,7 +85,7 @@ void ResamplingTest::Verify(span args) { int64_t out_len = out_sh.tensor_shape_span(s)[0]; int nchannels = out_sh.sample_dim() == 1 ? 1 : out_sh.tensor_shape_span(s)[1]; for (int64_t i = 0; i < out_len; i++) { - ASSERT_NEAR(out_data[i], out_ref[i], eps()) + ASSERT_NEAR(out_data[i], out_ref[i], eps_) << "Sample error too big @ sample=" << s << " pos=" << i << std::endl; float diff = std::abs(out_data[i] - out_ref[i]); if (diff > max_diff) @@ -89,26 +94,24 @@ void ResamplingTest::Verify(span args) { } err = std::sqrt(err / out_len); - EXPECT_LE(err, max_avg_err()) << "Average error too big"; + EXPECT_LE(err, max_avg_err_) << "Average error too big"; std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" "\n max difference vs fresh signal: " << max_diff << "\n RMS error: " << err << std::endl; } } -void ResamplingTest::RunTest(int nsamples, int nchannels, bool use_roi) { +void ResamplingTest::RunTest() { std::vector args_v; - for (int i = 0; i < nsamples; i++) { - int roi_start = use_roi ? 100 : 0; - int roi_end = use_roi ? 8000 : -1; + for (int i = 0; i < nsamples_; i++) { if (i % 2 == 0) - args_v.push_back({22050.0f, 16000.0f, roi_start, roi_end}); + args_v.push_back({22050.0f, 16000.0f, roi_start_, roi_end_}); else - args_v.push_back({44100.0f, 16000.0f, roi_start, roi_end}); + args_v.push_back({44100.0f, 16000.0f, roi_start_, roi_end_}); } auto args = make_cspan(args_v); - PrepareData(nsamples, nchannels, args); + PrepareData(args); RunResampling(args); @@ -121,17 +124,19 @@ class ResamplingCPUTest : public ResamplingTest { Resampler R; R.Initialize(16); - int nsamples = args.size(); + ASSERT_EQ(args.size(), nsamples_); auto in_view = ttl_in_.cpu(); auto out_view = ttl_out_.cpu(); - for (int s = 0; s < nsamples; s++) { + for (int s = 0; s < nsamples_; s++) { auto out_sh = out_view.shape[s]; auto in_sh = in_view.shape[s]; int n_in = in_sh[0]; + int n_out = resampled_length(n_in, args[s].in_rate, args[s].out_rate); int nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; int64_t out_begin = args[s].out_begin > 0 ? args[s].out_begin : 0; - int64_t out_end = args[s].out_end > 0 ? args[s].out_end : out_sh[0]; + int64_t out_end = args[s].out_end > 0 ? args[s].out_end : n_out; + ASSERT_EQ(out_sh[0], out_end - out_begin); R.Resample(out_view[s].data, out_begin, out_end, args[s].out_rate, in_view[s].data, n_in, args[s].in_rate, nchannels); } @@ -139,23 +144,53 @@ class ResamplingCPUTest : public ResamplingTest { }; TEST_F(ResamplingCPUTest, SingleChannel) { - this->RunTest(1, 1); + this->nchannels_ = 1; + this->RunTest(); } TEST_F(ResamplingCPUTest, TwoChannel) { - this->RunTest(1, 2); + this->nchannels_ = 2; + this->RunTest(); } TEST_F(ResamplingCPUTest, EightChannel) { - this->RunTest(1, 8); + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, ThirtyChannel) { + this->nchannels_ = 30; + this->RunTest(); } TEST_F(ResamplingCPUTest, OutBeginEnd) { - this->RunTest(1, 1, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->RunTest(); } TEST_F(ResamplingCPUTest, EightChannelOutBeginEnd) { - this->RunTest(3, 8, true); + this->roi_start_ = 100; + this->roi_end_ = 8000; + this->nchannels_ = 8; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, SingleChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->roi_start_ = 4000000; // enough to look at the tail + this->roi_end_ = -1; + this->RunTest(); +} + +TEST_F(ResamplingCPUTest, ThreeChannelNeedHighPrecision) { + this->default_freq_in_ = 0.49; + this->nsec_ = 400; + this->nchannels_ = 3; + this->roi_start_ = 4000000; // enough to look at the tail + this->roi_end_ = -1; + this->RunTest(); } } // namespace test diff --git a/dali/kernels/signal/resampling_test.h b/dali/kernels/signal/resampling_test.h index 51f7c4be7b2..66abc8fd800 100644 --- a/dali/kernels/signal/resampling_test.h +++ b/dali/kernels/signal/resampling_test.h @@ -31,28 +31,43 @@ namespace test { double HannWindow(int i, int n); template -void TestWave(T *out, int n, int stride, float freq, int i_start = 0, int i_end = -1) { - if (i_end <= 0) i_end = n; +void TestWave(T *out, int n, int stride, double freq, bool envelope = true, int i_start = 0, + int i_end = -1) { + if (i_end <= 0) + i_end = n; assert(i_start >= 0 && i_start <= n); assert(i_end >= 0 && i_end <= n); for (int i = i_start; i < i_end; i++) { - float f = std::sin(i * freq) * HannWindow(i, n); + float f; + if (envelope) + f = std::sin(i * freq) * HannWindow(i, n); + else + f = std::sin(i * freq); out[(i - i_start) * stride] = ConvertSatNorm(f); } } class ResamplingTest : public ::testing::Test { public: - void PrepareData(int nsamples, int nchannels, span args, int nsec = 1); + void PrepareData(span args); - virtual float eps() const { return 2e-3; } virtual float max_avg_err() const { return 1e-3; } void Verify(span args); virtual void RunResampling(span args) = 0; - void RunTest(int nsamples, int nchannels, bool use_roi = false); + void RunTest(); + protected: + int nsamples_ = 1; + int nchannels_ = 1; + double default_freq_in_ = 0.1; + int nsec_ = 1; + float eps_ = 2e-3; + float max_avg_err_ = 1e-3; + bool use_envelope_ = true; + int64_t roi_start_ = 0; + int64_t roi_end_ = -1; // means end-of-signal TestTensorList ttl_in_; TestTensorList ttl_out_; From 44cd21606f5ec7693b00c3b25ca68d95471f6df5 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 17 May 2022 11:43:33 +0200 Subject: [PATCH 29/36] Move resampling GPU to cu file & add sync to Initialize Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cu | 110 +++++++++++++++++++++ dali/kernels/signal/resampling_gpu.h | 107 +++++--------------- dali/kernels/signal/resampling_gpu_test.cu | 1 + 3 files changed, 137 insertions(+), 81 deletions(-) create mode 100644 dali/kernels/signal/resampling_gpu.cu diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu new file mode 100644 index 00000000000..d68e111aeed --- /dev/null +++ b/dali/kernels/signal/resampling_gpu.cu @@ -0,0 +1,110 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/core/dev_buffer.h" +#include "dali/core/mm/memory.h" +#include "dali/core/static_switch.h" +#include "dali/kernels/dynamic_scratchpad.h" +#include "dali/kernels/kernel.h" +#include "dali/kernels/signal/resampling_gpu.cuh" +#include "dali/kernels/signal/resampling_gpu.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +template +void ResamplerGPU::Initialize(int lobes, int lookup_size) { + windowed_sinc(window_cpu_, lookup_size, lobes); + window_gpu_storage_.from_host(window_cpu_.storage); + window_gpu_ = window_cpu_; + window_gpu_.lookup = window_gpu_storage_.data(); + CUDA_CALL(cudaStreamSynchronize(0)); +} + +template +KernelRequirements ResamplerGPU::Setup(KernelContext &context, const InListGPU &in, + span args) { + KernelRequirements req; + auto out_shape = in.shape; + for (int i = 0; i < in.num_samples(); i++) { + auto in_sh = in.shape.tensor_shape_span(i); + auto out_sh = out_shape.tensor_shape_span(i); + auto &arg = args[i]; + if (arg.out_begin > 0 || arg.out_end > 0) { + out_sh[0] = arg.out_end - arg.out_begin; + } else { + out_sh[0] = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + } + } + req.output_shapes = {out_shape}; + return req; +} + +template +void ResamplerGPU::Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span args) { + if (window_gpu_storage_.empty()) + Initialize(); + + assert(context.scratchpad); + auto &scratch = *context.scratchpad; + + int nsamples = in.num_samples(); + auto samples_cpu = + make_span(scratch.Allocate(nsamples), nsamples); + + bool any_multichannel = false; + for (int i = 0; i < nsamples; i++) { + auto &desc = samples_cpu[i]; + desc.in = in[i].data; + desc.out = out[i].data; + desc.window = window_gpu_; + const auto &in_sh = in[i].shape; + const auto &out_sh = out[i].shape; + desc.in_len = in_sh[0]; + auto &arg = args[i]; + desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; + desc.out_end = + arg.out_end > 0 ? arg.out_end : resampled_length(in_sh[0], arg.in_rate, arg.out_rate); + assert((desc.out_end - desc.out_begin) == out_sh[0]); + desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + desc.scale = arg.in_rate / arg.out_rate; + any_multichannel |= desc.nchannels > 1; + } + + auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); + + dim3 block(256, 1); + int blocks_per_sample = std::max(32, 1024 / nsamples); + dim3 grid(blocks_per_sample, nsamples); + + // window coefficients and temporary per channel out values + size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); + + BOOL_SWITCH(!any_multichannel, SingleChannel, + (ResampleGPUKernel + <<>>(samples_gpu);)); // NOLINT + CUDA_CALL(cudaGetLastError()); +} + +DALI_INSTANTIATE_RESAMPLER_GPU(); + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/resampling_gpu.h b/dali/kernels/signal/resampling_gpu.h index e8fe697ab62..f82ca7cf905 100644 --- a/dali/kernels/signal/resampling_gpu.h +++ b/dali/kernels/signal/resampling_gpu.h @@ -17,12 +17,8 @@ #include #include "dali/kernels/signal/resampling.h" -#include "dali/kernels/signal/resampling_gpu.cuh" #include "dali/kernels/kernel.h" -#include "dali/kernels/dynamic_scratchpad.h" -#include "dali/core/mm/memory.h" #include "dali/core/dev_buffer.h" -#include "dali/core/static_switch.h" namespace dali { namespace kernels { @@ -30,86 +26,15 @@ namespace signal { namespace resampling { -template -class ResamplerGPU { +template +class DLL_PUBLIC ResamplerGPU { public: - void Initialize(int lobes = 16, int lookup_size = 2048) { - windowed_sinc(window_cpu_, lookup_size, lobes); - window_gpu_storage_.from_host(window_cpu_.storage); - window_gpu_ = window_cpu_; - window_gpu_.lookup = window_gpu_storage_.data(); - } + void Initialize(int lobes = 16, int lookup_size = 2048); - KernelRequirements Setup(KernelContext &context, const InListGPU &in, - span args) { - KernelRequirements req; - auto out_shape = in.shape; - for (int i = 0; i < in.num_samples(); i++) { - auto in_sh = in.shape.tensor_shape_span(i); - auto out_sh = out_shape.tensor_shape_span(i); - auto &arg = args[i]; - auto out_len = resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - auto out_begin = arg.out_begin > 0 ? arg.out_begin : 0; - auto out_end = arg.out_end > 0 ? arg.out_end : out_len; - if (out_end < out_begin) - throw std::invalid_argument( - make_string("out_begin can't be larger than out_end. Got out_begin=", out_begin, - ", out_end=", out_end)); - if (out_end > out_len) - throw std::invalid_argument(make_string( - "out_end can't be outside of the range of the output signal: [0, ", out_len, ")")); - out_sh[0] = out_end - out_begin; - } - req.output_shapes = {out_shape}; - return req; - } + KernelRequirements Setup(KernelContext &context, const InListGPU &in, span args); - void Run(KernelContext &context, const OutListGPU &out, - const InListGPU &in, span args) { - if (window_gpu_storage_.empty()) - Initialize(); - - assert(context.scratchpad); - auto &scratch = *context.scratchpad; - - int nsamples = in.num_samples(); - auto samples_cpu = - make_span(scratch.Allocate(nsamples), nsamples); - - bool any_multichannel = false; - for (int i = 0; i < nsamples; i++) { - auto &desc = samples_cpu[i]; - desc.in = in[i].data; - desc.out = out[i].data; - desc.window = window_gpu_; - const auto &in_sh = in[i].shape; - const auto &out_sh = out[i].shape; - desc.in_len = in_sh[0]; - auto &arg = args[i]; - desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; - desc.out_end = arg.out_end > 0 ? arg.out_end : - resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - assert((desc.out_end - desc.out_begin) == out_sh[0]); - desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; - desc.scale = arg.in_rate / arg.out_rate; - any_multichannel |= desc.nchannels > 1; - } - - auto samples_gpu = scratch.ToGPU(context.gpu.stream, samples_cpu); - - dim3 block(256, 1); - int blocks_per_sample = std::max(32, 1024 / nsamples); - dim3 grid(blocks_per_sample, nsamples); - - // window coefficients and temporary per channel out values - size_t shm_size = (window_gpu_storage_.size() + (SHM_NCHANNELS + 1) * block.x) * sizeof(float); - - BOOL_SWITCH(!any_multichannel, SingleChannel, ( - ResampleGPUKernel - <<>>(samples_gpu); - )); // NOLINT - CUDA_CALL(cudaGetLastError()); - } + void Run(KernelContext &context, const OutListGPU &out, + const InListGPU &in, span args); private: ResamplingWindowCPU window_cpu_; @@ -117,6 +42,26 @@ class ResamplerGPU { DeviceBuffer window_gpu_storage_; }; +#define DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, Out)\ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + linkage template class ResamplerGPU; \ + +#define DALI_INSTANTIATE_RESAMPLER_GPU(linkage) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, float) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int8_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint8_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int16_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint16_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, int32_t) \ + DALI_INSTANTIATE_RESAMPLER_GPU_OUT(linkage, uint32_t) + +DALI_INSTANTIATE_RESAMPLER_GPU(extern) + } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_gpu_test.cu b/dali/kernels/signal/resampling_gpu_test.cu index 46a2302d818..632635c7c09 100644 --- a/dali/kernels/signal/resampling_gpu_test.cu +++ b/dali/kernels/signal/resampling_gpu_test.cu @@ -17,6 +17,7 @@ #include #include "dali/kernels/signal/resampling_gpu.h" #include "dali/kernels/signal/resampling_test.h" +#include "dali/kernels/dynamic_scratchpad.h" #include "dali/core/cuda_event.h" namespace dali { From ab0470d0357ecf9dbe537179e9c30ebb0955fc11 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 17 May 2022 16:29:11 +0200 Subject: [PATCH 30/36] Fix reference to temp member Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu index d68e111aeed..c9920842657 100644 --- a/dali/kernels/signal/resampling_gpu.cu +++ b/dali/kernels/signal/resampling_gpu.cu @@ -71,18 +71,20 @@ void ResamplerGPU::Run(KernelContext &context, const OutListGPU &o bool any_multichannel = false; for (int i = 0; i < nsamples; i++) { auto &desc = samples_cpu[i]; - desc.in = in[i].data; - desc.out = out[i].data; + auto in_sample = in[i]; + auto out_sample = out[i]; + desc.in = in_sample.data; + desc.out = out_sample.data; desc.window = window_gpu_; - const auto &in_sh = in[i].shape; - const auto &out_sh = out[i].shape; + const auto &in_sh = in_sample.shape; + const auto &out_sh = out_sample.shape; desc.in_len = in_sh[0]; auto &arg = args[i]; desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; desc.out_end = arg.out_end > 0 ? arg.out_end : resampled_length(in_sh[0], arg.in_rate, arg.out_rate); assert((desc.out_end - desc.out_begin) == out_sh[0]); - desc.nchannels = in[i].shape.sample_dim() > 1 ? in_sh[1] : 1; + desc.nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; desc.scale = arg.in_rate / arg.out_rate; any_multichannel |= desc.nchannels > 1; } From 6be84a4b4e2858bf694f090d9abcd662da492f76 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 17 May 2022 17:19:27 +0200 Subject: [PATCH 31/36] Call Initialize Signed-off-by: Joaquin Anton --- dali/operators/audio/resample_gpu.cc | 17 +++++++++++------ .../python/test_operator_audio_resample.py | 19 ++++++++----------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/dali/operators/audio/resample_gpu.cc b/dali/operators/audio/resample_gpu.cc index f327e2da80d..6bc13a3707d 100644 --- a/dali/operators/audio/resample_gpu.cc +++ b/dali/operators/audio/resample_gpu.cc @@ -43,22 +43,27 @@ class ResampleGPU : public ResampleBase { TYPE_SWITCH(dtype_, type2id, Out, (AUDIO_RESAMPLE_TYPES), ( TYPE_SWITCH(in.type(), type2id, In, (AUDIO_RESAMPLE_TYPES), ( ResampleTyped(view(out), view(in), ws.stream()); - ), ( + ), ( // NOLINT DALI_FAIL( make_string("Unsupported input type: ", in.type(), "\nSupported types are : ", ListTypeNames())); - )); - ), ( + )); // NOLINT + ), ( // NOLINT DALI_FAIL( make_string("Unsupported output type: ", dtype_, "\nSupported types are : ", ListTypeNames())); - )); + )); // NOLINT } template - void ResampleTyped(const OutListGPU &out, const InListGPU &in, cudaStream_t stream) { + void ResampleTyped(const OutListGPU &out, const InListGPU &in, + cudaStream_t stream) { using Kernel = kernels::signal::resampling::ResamplerGPU; - kmgr_.Resize(1); + if (kmgr_.NumInstances() == 0) { + kmgr_.Resize(1); + auto params = ResamplingParams::FromQuality(quality_); + kmgr_.Get(0).Initialize(params.lobes, params.lookup_size); + } auto args = make_cspan(args_); kernels::KernelContext ctx; ctx.gpu.stream = stream; diff --git a/dali/test/python/test_operator_audio_resample.py b/dali/test/python/test_operator_audio_resample.py index d1b542b02fd..ff396b819e8 100644 --- a/dali/test/python/test_operator_audio_resample.py +++ b/dali/test/python/test_operator_audio_resample.py @@ -64,7 +64,8 @@ def _test_standalone_vs_fused(device): for _ in range(2): outs = pipe.run() # two sampling rates - should be bit-exact - check_batch(outs[0], outs[1], eps=0, max_allowed_error=1e-4 if is_gpu else 0) + check_batch(outs[0], outs[1], eps=1e-6 if is_gpu else 0, + max_allowed_error=1e-4 if is_gpu else 0) # numerical round-off error in rate check_batch(outs[0], outs[2], eps=1e-6, max_allowed_error=1e-4) # here, the sampling rate is slightly different, so we can tolerate larger errors @@ -74,7 +75,7 @@ def test_standalone_vs_fused(): for device in ('gpu', 'cpu'): yield _test_standalone_vs_fused, device -def _test_type_conversion(device, src_type, in_values, dst_type, out_values, rtol=1e-6, atol=None): +def _test_type_conversion(device, src_type, in_values, dst_type, out_values, eps): src_nptype = dali_type_to_np(src_type) dst_nptype = dali_type_to_np(dst_type) assert len(out_values) == len(in_values) @@ -93,11 +94,7 @@ def test_pipe(device): for i in range(len(out_values)): ref = np.full_like(in_data[i], out_values[i], dst_nptype) out_arr = as_array(out[i]) - if atol is not None: - ok = np.allclose(out_arr, ref, rtol, atol) - else: - ok = np.allclose(out_arr, ref, rtol) - if not ok: + if not np.allclose(out_arr, ref, 1e-6, eps): print("Actual: ", out_arr) print(out_arr.dtype, out_arr.shape) print("Reference: ", ref) @@ -113,8 +110,8 @@ def test_dynamic_ranges(): (types.INT16, [-32768, -32767, -100, -1, 0, 1, 100, 32767], 0), (types.UINT32, [0, 1, 0x7fffffff, 0x80000000, 0xfffffffe, 0xffffffff], 128), (types.INT32, [-0x80000000, -0x7fffffff, -100, -1, 0, 1, 0x7fffffff], 128)]: - yield _test_type_conversion, 'gpu', type, values, type, values, 2e-5, eps - yield _test_type_conversion, 'cpu', type, values, type, values, 1e-6, eps + for device in ('cpu', 'gpu'): + yield _test_type_conversion, device, type, values, type, values, eps def test_type_conversion(): type_ranges = [(types.FLOAT, [-1, 1]), @@ -151,5 +148,5 @@ def test_type_conversion(): if eps < 1 and (o_lo != -o_hi or (i_hi != i_lo and dst_type != types.FLOAT)): eps = 1 - yield _test_type_conversion, 'gpu', src_type, in_values, dst_type, out_values, 2e-5, eps - yield _test_type_conversion, 'cpu', src_type, in_values, dst_type, out_values, 1e-6, eps + for device in ('cpu', 'gpu'): + yield _test_type_conversion, device, src_type, in_values, dst_type, out_values, eps From c5cec3b2d5628f39101fdcc36ededc8b810da400 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 18 May 2022 13:44:22 +0200 Subject: [PATCH 32/36] Rebase Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling.h | 253 +--------------- dali/kernels/signal/resampling_cpu.cc | 283 ++++++++++++++++++ dali/kernels/signal/resampling_cpu.h | 58 ++++ dali/kernels/signal/resampling_gpu.cuh | 1 + dali/kernels/signal/resampling_test.cc | 4 +- .../decoder/audio/audio_decoder_impl.cc | 6 +- .../decoder/audio/audio_decoder_impl.h | 14 +- .../decoder/audio/audio_decoder_op.h | 4 +- .../operators/reader/loader/nemo_asr_loader.h | 4 +- .../reader/loader/nemo_asr_loader_test.cc | 2 +- 10 files changed, 361 insertions(+), 268 deletions(-) create mode 100644 dali/kernels/signal/resampling_cpu.cc create mode 100644 dali/kernels/signal/resampling_cpu.h diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h index dc47d601ea4..8a551bb182c 100644 --- a/dali/kernels/signal/resampling.h +++ b/dali/kernels/signal/resampling.h @@ -15,20 +15,12 @@ #ifndef DALI_KERNELS_SIGNAL_RESAMPLING_H_ #define DALI_KERNELS_SIGNAL_RESAMPLING_H_ -#ifdef __SSE2__ -#include -#endif -#ifdef __ARM_NEON -#include -#endif #include +#include #include #include #include #include "dali/core/math_util.h" -#include "dali/core/small_vector.h" -#include "dali/core/convert.h" -#include "dali/core/static_switch.h" namespace dali { namespace kernels { @@ -45,19 +37,6 @@ inline double Hann(double x) { return 0.5 * (1 + std::cos(x * M_PI)); } -#ifdef __ARM_NEON - -inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { - float32x4_t x; - x = vdupq_n_f32(x0); - x = vsetq_lane_f32(x1, x, 1); - x = vsetq_lane_f32(x2, x, 2); - x = vsetq_lane_f32(x3, x, 3); - return x; -} - -#endif - struct ResamplingWindow { struct InputRange { int i0, i1; @@ -74,7 +53,7 @@ struct ResamplingWindow { * @brief Calculates the window coefficient at an arbitrary floating point position * by interpolating between two samples. */ - inline DALI_HOST_DEV float operator()(float x) const { + DALI_HOST_DEV float operator()(float x) const { float fi = x * scale + center; float floori = floorf(fi); float di = fi - floori; @@ -83,46 +62,6 @@ struct ResamplingWindow { return lookup[i] + di * (lookup[i + 1] - lookup[i]); } -#ifdef __ARM_NEON - inline float32x4_t operator()(float32x4_t x) const { - float32x4_t fi = vfmaq_n_f32(vdupq_n_f32(center), x, scale); - int32x4_t i = vcvtq_s32_f32(fi); - float32x4_t fifloor = vcvtq_f32_s32(i); - float32x4_t di = vsubq_f32(fi, fifloor); - int idx[4] = { - vgetq_lane_s32(i, 0), - vgetq_lane_s32(i, 1), - vgetq_lane_s32(i, 2), - vgetq_lane_s32(i, 3) - }; - float32x2_t c0 = vld1_f32(&lookup[idx[0]]); - float32x2_t c1 = vld1_f32(&lookup[idx[1]]); - float32x2_t c2 = vld1_f32(&lookup[idx[2]]); - float32x2_t c3 = vld1_f32(&lookup[idx[3]]); - float32x4x2_t w = vuzpq_f32(vcombine_f32(c0, c1), vcombine_f32(c2, c3)); - float32x4_t curr = w.val[0]; - float32x4_t next = w.val[1]; - return vfmaq_f32(curr, di, vsubq_f32(next, curr)); - } -#endif - -#ifdef __SSE2__ - inline __m128 operator()(__m128 x) const { - __m128 fi = _mm_add_ps(x * _mm_set1_ps(scale), _mm_set1_ps(center)); - __m128i i = _mm_cvttps_epi32(fi); - __m128 fifloor = _mm_cvtepi32_ps(i); - __m128 di = _mm_sub_ps(fi, fifloor); - int idx[4]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(idx), i); - __m128 curr = _mm_setr_ps(lookup[idx[0]], lookup[idx[1]], - lookup[idx[2]], lookup[idx[3]]); - __m128 next = _mm_setr_ps(lookup[idx[0]+1], lookup[idx[1]+1], - lookup[idx[2]+1], lookup[idx[3]+1]); - return _mm_add_ps(curr, _mm_mul_ps(di, _mm_sub_ps(next, curr))); - } -#endif - - float scale = 1, center = 1; int lobes = 0, coeffs = 0; int lookup_size = 0; @@ -155,198 +94,10 @@ inline void windowed_sinc(ResamplingWindowCPU &window, window.scale = 1 / scale; } - inline int64_t resampled_length(int64_t in_length, double in_rate, double out_rate) { return std::ceil(in_length * out_rate / in_rate); } -struct Resampler { - ResamplingWindowCPU window; - - void Initialize(int lobes = 16, int lookup_size = 2048) { - windowed_sinc(window, lookup_size, lobes); - } - -#if defined(__ARM_NEON) - inline float filter_vec(int &i_ref, float in_pos, int i1, const float *in) const { - const float32x4_t _0123 = vsetq_f32(0, 1, 2, 3); - float32x4_t f4 = vdupq_n_f32(0); - - int i = i_ref; - float32x4_t x4 = vaddq_f32(vdupq_n_f32(i - in_pos), _0123); - - for (; i + 3 < i1; i += 4) { - float32x4_t w4 = window(x4); - f4 = vfmaq_f32(f4, vld1q_f32(in + i), w4); - x4 = vaddq_f32(x4, vdupq_n_f32(4)); - } - // Sum elements in f4 - float32x2_t f2 = vpadd_f32(vget_low_f32(f4), vget_high_f32(f4)); - f2 = vpadd_f32(f2, f2); - i_ref = i; - return vget_lane_f32(f2, 0); - } -#elif defined(__SSE2__) - inline float filter_vec(int &i_ref, float in_pos, int i1, const float *in) const { - __m128 f4 = _mm_setzero_ps(); - int i = i_ref; - __m128 x4 = _mm_setr_ps(i - in_pos, i+1 - in_pos, i+2 - in_pos, i+3 - in_pos); - for (; i + 3 < i1; i += 4) { - __m128 w4 = window(x4); - - f4 = _mm_add_ps(f4, _mm_mul_ps(_mm_loadu_ps(in + i), w4)); - x4 = _mm_add_ps(x4, _mm_set1_ps(4)); - } - i_ref = i; - - // Sum elements in f4 - f4 = _mm_add_ps(f4, _mm_shuffle_ps(f4, f4, _MM_SHUFFLE(1, 0, 3, 2))); - f4 = _mm_add_ps(f4, _mm_shuffle_ps(f4, f4, _MM_SHUFFLE(0, 1, 0, 1))); - return _mm_cvtss_f32(f4); - } -#else - static float filter_vec(int &, float, int, const float *) { - return 0; - } -#endif - - /** - * @brief Resample single-channel signal and convert to Out - * - * Calculates a range of resampled signal. - * The function can seamlessly resample the input and produce the result in chunks. - * To reuse memory and still simulate chunk processing, adjust the in/out pointers. - */ - template - void Resample( - Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, - const float *__restrict__ in, int64_t n_in, double in_rate) const { - assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); - int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part - double scale = in_rate / out_rate; - float fscale = scale; - - for (int64_t out_block = out_begin; out_block < out_end; out_block += block) { - int64_t block_end = std::min(out_block + block, out_end); - double in_block_f = out_block * scale; - int64_t in_block_i = std::floor(in_block_f); - float in_pos = in_block_f - in_block_i; - const float *__restrict__ in_block_ptr = in + in_block_i; - for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - auto irange = window.input_range(in_pos); - int i0 = irange.i0; - int i1 = irange.i1; - if (i0 + in_block_i < 0) - i0 = -in_block_i; - if (i1 + in_block_i > n_in) - i1 = n_in - in_block_i; - int i = i0; - - float f = filter_vec(i, in_pos, i1, in_block_ptr); - - float x = i - in_pos; - for (; i < i1; i++, x++) { - float w = window(x); - f += in_block_ptr[i] * w; - } - assert(out_pos >= out_begin && out_pos < out_end); - auto rel_pos = out_pos - out_begin; - out[rel_pos] = ConvertSatNorm(f); - } - } - } - - - - /** - * @brief Resample multi-channel signal and convert to Out - * - * Calculates a range of resampled signal. - * The function can seamlessly resample the input and produce the result in chunks. - * To reuse memory and still simulate chunk processing, adjust the in/out pointers. - * - * @tparam static_channels number of channels, if known at compile time, or -1 - */ - template - void Resample( - Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, - const float *__restrict__ in, int64_t n_in, double in_rate, - int dynamic_num_channels) { - static_assert(static_channels != 0, "Static number of channels must be positive (use static) " - "or negative (use dynamic)."); - assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); - if (dynamic_num_channels == 1) { - // fast path - Resample(out, out_begin, out_end, out_rate, in, n_in, in_rate); - return; - } - // the check below is compile time, so num_channels will be a compile-time constant - // or a run-time constant, depending on the value of static_channels - const int num_channels = static_channels < 0 ? dynamic_num_channels : static_channels; - assert(num_channels > 0); - - int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part - double scale = in_rate / out_rate; - float fscale = scale; - SmallVector tmp; - tmp.resize(num_channels); - for (int64_t out_block = out_begin; out_block < out_end; out_block += block) { - int64_t block_end = std::min(out_block + block, out_end); - double in_block_f = out_block * scale; - int64_t in_block_i = std::floor(in_block_f); - float in_pos = in_block_f - in_block_i; - const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; - for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { - auto irange = window.input_range(in_pos); - int i0 = irange.i0; - int i1 = irange.i1; - if (i0 + in_block_i < 0) - i0 = -in_block_i; - if (i1 + in_block_i > n_in) - i1 = n_in - in_block_i; - - for (int c = 0; c < num_channels; c++) - tmp[c] = 0; - - float x = i0 - in_pos; - int ofs0 = i0 * num_channels; - int ofs1 = i1 * num_channels; - for (int in_ofs = ofs0; in_ofs < ofs1; in_ofs += num_channels, x++) { - float w = window(x); - for (int c = 0; c < num_channels; c++) { - assert(in_block_ptr + in_ofs + c >= in && - in_block_ptr + in_ofs + c < in + n_in * num_channels); - tmp[c] += in_block_ptr[in_ofs + c] * w; - } - } - assert(out_pos >= out_begin && out_pos < out_end); - auto rel_pos = out_pos - out_begin; - for (int c = 0; c < num_channels; c++) - out[rel_pos * num_channels + c] = ConvertSatNorm(tmp[c]); - } - } - } - - /** - * @brief Resample multi-channel signal and convert to Out - * - * Calculates a range of resampled signal. - * The function can seamlessly resample the input and produce the result in chunks. - * To reuse memory and still simulate chunk processing, adjust the in/out pointers. - */ - template - void Resample( - Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, - const float *__restrict__ in, int64_t n_in, double in_rate, - int num_channels) { - VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), - (Resample(out, out_begin, out_end, out_rate, - in, n_in, in_rate, static_channels);), - (Resample<-1, Out>(out, out_begin, out_end, out_rate, - in, n_in, in_rate, num_channels))); - } -}; - } // namespace resampling } // namespace signal } // namespace kernels diff --git a/dali/kernels/signal/resampling_cpu.cc b/dali/kernels/signal/resampling_cpu.cc new file mode 100644 index 00000000000..a96d20ef715 --- /dev/null +++ b/dali/kernels/signal/resampling_cpu.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/kernels/signal/resampling_cpu.h" + +#ifdef __SSE2__ +#include +#endif +#ifdef __ARM_NEON +#include +#endif +#include +#include +#include +#include +#include "dali/core/convert.h" +#include "dali/core/math_util.h" +#include "dali/core/small_vector.h" +#include "dali/core/static_switch.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +#if defined(__ARM_NEON) + +inline float32x4_t evaluate(const ResamplingWindow &window, float32x4_t x) { + float32x4_t fi = vfmaq_n_f32(vdupq_n_f32(window.center), x, window.scale); + int32x4_t i = vcvtq_s32_f32(fi); + float32x4_t fifloor = vcvtq_f32_s32(i); + float32x4_t di = vsubq_f32(fi, fifloor); + int idx[4] = {vgetq_lane_s32(i, 0), vgetq_lane_s32(i, 1), vgetq_lane_s32(i, 2), + vgetq_lane_s32(i, 3)}; + float32x2_t c0 = vld1_f32(&window.lookup[idx[0]]); + float32x2_t c1 = vld1_f32(&window.lookup[idx[1]]); + float32x2_t c2 = vld1_f32(&window.lookup[idx[2]]); + float32x2_t c3 = vld1_f32(&window.lookup[idx[3]]); + float32x4x2_t w = vuzpq_f32(vcombine_f32(c0, c1), vcombine_f32(c2, c3)); + float32x4_t curr = w.val[0]; + float32x4_t next = w.val[1]; + return vfmaq_f32(curr, di, vsubq_f32(next, curr)); +} + +inline float32x4_t vsetq_f32(float x0, float x1, float x2, float x3) { + float32x4_t x; + x = vdupq_n_f32(x0); + x = vsetq_lane_f32(x1, x, 1); + x = vsetq_lane_f32(x2, x, 2); + x = vsetq_lane_f32(x3, x, 3); + return x; +} + +inline float filter_vec(const ResamplingWindow &window, int &i_ref, float in_pos, int i1, + const float *in) { + const float32x4_t _0123 = vsetq_f32(0, 1, 2, 3); + float32x4_t f4 = vdupq_n_f32(0); + + int i = i_ref; + float32x4_t x4 = vaddq_f32(vdupq_n_f32(i - in_pos), _0123); + + for (; i + 3 < i1; i += 4) { + float32x4_t w4 = evaluate(window, x4); + f4 = vfmaq_f32(f4, vld1q_f32(in + i), w4); + x4 = vaddq_f32(x4, vdupq_n_f32(4)); + } + // Sum elements in f4 + float32x2_t f2 = vpadd_f32(vget_low_f32(f4), vget_high_f32(f4)); + f2 = vpadd_f32(f2, f2); + i_ref = i; + return vget_lane_f32(f2, 0); +} + +#elif defined(__SSE2__) + +inline __m128 evaluate(const ResamplingWindow &window, __m128 x) { + __m128 fi = _mm_add_ps(x * _mm_set1_ps(window.scale), _mm_set1_ps(window.center)); + __m128i i = _mm_cvttps_epi32(fi); + __m128 fifloor = _mm_cvtepi32_ps(i); + __m128 di = _mm_sub_ps(fi, fifloor); + int idx[4]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(idx), i); + __m128 curr = _mm_setr_ps(window.lookup[idx[0]], window.lookup[idx[1]], + window.lookup[idx[2]], window.lookup[idx[3]]); + __m128 next = _mm_setr_ps(window.lookup[idx[0] + 1], window.lookup[idx[1] + 1], + window.lookup[idx[2] + 1], window.lookup[idx[3] + 1]); + return _mm_add_ps(curr, _mm_mul_ps(di, _mm_sub_ps(next, curr))); +} + +inline float filter_vec(const ResamplingWindow &window, int &i_ref, float in_pos, int i1, + const float *in) { + __m128 f4 = _mm_setzero_ps(); + int i = i_ref; + __m128 x4 = _mm_setr_ps(i - in_pos, i + 1 - in_pos, i + 2 - in_pos, i + 3 - in_pos); + for (; i + 3 < i1; i += 4) { + __m128 w4 = evaluate(window, x4); + + f4 = _mm_add_ps(f4, _mm_mul_ps(_mm_loadu_ps(in + i), w4)); + x4 = _mm_add_ps(x4, _mm_set1_ps(4)); + } + i_ref = i; + + // Sum elements in f4 + f4 = _mm_add_ps(f4, _mm_shuffle_ps(f4, f4, _MM_SHUFFLE(1, 0, 3, 2))); + f4 = _mm_add_ps(f4, _mm_shuffle_ps(f4, f4, _MM_SHUFFLE(0, 1, 0, 1))); + return _mm_cvtss_f32(f4); +} + +#else + +inline float filter_vec(const ResamplingWindow &, int &, float, int, const float *) { + return 0; +} + +#endif + +/** + * @brief Resample single-channel signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can seamlessly resample the input and produce the result in chunks. + * To reuse memory and still simulate chunk processing, adjust the in/out pointers. + */ +template +void ResampleCPUImpl(ResamplingWindow window, Out *__restrict__ out, int64_t out_begin, + int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, + double in_rate) { + assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part + double scale = in_rate / out_rate; + float fscale = scale; + + for (int64_t out_block = out_begin; out_block < out_end; out_block += block) { + int64_t block_end = std::min(out_block + block, out_end); + double in_block_f = out_block * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos = in_block_f - in_block_i; + const float *__restrict__ in_block_ptr = in + in_block_i; + for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { + auto irange = window.input_range(in_pos); + int i0 = irange.i0; + int i1 = irange.i1; + if (i0 + in_block_i < 0) + i0 = -in_block_i; + if (i1 + in_block_i > n_in) + i1 = n_in - in_block_i; + int i = i0; + + float f = filter_vec(window, i, in_pos, i1, in_block_ptr); + + float x = i - in_pos; + for (; i < i1; i++, x++) { + float w = window(x); + f += in_block_ptr[i] * w; + } + assert(out_pos >= out_begin && out_pos < out_end); + auto rel_pos = out_pos - out_begin; + out[rel_pos] = ConvertSatNorm(f); + } + } +} + +/** + * @brief Resample multi-channel signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can seamlessly resample the input and produce the result in chunks. + * To reuse memory and still simulate chunk processing, adjust the in/out pointers. + * + * @tparam static_channels number of channels, if known at compile time, or -1 + */ +template +void ResampleCPUImpl(ResamplingWindow window, Out *__restrict__ out, int64_t out_begin, + int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, + double in_rate, int dynamic_num_channels) { + static_assert(static_channels != 0, + "Static number of channels must be positive (use static) " + "or negative (use dynamic)."); + assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive"); + if (dynamic_num_channels == 1) { + // fast path + ResampleCPUImpl(window, out, out_begin, out_end, out_rate, in, n_in, in_rate); + return; + } + // the check below is compile time, so num_channels will be a compile-time constant + // or a run-time constant, depending on the value of static_channels + const int num_channels = static_channels < 0 ? dynamic_num_channels : static_channels; + assert(num_channels > 0); + + int64_t block = 1 << 8; // still leaves 15 significant bits for fractional part + double scale = in_rate / out_rate; + float fscale = scale; + SmallVector tmp; + tmp.resize(num_channels); + for (int64_t out_block = out_begin; out_block < out_end; out_block += block) { + int64_t block_end = std::min(out_block + block, out_end); + double in_block_f = out_block * scale; + int64_t in_block_i = std::floor(in_block_f); + float in_pos = in_block_f - in_block_i; + const float *__restrict__ in_block_ptr = in + in_block_i * num_channels; + for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) { + auto irange = window.input_range(in_pos); + int i0 = irange.i0; + int i1 = irange.i1; + if (i0 + in_block_i < 0) + i0 = -in_block_i; + if (i1 + in_block_i > n_in) + i1 = n_in - in_block_i; + + for (int c = 0; c < num_channels; c++) + tmp[c] = 0; + + float x = i0 - in_pos; + int ofs0 = i0 * num_channels; + int ofs1 = i1 * num_channels; + for (int in_ofs = ofs0; in_ofs < ofs1; in_ofs += num_channels, x++) { + float w = window(x); + for (int c = 0; c < num_channels; c++) { + assert(in_block_ptr + in_ofs + c >= in && + in_block_ptr + in_ofs + c < in + n_in * num_channels); + tmp[c] += in_block_ptr[in_ofs + c] * w; + } + } + assert(out_pos >= out_begin && out_pos < out_end); + auto rel_pos = out_pos - out_begin; + for (int c = 0; c < num_channels; c++) + out[rel_pos * num_channels + c] = ConvertSatNorm(tmp[c]); + } + } +} + +/** + * @brief Resample multi-channel (or single channel) signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can resample a region-of-interest (ROI) of the output, specified by `out_begin` and + * `out_end`. In this case, the output pointer points to the beginning of the ROI. + */ +template +void ResampleCPUImpl(ResamplingWindow window, Out *__restrict__ out, int64_t out_begin, + int64_t out_end, double out_rate, const float *__restrict__ in, int64_t n_in, + double in_rate, int num_channels) { + VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), + (ResampleCPUImpl(window, out, out_begin, out_end, out_rate, + in, n_in, in_rate, static_channels);), + (ResampleCPUImpl<-1, Out>(window, out, out_begin, out_end, out_rate, + in, n_in, in_rate, num_channels))); +} + +#define DALI_INSTANTIATE_RESAMPLER_CPU_OUT(Out) \ + template void ResampleCPUImpl(ResamplingWindow window, Out *__restrict__ out, \ + int64_t out_begin, int64_t out_end, double out_rate, \ + const float *__restrict__ in, int64_t n_in, double in_rate, \ + int num_channels); + +#define DALI_INSTANTIATE_RESAMPLER_CPU() \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(float); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(int8_t); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(uint8_t); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(int16_t); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(uint16_t); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(int32_t); \ + DALI_INSTANTIATE_RESAMPLER_CPU_OUT(uint32_t); + +DALI_INSTANTIATE_RESAMPLER_CPU(); + + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/kernels/signal/resampling_cpu.h b/dali/kernels/signal/resampling_cpu.h new file mode 100644 index 00000000000..c5aabb1fde5 --- /dev/null +++ b/dali/kernels/signal/resampling_cpu.h @@ -0,0 +1,58 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_CPU_H_ +#define DALI_KERNELS_SIGNAL_RESAMPLING_CPU_H_ + +#include "dali/kernels/signal/resampling.h" +#include "dali/core/api_helper.h" + +namespace dali { +namespace kernels { +namespace signal { + +namespace resampling { + +template +DLL_PUBLIC void ResampleCPUImpl(ResamplingWindow window, Out *__restrict__ out, int64_t out_begin, + int64_t out_end, double out_rate, const float *__restrict__ in, + int64_t n_in, double in_rate, int num_channels); + +struct DLL_PUBLIC ResamplerCPU { + ResamplingWindowCPU window; + + inline void Initialize(int lobes = 16, int lookup_size = 2048) { + windowed_sinc(window, lookup_size, lobes); + } + + /** + * @brief Resample multi-channel (or single channel) signal and convert to Out + * + * Calculates a range of resampled signal. + * The function can resample a region-of-interest (ROI) of the output, specified by `out_begin` and + * `out_end`. In this case, the output pointer points to the beginning of the ROI. + */ + template + void Resample(Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate, + const float *__restrict__ in, int64_t n_in, double in_rate, int num_channels) { + ResampleCPUImpl(window, out, out_begin, out_end, out_rate, in, n_in, in_rate, num_channels); + } +}; + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali + +#endif // DALI_KERNELS_SIGNAL_RESAMPLING_CPU_H_ diff --git a/dali/kernels/signal/resampling_gpu.cuh b/dali/kernels/signal/resampling_gpu.cuh index 452f7f1c72d..e93ce31b2dc 100644 --- a/dali/kernels/signal/resampling_gpu.cuh +++ b/dali/kernels/signal/resampling_gpu.cuh @@ -18,6 +18,7 @@ #include #include "dali/kernels/signal/resampling.h" #include "dali/core/util.h" +#include "dali/core/convert.h" #define SHM_NCHANNELS 16 diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc index 723676f1cca..f76f98dd858 100644 --- a/dali/kernels/signal/resampling_test.cc +++ b/dali/kernels/signal/resampling_test.cc @@ -15,7 +15,7 @@ #include #include #include -#include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_cpu.h" #include "dali/kernels/signal/resampling_test.h" namespace dali { @@ -121,7 +121,7 @@ void ResamplingTest::RunTest() { class ResamplingCPUTest : public ResamplingTest { public: void RunResampling(span args) override { - Resampler R; + ResamplerCPU R; R.Initialize(16); ASSERT_EQ(args.size(), nsamples_); diff --git a/dali/operators/decoder/audio/audio_decoder_impl.cc b/dali/operators/decoder/audio/audio_decoder_impl.cc index 3001bb3a19e..9fe7545de29 100644 --- a/dali/operators/decoder/audio/audio_decoder_impl.cc +++ b/dali/operators/decoder/audio/audio_decoder_impl.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -48,7 +48,7 @@ TensorShape<> DecodedAudioShape(const AudioMetadata &meta, float target_sample_r template void DecodeAudio(TensorView audio, AudioDecoderBase &decoder, - const AudioMetadata &meta, kernels::signal::resampling::Resampler &resampler, + const AudioMetadata &meta, kernels::signal::resampling::ResamplerCPU &resampler, span decode_scratch_mem, span resample_scratch_mem, float target_sample_rate, bool downmix, @@ -107,7 +107,7 @@ void DecodeAudio(TensorView audio, AudioDecode #define DECLARE_IMPL(OutType) \ template void DecodeAudio( \ TensorView audio, AudioDecoderBase & decoder, \ - const AudioMetadata &meta, kernels::signal::resampling::Resampler &resampler, \ + const AudioMetadata &meta, kernels::signal::resampling::ResamplerCPU &resampler, \ span decode_scratch_mem, span resample_scratch_mem, \ float target_sample_rate, bool downmix, const char *audio_filepath); diff --git a/dali/operators/decoder/audio/audio_decoder_impl.h b/dali/operators/decoder/audio/audio_decoder_impl.h index 1ae0007f742..957bba96f53 100644 --- a/dali/operators/decoder/audio/audio_decoder_impl.h +++ b/dali/operators/decoder/audio/audio_decoder_impl.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "dali/operators/decoder/audio/audio_decoder.h" #include "dali/operators/decoder/audio/generic_decoder.h" #include "dali/pipeline/data/backend.h" -#include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_cpu.h" #include "dali/core/tensor_view.h" namespace dali { @@ -42,7 +42,7 @@ DLL_PUBLIC std::pair ProcessOffsetAndLength(const AudioMetadat * @param target_sample_rate If a positive number is provided, it represent the target sampling rate * (the audio data is expected to be resampled if its original sampling rate differs) * @param downmix If set to true, the audio channels are expected to be downmixed, resulting in a shape with 1 - * dimension ({nsamples,}), instead of 2 ({nsamples, nchannels}) + * dimension ({nsamples,}), instead of 2 ({nsamples, nchannels}) */ DLL_PUBLIC TensorShape<> DecodedAudioShape(const AudioMetadata &meta, float target_sample_rate = -1, bool downmix = true); @@ -52,23 +52,23 @@ DLL_PUBLIC TensorShape<> DecodedAudioShape(const AudioMetadata &meta, float targ * @param audio Destination buffer. The function will decode as many audio samples as the shape of this argument * @param decoder Decoder object. * @param meta Audio metadata. - * @param resampler Resampler instance used if resampling is required + * @param resampler ResamplerCPU instance used if resampling is required * @param decode_scratch_mem Scratch memory used for decoding, when decoding can't be done directly to the output buffer. * If downmixing or resampling is required, this buffer should have a positive length, representing * decoded audio length at the original sampling rate: ``length * nchannels`` * @param resample_scratch_mem Scratch memory used for the input of resampling. - * If resampling is required, the buffer should have a positive length, representing the + * If resampling is required, the buffer should have a positive length, representing the * decoded audio length, ``length`` if downmixing is enabled, or the decoded audio length including * channels, ``length * nchannels``, otherwise. * @param target_sample_rate If a positive value is provided, the signal will be resampled except when its original sampling rate * is equal to the target. * @param downmix If true, the audio channes will be downmixed to a single one - * @param audio_filepath Path to the audio file being decoded, only used for debugging purposes + * @param audio_filepath Path to the audio file being decoded, only used for debugging purposes */ template DLL_PUBLIC void DecodeAudio(TensorView audio, AudioDecoderBase &decoder, const AudioMetadata &meta, - kernels::signal::resampling::Resampler &resampler, + kernels::signal::resampling::ResamplerCPU &resampler, span decode_scratch_mem, span resample_scratch_mem, float target_sample_rate, bool downmix, const char *audio_filepath); diff --git a/dali/operators/decoder/audio/audio_decoder_op.h b/dali/operators/decoder/audio/audio_decoder_op.h index 17d29ca98c6..d46bbb4a5a1 100644 --- a/dali/operators/decoder/audio/audio_decoder_op.h +++ b/dali/operators/decoder/audio/audio_decoder_op.h @@ -26,7 +26,7 @@ #include "dali/pipeline/workspace/workspace.h" #include "dali/pipeline/operator/operator.h" #include "dali/pipeline/workspace/host_workspace.h" -#include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_cpu.h" #include "dali/kernels/signal/downmixing.h" #include "dali/core/tensor_view.h" @@ -83,7 +83,7 @@ class AudioDecoderCpu : public Operator { } std::vector target_sample_rates_; - kernels::signal::resampling::Resampler resampler_; + kernels::signal::resampling::ResamplerCPU resampler_; DALIDataType output_type_ = DALI_NO_TYPE, decode_type_ = DALI_NO_TYPE; const bool downmix_ = false, use_resampling_ = false; const float quality_ = 50.0f; diff --git a/dali/operators/reader/loader/nemo_asr_loader.h b/dali/operators/reader/loader/nemo_asr_loader.h index e1cf137e82b..f91c453e5f8 100644 --- a/dali/operators/reader/loader/nemo_asr_loader.h +++ b/dali/operators/reader/loader/nemo_asr_loader.h @@ -25,7 +25,7 @@ #include "dali/core/common.h" #include "dali/core/error_handling.h" -#include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/resampling_cpu.h" #include "dali/operators/decoder/audio/audio_decoder.h" #include "dali/operators/decoder/audio/audio_decoder_impl.h" #include "dali/operators/reader/loader/file_label_loader.h" @@ -181,7 +181,7 @@ class DLL_PUBLIC NemoAsrLoader : public Loader { double max_duration_; bool read_text_; int num_threads_; - kernels::signal::resampling::Resampler resampler_; + kernels::signal::resampling::ResamplerCPU resampler_; std::vector> decode_scratch_; std::vector> resample_scratch_; }; diff --git a/dali/operators/reader/loader/nemo_asr_loader_test.cc b/dali/operators/reader/loader/nemo_asr_loader_test.cc index 9adb476628b..c5726c5302c 100644 --- a/dali/operators/reader/loader/nemo_asr_loader_test.cc +++ b/dali/operators/reader/loader/nemo_asr_loader_test.cc @@ -283,7 +283,7 @@ TEST(NemoAsrLoaderTest, ReadSample) { std::vector downsampled(downsampled_len, 0.0f); constexpr double q = 50.0; int lobes = std::round(0.007 * q * q - 0.09 * q + 3); - kernels::signal::resampling::Resampler resampler; + kernels::signal::resampling::ResamplerCPU resampler; resampler.Initialize(lobes, lobes * 64 + 1); resampler.Resample(downsampled.data(), 0, downsampled_len, sr_out, downmixed.data(), downmixed.size(), sr_in, 1); From 7a0cc7a0696f68928edc257622629ae79eda2078 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 18 May 2022 13:45:42 +0200 Subject: [PATCH 33/36] Code review fixes Signed-off-by: Joaquin Anton --- dali/kernels/signal/resampling_gpu.cu | 3 +-- dali/operators/audio/resample.h | 2 +- dali/operators/audio/resample_gpu.cc | 6 +----- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/dali/kernels/signal/resampling_gpu.cu b/dali/kernels/signal/resampling_gpu.cu index c9920842657..d747eafdfd1 100644 --- a/dali/kernels/signal/resampling_gpu.cu +++ b/dali/kernels/signal/resampling_gpu.cu @@ -77,13 +77,12 @@ void ResamplerGPU::Run(KernelContext &context, const OutListGPU &o desc.out = out_sample.data; desc.window = window_gpu_; const auto &in_sh = in_sample.shape; - const auto &out_sh = out_sample.shape; desc.in_len = in_sh[0]; auto &arg = args[i]; desc.out_begin = arg.out_begin > 0 ? arg.out_begin : 0; desc.out_end = arg.out_end > 0 ? arg.out_end : resampled_length(in_sh[0], arg.in_rate, arg.out_rate); - assert((desc.out_end - desc.out_begin) == out_sh[0]); + assert((desc.out_end - desc.out_begin) == out_sample.shape[0]); desc.nchannels = in_sh.sample_dim() > 1 ? in_sh[1] : 1; desc.scale = arg.in_rate / arg.out_rate; any_multichannel |= desc.nchannels > 1; diff --git a/dali/operators/audio/resample.h b/dali/operators/audio/resample.h index 1ca8ee3c979..aa03b086601 100644 --- a/dali/operators/audio/resample.h +++ b/dali/operators/audio/resample.h @@ -149,7 +149,7 @@ class ResampleBase : public Operator { ArgValue out_length_{"out_length", spec_}; using Args = kernels::signal::resampling::Args; - SmallVector args_; + std::vector args_; }; } // namespace audio diff --git a/dali/operators/audio/resample_gpu.cc b/dali/operators/audio/resample_gpu.cc index 6bc13a3707d..bf1fb82d805 100644 --- a/dali/operators/audio/resample_gpu.cc +++ b/dali/operators/audio/resample_gpu.cc @@ -48,11 +48,7 @@ class ResampleGPU : public ResampleBase { make_string("Unsupported input type: ", in.type(), "\nSupported types are : ", ListTypeNames())); )); // NOLINT - ), ( // NOLINT - DALI_FAIL( - make_string("Unsupported output type: ", dtype_, "\nSupported types are : ", - ListTypeNames())); - )); // NOLINT + ), (assert(!"Unreachable code."))); // NOLINT } template From 934c17e61b677b4ec7257f32c7dc9922c7dcf03e Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Wed, 18 May 2022 14:29:31 +0200 Subject: [PATCH 34/36] Test full GPU pipe Signed-off-by: Joaquin Anton --- dali/operators/audio/resample.cc | 4 +- dali/test/python/test_torch_pipeline_rnnt.py | 79 ++++++++++++++++---- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/dali/operators/audio/resample.cc b/dali/operators/audio/resample.cc index 5e40b8c0119..f2b37e0eada 100644 --- a/dali/operators/audio/resample.cc +++ b/dali/operators/audio/resample.cc @@ -17,6 +17,8 @@ #include "dali/operators/audio/resample.h" #include "dali/operators/audio/resampling_params.h" #include "dali/kernels/kernel_params.h" +#include "dali/kernels/signal/resampling_cpu.h" +#include "dali/core/convert.h" namespace dali { @@ -176,7 +178,7 @@ class ResampleCPU : public ResampleBase { } private: - kernels::signal::resampling::Resampler R; + kernels::signal::resampling::ResamplerCPU R; std::vector> in_fp32; }; diff --git a/dali/test/python/test_torch_pipeline_rnnt.py b/dali/test/python/test_torch_pipeline_rnnt.py index 0a85522eccb..eef9692eeba 100644 --- a/dali/test/python/test_torch_pipeline_rnnt.py +++ b/dali/test/python/test_torch_pipeline_rnnt.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import math import random import os +from nose.tools import nottest audio_files = get_files('db/audio/wav', 'wav') audio_files = [file for file in audio_files if '237-134500' in file] # Filtering librispeech samples @@ -247,7 +248,7 @@ def dali_frame_splicing_graph(x, nfeatures, x_len, stacking=1, subsampling=1): if stacking > 1: seq = [x] for n in range(1, stacking): - f = fn.slice(x, start=n, shape=x_len, axes=(1,), out_of_bounds_policy='pad', fill_values=0) + f = fn.slice(x, n, x_len, axes=(1,), out_of_bounds_policy='pad', fill_values=0) seq.append(f) x = fn.cat(*seq, axis=0) nfeatures = nfeatures * stacking @@ -276,10 +277,10 @@ def flip_1d(x): x = fn.flip(x, vertical=1) x = fn.reshape(x, shape=(-1,), layout="t") return x - pad_start = fn.slice(x, start=1, shape=pad_amount, axes=(0,)) + pad_start = fn.slice(x, 1, pad_amount, axes=(0,)) pad_start = flip_1d(pad_start) - pad_end = fn.slice(x, start=(x_len-pad_amount-1), shape=pad_amount, axes=(0,)) + pad_end = fn.slice(x, x_len-pad_amount-1, pad_amount, axes=(0,)) pad_end = flip_1d(pad_end) x = fn.cat(pad_start, x, pad_end, axis=0) return x @@ -288,7 +289,9 @@ def flip_1d(x): def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, window_size=.02, window_stride=.01, window="hann", nfeatures=64, nfft=512, frame_splicing_stack=1, frame_splicing_subsample=1, - lowfreq=0.0, highfreq=None, normalize_type='per_feature', device='cpu'): + lowfreq=0.0, highfreq=None, normalize_type='per_feature', + speed_perturb=False, silence_trim=False, + device='cpu'): assert normalize_type == 'per_feature' or normalize_type == 'all_features' norm_axes = [1] if normalize_type == 'per_feature' else [0, 1] win_len, win_hop = win_args(sample_rate, window_size, window_stride) @@ -298,30 +301,48 @@ def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, data, _ = fn.readers.file(files=files, device="cpu", random_shuffle=False, shard_id=0, num_shards=1) audio, _ = fn.decoders.audio(data, dtype=types.FLOAT, downmix=True) - audio_shape = fn.shapes(audio, dtype=types.INT32) - orig_audio_len = fn.slice(audio_shape, 0, 1, axes=(0,)) + # splicing with subsampling doesn't work if audio_len is a GPU data node + if device == 'gpu' and frame_splicing_subsample == 1: + audio = audio.gpu() - if pad_amount > 0: - audio_len = orig_audio_len + 2 * pad_amount - else: - audio_len = orig_audio_len + # Speed perturbation 0.85x - 1.15x + if speed_perturb: + target_sr_factor = fn.random.uniform(device="cpu", range=(1/1.15, 1/0.85)) + audio = fn.experimental.audio_resample(audio, scale=target_sr_factor) - spec_len = audio_len // win_hop + 1 + # Silence trimming + if silence_trim: + begin, length = fn.nonsilent_region(audio, cutoff_db=-80) + audio = fn.slice(audio, begin, length, axes=[0]) - if device == 'gpu': + audio_shape = fn.shapes(audio, dtype=types.INT32) + orig_audio_len = fn.slice(audio_shape, 0, 1, axes=(0,)) + + # If we couldn't move to GPU earlier, do it now + if device == 'gpu' and frame_splicing_subsample > 1: audio = audio.gpu() if pad_amount > 0: + audio_len = orig_audio_len + 2 * pad_amount padded_audio = dali_reflect_pad_graph(audio, orig_audio_len, pad_amount) else: + audio_len = orig_audio_len padded_audio = audio + # Preemphasis filter preemph_audio = fn.preemphasis_filter(padded_audio, preemph_coeff=preemph_coeff, border='zero') + + # Spectrogram + spec_len = audio_len // win_hop + 1 spec = fn.spectrogram(preemph_audio, nfft=nfft, window_fn=window_fn_arg, window_length=win_len, window_step=win_hop, center_windows=True, reflect_padding=True) + # Mel spectrogram mel_spec = fn.mel_filter_bank(spec, sample_rate=sample_rate, nfilter=nfeatures, freq_low=lowfreq, freq_high=highfreq) + + # Log log_features = fn.to_decibels(mel_spec + 1e-20, multiplier=np.log(10), reference=1.0, cutoff_db=-80) + # Frame splicing if frame_splicing_stack > 1 or frame_splicing_subsample > 1: log_features_spliced = dali_frame_splicing_graph(log_features, nfeatures, spec_len, stacking=frame_splicing_stack, @@ -329,6 +350,7 @@ def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, else: log_features_spliced = log_features + # Normalization if normalize_type: norm_log_features = fn.normalize(log_features_spliced, axes=norm_axes, device=device, epsilon=4e-5, ddof=1) else: @@ -352,6 +374,8 @@ def _testimpl_rnnt_data_pipeline(device, pad_amount=0, preemph_coeff=.97, window window="hann", nfeatures=64, n_fft=512, frame_splicing_stack=1, frame_splicing_subsample=1, lowfreq=0.0, highfreq=None, normalize_type='per_feature', batch_size=32): sample_rate = npy_files_sr + speed_perturb = False + silence_trim = False ref_pipeline = FilterbankFeatures( sample_rate=sample_rate, window_size=window_size, window_stride=window_stride, window=window, normalize=normalize_type, @@ -369,8 +393,8 @@ def _testimpl_rnnt_data_pipeline(device, pad_amount=0, preemph_coeff=.97, window pipe = rnnt_train_pipe( audio_files, sample_rate, pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, - n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type, device, - seed=42, batch_size=batch_size + n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type, + speed_perturb, silence_trim, device, seed=42, batch_size=batch_size ) pipe.build() nbatches = (nrecordings + batch_size - 1) // batch_size @@ -445,3 +469,28 @@ def test_rnnt_data_pipeline(): yield _testimpl_rnnt_data_pipeline, device, \ pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, n_fft, \ frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type + +@nottest # To be run manually to check perf +def test_rnnt_data_pipeline_throughput(pad_amount=0, preemph_coeff=.97, window_size=.02, window_stride=.01, + window="hann", nfeatures=64, n_fft=512, frame_splicing_stack=1, frame_splicing_subsample=1, + speed_perturb=True, silence_trim=True, lowfreq=0.0, highfreq=None, normalize_type='per_feature', batch_size=32): + sample_rate = npy_files_sr + device = 'gpu' + pipe = rnnt_train_pipe( + audio_files, sample_rate, pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, + n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, + normalize_type, speed_perturb, silence_trim, device, seed=42, batch_size=batch_size + ) + pipe.build() + + import time + from test_utils import AverageMeter + end = time.time() + data_time = AverageMeter() + iters = 1000 + for j in range(iters): + pipe.run() + data_time.update(time.time() - end) + if j % 100 == 0: + print(f"run {j+1}/ {iters}, avg time: {data_time.avg} [s], worst time: {data_time.max_val} [s], speed: {batch_size / data_time.avg} [recordings/s]") + end = time.time() From 23525bf704787bdcdca4113b49b46f1990df64f8 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Fri, 20 May 2022 10:16:28 +0200 Subject: [PATCH 35/36] Code review fixes Signed-off-by: Joaquin Anton --- dali/operators/audio/resample.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dali/operators/audio/resample.h b/dali/operators/audio/resample.h index aa03b086601..f03fb36c93f 100644 --- a/dali/operators/audio/resample.h +++ b/dali/operators/audio/resample.h @@ -128,8 +128,8 @@ class ResampleBase : public Operator { DALI_FAIL(make_string("Cannot produce a non-empty signal from an empty input.\n" "Error at sample ", s)); } - args_[s].in_rate = 1.0; - args_[s].out_rate = in_length ? 1.0 * out_length / in_length : 0.0; + args_[s].in_rate = in_length ? in_length : 1; // avoid division by 0 + args_[s].out_rate = out_length ? out_length : 1; // avoid division by 0 out_shape.tensor_shape_span(s)[0] = out_length; } } else { From b2c2c57028415203a84619329ca5c463c044856c Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Fri, 20 May 2022 13:11:55 +0200 Subject: [PATCH 36/36] Code review fixes Signed-off-by: Joaquin Anton --- dali/operators/audio/resample.h | 2 +- .../python/test_operator_audio_resample.py | 2 +- dali/test/python/test_torch_pipeline_rnnt.py | 79 ++++--------------- 3 files changed, 17 insertions(+), 66 deletions(-) diff --git a/dali/operators/audio/resample.h b/dali/operators/audio/resample.h index f03fb36c93f..4255c6555b5 100644 --- a/dali/operators/audio/resample.h +++ b/dali/operators/audio/resample.h @@ -129,7 +129,7 @@ class ResampleBase : public Operator { "Error at sample ", s)); } args_[s].in_rate = in_length ? in_length : 1; // avoid division by 0 - args_[s].out_rate = out_length ? out_length : 1; // avoid division by 0 + args_[s].out_rate = out_length ? out_length : 1; // avoid division by 0 out_shape.tensor_shape_span(s)[0] = out_length; } } else { diff --git a/dali/test/python/test_operator_audio_resample.py b/dali/test/python/test_operator_audio_resample.py index ff396b819e8..28b861fba3c 100644 --- a/dali/test/python/test_operator_audio_resample.py +++ b/dali/test/python/test_operator_audio_resample.py @@ -100,7 +100,7 @@ def test_pipe(device): print("Reference: ", ref) print(ref.dtype, ref.shape) print("Diff: ", out_arr.astype(np.float) - ref) - assert False + assert np.allclose(out_arr, ref, 1e-6, eps) def test_dynamic_ranges(): for type, values, eps in [(types.FLOAT, [-1.e30, -1-1.e-6, -1, -0.5, -1.e-30, 0, 1.e-30, 0.5, 1, 1+1.e-6, 1e30], 0), diff --git a/dali/test/python/test_torch_pipeline_rnnt.py b/dali/test/python/test_torch_pipeline_rnnt.py index eef9692eeba..0a85522eccb 100644 --- a/dali/test/python/test_torch_pipeline_rnnt.py +++ b/dali/test/python/test_torch_pipeline_rnnt.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ import math import random import os -from nose.tools import nottest audio_files = get_files('db/audio/wav', 'wav') audio_files = [file for file in audio_files if '237-134500' in file] # Filtering librispeech samples @@ -248,7 +247,7 @@ def dali_frame_splicing_graph(x, nfeatures, x_len, stacking=1, subsampling=1): if stacking > 1: seq = [x] for n in range(1, stacking): - f = fn.slice(x, n, x_len, axes=(1,), out_of_bounds_policy='pad', fill_values=0) + f = fn.slice(x, start=n, shape=x_len, axes=(1,), out_of_bounds_policy='pad', fill_values=0) seq.append(f) x = fn.cat(*seq, axis=0) nfeatures = nfeatures * stacking @@ -277,10 +276,10 @@ def flip_1d(x): x = fn.flip(x, vertical=1) x = fn.reshape(x, shape=(-1,), layout="t") return x - pad_start = fn.slice(x, 1, pad_amount, axes=(0,)) + pad_start = fn.slice(x, start=1, shape=pad_amount, axes=(0,)) pad_start = flip_1d(pad_start) - pad_end = fn.slice(x, x_len-pad_amount-1, pad_amount, axes=(0,)) + pad_end = fn.slice(x, start=(x_len-pad_amount-1), shape=pad_amount, axes=(0,)) pad_end = flip_1d(pad_end) x = fn.cat(pad_start, x, pad_end, axis=0) return x @@ -289,9 +288,7 @@ def flip_1d(x): def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, window_size=.02, window_stride=.01, window="hann", nfeatures=64, nfft=512, frame_splicing_stack=1, frame_splicing_subsample=1, - lowfreq=0.0, highfreq=None, normalize_type='per_feature', - speed_perturb=False, silence_trim=False, - device='cpu'): + lowfreq=0.0, highfreq=None, normalize_type='per_feature', device='cpu'): assert normalize_type == 'per_feature' or normalize_type == 'all_features' norm_axes = [1] if normalize_type == 'per_feature' else [0, 1] win_len, win_hop = win_args(sample_rate, window_size, window_stride) @@ -301,48 +298,30 @@ def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, data, _ = fn.readers.file(files=files, device="cpu", random_shuffle=False, shard_id=0, num_shards=1) audio, _ = fn.decoders.audio(data, dtype=types.FLOAT, downmix=True) - # splicing with subsampling doesn't work if audio_len is a GPU data node - if device == 'gpu' and frame_splicing_subsample == 1: - audio = audio.gpu() - - # Speed perturbation 0.85x - 1.15x - if speed_perturb: - target_sr_factor = fn.random.uniform(device="cpu", range=(1/1.15, 1/0.85)) - audio = fn.experimental.audio_resample(audio, scale=target_sr_factor) - - # Silence trimming - if silence_trim: - begin, length = fn.nonsilent_region(audio, cutoff_db=-80) - audio = fn.slice(audio, begin, length, axes=[0]) - audio_shape = fn.shapes(audio, dtype=types.INT32) orig_audio_len = fn.slice(audio_shape, 0, 1, axes=(0,)) - # If we couldn't move to GPU earlier, do it now - if device == 'gpu' and frame_splicing_subsample > 1: + if pad_amount > 0: + audio_len = orig_audio_len + 2 * pad_amount + else: + audio_len = orig_audio_len + + spec_len = audio_len // win_hop + 1 + + if device == 'gpu': audio = audio.gpu() if pad_amount > 0: - audio_len = orig_audio_len + 2 * pad_amount padded_audio = dali_reflect_pad_graph(audio, orig_audio_len, pad_amount) else: - audio_len = orig_audio_len padded_audio = audio - # Preemphasis filter preemph_audio = fn.preemphasis_filter(padded_audio, preemph_coeff=preemph_coeff, border='zero') - - # Spectrogram - spec_len = audio_len // win_hop + 1 spec = fn.spectrogram(preemph_audio, nfft=nfft, window_fn=window_fn_arg, window_length=win_len, window_step=win_hop, center_windows=True, reflect_padding=True) - # Mel spectrogram mel_spec = fn.mel_filter_bank(spec, sample_rate=sample_rate, nfilter=nfeatures, freq_low=lowfreq, freq_high=highfreq) - - # Log log_features = fn.to_decibels(mel_spec + 1e-20, multiplier=np.log(10), reference=1.0, cutoff_db=-80) - # Frame splicing if frame_splicing_stack > 1 or frame_splicing_subsample > 1: log_features_spliced = dali_frame_splicing_graph(log_features, nfeatures, spec_len, stacking=frame_splicing_stack, @@ -350,7 +329,6 @@ def rnnt_train_pipe(files, sample_rate, pad_amount=0, preemph_coeff=.97, else: log_features_spliced = log_features - # Normalization if normalize_type: norm_log_features = fn.normalize(log_features_spliced, axes=norm_axes, device=device, epsilon=4e-5, ddof=1) else: @@ -374,8 +352,6 @@ def _testimpl_rnnt_data_pipeline(device, pad_amount=0, preemph_coeff=.97, window window="hann", nfeatures=64, n_fft=512, frame_splicing_stack=1, frame_splicing_subsample=1, lowfreq=0.0, highfreq=None, normalize_type='per_feature', batch_size=32): sample_rate = npy_files_sr - speed_perturb = False - silence_trim = False ref_pipeline = FilterbankFeatures( sample_rate=sample_rate, window_size=window_size, window_stride=window_stride, window=window, normalize=normalize_type, @@ -393,8 +369,8 @@ def _testimpl_rnnt_data_pipeline(device, pad_amount=0, preemph_coeff=.97, window pipe = rnnt_train_pipe( audio_files, sample_rate, pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, - n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type, - speed_perturb, silence_trim, device, seed=42, batch_size=batch_size + n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type, device, + seed=42, batch_size=batch_size ) pipe.build() nbatches = (nrecordings + batch_size - 1) // batch_size @@ -469,28 +445,3 @@ def test_rnnt_data_pipeline(): yield _testimpl_rnnt_data_pipeline, device, \ pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, n_fft, \ frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, normalize_type - -@nottest # To be run manually to check perf -def test_rnnt_data_pipeline_throughput(pad_amount=0, preemph_coeff=.97, window_size=.02, window_stride=.01, - window="hann", nfeatures=64, n_fft=512, frame_splicing_stack=1, frame_splicing_subsample=1, - speed_perturb=True, silence_trim=True, lowfreq=0.0, highfreq=None, normalize_type='per_feature', batch_size=32): - sample_rate = npy_files_sr - device = 'gpu' - pipe = rnnt_train_pipe( - audio_files, sample_rate, pad_amount, preemph_coeff, window_size, window_stride, window, nfeatures, - n_fft, frame_splicing_stack, frame_splicing_subsample, lowfreq, highfreq, - normalize_type, speed_perturb, silence_trim, device, seed=42, batch_size=batch_size - ) - pipe.build() - - import time - from test_utils import AverageMeter - end = time.time() - data_time = AverageMeter() - iters = 1000 - for j in range(iters): - pipe.run() - data_time.update(time.time() - end) - if j % 100 == 0: - print(f"run {j+1}/ {iters}, avg time: {data_time.avg} [s], worst time: {data_time.max_val} [s], speed: {batch_size / data_time.avg} [recordings/s]") - end = time.time()