
Commit

apply review
Bartlomiej Gawrych committed Feb 9, 2022
1 parent 761e956 commit fa5fb86
Showing 3 changed files with 6 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/operator/nn/dnnl/dnnl_where-inl.h
@@ -28,8 +28,8 @@
 #include <memory>
 #include <unordered_map>
 #include <vector>
-#include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
+#include "dnnl_base-inl.h"
+#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
8 changes: 2 additions & 6 deletions src/operator/nn/dnnl/dnnl_where.cc
@@ -88,12 +88,8 @@ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape,

   mxnet::TShape broadcastable_in_shape(out_shape.ndim(), -1);
   const int lack_dims = out_shape.ndim() - in_shape.ndim();
-  for (int i = 0; i < out_shape.ndim(); ++i) {
-    int y = 1;
-    if (i >= lack_dims) {
-      y = in_shape[i - lack_dims];
-    }
-    broadcastable_in_shape[i] = y;
+  for (int i = lack_dims; i < out_shape.ndim(); ++i) {
+    broadcastable_in_shape[i] = in_shape[i - lack_dims];
   }
   return broadcastable_in_shape;
 }
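For context, a minimal standalone sketch of the padding logic that the simplified loop above implements, using std::vector<int> in place of mxnet::TShape so it compiles outside of MXNet; the names and the fill parameter are illustrative, not MXNet API.

#include <cassert>
#include <vector>

// Illustrative stand-in for mxnet::TShape.
using Shape = std::vector<int>;

// Pads in_shape with leading fill values so it has as many dimensions as
// out_shape; the trailing dimensions are copied from in_shape, mirroring
// the simplified loop in the diff above.
Shape GetBroadcastableShape(const Shape& in_shape, const Shape& out_shape, int fill) {
  Shape broadcastable_in_shape(out_shape.size(), fill);
  const int lack_dims = static_cast<int>(out_shape.size()) - static_cast<int>(in_shape.size());
  for (int i = lack_dims; i < static_cast<int>(out_shape.size()); ++i) {
    broadcastable_in_shape[i] = in_shape[i - lack_dims];
  }
  return broadcastable_in_shape;
}

int main() {
  // in_shape (3, 4) broadcast against out_shape (2, 3, 4): the missing
  // leading dimension is padded with the fill value.
  const Shape result = GetBroadcastableShape({3, 4}, {2, 3, 4}, 1);
  assert((result == Shape{1, 3, 4}));
  return 0;
}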
5 changes: 2 additions & 3 deletions src/operator/numpy/np_where_forward_op.cc
@@ -148,9 +148,8 @@ NNVM_REGISTER_OP(_npi_where)
 #endif
     .set_attr<nnvm::FGradient>(
         "FGradient",
-        // Use the following lambda function instead of ElemwiseGradUseIn
-        // for best efficiency. grad[condition] = 0; to calculate grad[x] and
-        // grad[y] we need only condition from input.
+        // Use the following lambda function instead of ElemwiseGradUseIn for best efficiency.
+        // grad[condition] = 0; to calculate grad[x] and grad[y] we need only condition from input.
         [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
           std::vector<nnvm::NodeEntry> ret;
           // make zero grad node for grad[condition]
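As a side note on the comment reworded above, the reason only condition is needed in the backward pass can be shown with a small element-wise sketch in plain C++ (std::vector instead of the NNVM graph API; WhereBackward and its parameters are illustrative names, not MXNet code).

#include <cassert>
#include <vector>

// Element-wise backward pass of where(condition, x, y):
//   grad_x[i]         = condition[i] ? ograd[i] : 0
//   grad_y[i]         = condition[i] ? 0 : ograd[i]
//   grad_condition[i] = 0
// Neither x nor y is read, so only condition and the output gradient are needed.
void WhereBackward(const std::vector<bool>& condition,
                   const std::vector<float>& ograd,
                   std::vector<float>* grad_x,
                   std::vector<float>* grad_y) {
  grad_x->assign(ograd.size(), 0.0f);
  grad_y->assign(ograd.size(), 0.0f);
  for (size_t i = 0; i < ograd.size(); ++i) {
    if (condition[i]) {
      (*grad_x)[i] = ograd[i];
    } else {
      (*grad_y)[i] = ograd[i];
    }
  }
}

int main() {
  std::vector<float> grad_x, grad_y;
  WhereBackward({true, false, true}, {1.0f, 2.0f, 3.0f}, &grad_x, &grad_y);
  assert((grad_x == std::vector<float>{1.0f, 0.0f, 3.0f}));
  assert((grad_y == std::vector<float>{0.0f, 2.0f, 0.0f}));
  return 0;
}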
