
Commit

backport mixed type binary op
yijunc committed Jul 2, 2020
1 parent 9c06894 commit 306fcad
Showing 25 changed files with 594 additions and 485 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -121,6 +121,10 @@ cmake_install.cmake
# Mac OS X
.DS_Store

# Windows
windows_package.7z
windows_package

#Notebook Automated Test
!tests/nightly/test_tutorial_config.txt
!tests/nightly/TestNotebook
11 changes: 11 additions & 0 deletions include/mxnet/imperative.h
@@ -117,6 +117,16 @@ class Imperative {
    }
    return is_np_shape_thread_local_ ? 1 : 0;
  }

  /*! \brief return current numpy default dtype compatibility status.
   * */
  bool is_np_default_dtype() const {
    if (is_np_default_dtype_global_) {
      return true;
    }
    return false;
  }

  /*! \brief specify numpy compatibility off, thread local on or global on. */
  bool set_is_np_shape(int is_np_shape) {
    NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
@@ -215,6 +225,7 @@ class Imperative {
  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
  bool is_np_shape_global_{false};
  bool is_np_default_dtype_global_{false};
  /*! \brief node count used for naming */
  std::atomic<uint64_t> node_count_{0};
  /*! \brief variable count used for naming */
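Callers that need a floating-point default are expected to consult this accessor; a minimal sketch of that pattern (illustrative only, mirroring the GetDefaultDtype helper added to src/common/utils.h later in this commit):

#include <mxnet/imperative.h>

// Sketch: pick the default floating-point dtype from the global flag
// (float64 under NumPy default-dtype semantics, float32 otherwise).
int default_dtype = mxnet::Imperative::Get()->is_np_default_dtype()
                        ? mshadow::kFloat64
                        : mshadow::kFloat32;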
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -1529,13 +1529,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
        if isinstance(rhs, numeric_types):
            return fn_scalar(lhs, rhs, out=out)
        else:
            is_int = isinstance(rhs, integer_types)
            if rfn_scalar is None:
                # commutative function
                return lfn_scalar(rhs, float(lhs), out=out)
                return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
            else:
                return rfn_scalar(rhs, float(lhs), out=out)
                return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
    elif isinstance(rhs, numeric_types):
        return lfn_scalar(lhs, float(rhs), out=out)
        is_int = isinstance(rhs, integer_types)
        return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out)
    elif isinstance(rhs, Symbol):
        return fn_array(lhs, rhs, out=out)
    else:
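The frontend now forwards the scalar as a keyword argument together with an is_int flag, so the backend scalar operators can tell whether the Python scalar was an integer and promote dtypes accordingly. The parameter struct those operators parse is not shown in this excerpt; presumably it looks roughly like the sketch below (a hypothetical reconstruction, with field names inferred from the keyword arguments above and from param.scalar used later in this commit):

#include <dmlc/parameter.h>

// Hypothetical sketch of NumpyBinaryScalarParam; the real definition lives in
// the operator headers and may differ in defaults and descriptions.
struct NumpyBinaryScalarParam : public dmlc::Parameter<NumpyBinaryScalarParam> {
  double scalar;  // the scalar operand, always passed from Python as a float
  bool is_int;    // whether the original Python scalar was an integer
  DMLC_DECLARE_PARAMETER(NumpyBinaryScalarParam) {
    DMLC_DECLARE_FIELD(scalar).set_default(1.0)
        .describe("Scalar input value");
    DMLC_DECLARE_FIELD(is_int).set_default(true)
        .describe("Whether the scalar should be treated as an integer");
  }
};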
19 changes: 19 additions & 0 deletions src/common/utils.h
@@ -31,6 +31,7 @@
#include <nnvm/node.h>
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/imperative.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
@@ -874,6 +875,11 @@ inline bool is_float(const int dtype) {
  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline bool is_int(const int dtype) {
  return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 ||
         dtype == mshadow::kInt32 || dtype == mshadow::kInt64;
}

inline int get_more_precise_type(const int type1, const int type2) {
  if (type1 == type2) return type1;
  if (is_float(type1) && is_float(type2)) {
@@ -910,6 +916,19 @@ inline int np_binary_out_infer_type(const int type1, const int type2) {
  return get_more_precise_type(type1, type2);
}

inline int GetDefaultDtype() {
  return Imperative::Get()->is_np_default_dtype() ?
         mshadow::kFloat64 :
         mshadow::kFloat32;
}

inline int GetDefaultDtype(int dtype) {
  if (dtype != -1) return dtype;
  return Imperative::Get()->is_np_default_dtype() ?
         mshadow::kFloat64 :
         mshadow::kFloat32;
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
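Illustrative call sites for the two overloads above (hypothetical, not part of this commit): creation-style code with no dtype argument takes the global default directly, while code with an optional dtype treats -1 as "unspecified".

// Hypothetical usage sketch of the GetDefaultDtype helpers.
int a = mxnet::common::GetDefaultDtype();                 // kFloat32, or kFloat64 under NumPy semantics
int b = mxnet::common::GetDefaultDtype(-1);               // same: -1 means "not specified"
int c = mxnet::common::GetDefaultDtype(mshadow::kInt32);  // an explicit dtype is returned unchanged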
6 changes: 2 additions & 4 deletions src/operator/contrib/gradient_multiplier_op.cc
@@ -77,9 +77,7 @@ In forward pass it acts as an identity transform. During backpropagation it
multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to
the preceding layer.
)code" ADD_FILELINE)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = dmlc::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
@@ -88,7 +86,7 @@ the preceding layer.
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.add_argument("scalar", "float", "lambda multiplier");
.add_arguments(NumpyBinaryScalarParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
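With the parser switched to ParamParser<NumpyBinaryScalarParam>, compute code reads the multiplier out of the parsed struct rather than out of a raw double stored in attrs.parsed; the pattern, as it appears later in this commit in np_elemwise_broadcast_logic_op.cc, is roughly:

// Reading the parsed scalar parameter inside a compute function (sketch).
const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
const double lambda = param.scalar;  // the "lambda multiplier"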
83 changes: 76 additions & 7 deletions src/operator/mshadow_op.h
@@ -148,7 +148,6 @@ struct true_divide : public mxnet_op::tunable {
return static_cast<float>(a) / static_cast<float>(b);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -166,7 +165,6 @@ struct true_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) / b;
}
#endif
};

struct rtrue_divide : public mxnet_op::tunable {
@@ -182,7 +180,6 @@ struct rtrue_divide : public mxnet_op::tunable {
return static_cast<float>(b) / static_cast<float>(a);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -200,14 +197,12 @@ struct rtrue_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static double Map(DType a, double b) {
return b / static_cast<double>(a);
}
#endif
};

MXNET_BINARY_MATH_OP_NC(left, a);

MXNET_BINARY_MATH_OP_NC(right, b);

#ifndef _WIN32
struct mixed_plus {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
@@ -345,8 +340,12 @@ struct mixed_rpower {
return static_cast<double>(math::pow(b, a));
}
};
#endif

#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);
@@ -575,7 +574,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));

MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));

MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));

MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));
@@ -728,6 +726,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));

MXNET_BINARY_MATH_OP(rminus, b - a);

MXNET_BINARY_MATH_OP_NC(posone, 1);

MXNET_BINARY_MATH_OP_NC(negone, -1);

MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));

template<>
@@ -795,6 +797,73 @@ struct mod : public mxnet_op::tunable {
}
};

struct mixed_mod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(static_cast<mshadow::half::half_t>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(static_cast<float>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(static_cast<double>(a), b);
  }
};

struct mixed_rmod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(b, static_cast<mshadow::half::half_t>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(b, static_cast<float>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(b, static_cast<double>(a));
  }
};

struct fmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
    }
  }
};

struct rfmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
    }
  }
};

template<>
MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
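The mixed_* functors above rely on std::enable_if so that each promoted output type gets exactly one viable Map overload. A self-contained sketch of the same dispatch pattern outside MXNet (names are illustrative):

#include <cstdio>
#include <type_traits>

// Stand-alone analogue of the mixed_* functors: an integral left operand is
// promoted to the floating-point type of the right operand.
struct mixed_plus_demo {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static float Map(DType a, float b) {
    return static_cast<float>(a) + b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  static double Map(DType a, double b) {
    return static_cast<double>(a) + b;
  }
};

int main() {
  std::printf("%f\n", mixed_plus_demo::Map(3, 2.5f));  // int + float  -> 5.500000 (float)
  std::printf("%f\n", mixed_plus_demo::Map(3, 2.5));   // int + double -> 5.500000 (double)
  return 0;
}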
9 changes: 7 additions & 2 deletions src/operator/mxnet_op.h
@@ -849,7 +849,13 @@ struct op_with_req {
    KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
  }

#ifndef _WIN32
  /*! \brief input is two tensors with different type and with a boolean output tensor */
  template<typename LType, typename RType,
           typename std::enable_if<!std::is_same<LType, RType>::value, int>::type = 0>
  MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) {
    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
  }

  /*! \brief inputs are two tensors with a half_t output tensor */
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
@@ -903,7 +909,6 @@ struct op_with_req {
  MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
  }
#endif

  /*! \brief inputs are two tensors with a float output tensor */
  template<typename DType,
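The new overload is what lets an element-wise kernel consume two inputs of different dtypes and write a boolean result. A stand-alone sketch of that kernel-body shape (illustrative only; it mimics, but is not, the MXNet kernel-launch machinery):

#include <cstdio>
#include <type_traits>

// Illustrative comparison functor with mixed-type inputs and a bool result.
struct greater_demo {
  template<typename LType, typename RType>
  static bool Map(LType a, RType b) { return a > b; }
};

// Loop analogue of the mixed-type op_with_req::Map overload:
// out[i] = OP::Map(lhs[i], rhs[i]) for inputs of different dtypes.
template<typename OP, typename LType, typename RType,
         typename std::enable_if<!std::is_same<LType, RType>::value, int>::type = 0>
void MapAll(size_t n, bool *out, const LType *lhs, const RType *rhs) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = OP::Map(lhs[i], rhs[i]);
  }
}

int main() {
  const int lhs[3] = {1, 2, 3};
  const float rhs[3] = {1.5f, 1.5f, 1.5f};
  bool out[3];
  MapAll<greater_demo>(3, out, lhs, rhs);
  for (bool b : out) std::printf("%d ", b ? 1 : 0);  // prints: 0 1 1
  std::printf("\n");
  return 0;
}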
18 changes: 10 additions & 8 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -206,7 +206,8 @@ struct TVMBinaryBroadcastScalarCompute {

// scalar param
type_codes[1] = kDLFloat;
values[1].v_float64 = nnvm::get<double>(attrs.parsed);
const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
values[1].v_float64 = param.scalar;

// output tensor
type_codes[2] = kTVMDLTensorHandle;
@@ -225,9 +226,7 @@ struct TVMBinaryBroadcastScalarCompute {
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \
}) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"data"}; \
@@ -240,7 +239,7 @@ struct TVMBinaryBroadcastScalarCompute {
}) \
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
.add_argument("data", "NDArray-or-Symbol", "First input to the function") \
.add_argument("scalar", "float", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal);
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal);
@@ -285,9 +284,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal);

#else

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
})

#endif // MXNET_USE_TVM_OP

