From 0ce1f68896c47d83862133bca994732ff735a7f5 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 23 Jan 2019 00:47:56 -0800 Subject: [PATCH] better type infer ability --- src/relay/ir/expr.cc | 2 +- src/relay/pass/type_infer.cc | 22 ++++++---------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 8a5115db6cb75..b055b3eb90add 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -315,7 +315,7 @@ TVM_REGISTER_API("relay._make.RefWrite").set_body([](TVMArgs args, TVMRetValue* TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefWriteNode* node, tvm::IRPrinter* p) { - p->stream << "RefWriteNode(" << node->value << ")"; + p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); TVM_REGISTER_API("relay._expr.TempExprRealize") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 0edfa9a11add2..c135d6fbc0299 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -396,25 +396,15 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const RefReadNode* op) final { - // TODO(M.K.) - // handle case where ref type is not known - Type ref_type = GetType(op->ref); - auto ref_ty_node = ref_type.as(); - if (!ref_ty_node) { - LOG(FATAL) << "only expressions with ref types is accepted" << GetRef(op); - } - return ref_ty_node->value; + Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + this->Unify(GetType(op->ref), RefTypeNode::make(it), op->span); + return it; } Type VisitExpr_(const RefWriteNode* op) final { - // TODO(M.K.) - // handle case where ref type is not known - Type ref_type = GetType(op->ref); - auto ref_ty_node = ref_type.as(); - if (!ref_ty_node) { - LOG(FATAL) << "only expressions with ref types is accepted" << GetRef(op); - } - this->Unify(ref_ty_node->value, GetType(op->value), op->span); + Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + this->Unify(GetType(op->ref), RefTypeNode::make(it), op->span); + this->Unify(GetType(op->value), it, op->span); return TupleTypeNode::make({}); } };