From 740d6266ae48ecbd84709d2f937835e26ea7cee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20Karpi=C5=84ski?= <34919255+hugo213@users.noreply.github.com> Date: Tue, 14 Dec 2021 20:00:13 +0100 Subject: [PATCH] Coalesce stores in Slice for smaller output types (#3568) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Coalesce stores in Slice for smaller output types This change coalesces stores to global memory in SliceGPU when OutputType is smaller than 4 bytes in order to improve performance. Signed-off-by: Szymon KarpiƄski --- dali/kernels/slice/slice_gpu.cuh | 89 ++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/dali/kernels/slice/slice_gpu.cuh b/dali/kernels/slice/slice_gpu.cuh index de9eba34384..cc840aebdff 100644 --- a/dali/kernels/slice/slice_gpu.cuh +++ b/dali/kernels/slice/slice_gpu.cuh @@ -72,6 +72,9 @@ struct SliceBlockDesc { uint64_t size; }; +template +constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType); + /** * @brief Simplified algorithm when no padding is necessary * @remarks `in` already refers to the slice anchor start @@ -87,18 +90,22 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__ return; } - for (; offset < block_end; offset += blockDim.x) { - uint64_t idx = offset; - uint64_t out_idx = idx; - uint64_t in_idx = 0; - + for (; offset < block_end; offset += blockDim.x * coalesced_values) { #pragma unroll - for (int d = 0; d < Dims; d++) { - int i_d = div_mod(idx, idx, out_strides[d]); - in_idx += i_d * in_strides[d]; + for (uint64_t i = 0; i < coalesced_values; i++) { + uint64_t idx = offset + i; + if (idx >= block_end) break; + uint64_t out_idx = idx; + uint64_t in_idx = 0; + + #pragma unroll + for (int d = 0; d < Dims; d++) { + int i_d = div_mod(idx, idx, out_strides[d]); + in_idx += i_d * in_strides[d]; + } + in_idx += idx; // remaining dims have equal strides + out[out_idx] = clamp(in[in_idx]); } - in_idx += idx; // remaining dims have equal strides - out[out_idx] = clamp(in[in_idx]); } } @@ -131,44 +138,50 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr inner_in_extent = Dims > 1 ? in_strides[LastDim - 1] : in_shape[LastDim] * in_strides[LastDim]; } - for (; offset < block_end; offset += blockDim.x) { - uint64_t idx = offset; - uint64_t out_idx = idx; - - // If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension, - // because know the strides are 1 (or we treat them as 1 if we fused dimensions) - int i_c = 0; - int i_d; - bool out_of_bounds = false; - uint64_t in_idx = 0; - + for (; offset < block_end; offset += blockDim.x * coalesced_values) { + #ifndef __clang__ #pragma unroll - for (int d = 0; d < Dims - 1; d++) { - i_d = div_mod(idx, idx, out_strides[d]); - if (d == channel_dim) + #endif + for (uint64_t i = 0; i < coalesced_values; i++) { + uint64_t idx = offset + i; + if (idx >= block_end) break; + uint64_t out_idx = idx; + + // If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension, + // because know the strides are 1 (or we treat them as 1 if we fused dimensions) + int i_c = 0; + int i_d; + bool out_of_bounds = false; + uint64_t in_idx = 0; + + #pragma unroll + for (int d = 0; d < Dims - 1; d++) { + i_d = div_mod(idx, idx, out_strides[d]); + if (d == channel_dim) + i_c = i_d; + out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]); + if (!out_of_bounds) + in_idx += i_d * in_strides[d]; + } + + constexpr int d = LastDim; + i_d = idx; // out_strides[d] is 1 + if (AllDims && d == channel_dim) i_c = i_d; - out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]); + out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent); if (!out_of_bounds) - in_idx += i_d * in_strides[d]; - } - - constexpr int d = LastDim; - i_d = idx; // out_strides[d] is 1 - if (AllDims && d == channel_dim) - i_c = i_d; - out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent); - if (!out_of_bounds) - in_idx += i_d; // in_strides[d] is 1 + in_idx += i_d; // in_strides[d] is 1 - // Fill values are reused a lot, so let's make sure they are cached (by using __ldg()) - out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp(in[in_idx]); + // Fill values are reused a lot, so let's make sure they are cached (by using __ldg()) + out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp(in[in_idx]); + } } } template __global__ void SliceKernel(const SliceSampleDesc *samples, const SliceBlockDesc *blocks) { int sampleIdx = blocks[blockIdx.x].sampleIdx; - uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x; + uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * coalesced_values; uint64_t block_end = blocks[blockIdx.x].offset + blocks[blockIdx.x].size; auto sample = samples[sampleIdx]; auto *out = static_cast(sample.out);