From 71ff22eb66236834acdd2b541c77e7646bcc91e5 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Fri, 27 Mar 2020 14:49:36 +0000 Subject: [PATCH] Improving performance by leveraging vectorization. --- src/operator/numpy/np_matmul_op-inl.h | 19 +++- src/operator/tensor/broadcast_reduce_op.h | 106 ++++++++++++++++------ 2 files changed, 93 insertions(+), 32 deletions(-) diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h index 89560f64d8c0..610402a6253a 100644 --- a/src/operator/numpy/np_matmul_op-inl.h +++ b/src/operator/numpy/np_matmul_op-inl.h @@ -157,12 +157,27 @@ inline void MatmulImpl(const OpContext& ctx, DType* bc_b_ptr = bc_a_ptr + bc_size_a; MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, { + uint64_t axes[ndim-2], out_stride[ndim-2]; + int iter = ndim - 3, i = 0; + out_stride[iter] = 1; + if (k_a_shape[iter] != k_a_shape_bc[iter]) { + axes[i] = iter; + i++; + } + --iter; + for (; iter >= 0; --iter) { + out_stride[iter] = out_stride[iter+1] * k_a_shape_bc[iter+1]; + if (k_a_shape[iter] != k_a_shape_bc[iter]) { + axes[i] = iter; + i++; + } + } Kernel, xpu>::Launch( s, bc_size_a, input_a.dptr(), bc_a_ptr, - k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim); + k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim, axes, out_stride, i); Kernel, xpu>::Launch( s, bc_size_b, input_b.dptr(), bc_b_ptr, - k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim); + k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim, axes, out_stride, i); }); }); ans = mshadow::Tensor(output.dptr(), diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index dd57f0826562..0a84ee8e8566 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1051,29 +1051,57 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, template struct broadcast_kernel { - template - MSHADOW_XINLINE 
static void Map(IDXType i, + template + MSHADOW_XINLINE static void Map(index_t i, IType *input, OType *output, mshadow::Shape in_shape, mshadow::Shape out_shape, const OpReqType req, - const uint32_t ndim) { - size_t in_stride = 1; - size_t out_stride = 1; - IDXType idx = i; - IDXType in_idx = i; - for (int iter = ndim - 1; iter >= 0; --iter) { - size_t dim_idx = idx % out_shape[iter]; - in_idx -= dim_idx * out_stride; - if (in_shape[iter] != 1) { - in_idx += dim_idx * in_stride; - } - idx /= out_shape[iter]; - in_stride *= in_shape[iter]; - out_stride *= out_shape[iter]; - } - KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx])); + const uint32_t ndim, + const uint64_t *axes, + const uint64_t *out_stride, + const size_t no_axes) { + index_t idx = i; + index_t init_off = 0; + for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + init_off += dim_idx * out_stride[iter]; + idx /= in_shape[iter]; + } + index_t stride_0, stride_1, stride_2; + switch (no_axes) { + case 1 : + stride_0 = out_stride[axes[0]]; + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + l*stride_0], + req, OP::Map(input[i])); + } + break; + + case 2: + stride_1 = out_stride[axes[1]], stride_0 = out_stride[axes[0]]; + for (int k=0; k < out_shape[axes[1]]; k++) { + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + k*stride_1 + l*stride_0], + req, OP::Map(input[i])); + } + } + break; + + case 3: + stride_2 = out_stride[axes[2]], stride_1 = out_stride[axes[1]]; + stride_0 = out_stride[axes[0]]; + for (int j=0; j < out_shape[axes[2]]; j++) { + for (int k=0; k < out_shape[axes[1]]; k++) { + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + j*stride_2 + k*stride_1 + l*stride_0], + req, OP::Map(input[i])); + } + } + } + break; + } } }; @@ -1090,6 +1118,7 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; 
BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape); Stream *s = ctx.get_stream(); + bool enable_lt = dst_shape.Size() > INT_MAX; MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { mshadow::Shape in_shape; @@ -1103,19 +1132,35 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, out_shape[i] = 1; } } + uint64_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()]; + int iter = dst_shape.ndim() - 1, i = 0; + out_stride[iter] = 1; + if (in_shape[iter] != dst_shape[iter]) { + axes[i] = iter; + i++; + } + --iter; + for (; iter >= 0; --iter) { + if (in_shape[iter] != dst_shape[iter]) { + axes[i] = iter; + i++; + } + out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1]; + } if (dst_shape.ndim() == 2) { Tensor out = outputs[0].get_with_shape(dst_shape.get<2>(), s); Tensor data = inputs[0].get_with_shape(src_shape.get<2>(), s); - if (out_shape.Size() > ((int64_t{1} << 31) - 1)) { - typedef int64_t IDXType; - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2); + if (!enable_lt) { + typedef int32_t index_t; + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], 2, axes, out_stride, 1); } else { - typedef int32_t IDXType; - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2); + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], 2, axes, out_stride, 1); } } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; @@ -1123,14 +1168,15 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, outputs[0].get_with_shape(dst_shape.get(), s); Tensor data = inputs[0].get_with_shape(src_shape.get(), s); - if (out_shape.Size() > ((int64_t{1} << 31) - 1)) { - typedef int64_t IDXType; + if (!enable_lt) { + typedef int32_t index_t; Kernel, xpu>::Launch( - s, out.shape_.Size(), 
data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim); + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], ndim, axes, out_stride, i); } else { - typedef int32_t IDXType; Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim); + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], ndim, axes, out_stride, i); } } });