Skip to content

Commit

Permalink
Fix the shift column for scale_shift_nchw and scale_shift_nhwc in C t…
Browse files Browse the repository at this point in the history
…opi (apache#5679)
  • Loading branch information
tobegit3hub authored and Trevor Morris committed Jun 9, 2020
1 parent 6d0d74e commit 115ee90
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions topi/include/topi/nn/mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ using namespace tvm::te;
inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift,
std::string name = "ScaleShift", std::string tag = kBroadcast) {
return tvm::te::compute(
x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(w); },
x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(c); },
name, tag);
}

Expand All @@ -66,7 +66,7 @@ inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tenso
inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift,
std::string name = "ScaleShift", std::string tag = kBroadcast) {
return tvm::te::compute(
x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(w); },
x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(c); },
name, tag);
}

Expand Down

0 comments on commit 115ee90

Please sign in to comment.