Skip to content

Commit

Permalink
better type infer ability
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jan 23, 2019
1 parent f5258a8 commit 0ce1f68
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ TVM_REGISTER_API("relay._make.RefWrite").set_body([](TVMArgs args, TVMRetValue*

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
p->stream << "RefWriteNode(" << node->value << ")";
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});

TVM_REGISTER_API("relay._expr.TempExprRealize")
Expand Down
22 changes: 6 additions & 16 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}

Type VisitExpr_(const RefReadNode* op) final {
// TODO(M.K.)
// handle case where ref type is not known
Type ref_type = GetType(op->ref);
auto ref_ty_node = ref_type.as<RefTypeNode>();
if (!ref_ty_node) {
LOG(FATAL) << "only expressions with ref types is accepted" << GetRef<RefRead>(op);
}
return ref_ty_node->value;
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), op->span);
return it;
}

Type VisitExpr_(const RefWriteNode* op) final {
// TODO(M.K.)
// handle case where ref type is not known
Type ref_type = GetType(op->ref);
auto ref_ty_node = ref_type.as<RefTypeNode>();
if (!ref_ty_node) {
LOG(FATAL) << "only expressions with ref types is accepted" << GetRef<RefWrite>(op);
}
this->Unify(ref_ty_node->value, GetType(op->value), op->span);
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), op->span);
this->Unify(GetType(op->value), it, op->span);
return TupleTypeNode::make({});
}
};
Expand Down

0 comments on commit 0ce1f68

Please sign in to comment.