Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix operations on operands of different types
Browse files Browse the repository at this point in the history
  • Loading branch information
agrabow committed Nov 16, 2021
1 parent a3dd206 commit 34ff9eb
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/operator/nn/dnnl/dnnl_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ bool SupportDNNLTranspose(const NDArray& data);
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
5 changes: 5 additions & 0 deletions src/operator/nn/dnnl/dnnl_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
DNNLStream::Get()->Submit();
}

bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
return inputs[0].shape().ndim() != 0 && IsDNNLType(inputs[0].dtype()) &&
inputs[1].shape().ndim() != 0 && IsDNNLType(inputs[1].dtype());
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
16 changes: 0 additions & 16 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,22 +548,6 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

// Type inference for numpy-style binary ops that may mix input precisions.
// If both input dtypes are known (!= -1) and differ, the output dtype is
// derived from the pair via np_binary_out_infer_type (mixed-precision mode);
// otherwise inference falls back to the standard elementwise 2-in/1-out rule.
inline bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                          std::vector<int>* in_attrs,
                                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int ltype = in_attrs->at(0);
  const int rtype = in_attrs->at(1);
  if (ltype != -1 && rtype != -1 && (ltype != rtype)) {
    // Only when both input types are known and not the same, we enter the mixed-precision mode
    TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype));
  } else {
    // At least one dtype is unknown, or they already agree: use the
    // ordinary elementwise type propagation.
    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
  }
  return true;
}

#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
Expand Down
18 changes: 17 additions & 1 deletion src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,22 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
}
}

// Type inference for numpy-style binary ops that may mix input precisions.
// When both input dtypes are known and differ, the output dtype is derived
// from the pair; otherwise standard 2-in/1-out elementwise inference applies.
inline bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                          std::vector<int>* in_attrs,
                                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int ltype = in_attrs->at(0);
  const int rtype = in_attrs->at(1);
  const bool mixed_precision = ltype != -1 && rtype != -1 && ltype != rtype;
  if (!mixed_precision) {
    // At least one dtype is unknown, or they already agree: fall back to the
    // ordinary elementwise type propagation.
    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
  }
  // Both dtypes known and distinct: infer the promoted output dtype.
  TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype));
  return true;
}

#define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
Expand All @@ -772,7 +788,7 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
#if MXNET_USE_ONEDNN == 1
if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
if (IsDNNLType(inputs[0].dtype())) {
if (SupportDNNLBinary(inputs)) {
const dnnl::algorithm alg = GetDNNLAlgorithm<OP>::dnnl_alg;
DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
} else {
Expand Down

0 comments on commit 34ff9eb

Please sign in to comment.