diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index bffe63d780e4..b91bfe03ae87 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -268,7 +268,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I, : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) { return tvm::sum( - T(b, i, stride_h * h + kh, stride_w * w + kw) * W(i, o, kh, kw), + T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); }; return tvm::compute(output_shape, l, name, tag);