Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix mixed type binary logic operators
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed May 28, 2020
1 parent 382279e commit 8ef646d
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 10 deletions.
6 changes: 4 additions & 2 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ using std::is_integral;

#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static bool Map(DType a, DType b) { \
template<typename DType, typename EType> \
MSHADOW_XINLINE static bool Map(DType lhs, EType rhs) { \
long double a = static_cast<long double>(lhs); \
long double b = static_cast<long double>(rhs); \
return (expr); \
} \
}
Expand Down
7 changes: 7 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,13 @@ struct op_with_req {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
}

/*! \brief input is two tensors with different type and with a boolean output tensor */
template<typename LType, typename RType,
typename std::enable_if<!std::is_same<LType, RType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

#ifndef _WIN32
/*! \brief inputs are two tensors with a half_t output tensor */
template<typename DType,
Expand Down
2 changes: 0 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false;
TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1));
TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
return true;
}
Expand Down
34 changes: 29 additions & 5 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,25 @@ struct binary_broadcast_kernel {
}
}

/*! \brief Map function for binary_broadcast_kernel */
template<typename LType, typename RType, typename OType>
MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
const Shape <ndim> &lstride, const Shape <ndim> &rstride,
const Shape <ndim> &oshape, LType *lhs, RType *rhs,
OType *out) {
Shape <ndim> coord = unravel(base, oshape);
auto lidx = static_cast<index_t>(dot(coord, lstride));
auto ridx = static_cast<index_t>(dot(coord, rstride));
KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
// starts from 1 to avoid extra inc at end of loop
for (index_t i = 1; i < length; ++i) {
inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
// When tuning, don't actually run the op, since it's not going to be tuned against
// the actual op we'll eventually be using
KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
}
}

/*! \brief Map function for binary_broadcast_kernel */
template<typename IType, typename DType>
MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
Expand Down Expand Up @@ -430,23 +449,28 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U) return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
outputs[0].dptr<bool>());
lhs.dptr<DType>(), rhs.dptr<EType>(),
out.dptr<bool>());
});
});
});
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -620,14 +620,16 @@ template<typename xpu, typename OP>
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, EType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
outputs[0].dptr<bool>(),
inputs[0].dptr<DType>(),
inputs[1].dptr<DType>());
inputs[1].dptr<EType>());
}
});
});
});
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
'mod': (1.0, 5.0, None, None),
'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
'equal': (0.0, 2.0, None, None),
'not_equal': (0.0, 2.0, None, None),
'greater': (0.0, 2.0, None, None),
'less': (0.0, 2.0, None, None),
'greater_equal': (0.0, 2.0, None, None),
'less_equal': (0.0, 2.0, None, None),
'logical_and': (0.0, 2.0, None, None),
'logical_or': (0.0, 2.0, None, None),
'logical_xor': (0.0, 2.0, None, None),
}

shape_pairs = [((3, 2), (3, 2)),
Expand Down

0 comments on commit 8ef646d

Please sign in to comment.