From 61fb9d5c528e6865ba274228a11407e98d13ef9e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 4 Oct 2018 17:47:51 -0700 Subject: [PATCH] Visit type params first in type alpha_eq so eq_map is updated --- src/relay/pass/alpha_eq.cc | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 87a93dbc2dd2f..39f55af6fe70e 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -95,6 +95,16 @@ struct TypeAlphaEq : TypeVisitor { return; } + // must visit params first so they are appropriate entered + // into equality map + for (size_t i = 0; i < op->type_params.size(); i++) { + eq_map.Set(op->type_params[i], ta2->type_params[i]); + this->VisitType(op->type_params[i], ta2->type_params[i]); + if (!equal) { + return; + } + } + for (size_t i = 0; i < op->arg_types.size(); i++) { this->VisitType(op->arg_types[i], ta2->arg_types[i]); if (!equal) { @@ -107,14 +117,6 @@ struct TypeAlphaEq : TypeVisitor { return; } - for (size_t i = 0; i < op->type_params.size(); i++) { - eq_map.Set(op->type_params[i], ta2->type_params[i]); - this->VisitType(op->type_params[i], ta2->type_params[i]); - if (!equal) { - return; - } - } - for (size_t i = 0; i < op->type_constraints.size(); i++) { this->VisitType(op->type_constraints[i], ta2->type_constraints[i]); if (!equal) { @@ -128,7 +130,24 @@ struct TypeAlphaEq : TypeVisitor { void VisitType_(const TypeRelationNode *tr1, const Type& t2) final { if (const TypeRelationNode *tr2 = t2.as()) { - equal = equal && (tr1 == tr2); + if (tr1->func != tr2->func + || tr1->num_inputs != tr2->num_inputs + || tr1->attrs != tr2->attrs) { + equal = false; + return; + } + + if (tr1->args.size() != tr2->args.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < tr1->args.size(); i++) { + this->VisitType(tr1->args[i], tr2->args[i]); + if (!equal) { + return; + } + } } else { equal = false; }