Improving performance by leveraging vectorization.
Rohit Kumar Srivastava committed Apr 13, 2020
1 parent 6cbc565 commit 71ff22e
Showing 2 changed files with 93 additions and 32 deletions.
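In outline: the old broadcast_kernel mapped every output element back to its input source, paying a division and a modulo per axis for each output element. The new version precomputes the broadcast axes and the row-major output strides once on the host, then has each kernel invocation copy a single input element into all of its broadcast destinations through up to three tight constant-stride loops, which a compiler should find straightforward to vectorize. Standalone sketches of the two pieces follow the hunks below.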
19 changes: 17 additions & 2 deletions src/operator/numpy/np_matmul_op-inl.h
@@ -157,12 +157,27 @@ inline void MatmulImpl(const OpContext& ctx,
   DType* bc_b_ptr = bc_a_ptr + bc_size_a;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, {
+      uint64_t axes[ndim-2], out_stride[ndim-2];
+      int iter = ndim - 3, i = 0;
+      out_stride[iter] = 1;
+      if (k_a_shape[iter] != k_a_shape_bc[iter]) {
+        axes[i] = iter;
+        i++;
+      }
+      --iter;
+      for (; iter >= 0; --iter) {
+        out_stride[iter] = out_stride[iter+1] * k_a_shape_bc[iter+1];
+        if (k_a_shape[iter] != k_a_shape_bc[iter]) {
+          axes[i] = iter;
+          i++;
+        }
+      }
       Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
         s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
-        k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim);
+        k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim, axes, out_stride, i);
       Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
         s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
-        k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim);
+        k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim, axes, out_stride, i);
     });
   });
   ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(),
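The hunk above records, right to left, every batch axis where the operand's shape differs from the broadcast shape, and builds row-major output strides along the way. A minimal standalone sketch of that precomputation, with hypothetical shapes (not the MXNet code itself):

  #include <cstdint>
  #include <cstdio>

  int main() {
    const int ndim = 3;
    const uint64_t in_shape[ndim] = {1, 3, 1};   // e.g. batch dims of one operand
    const uint64_t bc_shape[ndim] = {2, 3, 5};   // broadcast batch dims
    uint64_t axes[ndim], out_stride[ndim];

    // Walk dims right to left: note each axis that needs broadcasting,
    // and accumulate row-major strides of the broadcast shape.
    int iter = ndim - 1, n_axes = 0;
    out_stride[iter] = 1;
    if (in_shape[iter] != bc_shape[iter]) axes[n_axes++] = iter;
    for (--iter; iter >= 0; --iter) {
      out_stride[iter] = out_stride[iter + 1] * bc_shape[iter + 1];
      if (in_shape[iter] != bc_shape[iter]) axes[n_axes++] = iter;
    }

    // Expect axes = {2, 0} (innermost first) and out_stride = {15, 5, 1}.
    for (int a = 0; a < n_axes; ++a)
      printf("axis %llu ", (unsigned long long)axes[a]);
    printf("| strides %llu %llu %llu\n",
           (unsigned long long)out_stride[0],
           (unsigned long long)out_stride[1],
           (unsigned long long)out_stride[2]);
    return 0;
  }

Because the scan runs right to left, the innermost differing axis lands in axes[0], so the kernel's tightest loop runs over the smallest stride and writes are as contiguous as the broadcast allows.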
106 changes: 76 additions & 30 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -1051,29 +1051,57 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,

 template<typename OP>
 struct broadcast_kernel {
-  template<typename IType, typename OType, typename IDXType>
-  MSHADOW_XINLINE static void Map(IDXType i,
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
                                   OType *output,
                                   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
                                   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
                                   const OpReqType req,
-                                  const uint32_t ndim) {
-    size_t in_stride = 1;
-    size_t out_stride = 1;
-    IDXType idx = i;
-    IDXType in_idx = i;
-    for (int iter = ndim - 1; iter >= 0; --iter) {
-      size_t dim_idx = idx % out_shape[iter];
-      in_idx -= dim_idx * out_stride;
-      if (in_shape[iter] != 1) {
-        in_idx += dim_idx * in_stride;
-      }
-      idx /= out_shape[iter];
-      in_stride *= in_shape[iter];
-      out_stride *= out_shape[iter];
-    }
-    KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
+                                  const uint32_t ndim,
+                                  const uint64_t *axes,
+                                  const uint64_t *out_stride,
+                                  const size_t no_axes) {
+    index_t idx = i;
+    index_t init_off = 0;
+    for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+      size_t dim_idx = idx % in_shape[iter];
+      init_off += dim_idx * out_stride[iter];
+      idx /= in_shape[iter];
+    }
+    index_t stride_0, stride_1, stride_2;
+    switch (no_axes) {
+      case 1:
+        stride_0 = out_stride[axes[0]];
+        for (int l = 0; l < out_shape[axes[0]]; l++) {
+          KERNEL_ASSIGN(output[init_off + l*stride_0],
+                        req, OP::Map(input[i]));
+        }
+        break;
+
+      case 2:
+        stride_1 = out_stride[axes[1]], stride_0 = out_stride[axes[0]];
+        for (int k = 0; k < out_shape[axes[1]]; k++) {
+          for (int l = 0; l < out_shape[axes[0]]; l++) {
+            KERNEL_ASSIGN(output[init_off + k*stride_1 + l*stride_0],
+                          req, OP::Map(input[i]));
+          }
+        }
+        break;
+
+      case 3:
+        stride_2 = out_stride[axes[2]], stride_1 = out_stride[axes[1]];
+        stride_0 = out_stride[axes[0]];
+        for (int j = 0; j < out_shape[axes[2]]; j++) {
+          for (int k = 0; k < out_shape[axes[1]]; k++) {
+            for (int l = 0; l < out_shape[axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + j*stride_2 + k*stride_1 + l*stride_0],
+                            req, OP::Map(input[i]));
+            }
+          }
+        }
+        break;
+    }
   }
 };
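The rewritten kernel inverts the old mapping: thread i now owns input element i, computes that element's offset in the output with the broadcast axes pinned to zero, then sweeps the broadcast axes with constant strides. A runnable host-side sketch of the no_axes == 1 case, using a hypothetical (3, 1) -> (3, 4) broadcast (again, not the MXNet code itself):

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    const int ndim = 2;
    const int64_t in_shape[ndim]   = {3, 1};
    const int64_t out_shape[ndim]  = {3, 4};
    const int64_t out_stride[ndim] = {4, 1};  // row-major strides of out_shape
    const int64_t axes[1] = {1};              // axis 1 is the only broadcast axis

    std::vector<float> input = {10.f, 20.f, 30.f};
    std::vector<float> output(3 * 4, 0.f);

    for (int64_t i = 0; i < (int64_t)input.size(); ++i) {
      // Offset of input element i in the output, with broadcast axes at 0.
      int64_t idx = i, init_off = 0;
      for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
        init_off += (idx % in_shape[iter]) * out_stride[iter];
        idx /= in_shape[iter];
      }
      // no_axes == 1: one tight constant-stride loop over the broadcast axis.
      const int64_t stride_0 = out_stride[axes[0]];
      for (int64_t l = 0; l < out_shape[axes[0]]; ++l)
        output[init_off + l * stride_0] = input[i];
    }

    // Each input value is replicated across its row of the output.
    for (int r = 0; r < 3; ++r) {
      for (int c = 0; c < 4; ++c) printf("%5.1f ", output[r * 4 + c]);
      printf("\n");
    }
    return 0;
  }

Writes to the same output element never race because distinct input elements broadcast to disjoint output positions; the per-element division/modulo work of the old kernel is amortized over a whole run of output writes.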

@@ -1090,6 +1118,7 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool enable_lt = dst_shape.Size() > INT_MAX;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
Expand All @@ -1103,34 +1132,51 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
out_shape[i] = 1;
}
}
uint64_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()];
int iter = dst_shape.ndim() - 1, i = 0;
out_stride[iter] = 1;
if (in_shape[iter] != dst_shape[iter]) {
axes[i] = iter;
i++;
}
--iter;
for (; iter >= 0; --iter) {
if (in_shape[iter] != dst_shape[iter]) {
axes[i] = iter;
i++;
}
out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1];
}
if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, OType> out =
outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, IType> data =
inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
if (out_shape.Size() > ((int64_t{1} << 31) - 1)) {
typedef int64_t IDXType;
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
if (!enable_lt) {
typedef int32_t index_t;
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
out_shape, req[0], 2, axes, out_stride, 1);
} else {
typedef int32_t IDXType;
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
out_shape, req[0], 2, axes, out_stride, 1);
}
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, OType> out =
outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
Tensor<xpu, ndim, IType> data =
inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
if (out_shape.Size() > ((int64_t{1} << 31) - 1)) {
typedef int64_t IDXType;
if (!enable_lt) {
typedef int32_t index_t;
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
out_shape, req[0], ndim, axes, out_stride, i);
} else {
typedef int32_t IDXType;
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
out_shape, req[0], ndim, axes, out_stride, i);
}
}
});
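One reading of the dispatch above: enable_lt flags "large tensor" outputs with more than INT_MAX elements; below that threshold the local typedef int32_t index_t suggests the intent of keeping index arithmetic 32-bit on the fast path, while the large case falls through to the build's default index width. Note also that both branches now launch one kernel thread per input element (data.shape_.Size()) rather than per output element (out.shape_.Size()), matching the input-centred kernel. In the compacted 2D case there is exactly one broadcast axis, hence the hard-coded no_axes argument of 1; the general branch passes the counted i.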
