Skip to content

Commit

Permalink
[Relay] Keep fixed dim when unifying dynamic shape (apache#5795)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and trevor-m committed Sep 3, 2020
1 parent c854567 commit 94064eb
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
11 changes: 11 additions & 0 deletions src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
if (ulhs.same_as(urhs)) {
return ulhs;
}

if (ulhs.as<AnyNode>() && urhs.as<tvm::IntImmNode>()) {
solver_->shape_uf_.Set(urhs, ulhs);
return urhs;
}

if (ulhs.as<tvm::IntImmNode>() && urhs.as<AnyNode>()) {
solver_->shape_uf_.Set(ulhs, urhs);
return ulhs;
}

if (ulhs.as<AnyNode>() || urhs.as<AnyNode>()) {
return Any();
}
Expand Down
4 changes: 1 addition & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,9 +721,7 @@ def _body(i, st):
mod["main"] = func
data = np.array(0.0, dtype='int32')
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
# TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
# so currently we cannot run this test case on VM
for kind in ["debug"]:
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
np.testing.assert_allclose(result.asnumpy(), ref)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay import Any


def run_infer_type(expr, mod=None):
Expand Down Expand Up @@ -362,6 +363,15 @@ def test_let_polymorphism():
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))


def test_if():
choice_t = relay.FuncType([], relay.scalar_type('bool'))
f = relay.Var('f', choice_t)
true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch))
ft = run_infer_type(top)
tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))

if __name__ == "__main__":
test_free_expr()
test_dual_op()
Expand All @@ -380,3 +390,4 @@ def test_let_polymorphism():
test_constructor_call()
test_adt_match()
test_let_polymorphism()
test_if()

0 comments on commit 94064eb

Please sign in to comment.