From 91d85e844112aa6bdbbd800001bf29857d226f30 Mon Sep 17 00:00:00 2001 From: Yijun Chen Date: Wed, 10 Jun 2020 06:38:43 +0000 Subject: [PATCH] unify impl --- .gitignore | 4 + src/operator/mshadow_op.h | 10 -- src/operator/mxnet_op.h | 2 - .../numpy/np_elemwise_broadcast_op.cc | 53 -------- .../numpy/np_elemwise_broadcast_op.cu | 30 ----- src/operator/numpy/np_elemwise_broadcast_op.h | 76 +----------- src/operator/numpy/np_true_divide-inl.h | 113 ------------------ src/operator/numpy/np_true_divide.cc | 12 -- .../tensor/elemwise_binary_broadcast_op.h | 4 +- 9 files changed, 7 insertions(+), 297 deletions(-) diff --git a/.gitignore b/.gitignore index c50d1ec99b9f..9fafdb13cb7e 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,10 @@ cmake_install.cmake # Mac OS X .DS_Store +# Windows +windows_package.7z +windows_package + #Notebook Automated Test !tests/nightly/test_tutorial_config.txt !tests/nightly/TestNotebook diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 9069af9d2bf7..55f26b08fcc1 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -150,7 +150,6 @@ struct true_divide : public mxnet_op::tunable { return static_cast(a) / static_cast(b); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -168,7 +167,6 @@ struct true_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return static_cast(a) / b; } -#endif }; struct rtrue_divide : public mxnet_op::tunable { @@ -184,7 +182,6 @@ struct rtrue_divide : public mxnet_op::tunable { return static_cast(b) / static_cast(a); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -202,14 +199,12 @@ struct rtrue_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return b / static_cast(a); } -#endif }; MXNET_BINARY_MATH_OP_NC(left, a); MXNET_BINARY_MATH_OP_NC(right, b); -#ifndef _WIN32 struct mixed_plus { template::value, int>::type = 0> @@ -347,8 +342,6 @@ struct mixed_rpower { return static_cast(math::pow(b, a)); } }; -#endif - #pragma GCC diagnostic push #if __GNUC__ >= 7 @@ -584,7 +577,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a)); MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b)); MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b)); - MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b))); MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b))); @@ -819,7 +811,6 @@ struct mod : public mxnet_op::tunable { } }; -#ifndef _WIN32 struct mixed_mod { template::value, int>::type = 0> @@ -865,7 +856,6 @@ struct mixed_rmod { return mod::Map(b, static_cast(a)); } }; -#endif struct fmod : public mxnet_op::tunable { template diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index bc8c0afcf1a2..8b7a38be3986 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -867,7 +867,6 @@ struct op_with_req { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } -#ifndef _WIN32 /*! \brief inputs are two tensors with a half_t output tensor */ template::value, int>::type = 0> @@ -921,7 +920,6 @@ struct op_with_req { MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } -#endif /*! \brief inputs are two tensors with a float output tensor */ template("FListInputNames", \ - [](const NodeAttrs& attrs) { \ - return std::vector{"lhs", "rhs"}; \ - }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ - .set_attr("FInferType", NumpyBinaryMixedPrecisionType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}, {1, 0}}; \ - }) \ - .set_attr("FResourceRequest", \ - [](const NodeAttrs& attrs) { \ - return std::vector{ResourceRequest::kTempSpace}; \ - }) \ - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") -#endif MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"}); NNVM_REGISTER_OP(_backward_npi_broadcast_add) @@ -133,16 +104,10 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_add) mshadow_op::posone>); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"}); NNVM_REGISTER_OP(_backward_npi_broadcast_sub) @@ -161,16 +126,10 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_sub) mshadow_op::negone>); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mul) @@ -189,16 +148,10 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul) mshadow_op::left>); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_mod) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mod"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mod) @@ -217,16 +170,10 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mod) mshadow_op::mod_rgrad>); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_power"}); NNVM_REGISTER_OP(_backward_npi_broadcast_power) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 8a13b42e4846..a2927cda61ff 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -29,80 +29,50 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_add) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_sub) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_mod) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_mod) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_power) .set_attr("FCompute", NumpyBinaryBackwardUseIn void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, const OpContext& ctx, @@ -216,13 +215,9 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name, } }); } -#endif -#ifndef _WIN32 + template -#else -template -#endif void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -237,7 +232,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); @@ -303,64 +297,9 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } -#else - mshadow::Stream *s = ctx.get_stream(); - if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - TBlob temp_tblob; - // one is float, the other is bool - CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == rhs.type_flag_)) - << "This case out type should be same as the float type"; - if (lhs.type_flag_ == out.type_flag_) { - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { - TBlob temp_tblob; - if (lhs.type_flag_ == out.type_flag_) { - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); - } -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -382,18 +321,10 @@ void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, return; } -#ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -438,12 +369,7 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, } return; } - -#ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } template diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index 9edd795ffb77..6e975111c8c2 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -59,7 +59,6 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, }); }); } else { -#ifndef _WIN32 CHECK_EQ(outputs[0].type_flag_, mxnet::common::GetDefaultDtype()) << "true_divide only supports float32 and float64" " output when input's dtype is " @@ -71,13 +70,6 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, static_cast(alpha)); }); }); -#else - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(data.Size()), s); - TBlob temp_tblob(temp_tensor); - CastCompute(attrs, ctx, {data}, {kWriteTo}, {temp_tblob}); - TrueDivideScalarCompute(attrs, ctx, {temp_tblob}, req, outputs); -#endif } } @@ -120,8 +112,6 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, }); } } else { -#ifndef _WIN32 - // Non-windows case: no usage of temporary space // Case when types of the 2 input tensors are different if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // both lhs and rhs are float types, output type is the more precise one @@ -156,44 +146,6 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, // lhs is integer type, rhs is integer type, output type should be float LOG(FATAL) << "not implemented yet..."; } -#else - // Windows case: using temp space for casting the type - // Case when types of the 2 input tensors are different - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // both lhs and rhs are float types, output type is the more precise one - LOG(FATAL) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // lhs is float type, rhs is integer type, the output type should be the same as lhs - CHECK_EQ(out.type_flag_, - common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_) - << "This case out type should be same as the float type"; - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is the float one - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is the float one - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs is integer type, rhs is integer type, output type should be float - LOG(FATAL) << "not implemented yet..."; - } -#endif } } @@ -217,7 +169,6 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = calc_stride(new_lshape.get()); @@ -277,70 +228,6 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } } }); -#else - if (lhs.type_flag_ == rhs.type_flag_) { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = calc_stride(new_lshape.get()); - mshadow::Shape rstride = calc_stride(new_rshape.get()); - // When the both inputs have the same data types - if (common::is_float(lhs.type_flag_)) { - // If both inputs are the same float types, output is the same float type - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } else { - CHECK_EQ(out.type_flag_, mxnet::common::GetDefaultDtype()) - << "true_divide only supports float32 and float64 output when input's dtype is " - << type_string(lhs.type_flag_); - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { - // If both inputs are the same integer types, output is float type - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } - }); - } else { - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // lhs and rhs have different float types, the output is the more precise one - LOG(FATAL) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // one of lhs and rhs is float, the output is the same type as the float one - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is float type, output will be the same float type - CHECK_EQ(lhs.type_flag_, out.type_flag_) - << "lhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is float type, output will be the same float type - CHECK_EQ(rhs.type_flag_, out.type_flag_) - << "rhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs and rhs have different integer types, the output is float type - LOG(FATAL) << "not implemented yet..."; - } - } -#endif } } diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 7311d896d71c..3c2d744e4a21 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -74,12 +74,6 @@ NNVM_REGISTER_OP(_npi_true_divide) [](const NodeAttrs& attrs){ return std::vector >{{0, 0}, {1, 0}}; }) -#ifdef _WIN32 -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -#endif .set_attr("FCompute", TrueDivideBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"}) .add_argument("lhs", "NDArray-or-Symbol", "Dividend array") @@ -111,12 +105,6 @@ NNVM_REGISTER_OP(_npi_true_divide_scalar) [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) -#ifdef _WIN32 -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -#endif .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 6f6711e9f881..ca83bdb01e37 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -246,8 +246,8 @@ struct binary_broadcast_kernel { } } -#ifndef _WIN32 /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, @@ -268,6 +268,7 @@ struct binary_broadcast_kernel { } /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value && !std::is_pointer::value, int>::type = 0> @@ -287,7 +288,6 @@ struct binary_broadcast_kernel { KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); } } -#endif }; template