Skip to content

Commit

Permalink
Visit type params first in type alpha_eq so eq_map is updated
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Oct 5, 2018
1 parent 1c25c01 commit 61fb9d5
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
return;
}

// must visit params first so they are appropriate entered
// into equality map
for (size_t i = 0; i < op->type_params.size(); i++) {
eq_map.Set(op->type_params[i], ta2->type_params[i]);
this->VisitType(op->type_params[i], ta2->type_params[i]);
if (!equal) {
return;
}
}

for (size_t i = 0; i < op->arg_types.size(); i++) {
this->VisitType(op->arg_types[i], ta2->arg_types[i]);
if (!equal) {
Expand All @@ -107,14 +117,6 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
return;
}

for (size_t i = 0; i < op->type_params.size(); i++) {
eq_map.Set(op->type_params[i], ta2->type_params[i]);
this->VisitType(op->type_params[i], ta2->type_params[i]);
if (!equal) {
return;
}
}

for (size_t i = 0; i < op->type_constraints.size(); i++) {
this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
if (!equal) {
Expand All @@ -128,7 +130,24 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {

void VisitType_(const TypeRelationNode *tr1, const Type& t2) final {
if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) {
equal = equal && (tr1 == tr2);
if (tr1->func != tr2->func
|| tr1->num_inputs != tr2->num_inputs
|| tr1->attrs != tr2->attrs) {
equal = false;
return;
}

if (tr1->args.size() != tr2->args.size()) {
equal = false;
return;
}

for (size_t i = 0; i < tr1->args.size(); i++) {
this->VisitType(tr1->args[i], tr2->args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
Expand Down

0 comments on commit 61fb9d5

Please sign in to comment.