From 6e25c884bab56727ef0e9b81b444875d492d0587 Mon Sep 17 00:00:00 2001
From: bgawrych
Date: Fri, 4 Mar 2022 09:26:20 +0100
Subject: [PATCH] Add oneDNN support for "where" operator (#20862)

* Where operator enabled in oneDNN

* Fix bug & refactor

* fix sanity

* apply review

* Fix get_broadcastable_shape function

* Apply review

* Remove unused variable

* Apply suggestions from code review

Co-authored-by: bartekkuncer
---
 src/operator/nn/dnnl/dnnl_ops-inl.h       |   7 +
 src/operator/nn/dnnl/dnnl_where-inl.h     |  73 +++++++
 src/operator/nn/dnnl/dnnl_where.cc        | 224 ++++++++++++++++++++++
 src/operator/numpy/np_where_forward_op.cc |  48 ++++-
 4 files changed, 349 insertions(+), 3 deletions(-)
 create mode 100644 src/operator/nn/dnnl/dnnl_where-inl.h
 create mode 100644 src/operator/nn/dnnl/dnnl_where.cc

diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h
index 40e944939bea..06ed1e0f2625 100644
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ b/src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -210,6 +210,13 @@ void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
                         const NDArray& input,
                         const OpReqType& req,
                         const NDArray& output);
+
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h
new file mode 100644
index 000000000000..bfda68466892
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where-inl.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include
+#include
+#include
+#include "dnnl_base-inl.h"
+#include "dnnl_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class DNNLWhereFwd {
+ public:
+  struct Tensors {
+    Tensors(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+    const NDArray& condition;
+    const NDArray& left;
+    const NDArray& right;
+    const NDArray& output;
+  };
+
+  static DNNLWhereFwd GetCached(const Tensors& tensors);
+
+  explicit DNNLWhereFwd(const Tensors& tensors);
+
+  void Execute(const Tensors& tensors,
+               const std::vector<OpReqType>& req,
+               const OpContext& ctx) const;
+
+ private:
+  dnnl::binary::primitive_desc binary_eq_zero_pd;
+  dnnl::binary::primitive_desc binary_ne_zero_pd;
+  dnnl::binary::primitive_desc binary_mul_l_pd;
+  dnnl::binary::primitive_desc binary_mul_r_pd;
+  dnnl::binary::primitive_desc binary_sum_pd;
+  dnnl::binary binary_eq_zero;
+  dnnl::binary binary_ne_zero;
+  dnnl::binary binary_mul_l;
+  dnnl::binary binary_mul_r;
+  dnnl::binary binary_sum;
+};
+
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
+
+}  // namespace op
+}  // namespace mxnet
+#endif
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc
new file mode 100644
index 000000000000..c2335b9c8d63
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_where.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include
+#include
+#include
+#include "dnnl_where-inl.h"
+#include "operator/operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
+  static const std::set<int> supported_dtypes = {
+      mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8};
+  for (int i = 0; i < inputs.size(); ++i) {
+    if (!supported_dtypes.count(inputs[i].dtype()) || inputs[i].shape().Size() <= 0 ||
+        inputs[i].shape().ndim() <= 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  TmpMemMgr::Get()->Init(ctx.requested[0]);
+  const auto tensors = DNNLWhereFwd::Tensors(inputs, outputs);
+  const auto fwd = DNNLWhereFwd::GetCached(tensors);
+  fwd.Execute(tensors, req, ctx);
+}
+
+DNNLWhereFwd::Tensors::Tensors(const std::vector<NDArray>& inputs,
+                               const std::vector<NDArray>& outputs)
+    : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(outputs[0]) {}
+
+DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) {
+  using where_op_fwd_map = std::unordered_map<OpSignature, DNNLWhereFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local where_op_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL where_op_fwd_map fwds;
+#endif
+
+  OpSignature key;
+  key.AddSign(tensors.condition);
+  key.AddSign(tensors.left);
+  key.AddSign(tensors.right);
+  key.AddSign(tensors.output);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    DNNLWhereFwd fwd(tensors);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+/*!
+ * \brief Align the number of input dimensions to the output by prepending the input shape
+ *        with ones. oneDNN requires both shapes to have the same number of dimensions even
+ *        when they are broadcastable.
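+ *        For example, in_shape = (3, 4) aligned to out_shape = (2, 3, 4) is returned as (1, 3, 4).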
+ * \param in_shape input shape which should be broadcastable with the output
+ * \param out_shape output shape to which the number of input dimensions should be aligned
+ * \return input shape extended with ones to match the number of dimensions of the output
+ */
+static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape,
+                                           const mxnet::TShape& out_shape) {
+  if (in_shape == out_shape) {
+    return in_shape;
+  }
+
+  mxnet::TShape broadcastable_in_shape(out_shape.ndim(), 1);
+  const int lack_dims = out_shape.ndim() - in_shape.ndim();
+  for (int i = lack_dims; i < out_shape.ndim(); ++i) {
+    broadcastable_in_shape[i] = in_shape[i - lack_dims];
+  }
+  return broadcastable_in_shape;
+}
+
+DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) {
+  const auto cpu_engine = CpuEngine::Get()->get_engine();
+
+  const auto cnd = tensors.condition;
+  const auto lhs = tensors.left;
+  const auto rhs = tensors.right;
+  const auto out = tensors.output;
+
+  const auto cnd_shape = GetBroadcastableShape(cnd.shape(), out.shape());
+  const auto lhs_shape = GetBroadcastableShape(lhs.shape(), out.shape());
+  const auto rhs_shape = GetBroadcastableShape(rhs.shape(), out.shape());
+
+  const auto& cnd_dtype = get_dnnl_type(cnd.dtype());
+  const auto& inp_dtype = get_dnnl_type(lhs.dtype());
+  const auto& def_ft = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(lhs_shape.ndim()));
+
+  const auto& cnd_dims = dnnl::memory::dims(cnd_shape.begin(), cnd_shape.end());
+  const auto& lhs_dims = dnnl::memory::dims(lhs_shape.begin(), lhs_shape.end());
+  const auto& rhs_dims = dnnl::memory::dims(rhs_shape.begin(), rhs_shape.end());
+  const auto& out_dims = dnnl::memory::dims(out.shape().begin(), out.shape().end());
+  const auto& scalar_dims = dnnl::memory::dims(cnd_shape.ndim(), 1);  // broadcastable scalar
+
+  auto cnd_md = dnnl::memory::desc(cnd_dims, cnd_dtype, def_ft);
+  auto lhs_md = dnnl::memory::desc(lhs_dims, inp_dtype, def_ft);
+  auto rhs_md = dnnl::memory::desc(rhs_dims, inp_dtype, def_ft);
+  auto out_md = dnnl::memory::desc(out_dims, inp_dtype, def_ft);
+  auto scalar_md = dnnl::memory::desc(scalar_dims, cnd_dtype, def_ft);
+
+  binary_ne_zero_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine);
+  binary_eq_zero_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_eq, cnd_md, scalar_md, cnd_md), cpu_engine);
+
+  // if broadcasting is needed, the masking output must use the larger of the two shapes
+  auto lmask_dim = lhs_shape.Size() > cnd_shape.Size() ? lhs_dims : cnd_dims;
+  auto lmask_md = dnnl::memory::desc(lmask_dim, inp_dtype, def_ft);
+  binary_mul_l_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_mul, lhs_md, cnd_md, lmask_md), cpu_engine);
+
+  auto rmask_dim = rhs_shape.Size() > cnd_shape.Size() ? rhs_dims : cnd_dims;
+  auto rmask_md = dnnl::memory::desc(rmask_dim, inp_dtype, def_ft);
+  binary_mul_r_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_mul, rhs_md, cnd_md, rmask_md), cpu_engine);
+
+  binary_sum_pd = dnnl::binary::primitive_desc(
+      dnnl::binary::desc(dnnl::algorithm::binary_add, lmask_md, rmask_md, out_md), cpu_engine);
+
+  binary_ne_zero = dnnl::binary(binary_ne_zero_pd);
+  binary_eq_zero = dnnl::binary(binary_eq_zero_pd);
+  binary_mul_l = dnnl::binary(binary_mul_l_pd);
+  binary_mul_r = dnnl::binary(binary_mul_r_pd);
+  binary_sum = dnnl::binary(binary_sum_pd);
+}
+
+/*!
+ * \brief
+ * Execute the where operator using oneDNN binary primitives:
+ * 1. Create tensor cnd_lhs = (condition != 0) ==> convert every non-zero value to 1
+ * 2. Create tensor cnd_rhs = (condition == 0) ==> convert 0 to 1 and all other values to 0
+ * 3. Mask lhs tensor with cnd_lhs => mask_lhs = lhs * cnd_lhs
+ * 4. Mask rhs tensor with cnd_rhs => mask_rhs = rhs * cnd_rhs
+ * 5. output = mask_lhs + mask_rhs
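+ *
+ * For example, with condition = [0, 1, 2], lhs = [10, 20, 30] and rhs = [1, 2, 3]:
+ *   cnd_lhs = [0, 1, 1], cnd_rhs = [1, 0, 0],
+ *   mask_lhs = [0, 20, 30], mask_rhs = [1, 0, 0], so output = [1, 20, 30]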
+ */
+void DNNLWhereFwd::Execute(const Tensors& tensors,
+                           const std::vector<OpReqType>& req,
+                           const OpContext& ctx) const {
+  const auto& cpu_engine = CpuEngine::Get()->get_engine();
+  const auto& cpu_stream = ctx.get_stream<cpu>();
+
+  const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc());
+  const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc());
+  const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc());
+
+  mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]);
+
+  const int dtype_size =
+      std::max(GetTypeSize(tensors.condition.dtype()), GetTypeSize(tensors.left.dtype()));
+
+  // allocate temporary memory for 4 additional tensors
+  mshadow::Tensor<cpu, 1, real_t> tmp_workspace = ctx.requested[0].get_space<cpu>(
+      mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size), cpu_stream);
+  char* workspace_ptr = reinterpret_cast<char*>(tmp_workspace.dptr_);
+  const int offset_size = tensors.output.shape().Size() * dtype_size;
+
+  dnnl::memory cnd_lhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr);
+  dnnl::memory cnd_rhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size);
+  dnnl::memory masked_lhs(binary_mul_l_pd.dst_desc(), cpu_engine, workspace_ptr + 2 * offset_size);
+  dnnl::memory masked_rhs(binary_mul_r_pd.dst_desc(), cpu_engine, workspace_ptr + 3 * offset_size);
+
+  double zero{0};
+  dnnl::memory zero_scalar(binary_eq_zero_pd.src1_desc(), cpu_engine, &zero);
+
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_ne_zero,
+      {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_lhs}});
+
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_eq_zero,
+      {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_rhs}});
+
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_mul_l,
+      {{DNNL_ARG_SRC_0, *lhs_tensor}, {DNNL_ARG_SRC_1, cnd_lhs}, {DNNL_ARG_DST, masked_lhs}});
+
+  DNNLStream::Get()->RegisterPrimArgs(
+      binary_mul_r,
+      {{DNNL_ARG_SRC_0, *rhs_tensor}, {DNNL_ARG_SRC_1, cnd_rhs}, {DNNL_ARG_DST, masked_rhs}});
+
+  DNNLStream::Get()->RegisterPrimArgs(binary_sum,
+                                      {{DNNL_ARG_SRC_0, masked_lhs},
+                                       {DNNL_ARG_SRC_1, masked_rhs},
+                                       {DNNL_ARG_DST, *out_mem.second}});
+
+  CommitOutput(tensors.output, out_mem);
+  DNNLStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif
diff --git a/src/operator/numpy/np_where_forward_op.cc b/src/operator/numpy/np_where_forward_op.cc
index bef9b19b0c94..6caa58d197ac 100644
--- a/src/operator/numpy/np_where_forward_op.cc
+++ b/src/operator/numpy/np_where_forward_op.cc
@@ -23,6 +23,7 @@
  */
 
 #include "np_where_op-inl.h"
+#include "../nn/dnnl/dnnl_where-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -89,6 +90,39 @@ inline bool NumpyWhereScalarOpType(const nnvm::NodeAttrs& attrs,
 DMLC_REGISTER_PARAMETER(NumpyWhereScalarParam);
 DMLC_REGISTER_PARAMETER(NumpyWhereScalar2Param);
 
+#if MXNET_USE_ONEDNN == 1
+static void WhereForwardEx(const nnvm::NodeAttrs& attrs,
+                           const OpContext& op_ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  if (req[0] == kNullOp) {
+    return;
+  }
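+  // Run the oneDNN kernel only when all inputs have supported dtypes and shapes;
+  // otherwise fall back to the native CPU implementation.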
+  if (SupportDNNLWhere(inputs)) {
+    DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
+    DNNLRun(DNNLWhereForward, attrs, op_ctx, inputs, req, outputs);
+    DNNL_OPCHECK_RUN(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(NumpyWhereOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
+  }
+}
+
+inline static bool WhereInferStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_attrs,
+                                         std::vector<int>* out_attrs) {
+  return DNNLStorageType(attrs,
+                         dev_mask,
+                         /*support onednn*/ true,
+                         dispatch_mode,
+                         in_attrs,
+                         out_attrs);
+}
+#endif  // MXNET_USE_ONEDNN == 1
+
 NNVM_REGISTER_OP(_npi_where)
     .set_num_inputs(3)
     .set_num_outputs(1)
@@ -103,11 +137,19 @@ NNVM_REGISTER_OP(_npi_where)
                                       return std::vector<std::pair<int, int> >{{1, 0}, {2, 0}};
                                     })
     .set_attr<FCompute>("FCompute<cpu>", NumpyWhereOpForward<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<FComputeEx>("FComputeEx<cpu>", WhereForwardEx)
+    .set_attr<bool>("TIsDNNL", true)
+    .set_attr<FInferStorageType>("FInferStorageType", WhereInferStorageType)
+#endif
     .set_attr<nnvm::FGradient>(
         "FGradient",
-        // Use the following lambda function instead of ElemwiseGradUseIn
-        // for best efficiency. grad[condition] = 0; to calculate grad[x] and grad[y]
-        // we need only condition from input.
+        // Use the following lambda function instead of ElemwiseGradUseIn for best efficiency.
+        // grad[condition] = 0; to calculate grad[x] and grad[y] we need only condition from input.
         [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
           std::vector<nnvm::NodeEntry> ret;
           // make zero grad node for grad[condition]