Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix sanity
Browse files Browse the repository at this point in the history
  • Loading branch information
Bartlomiej Gawrych committed Feb 7, 2022
1 parent aa7482e commit 761e956
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/operator/nn/dnnl/dnnl_where-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
#define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_

#if MXNET_USE_ONEDNN == 1
#include <memory>
#include <unordered_map>
#include <vector>

#include "./dnnl_base-inl.h"
#include "./dnnl_ops-inl.h"

Expand All @@ -45,7 +46,7 @@ class DNNLWhereFwd {

static DNNLWhereFwd GetCached(const Tensors& tensors);

DNNLWhereFwd(const Tensors& tensors);
explicit DNNLWhereFwd(const Tensors& tensors);

void Execute(const Tensors& tensors,
const std::vector<OpReqType>& req,
Expand All @@ -69,4 +70,4 @@ bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
} // namespace op
} // namespace mxnet
#endif
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_
9 changes: 6 additions & 3 deletions src/operator/nn/dnnl/dnnl_where.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@

#if MXNET_USE_ONEDNN == 1

#include <algorithm>
#include <set>
#include <unordered_set>
#include "dnnl_where-inl.h"
#include "src/operator/operator_common.h"
#include "operator/operator_common.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -168,7 +171,7 @@ void DNNLWhereFwd::Execute(const Tensors& tensors,

// allocate temporary memory for 4 additional tensors
mshadow::Tensor<cpu, 1> tmp_workspace = ctx.requested[0].get_space<cpu>(
mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size, cpu_stream);
mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size), cpu_stream);
char* workspace_ptr = reinterpret_cast<char*>(tmp_workspace.dptr_);
const int offset_size = tensors.output.shape().Size() * dtype_size;

Expand Down Expand Up @@ -207,4 +210,4 @@ void DNNLWhereFwd::Execute(const Tensors& tensors,

} // namespace op
} // namespace mxnet
#endif
#endif

0 comments on commit 761e956

Please sign in to comment.