From b3fe5da8d1e34c86efe72c07a93e20b0fde3b855 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 May 2019 01:59:12 -0700 Subject: [PATCH] First version that works with hints, needs clean up --- python/tvm/relay/__init__.py | 46 +++--- python/tvm/relay/scope_builder.py | 72 ++++++++++ src/relay/pass/type_infer.cc | 168 +++++++++------------- src/relay/pass/type_solver.cc | 223 ++++++++++++++++++++++-------- src/relay/pass/type_solver.h | 11 +- tests/python/relay/test_any.py | 75 ++++------ 6 files changed, 362 insertions(+), 233 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 8a314d9a38b0e..1aa56dbfac5d8 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -127,35 +127,26 @@ def int32(val): return relay.const(val, 'int32') -# TODO(@jroesch): possibly move to their own file -def while_loop(cond, loop_vars, loop_bodies): - sb = ScopeBuilder() - wl = Var("while_loop") - with sb.if_scope(cond(*loop_vars)): - sb.ret(wl(*loop_bodies(*loop_vars))) - with sb.else_scope(): - sb.ret(Tuple(loop_vars)) - - # return wl, Function(loop_vars, sb.get()) - # Return this decision - def _while_loop(*args): - return Let( - wl, Function(loop_vars, sb.get()), - wl(*args)) - - return _while_loop +def type_of(sb, expr): + if isinstance(expr, Var): + return expr.type_annotation + else: + it = IncompleteType() + v = var("unify", it) + sb.let(v, expr) + return it # TODO(@jroesch): possibly move to their own file -def while_loop2(cond, loop_vars, loop_bodies): +def while_loop(cond, loop_vars, loop_bodies): sb = ScopeBuilder() wl = Var("while_loop") type_hints = [] fresh_vars = [] - for lv in loop_vars: - assert isinstance(lv, Var) - v = var(lv.name_hint, type_annotation=IncompleteType()) + for i, lv in enumerate(loop_vars): + name = lv.name_hint if isinstance(lv, Var) else "arg{}".format(i) + v = var(name, type_annotation=IncompleteType()) fresh_vars.append(v) - type_hints.append(TypeHint(v, lv.type_annotation)) + type_hints.append(TypeHint(v, type_of(sb, lv))) for i, hint in enumerate(type_hints): var_hint = var("hint{}".format(i)) @@ -166,13 +157,10 @@ def while_loop2(cond, loop_vars, loop_bodies): with sb.else_scope(): sb.ret(Tuple(fresh_vars)) - # return wl, Function(fresh_vars, sb.get()) - def _while_loop(*args): - return Let( - wl, Function(fresh_vars, sb.get()), - wl(*args)) - - return _while_loop + func = Function(fresh_vars, sb.get()) + let = Let(wl, func, wl) + print(let) + return let def foreach(iter, init, body): i = var("i", shape=(), dtype='int32') diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index 337044098cd5a..84ef23beac6fe 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -43,6 +43,9 @@ def __exit__(self, ptype, value, trace): self._exit_cb() +def int32(val): + return relay.const(val, 'int32') + def _make_lets(bindings, ret_value): """Make a nested let expressions. @@ -176,6 +179,75 @@ def _on_exit(): false_branch) return WithScope(None, _on_exit) + + def type_of(self, expr): + if isinstance(expr, Var): + return expr.type_annotation + else: + it = _ty.IncompleteType() + v = _expr.var("unify", it) + self.let(v, expr) + return it + +# def while_loop(self, cond, loop_vars, loop_bodies): +# self._enter_scope() +# wl = _expr_Var("while_loop") + +# with self.if_scope(cond(*loop_vars)): +# self.ret(wl(*loop_bodies(*loop_vars))) +# with sb.else_scope(): +# self.ret(Tuple(loop_vars)) + +# def _on_exit(): +# bindings, ret_value = self._exit_scope() +# _make_lets(bindings,) +# self._ret_values[-1] = _expr.Let( +# wl, +# Function(loop_vars, sb.get()), +# wl(loop_vars) + +# return WithScope(None, _on_exit) + +# # TODO(@jroesch): possibly move to their own file +# def while_loop2(cond, loop_vars, loop_bodies): +# sb = ScopeBuilder() +# wl = Var("while_loop") +# type_hints = [] +# fresh_vars = [] +# for i, lv in enumerate(loop_vars): +# name = lv.name_hint if isinstance(lv, Var) else "arg{}".format(i) +# v = var(name, type_annotation=IncompleteType()) +# fresh_vars.append(v) +# type_hints.append(TypeHint(v, type_of(sb, lv))) + +# for i, hint in enumerate(type_hints): +# var_hint = var("hint{}".format(i)) +# sb.let(var_hint, hint) + +# with sb.if_scope(cond(*fresh_vars)): +# sb.ret(wl(*loop_bodies(*fresh_vars))) +# with sb.else_scope(): +# sb.ret(Tuple(fresh_vars)) + +# # return wl, Function(fresh_vars, sb.get()) +# def _while_loop(*args): +# return Let( +# wl, Function(fresh_vars, sb.get()), +# wl(*args)) + +# return _while_loop + +# def foreach(iter, init, body): +# i = var("i", shape=(), dtype='int32') +# st = var("st", type_annotation=relay.IncompleteType()) +# update = body(i, st) +# dim = take(op.shape_of(iter), indices=i, axis=0) +# def _cond(i, st): +# return op.min(op.less(i, dim)) +# loop = while_loop( +# _cond, [i, st], [i + int32(1), update]) +# return loop(int32(0), init) + def ret(self, value): """Set the return value of this scope. diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6f7319bac36c9..b2bc12503c064 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -231,7 +231,7 @@ class TypeInferencer : private ExprFunctor, return op->tensor_type(); } - Type VisitExpr_(const TupleNode* op) final { + /* Type VisitExpr_(const TupleNode* op) final { if (!make_tuple_rel_.defined()) { make_tuple_rel_ = TypeRelationFn( EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_); @@ -245,6 +245,14 @@ class TypeInferencer : private ExprFunctor, solver_.AddConstraint(TypeRelationNode::make( make_tuple_rel_, types, op->fields.size(), {}), GetRef(op)); return rtype; + } */ + + Type VisitExpr_(const TupleNode* op) final { + Array types; + for (Expr field : op->fields) { + types.push_back(GetType(field)); + } + return TupleTypeNode::make(types); } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -335,6 +343,7 @@ class TypeInferencer : private ExprFunctor, Type let_type = IncompleteTypeNode::make(Kind::kType); if (is_functional_literal) { + let_type = GetType(let->var); type_map_[let->var].checked_type = let_type; } @@ -589,6 +598,8 @@ class TypeInferencer : private ExprFunctor, type_hint->type_hint.as() != nullptr) << "type hints can only be applied to tensor types: " << PrettyPrint(type_hint->type_hint); + auto ty = GetType(type_hint->expr); + solver_.AddHint(ty, type_hint->type_hint); auto it = this->hint_map_.find(type_hint->expr); if (it == this->hint_map_.end()) { this->hint_map_[type_hint->expr] = { type_hint->type_hint }; @@ -598,95 +609,55 @@ class TypeInferencer : private ExprFunctor, return TupleTypeNode::make({}); } - Type GeneralizeShape(const std::vector& concrete_types) { - if (concrete_types.size() >= 1) { - CHECK(concrete_types[0].defined()); - std::cout << "shape" << concrete_types[0]; - Shape shape = concrete_types[0]->shape; - DataType dtype = concrete_types[0]->dtype; - for (auto ty : concrete_types) { - std::cout << ty; - CHECK_EQ(dtype, ty->dtype); - CHECK_EQ(shape.size(), ty->shape.size()); - for (size_t i = 0; i < shape.size(); i++) { - auto ldim = shape[i]; - auto rdim = ty->shape[i]; - - std::cout << "dim1: " << ldim << "dim2: " << rdim; - - if (ldim.same_as(rdim)) { - continue; - } - - if (ldim.same_as(Any()) || rdim.same_as(Any())) { - continue; - } - - auto left_is_int = ldim.as(); - auto right_is_int = rdim.as(); - if (left_is_int && right_is_int && left_is_int->value != right_is_int->value) { - shape.Set(i, Any()); - continue; - } - - auto left_is_var = ldim.as(); - auto right_is_var = rdim.as(); - if (left_is_var && right_is_var && - !GetRef(right_is_var).same_as(GetRef(left_is_var))) { - shape.Set(i, Any()); - continue; - } - } - LOG(INFO) << "shape after: " << shape; - } - return TensorTypeNode::make(shape, dtype); - } else { - return IncompleteTypeNode::make(Kind::kType); - } - } - - void ApplyTypeHints() { - for (auto hint : hint_map_) { - auto expr = hint.first; - auto hints = hint.second; - LOG(INFO) << "type hint processing: " << expr; - // Each hint represents a possible type - // for the expression. - // - // In order to unify types we must - // split all concretete type information - // to construct the most general type. - // - // We will then unify the result with - // any hint types which are non-concrete. - std::vector concrete_types; - std::vector to_unify; - for (auto ty : hints) { - auto rty = solver_.Resolve(ty); - if (rty.as() != nullptr) { - to_unify.push_back(rty); - } else { - concrete_types.push_back(Downcast(rty)); - } - } - - std::cout << "above generalize"; - // Add the generalized shape. - Type general_shape = GeneralizeShape(concrete_types); - std::cout << "Generalized Shape: " << general_shape; - for (auto ty : to_unify) { - CHECK(general_shape.defined()); - general_shape = Unify(general_shape, ty, expr); - } + // void ApplyTypeHints() { + // for (auto hint : hint_map_) { + // auto expr = hint.first; + // auto hints = hint.second; + // LOG(INFO) << "type hint processing: " << expr; + // // Each hint represents a possible type + // // for the expression. + // // + // // In order to unify types we must + // // split all concretete type information + // // to construct the most general type. + // // + // // We will then unify the result with + // // any hint types which are non-concrete. + // std::vector concrete_types; + // std::vector to_unify; + // for (auto ty : hints) { + // auto rty = solver_.Resolve(ty); + // if (rty.as() != nullptr) { + // to_unify.push_back(rty); + // } else { + // concrete_types.push_back(Downcast(rty)); + // } + // } + + // std::cout << "above generalize"; + // // Add the generalized shape. + // Type general_shape = GeneralizeShape(concrete_types); + // std::cout << "Generalized Shape: " << general_shape; + // for (auto ty : to_unify) { + // CHECK(general_shape.defined()); + // general_shape = Unify(general_shape, ty, expr); + // } + + // auto it = type_map_.find(expr); + // CHECK(it != type_map_.end()); + // auto rtype = solver_.Resolve(it->second.checked_type); + // if (rtype.as() == nullptr) { + // it->second.checked_type = general_shape; + // } else { + // it->second.checked_type = Unify(rtype, general_shape, expr); + // } + // } + // } + void Solve() { + solver_.Solve(); - auto it = type_map_.find(expr); - CHECK(it != type_map_.end()); - auto rtype = solver_.Resolve(it->second.checked_type); - if (rtype.as() == nullptr) { - it->second.checked_type = general_shape; - } else { - it->second.checked_type = Unify(rtype, general_shape, expr); - } + if (err_reporter.AnyErrors()) { + err_reporter.RenderErrors(mod_); } } }; @@ -798,18 +769,18 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // check if we need update the new_e bool need_update_type = !checked_type.same_as(new_e->checked_type_); bool need_update_call = ( - std::is_base_of::value); /* && + std::is_base_of::value && it->second.type_args.defined() && - !it->second.type_args.same_as(new_call->type_args)) */ + !it->second.type_args.same_as(new_call->type_args)); bool need_update_var = ( std::is_base_of::value && update_missing_type_annotation_ && !new_var->type_annotation.defined()); - bool need_update_fn = ( - std::is_base_of::value && + bool need_update_fn =( + std::is_base_of::value); /* && update_missing_type_annotation_ && - !new_fn->ret_type.defined()); + !new_fn->ret_type.defined()); */ if (!need_update_type && !need_update_var && @@ -875,14 +846,7 @@ Expr TypeInferencer::Infer(Expr expr) { GetType(expr); // Step 2: Solve the constraints. - solver_.Solve(); - - // Step 1: Apply type hints. - ApplyTypeHints(); - - if (err_reporter.AnyErrors()) { - err_reporter.RenderErrors(mod_); - } + Solve(); // Step 3: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8042342c0c23e..50abcc24a4197 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -88,6 +88,53 @@ class TypeSolver::OccursChecker : public TypeVisitor { bool found_; }; +static Type GeneralizeShape(const std::vector& concrete_types) { + if (concrete_types.size() >= 1) { + CHECK(concrete_types[0].defined()); + std::cout << "shape" << concrete_types[0]; + Shape shape = concrete_types[0]->shape; + DataType dtype = concrete_types[0]->dtype; + for (auto ty : concrete_types) { + std::cout << ty; + CHECK_EQ(dtype, ty->dtype); + CHECK_EQ(shape.size(), ty->shape.size()); + for (size_t i = 0; i < shape.size(); i++) { + auto ldim = shape[i]; + auto rdim = ty->shape[i]; + + std::cout << "dim1: " << ldim << "dim2: " << rdim; + + if (ldim.same_as(rdim)) { + continue; + } + + if (ldim.same_as(Any()) || rdim.same_as(Any())) { + continue; + } + + auto left_is_int = ldim.as(); + auto right_is_int = rdim.as(); + if (left_is_int && right_is_int && left_is_int->value != right_is_int->value) { + shape.Set(i, Any()); + continue; + } + + auto left_is_var = ldim.as(); + auto right_is_var = rdim.as(); + if (left_is_var && right_is_var && + !GetRef(right_is_var).same_as(GetRef(left_is_var))) { + shape.Set(i, Any()); + continue; + } + } + LOG(INFO) << "shape after: " << shape; + } + return TensorTypeNode::make(shape, dtype); + } else { + return IncompleteTypeNode::make(Kind::kType); + } +} + class TypeSolver::Unifier : public TypeFunctor { public: explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {} @@ -102,10 +149,15 @@ class TypeSolver::Unifier : public TypeFunctor { if (lhs->FindRoot() == rhs->FindRoot()) { return lhs->resolved_type; } + + std::cout << "LHS: " << lhs->resolved_type << " hints " << lhs->FindRoot()->hint << std::endl; + std::cout << "RHS: " << rhs->resolved_type << " hints " << rhs->FindRoot()->hint << std::endl; + if (lhs->resolved_type.as()) { CHECK(!OccursCheck(lhs, rhs->resolved_type)) << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { @@ -115,20 +167,36 @@ class TypeSolver::Unifier : public TypeFunctor { solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { - Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); - if (!resolved.defined()) { - solver_->ReportError( - RELAY_ERROR( - "unable to unify: " - << "`" << PrettyPrint(lhs->resolved_type) - << "` and `" - << PrettyPrint(rhs->resolved_type) << "`"), this->loc); - return lhs->resolved_type; + if (lhs->hint.defined() || rhs->hint.defined()) { + CHECK(!(lhs->hint.defined() && rhs->hint.defined())); + std::vector tts; + tts.push_back(Downcast(lhs->resolved_type)); + if (lhs->hint.defined()) tts.push_back(Downcast(solver_->Resolve(lhs->hint))); + tts.push_back(Downcast(rhs->resolved_type)); + if (rhs->hint.defined()) tts.push_back(Downcast(solver_->Resolve(rhs->hint))); + auto gtt = GeneralizeShape(tts); + TypeNode* repr = solver_->GetTypeNode(gtt); + // Do we update equivalence classes, this is conflict unification only. + solver_->MergeFromTo(lhs, repr); + solver_->MergeFromTo(rhs, repr); + std::cout << "Solution: " << gtt << std::endl; + return gtt; } else { - TypeNode* top = solver_->GetTypeNode(resolved); - solver_->MergeFromTo(lhs, top); - solver_->MergeFromTo(rhs, top); - return resolved; + Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); + if (!resolved.defined()) { + solver_->ReportError( + RELAY_ERROR( + "unable to unify: " + << "`" << PrettyPrint(lhs->resolved_type) + << "` and `" + << PrettyPrint(rhs->resolved_type) << "`"), this->loc); + return lhs->resolved_type; + } else { + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); + return resolved; + } } } } @@ -446,6 +514,11 @@ class TypeSolver::Merger : public TypeFunctor { VisitType(src->resolved_type); // set parent at the end so later calls to GetTypeNode go back to src src->parent = dst; + if (src->hint.defined() && !dst->hint.defined()) { + dst->hint = src->hint; + } else if (src->hint.defined() && dst->hint.defined()) { + std::cout << "SRC: " << src->hint << "DST: " << dst->hint << std::endl; + } // now propagate relations to child nodes, since change to // a child node should update parent too @@ -574,51 +647,93 @@ Type TypeSolver::Resolve(const Type& type) { return resolver.Resolve(t); } + +void TypeSolver::AddHint(const Type& t, const Type& hint) { + auto root = GetTypeNode(t)->FindRoot(); + Type old_hint = root->hint; + if (!old_hint.defined()) { + root->hint = hint; + } + + std::vector concrete_types; + if (old_hint.as() != nullptr) { + auto rhint = Resolve(old_hint); + if (rhint.as() == nullptr) { + LOG(FATAL) << "UNRE"; + // root->hint = Unify(hint, rhint, hint); + } else { + concrete_types.push_back(Downcast(rhint)); + } + auto ctype = GeneralizeShape(concrete_types); + root->hint = Unify(hint, ctype, ctype); + std::cout << "AddHint: " << root->hint << std::endl; + } +} + bool TypeSolver::Solve() { // Update until queue is empty. - while (!update_queue_.empty()) { - RelationNode* rnode = update_queue_.front(); - const auto& rel = rnode->rel; - update_queue_.pop(); - CHECK(!rnode->resolved); - // update the relation with given evidence. - Array args; - for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { - args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); - CHECK_LE(args.size(), rel->args.size()); - } - - CHECK(rnode->location.defined()) - << "undefined location, should be set when constructing relation node"; - - // We need to set this in order to understand where unification - // errors generated by the error reporting are coming from. - reporter_->SetLocation(rnode->location); - - try { - // Call the Type Relation's function. - bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); - - if (resolved) { - ++num_resolved_rels_; + while (true) { + while (!update_queue_.empty()) { + RelationNode* rnode = update_queue_.front(); + const auto& rel = rnode->rel; + update_queue_.pop(); + CHECK(!rnode->resolved); + // update the relation with given evidence. + Array args; + for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { + args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); + CHECK_LE(args.size(), rel->args.size()); } - rnode->resolved = resolved; - } catch (const Error& err) { - this->ReportError(err, rnode->location); - rnode->resolved = false; - } catch (const dmlc::Error& err) { - rnode->resolved = false; - this->ReportError( - RELAY_ERROR( - "an internal invariant was violated while " \ - "typechecking your program " << - err.what()), rnode->location); - } - - // Mark inqueue as false after the function call - // so that rnode itself won't get enqueued again. - rnode->inqueue = false; + CHECK(rnode->location.defined()) + << "undefined location, should be set when constructing relation node"; + + // We need to set this in order to understand where unification + // errors generated by the error reporting are coming from. + reporter_->SetLocation(rnode->location); + + try { + // Call the Type Relation's function. + bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); + + if (resolved) { + ++num_resolved_rels_; + } + + rnode->resolved = resolved; + } catch (const Error& err) { + this->ReportError(err, rnode->location); + rnode->resolved = false; + } catch (const dmlc::Error& err) { + rnode->resolved = false; + this->ReportError( + RELAY_ERROR( + "an internal invariant was violated while " \ + "typechecking your program " << + err.what()), rnode->location); + } + + // Mark inqueue as false after the function call + // so that rnode itself won't get enqueued again. + rnode->inqueue = false; + } + + bool solved = false; + + for (auto type : type_nodes_) { + auto root = type->FindRoot(); + auto rty = root->resolved_type; + if (rty.as() != nullptr && root->hint.defined()) { + solved = true; + Unify(rty, root->hint, rty); + } + } + + // If hints provided no change then we should just exit, solving is done. + if (!solved) { + break; + } + } // This criterion is not necessarily right for all the possible cases diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index ff0464610aefa..a044fc9f6d235 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -96,6 +96,12 @@ class TypeSolver { */ void ReportError(const Error& err, const NodeRef& location); + /*! + * \brief Add type hint. + * \param type The type to hint. + * \param hints The hint for the type. + */ + void AddHint(const Type& t, const Type& hint); private: class OccursChecker; class Unifier; @@ -123,6 +129,8 @@ class TypeSolver { TypeNode* parent{nullptr}; /*! \brief set of relations that is related to this type node */ std::unordered_set rel_set; + Type hint{nullptr}; + /*! * \brief Find the root type node, perform path compression * \return The root type node. @@ -166,7 +174,7 @@ class TypeSolver { std::vector rel_nodes_; /*! \brief Number of resolved relations */ size_t num_resolved_rels_{0}; - /*! \brief map from type node to types. */ + /*! \brief map from types to type nodes. */ std::unordered_map tmap_; /*! \brief Internal queue to update the relation */ std::queue update_queue_; @@ -206,6 +214,7 @@ class TypeSolver { rel->inqueue = true; update_queue_.push(rel); } + /*! * \brief Merge rhs type node to lhs * \param src The source operand diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 36c35c6992943..4da2e635d0ce5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1,6 +1,6 @@ import tvm from tvm import relay -from tvm.relay import Kind, while_loop, while_loop2, foreach +from tvm.relay import Kind, while_loop, foreach import numpy as np def int32(val): @@ -114,19 +114,28 @@ def _body(i, st): def test_dynamic_concat_with_hint(): """ - fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { - if (%i < 10) { - let %i = reshape(cast(i, "float32"), newshape=(1, )) - let %new_st = concatenate((st, i), axis=0) - concat_loop(%i + 1, ) + v0.0.1 + let %while_loop = fn (%i: int32, %arg1: Tensor[(?, 1), int32]) -> (int32, Tensor[(?, 1), int32]) { + let %unify: Tensor[(?, 1), int32] = meta[relay.Constant][0] + let %hint0: () = type_hint(for=%i, hint=int32) + let %hint1: () = type_hint(for=%arg1, hint=Tensor[(?, 1), int32]) + %0 = less(%i, 10 /* ty=int32 */) /* ty=bool */ + %1 = min(%0) /* ty=bool */ + if (%1) { + %2 = add(%i, 1 /* ty=int32 */) /* ty=int32 */ + %3 = reshape(%i, newshape=[1, 1]) /* ty=Tensor[(1, 1), int32] */ + %4 = (%arg1, %3) + %5 = concatenate(%4) /* ty=Tensor[(?, 1), int32] */ + %while_loop(%2, %5) /* ty=(int32, Tensor[(?, 1), int32]) */ } else { - st + (%i, %arg1) } } + %while_loop """ # Initial Values. i = relay.var('i', shape=(), dtype='int32') - st = relay.var('st', type_annotation=relay.IncompleteType()) + st = relay.const(np.ones((1, 1), dtype='int32')) def _cond(i, st): return relay.op.min(relay.op.less(i, int32(10))) @@ -136,48 +145,20 @@ def _body(i, st): ret = relay.op.concatenate([st, i_vec], axis=0) return i + int32(1), ret - loop = while_loop2(_cond, [i, st], _body) + loop = while_loop(_cond, [i, st], _body) + print(loop) + loop = relay.ir_pass.infer_type(loop) start = relay.var('start', shape=(), dtype='int32') - body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) - print(body) - body = relay.ir_pass.infer_type(body) + print(loop) + # body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) + # print(body) + # body = relay.ir_pass.infer_type(body) # func = relay.Function([start], relay.TupleGetItem(body, 1)) # func = relay.ir_pass.infer_type(func) - import pdb; pdb.set_trace() - -# def test_dynamic_concat_with_hint(): -# """ -# fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { -# if (%i < 10) { -# let %i = reshape(cast(i, "float32"), newshape=(1, )) -# let %new_st = concatenate((st, i), axis=0) -# concat_loop(%i + 1, ) -# } else { -# st -# } -# } -# """ -# # Initial Values. -# i = relay.var('i', shape=(), dtype='int32') -# st = relay.var('st', type_annotation=relay.IncompleteType()) - -# def _cond(i, st): -# return relay.op.min(relay.op.less(i, int32(10))) - -# def _body(i, st): -# i_vec = relay.op.reshape(i, (1,1)) -# ret = relay.op.concatenate([st, i_vec], axis=0) -# return i + int32(1), ret - -# loop = while_loop(_cond, [i, st], _body) -# start = relay.var('start', shape=(), dtype='int32') -# body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) -# func = relay.Function([start], relay.TupleGetItem(body, 1)) -# # func = relay.ir_pass.infer_type(func) -# import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() if __name__ == "__main__": - # test_arange_with_dynamic_shape() - # test_dynamic_concat() - # test_dynamic_concat_with_wrong_annotation() + test_arange_with_dynamic_shape() + test_dynamic_concat() + test_dynamic_concat_with_wrong_annotation() test_dynamic_concat_with_hint()