Skip to content

Commit

Permalink
Remove unification hack
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 22, 2019
1 parent e7e2322 commit 1acbdcd
Showing 1 changed file with 3 additions and 48 deletions.
51 changes: 3 additions & 48 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,6 @@ TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
MakeTupleRel);


// Deferred relation for call arg
bool CallArgRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
reporter->AssignArg(types[1], types[0]);
return true;
}
TVM_REGISTER_API("tvm.relay.type_relation.CallArg")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
CallArgRel);


struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {}
Expand Down Expand Up @@ -167,7 +153,6 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// relation function
TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;
TypeRelationFn call_arg_rel_;

// Perform unification on two types and report the error at the expression
// or the span of the expression.
Expand Down Expand Up @@ -394,30 +379,6 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return rtype;
}

// struct SubShapeMutator : TypeMutator {
// std::unordered_map<IndexExpr, IndexExpr, NodeEqual, NodeHash> sh_map_;

// SubShapeMutator(std::unordered_map<IndexExpr, IndexExpr, NodeEqual, NodeHash> sh_map) :
// sh_map_(sh_map) {}

// Type VisitType_(const TensorTypeNode* ty) {
// tvm::Array<IndexExpr> shape;
// for (auto sh : ty->shape) {
// auto it = sh_map_.find(sh);
// if (it != sh_map_.end()) {
// shape.push_back(it->second);
// } else {
// shape.push_back(sh);
// }
// return TensorTypeNode::make(shape, ty->dtype);
// }
// }
// };

// Type SubShape(const Type& ty, std::unordered_map<IndexExpr, IndexExpr, NodeEqual, NodeHash> sh_map) {
// return SubShapeMutator(sh_map).VisitType(ty);
// }

// substitute the type args in the function type
FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array<Type>& ty_args) {
tvm::Map<TypeVar, Type> subst_map;
Expand Down Expand Up @@ -529,14 +490,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
}

if (!call_arg_rel_.defined()) {
call_arg_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.CallArg").node_);
}
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
solver_.AddConstraint(
TypeRelationNode::make(call_arg_rel_, {arg_types[i], fn_ty->arg_types[i]}, 2, {}),
GetRef<Call>(call));
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
}

for (auto cs : fn_ty->type_constraints) {
Expand Down Expand Up @@ -629,7 +584,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
}

Expr VisitExpr_(const VarNode* op) final {
return AttachCheckedType(op);
return VisitVar(GetRef<Var>(op));
}

Expr VisitExpr_(const ConstantNode* op) final {
Expand Down Expand Up @@ -734,7 +689,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.same_as(checked_type));
!new_var->type_annotation.defined());

bool need_update_fn = (
std::is_base_of<FunctionNode, T>::value &&
Expand Down

0 comments on commit 1acbdcd

Please sign in to comment.