Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jul 17, 2019
1 parent fe191f5 commit 62a07ef
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def define_list_zip(self):
self.zip = GlobalVar("zip")
a = TypeVar("a")
b = TypeVar("b")
nil_case = Clause(PatternConstructor(self.nil), self.nil())
inner_nil_case = Clause(PatternConstructor(self.nil), self.nil())
outer_nil_case = Clause(PatternConstructor(self.nil), self.nil())
l1 = Var("l1")
l2 = Var("l2")
h1 = Var("h1")
Expand All @@ -249,8 +250,8 @@ def define_list_zip(self):
inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]),
self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]),
Match(l2, [nil_case, inner_cons_case]))
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
Match(l2, [inner_nil_case, inner_cons_case]))
self.mod[self.zip] = Function([l1, l2], Match(l1, [outer_nil_case, outer_cons_case]),
self.l(TupleType([a, b])), [a, b])


Expand Down
12 changes: 1 addition & 11 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver
try {
// instantiate higher-order func types when unifying because
// we only allow polymorphism at the top level
Type first = t1;
Type second = t2;
if (auto* ft1 = t1.as<FuncTypeNode>()) {
first = InstantiateFuncType(ft1);
}
if (auto* ft2 = t2.as<FuncTypeNode>()) {
second = InstantiateFuncType(ft2);
}
return solver_.Unify(first, second, expr);
return solver_.Unify(t1, t2, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
Expand Down
27 changes: 19 additions & 8 deletions src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,30 +289,41 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
|| op->type_params.size() != ftn->type_params.size()
|| op->type_constraints.size() != ftn->type_constraints.size()) {
return Type(nullptr);
}

// without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
if (op->type_params.size() < ftn->type_params.size()) {
return VisitType_(ftn, GetRef<FuncType>(op));
}

// remap type vars so they match
Map<TypeVar, Type> subst_map;
for (size_t i = 0; i < op->type_params.size(); i++) {
subst_map.Set(ftn->type_params[i], op->type_params[i]);
tvm::Array<TypeVar> ft_type_params;
for (size_t i = 0; i < ftn->type_params.size(); ++i) {
subst_map.Set(op->type_params[i], ftn->type_params[i]);
ft_type_params.push_back(op->type_params[i]);
}

for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
}

auto ft1 = GetRef<FuncType>(op);
auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
FuncType ft = FuncTypeNode::make(op->arg_types, op->ret_type, ft_type_params, op->type_constraints);
auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
auto ft2 = GetRef<FuncType>(ftn);

Type ret_type = Unify(ft1->ret_type, ft2->ret_type);

std::vector<Type> arg_types;
for (size_t i = 0; i < ft1->arg_types.size(); i++) {
for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
arg_types.push_back(arg_type);
}

std::vector<TypeConstraint> type_constraints;
for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
Type unified_constraint = Unify(ft1->type_constraints[i],
ft2->type_constraints[i]);
const auto* tcn = unified_constraint.as<TypeConstraintNode>();
Expand All @@ -321,7 +332,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
type_constraints.push_back(GetRef<TypeConstraint>(tcn));
}

return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
}

Type VisitType_(const RefTypeNode* op, const Type& tn) final {
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_prelude():


def test_ad():
return
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
Expand Down

0 comments on commit 62a07ef

Please sign in to comment.