Skip to content

Commit

Permalink
[Relay] Minor fix for some TF OD models (apache#6729)
Browse files Browse the repository at this point in the history
* Minor fix for some tf od models

* More fix

* Minor fix

* Fix lint

* Minor fix
  • Loading branch information
kevinthesun authored and Trevor Morris committed Oct 23, 2020
1 parent 36aa2cc commit f67e0ae
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 10 deletions.
20 changes: 15 additions & 5 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -897,8 +897,14 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
return compute(
condition->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(), y());
PrimExpr cond;
if (condition->shape.size() == 0) {
cond = condition();
} else {
Array<PrimExpr> condition_idx{indices[0]};
cond = condition(condition_idx);
}
return tvm::tir::Select(cond != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
Expand All @@ -913,9 +919,13 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
},
name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
<< "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0]
<< " vs " << x->shape[0];
int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
if (cond_first_dim > 0 && x_first_dim > 0) {
CHECK_EQ(cond_first_dim, x_first_dim)
<< "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
<< " vs " << x_first_dim;
}
return compute(
x->shape,
[&](const Array<Var>& indices) {
Expand Down
22 changes: 17 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ def _impl(inputs, attr, params, mod):
idx += st

# Only return when in_shape is fully static in the range from begin to end.
if idx >= st:
if idx >= ed:
ret = _expr.const(out_data, dtype)
if shrink_axis_mask:
ret = _op.squeeze(ret)
Expand Down Expand Up @@ -1659,14 +1659,26 @@ def _transform_mask(stride_dim, ellipsis_mask):

def _pad(name):
def _impl(inputs, attr, params, mod):
padlist = _get_param(params, inputs[1])
paddings = tuple(tuple(l) for l in padlist)
try:
padlist = _get_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
try:
padlist = _infer_value(inputs[1], params, mod).asnumpy().tolist()
except Exception:
padlist = inputs[1]

if isinstance(padlist, _expr.Expr):
paddings = padlist
else:
paddings = tuple(tuple(l) for l in padlist)
attr["pad_width"] = paddings
attr["pad_value"] = 0
new_inputs = [inputs[0]]
if name == "PadV2":
constant_values = _get_num_param(params, inputs[2])
attr["pad_value"] = constant_values
try:
attr["pad_value"] = _get_num_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
attr["pad_value"] = inputs[2]
return AttrCvt(
op_name="pad",
ignores=["Tpaddings"],
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,9 @@ def repeat_shape_func(attrs, inputs, _):

@_reg.register_shape_func("broadcast_to_like", False)
def broadcast_to_like_shape_func(attrs, inputs, _):
"""
Shape func for broadcast_to_like.
"""
return [topi.math.identity(inputs[1])]


Expand All @@ -796,7 +799,22 @@ def _stack_shape_func(data_shape, axis, num_inputs):

@_reg.register_shape_func("stack", False)
def stack_shape_func(attrs, inputs, _):
"""
Shape func for stack.
"""
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0] + 1
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]


@_reg.register_shape_func("where", False)
def where_shape_func(attrs, inputs, _):
"""
Shape func for where.
"""
cond_shape = inputs[0]
x_shape = inputs[1]
out_shape = x_shape if x_shape.shape else cond_shape

return [topi.math.identity(out_shape)]
1 change: 1 addition & 0 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def test_reshape():

@tvm.testing.uses_gpu
def test_where():
verify_where(())
verify_where((1, 2, 3, 4))


Expand Down

0 comments on commit f67e0ae

Please sign in to comment.