This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-323] Improve performance of broadcast ops backward pass #11252

Merged · 11 commits · Jul 13, 2018
4 changes: 4 additions & 0 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -602,6 +602,10 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
ReduceImpl<Reducer, ndim, DType, OP>(stream, small, req, big, workspace, config);
}

// No-op placeholder for the GPU build: the CUDA backward pass keeps using
// Reduce() above, so this overload only has to exist and does nothing.
template <typename Reducer, int ndim, typename DType, typename OP>
void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
                        const Tensor<gpu, 1, char>& workspace, const TBlob& big) {}

template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big,
48 changes: 45 additions & 3 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -31,6 +31,7 @@
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../operator_common.h"

namespace mxnet {
namespace op {
@@ -204,16 +205,57 @@ void seq_reduce_compute(const int N, const int M, const bool addto,
}
}

// Same reduction as seq_reduce_compute above, except that the flat offsets of
// the M reduced elements have already been cached in ws_dptr, so the inner
// loop is a plain indexed read instead of an unravel/dot per (idx, k) pair.
template <typename Reducer, int ndim, typename DType, typename OP>
void seq_reduce_compute_extra_mem(const int N, const int M, const bool addto,
                                  const DType* big, DType* small,
                                  const Shape<ndim> bshape,
                                  const Shape<ndim> sshape,
                                  const Shape<ndim> rshape,
                                  const Shape<ndim> rstride,
                                  const index_t* ws_dptr) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
  for (int idx = 0; idx < N; ++idx) {
    Shape<ndim> coord = unravel(idx, sshape);
    int j = ravel(coord, bshape);
    DType val, residual;
    Reducer::SetInitValue(val, residual);
    for (int k = 0; k < M; ++k) {
      Reducer::Reduce(val, OP::Map(big[j + ws_dptr[k]]), residual);
    }
    assign(&small[idx], addto, val);
  }
}
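
For readers unfamiliar with the index helpers used in these loops, here is a small self-contained sketch of what unravel, ravel and dot compute. These are hypothetical stand-ins written for fixed ndim = 3 and plain arrays, not the mshadow implementations; the shapes and numbers are made up for illustration.

#include <array>
#include <cstdio>

constexpr int ndim = 3;
using Coord = std::array<int, ndim>;

// Flat index -> coordinate under `shape` (row-major).
Coord unravel(int idx, const Coord& shape) {
  Coord coord{};
  for (int d = ndim - 1; d >= 0; --d) {
    coord[d] = idx % shape[d];
    idx /= shape[d];
  }
  return coord;
}

// Coordinate -> flat index under `shape`; size-1 (broadcast) dims contribute 0.
int ravel(const Coord& coord, const Coord& shape) {
  int idx = 0;
  for (int d = 0; d < ndim; ++d)
    idx = idx * shape[d] + (shape[d] > 1 ? coord[d] : 0);
  return idx;
}

// Coordinate folded with a stride vector -> flat offset.
int dot(const Coord& coord, const Coord& stride) {
  int off = 0;
  for (int d = 0; d < ndim; ++d) off += coord[d] * stride[d];
  return off;
}

int main() {
  Coord sshape  = {2, 1, 4};           // shape of the reduced (small) tensor
  Coord bshape  = {2, 3, 4};           // shape of the broadcast (big) tensor
  Coord bstride = {12, 4, 1};          // row-major strides of bshape
  Coord coord = unravel(5, sshape);    // -> (1, 0, 1)
  std::printf("ravel into big: %d\n", ravel(coord, bshape));   // 1*12 + 0 + 1 = 13
  std::printf("dot with strides: %d\n", dot(coord, bstride));  // also 13
  return 0;
}

The key observation behind the PR is that, in the backward pass, the dot(coord, rstride) part of this arithmetic does not depend on which output element is being reduced, so it can be hoisted out of the per-output loop and cached.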

template <typename Reducer, int ndim, typename DType, typename OP>
void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
            const Tensor<cpu, 1, char>& workspace, const TBlob& big) {
  if (req == kNullOp) return;
  Shape<ndim> rshape, rstride;
  diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
  int N = small.shape_.Size(), M = rshape.Size();
  seq_reduce_compute<Reducer, ndim, DType, OP>(
    N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
    big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
}

// CPU backward path for broadcast ops: compute the M flat offsets of the
// reduced elements once (in parallel) into the caller-provided workspace,
// then run the reduction with those cached offsets.
template <typename Reducer, int ndim, typename DType, typename OP>
void ReduceWithExtraMem(Stream<cpu>* s, const TBlob& small, const OpReqType req,
                        const Tensor<cpu, 1, char>& workspace, const TBlob& big) {
  using namespace mxnet_op;
  if (req == kNullOp) return;
  Shape<ndim> rshape, rstride;
  diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
  index_t* ws_dptr = reinterpret_cast<index_t*>(workspace.dptr_);
  int N = small.shape_.Size(), M = rshape.Size();
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
  for (int k = 0; k < M; k++) {
    Shape<ndim> coord = unravel(k, rshape);
    ws_dptr[k] = dot(coord, rstride);
  }

  seq_reduce_compute_extra_mem<Reducer, ndim, DType, OP>(
      N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
      small.shape_.get<ndim>(), rshape, rstride, ws_dptr);
}

template<int ndim, typename DType>
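
To see why caching those offsets helps, the following stand-alone sketch mirrors the same trade on plain arrays: the M flat offsets of the reduced elements are computed once into a scratch buffer and reused for every output element, instead of being recomputed in the innermost loop. All names, shapes and the hand-written strides here are hypothetical and far simpler than the ndim-generic MXNet code.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // big has shape (2, 3, 4); small has shape (2, 1, 4): axis 1 is reduced away.
  const int bstride[3] = {12, 4, 1};   // row-major strides of big
  const int N = 2 * 1 * 4;             // output elements in small
  const int M = 3;                     // reduced elements per output

  // Precompute, once, the flat offset of each of the M reduced positions.
  std::vector<std::int64_t> ws(M);
  for (int k = 0; k < M; ++k)
    ws[k] = static_cast<std::int64_t>(k) * bstride[1];

  std::vector<float> big(2 * 3 * 4);
  for (std::size_t i = 0; i < big.size(); ++i) big[i] = static_cast<float>(i);

  std::vector<float> small(N, 0.0f);
  for (int idx = 0; idx < N; ++idx) {
    // Offset of output element idx inside big, over the non-reduced dims only.
    const int i0 = idx / 4, i2 = idx % 4;
    const int base = i0 * bstride[0] + i2 * bstride[2];
    float acc = 0.0f;
    for (int k = 0; k < M; ++k)
      acc += big[base + ws[k]];        // reuse the cached offsets
    small[idx] = acc;
  }
  std::printf("small[0] = %g (expected 0 + 4 + 8 = 12)\n", small[0]);
  return 0;
}

This is the same trade ReduceWithExtraMem makes: a small index_t scratch buffer, taken from the operator's requested workspace, in exchange for far less index arithmetic per reduced element.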
64 changes: 64 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op-inl.cuh
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_CUH_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_CUH_
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "broadcast_reduce-inl.h"
namespace mxnet {
namespace op {
template<typename xpu, typename LOP, typename ROP>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  using namespace broadcast;
  TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
                                         &new_lshape, &new_rshape, &new_oshape);
  if (!ndim) {
    ElemwiseBinaryOp::BackwardUseNone<gpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
  } else {
    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      Stream<gpu> *s = ctx.get_stream<gpu>();
      const TBlob lhs = outputs[0].reshape(new_lshape);
      const TBlob rhs = outputs[1].reshape(new_rshape);
      const TBlob out = inputs[0].reshape(new_oshape);
      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        // Request temporary storage
        size_t workspace_size = new_oshape.Size();
        Tensor<gpu, 1, char> workspace =
          ctx.requested[0].get_space_typed<gpu, 1, char>(
            Shape1(workspace_size * sizeof(index_t)), s);
        Reduce<red::sum, NDim, DType, LOP>(s, lhs, req[0], workspace, out);
        Reduce<red::sum, NDim, DType, ROP>(s, rhs, req[1], workspace, out);
      });
    });
  }
}
} // namespace op
} // namespace mxnet
#endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_CUH_
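
The cpu/gpu split used here rests on a standard enable_if idiom: two overloads of one function template are gated on std::is_same<xpu, cpu> and std::is_same<xpu, gpu>, so generic callers keep writing a single name and the right body is picked at compile time. A minimal self-contained sketch of the pattern follows; the device tags and the function name are stand-ins, not the MXNet/mshadow types.

#include <iostream>
#include <type_traits>

struct cpu {};  // stand-ins for the mshadow device tag types
struct gpu {};

template <typename xpu>
typename std::enable_if<std::is_same<xpu, cpu>::value, void>::type
BackwardUseNone() {
  std::cout << "cpu overload: ReduceWithExtraMem-style reduction\n";
}

template <typename xpu>
typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
BackwardUseNone() {
  std::cout << "gpu overload: workspace-based Reduce kernel\n";
}

int main() {
  BackwardUseNone<cpu>();  // exactly one overload survives substitution
  BackwardUseNone<gpu>();
  return 0;
}

In the real headers the gpu overload is only declared in elemwise_binary_broadcast_op.h and defined in this .cuh file, which the .h pulls in under #ifdef __CUDACC__ (see its last lines below), so plain C++ translation units never compile any CUDA code.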
35 changes: 22 additions & 13 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -525,7 +525,8 @@ void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
}

template<typename xpu, typename LOP, typename ROP>
inline typename std::enable_if<std::is_same<xpu, cpu>::value, void>::type
BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
@@ -535,29 +536,34 @@ void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
  int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
                                         &new_lshape, &new_rshape, &new_oshape);
  if (!ndim) {
    ElemwiseBinaryOp::BackwardUseNone<cpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
  } else {
    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      Stream<cpu> *s = ctx.get_stream<cpu>();
      const TBlob lhs = outputs[0].reshape(new_lshape);
      const TBlob rhs = outputs[1].reshape(new_rshape);
      const TBlob out = inputs[0].reshape(new_oshape);

Review comment (Member): since this implementation is only for cpu, is it better to replace xpu with cpu inside?

      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        // Request temporary storage: new_oshape.Size() index_t offsets, reused
        // for both the lhs and rhs reductions.
        size_t workspace_size = new_oshape.Size();
        Tensor<cpu, 1, char> workspace =
          ctx.requested[0].get_space_typed<cpu, 1, char>(
            Shape1(workspace_size * sizeof(index_t)), s);
        ReduceWithExtraMem<red::sum, NDim, DType, LOP>(s, lhs, req[0], workspace, out);
        ReduceWithExtraMem<red::sum, NDim, DType, ROP>(s, rhs, req[1], workspace, out);
      });
    });
  }
}

template<typename xpu, typename LOP, typename ROP>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs);
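
One sizing detail worth spelling out: the cpu backward above requests new_oshape.Size() * sizeof(index_t) bytes and reuses the one buffer for both gradient reductions. That is always sufficient because each ReduceWithExtraMem call stores only M = rshape.Size() offsets, and every entry of rshape is either 1 or the matching entry of new_oshape, so M can never exceed new_oshape.Size(). A tiny numeric check with hypothetical shapes:

#include <cstdint>
#include <cstdio>

int main() {
  const std::int64_t oshape[3]   = {2, 3, 4};  // broadcast output shape
  const std::int64_t rshape_l[3] = {1, 3, 1};  // dims reduced for the lhs grad
  const std::int64_t rshape_r[3] = {2, 1, 4};  // dims reduced for the rhs grad

  std::int64_t O = 1, Ml = 1, Mr = 1;
  for (int d = 0; d < 3; ++d) {
    O *= oshape[d];
    Ml *= rshape_l[d];
    Mr *= rshape_r[d];
  }
  // 24 index_t entries are requested; the two reductions use 3 and 8 of them.
  std::printf("requested entries: %lld, used for lhs: %lld, rhs: %lld\n",
              static_cast<long long>(O), static_cast<long long>(Ml),
              static_cast<long long>(Mr));
  return 0;
}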

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
                                             const std::vector<TBlob>& inputs,
Expand All @@ -581,7 +587,7 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
@@ -629,4 +635,7 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,

} // namespace op
} // namespace mxnet
#ifdef __CUDACC__
#include "./elemwise_binary_broadcast_op-inl.cuh"
#endif
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_