From 1acbdcd197433f7164e5acf694a607c481b6aaef Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 May 2019 17:49:29 -0700 Subject: [PATCH] Remove unification hack --- src/relay/pass/type_infer.cc | 51 +++--------------------------------- 1 file changed, 3 insertions(+), 48 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 887d3bdf2801c..e8185f22e9baf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -101,20 +101,6 @@ TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple") MakeTupleRel); -// Deferred relation for call arg -bool CallArgRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - reporter->AssignArg(types[1], types[0]); - return true; -} -TVM_REGISTER_API("tvm.relay.type_relation.CallArg") -.set_body_typed&, int, const Attrs&, const TypeReporter&)>( - CallArgRel); - - struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} @@ -167,7 +153,6 @@ class TypeInferencer : private ExprFunctor, // relation function TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; - TypeRelationFn call_arg_rel_; // Perform unification on two types and report the error at the expression // or the span of the expression. @@ -394,30 +379,6 @@ class TypeInferencer : private ExprFunctor, return rtype; } - // struct SubShapeMutator : TypeMutator { - // std::unordered_map sh_map_; - - // SubShapeMutator(std::unordered_map sh_map) : - // sh_map_(sh_map) {} - - // Type VisitType_(const TensorTypeNode* ty) { - // tvm::Array shape; - // for (auto sh : ty->shape) { - // auto it = sh_map_.find(sh); - // if (it != sh_map_.end()) { - // shape.push_back(it->second); - // } else { - // shape.push_back(sh); - // } - // return TensorTypeNode::make(shape, ty->dtype); - // } - // } - // }; - - // Type SubShape(const Type& ty, std::unordered_map sh_map) { - // return SubShapeMutator(sh_map).VisitType(ty); - // } - // substitute the type args in the function type FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array& ty_args) { tvm::Map subst_map; @@ -529,14 +490,8 @@ class TypeInferencer : private ExprFunctor, } } - if (!call_arg_rel_.defined()) { - call_arg_rel_ = TypeRelationFn( - EnvFunc::Get("tvm.relay.type_relation.CallArg").node_); - } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - solver_.AddConstraint( - TypeRelationNode::make(call_arg_rel_, {arg_types[i], fn_ty->arg_types[i]}, 2, {}), - GetRef(call)); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]); } for (auto cs : fn_ty->type_constraints) { @@ -629,7 +584,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } Expr VisitExpr_(const VarNode* op) final { - return AttachCheckedType(op); + return VisitVar(GetRef(op)); } Expr VisitExpr_(const ConstantNode* op) final { @@ -734,7 +689,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { bool need_update_var = ( std::is_base_of::value && update_missing_type_annotation_ && - !new_var->type_annotation.same_as(checked_type)); + !new_var->type_annotation.defined()); bool need_update_fn = ( std::is_base_of::value &&