From e6abad226821d08fbb3453384f756edf032ae65e Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 00:35:10 -0700
Subject: [PATCH 1/7] Add support for populating type args

---
 include/tvm/relay/expr.h              |   2 +-
 src/relay/pass/type_infer.cc          | 108 +++++++++++++++++++-------
 tests/python/relay/test_type_infer.py |  11 +++
 3 files changed, 93 insertions(+), 28 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 142982d48907..f459663ad705 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -267,7 +267,7 @@ class CallNode : public ExprNode {
    *
    * \endcode
    */
-  tvm::Array<Type> type_args;
+  mutable tvm::Array<Type> type_args;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("op", &op);
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 0cbce833aed9..63ff7860a0f4 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -61,6 +61,18 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
 .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
     TupleGetItemRel);
 
+struct ResolvedTypeInfo {
+  explicit ResolvedTypeInfo(Type checked_type) : checked_type(checked_type), type_args() {}
+  explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) : checked_type(checked_type), type_args() {}
+  explicit ResolvedTypeInfo(Array<Type> type_args) : checked_type(), type_args(type_args) {}
+  ResolvedTypeInfo(const ResolvedTypeInfo& rti) : checked_type(rti.checked_type), type_args(rti.type_args) {}
+  ResolvedTypeInfo() : checked_type(), type_args() {}
+
+  Type checked_type;
+  // Only allocated when the expression is a call.
+  Array<Type> type_args;
+};
+
 //
 // The inference algorithm can roughly be devided into three stages:
 // - Populate the constraints by visiting the expression (TypeInferencer.GetType)
@@ -87,7 +99,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   Environment env_;
   // map from expression to checked type
   // type inferencer will populate it up
-  std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
+  std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
+
   // The solver used by the inferencer.
   TypeSolver solver_;
   // relation function
@@ -111,11 +124,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   // will call visit to deduce it if it is not in the type_map_
   Type GetType(const Expr &expr) {
     auto it = type_map_.find(expr);
-    if (it != type_map_.end()) {
-      return it->second;
+    if (it != type_map_.end() && it->second.checked_type.defined()) {
+      return it->second.checked_type;
     }
     Type ret = this->VisitExpr(expr);
-    type_map_[expr] = ret;
+    ResolvedTypeInfo& rti = type_map_[expr];
+    rti.checked_type = ret;
     return ret;
   }
 
@@ -176,7 +190,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
     CHECK(!type_map_.count(op->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[op->var] = vtype;
+    type_map_[op->var].checked_type = { vtype };
     return GetType(op->body);
   }
 
@@ -224,6 +238,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
       subst_map.Set(ty_param, fresh);
       ty_args->push_back(fresh);
     }
+
     Type ret_type = fn_ty->ret_type;
 
     // If the function type is incomplete, place a new IncompleteType
@@ -234,6 +249,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     if (!ret_type.defined()) {
       ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
     }
+
     Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
                                       ret_type, {},
                                       fn_ty->type_constraints);
@@ -241,49 +257,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     return Downcast<FuncType>(inst_ty);
   }
 
+  void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
+    auto type_info = type_map_.find(expr);
+    if (type_info == type_map_.end()) {
+      type_map_.insert({expr, ResolvedTypeInfo(type_args) });
+    } else {
+      CHECK(!type_info->second.type_args.defined());
+      type_info->second.type_args = type_args;
+    }
+  }
+
   // Handle general call node.
-  Type GeneralCall(const CallNode* op, Array<Type> arg_types) {
-    Type ftype = GetType(op->op);
+  Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
+    Type ftype = GetType(call->op);
     auto* fn_ty_node = ftype.as<FuncTypeNode>();
+
     CHECK(fn_ty_node != nullptr)
         << "only expressions with function types can be called, at "
-        << op->span;
+        << call->span;
 
     Array<Type> type_args;
     FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
+
+    AddTypeArgs(GetRef<Call>(call), type_args);
+
     size_t type_arity = fn_ty->arg_types.size();
     size_t number_of_args = arg_types.size();
 
     if (type_arity != number_of_args) {
       if (type_arity < number_of_args) {
-        LOG(FATAL) << "the function is provided too many arguments " << op->span;
+        LOG(FATAL) << "the function is provided too many arguments " << call->span;
       } else {
-        LOG(FATAL) << "the function is provided too few arguments" << op->span;
+        LOG(FATAL) << "the function is provided too few arguments" << call->span;
       }
     }
+
     for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
-      this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span);
+      this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
     }
 
     for (auto cs : fn_ty->type_constraints) {
-      solver_.AddConstraint(cs);
+      if (auto tr = cs.as<TypeRelationNode>()) {
+        solver_.AddConstraint(
+          TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
+      } else {
+        solver_.AddConstraint(cs);
+      }
     }
+
     return fn_ty->ret_type;
   }
 
-  Type VisitExpr_(const CallNode* op) final {
-    // Fast path: well-formed primitive op
+  Type VisitExpr_(const CallNode* call) final {
     Array<Type> arg_types;
-    for (Expr arg : op->args) {
+    for (Expr arg : call->args) {
       arg_types.push_back(GetType(arg));
     }
-    if (const OpNode* opnode = op->op.as<OpNode>()) {
+
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
       Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
                                  arg_types,
-                                 op->attrs);
-      if (rtype.defined()) return rtype;
+                                 call->attrs);
+      if (rtype.defined()) {
+        AddTypeArgs(GetRef<Call>(call), arg_types);
+        return rtype;
+      }
     }
-    return GeneralCall(op, arg_types);
+
+    return GeneralCall(call, arg_types);
   }
 
   Type VisitExpr_(const FunctionNode* f) final {
@@ -312,7 +353,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
 
 class TypeInferencer::Resolver : public ExprMutator {
  public:
-  Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap,
+  Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
            TypeSolver* solver)
       : tmap_(tmap), solver_(solver) {
   }
@@ -346,7 +387,20 @@ class TypeInferencer::Resolver : public ExprMutator {
   }
 
   Expr VisitExpr_(const CallNode* op) final {
-    return AttachCheckedType(op);
+    auto call = GetRef<Call>(op);
+    auto it = tmap_.find(call);
+    if (it != tmap_.end()) {
+      Call new_op = Downcast<Call>(AttachCheckedType(op));
+      new_op->type_args = it->second.type_args;
+
+      for (int i = 0; i < new_op->type_args.size(); i++) {
+        new_op->type_args.Set(i, solver_->Resolve(new_op->type_args[i]));
+      }
+
+      return new_op;
+    } else {
+      return AttachCheckedType(op);
+    }
   }
 
   Expr VisitExpr_(const LetNode* op) final {
@@ -362,7 +416,7 @@ class TypeInferencer::Resolver : public ExprMutator {
   Expr AttachCheckedType(const T* op) {
     auto it = tmap_.find(GetRef<Expr>(op));
     CHECK(it != tmap_.end());
-    Type checked_type = solver_->Resolve(it->second);
+    Type checked_type = solver_->Resolve(it->second.checked_type);
     CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
         << "Cannot resolve type of " << GetRef<Expr>(op)
         << " at " << op->span;
@@ -379,22 +433,22 @@ class TypeInferencer::Resolver : public ExprMutator {
     return new_e;
   }
 
-  Type VisitType(const Type& t) final {
+  Type VisitType(const Type &t) final {
     return solver_->Resolve(t);
   }
 
  private:
-  const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_;
+  const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
   TypeSolver* solver_;
 };
 
 
 Expr TypeInferencer::Infer(Expr expr) {
-  // step 0: populate the constraints
+  // Step 0: Populate the constraints.
   GetType(expr);
-  // step 1: solve the constraints
+  // Step 1: Solve the constraints.
   solver_.Solve();
-  // step 2: attach resolved types to checked_type field
+  // Step 2: Attach resolved types to checked_type field.
   return Resolver(type_map_, &solver_).VisitExpr(expr);
 }
 
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 2d8f98974639..d59f90b6484e 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -91,6 +91,17 @@ def test_free_expr():
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.scalar_type("float32")
 
+def test_type_args():
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(10, 5))
+    z = relay.add(x, y)
+    ty_z = relay.ir_pass.infer_type(z)
+    ty_args = ty_z.type_args
+    assert len(ty_args) == 2
+    assert ty_args[0].dtype == "float32"
+    assert ty_args[1].dtype == "float32"
+    assert ty_args[0].shape == (10, 10)
+    assert ty_args[1].shape == (10, 5)
 
 if __name__ == "__main__":
     test_free_expr()

From 1660b311f40a87a46dd45accc2db4acb9e07106b Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 00:46:03 -0700
Subject: [PATCH 2/7] Fix text printer and get test green

---
 src/relay/ir/text_printer.cc          | 20 +++++++++++++++++---
 tests/python/relay/test_type_infer.py | 12 +++++++++---
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index 66ef86641fae..09f2ba23141f 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -277,8 +277,6 @@ class TextPrinter :
   TextValue VisitExpr_(const CallNode* op) final {
     // TODO(tqchen, M.K.): support generic call
     // possibly through meta-data
-    CHECK_EQ(op->type_args.size(), 0U)
-        << "generic call not yet supported";
     TextValue call_op = GetValue(op->op);
     std::vector<TextValue> args;
     for (Expr arg : op->args) {
@@ -286,7 +284,23 @@ class TextPrinter :
     }
     TextValue id = this->AllocTempVar();
     this->PrintIndent();
-    stream_ << id << " = " << call_op << "(";
+
+    stream_ << id << " = " << call_op;
+
+    auto type_args = op->type_args;
+
+    if (type_args.size() > 0U) {
+      stream_ << "<";
+      for (size_t i = 0; i < op->type_args.size(); ++i) {
+        this->PrintType(type_args[i], stream_);
+        if (i + 1 != type_args.size()) {
+          stream_ << ", ";
+        }
+      }
+      stream_ << ">";
+    }
+
+    stream_ << "(";
     for (size_t i = 0; i < args.size(); ++i) {
       stream_ << args[i];
       if (i + 1 != args.size()) {
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index d59f90b6484e..e1d749e75863 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -93,15 +93,19 @@ def test_free_expr():
 
 def test_type_args():
     x = relay.var("x", shape=(10, 10))
-    y = relay.var("y", shape=(10, 5))
+    y = relay.var("y", shape=(1, 10))
     z = relay.add(x, y)
     ty_z = relay.ir_pass.infer_type(z)
     ty_args = ty_z.type_args
     assert len(ty_args) == 2
     assert ty_args[0].dtype == "float32"
     assert ty_args[1].dtype == "float32"
-    assert ty_args[0].shape == (10, 10)
-    assert ty_args[1].shape == (10, 5)
+    sh1 = ty_args[0].shape
+    sh2 = ty_args[1].shape
+    assert sh1[0].value == 10
+    assert sh1[1].value == 10
+    assert sh2[0].value == 1
+    assert sh2[1].value == 10
 
 if __name__ == "__main__":
     test_free_expr()
@@ -111,3 +115,5 @@ def test_type_args():
     test_decl()
     test_recursion()
     test_tuple()
+    test_free_expr()
+    test_type_args()

From d55cc315d8f2f34a757a429405e10934be2391cb Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 00:52:16 -0700
Subject: [PATCH 3/7] Fix lint

---
 src/relay/pass/type_infer.cc | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 63ff7860a0f4..39ea112f33f0 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -62,10 +62,14 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
     TupleGetItemRel);
 
 struct ResolvedTypeInfo {
-  explicit ResolvedTypeInfo(Type checked_type) : checked_type(checked_type), type_args() {}
-  explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) : checked_type(checked_type), type_args() {}
-  explicit ResolvedTypeInfo(Array<Type> type_args) : checked_type(), type_args(type_args) {}
-  ResolvedTypeInfo(const ResolvedTypeInfo& rti) : checked_type(rti.checked_type), type_args(rti.type_args) {}
+  explicit ResolvedTypeInfo(Type checked_type)
+      : checked_type(checked_type), type_args() {}
+  explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
+      : checked_type(checked_type), type_args() {}
+  explicit ResolvedTypeInfo(Array<Type> type_args)
+      : checked_type(), type_args(type_args) {}
+  ResolvedTypeInfo(const ResolvedTypeInfo& rti)
+      : checked_type(rti.checked_type), type_args(rti.type_args) {}
   ResolvedTypeInfo() : checked_type(), type_args() {}
 
   Type checked_type;

From ca33fc2b44883e19022ad8e6006917ee795a5695 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 14:44:01 -0700
Subject: [PATCH 4/7] Addresss some comments

---
 include/tvm/relay/expr.h     |  2 +-
 src/relay/pass/type_infer.cc | 42 ++++++++++++++++--------------------
 2 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index f459663ad705..142982d48907 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -267,7 +267,7 @@ class CallNode : public ExprNode {
    *
    * \endcode
    */
-  mutable tvm::Array<Type> type_args;
+  tvm::Array<Type> type_args;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("op", &op);
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 39ea112f33f0..3ba249777ef8 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -62,19 +62,16 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
     TupleGetItemRel);
 
 struct ResolvedTypeInfo {
-  explicit ResolvedTypeInfo(Type checked_type)
-      : checked_type(checked_type), type_args() {}
   explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
-      : checked_type(checked_type), type_args() {}
-  explicit ResolvedTypeInfo(Array<Type> type_args)
-      : checked_type(), type_args(type_args) {}
+      : checked_type(checked_type), type_args(type_args) {}
   ResolvedTypeInfo(const ResolvedTypeInfo& rti)
       : checked_type(rti.checked_type), type_args(rti.type_args) {}
-  ResolvedTypeInfo() : checked_type(), type_args() {}
+  ResolvedTypeInfo() : checked_type() {}
 
   Type checked_type;
   // Only allocated when the expression is a call.
-  Array<Type> type_args;
+
+  Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
 };
 
 //
@@ -194,7 +191,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
     CHECK(!type_map_.count(op->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[op->var].checked_type = { vtype };
+    type_map_[op->var].checked_type = vtype; // ResolvedTypeInfo(vtype, Array<Type>(NodePtr<Node>(nullptr)));
     return GetType(op->body);
   }
 
@@ -264,7 +261,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
     auto type_info = type_map_.find(expr);
     if (type_info == type_map_.end()) {
-      type_map_.insert({expr, ResolvedTypeInfo(type_args) });
+      type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
     } else {
       CHECK(!type_info->second.type_args.defined());
       type_info->second.type_args = type_args;
@@ -391,20 +388,7 @@ class TypeInferencer::Resolver : public ExprMutator {
   }
 
   Expr VisitExpr_(const CallNode* op) final {
-    auto call = GetRef<Call>(op);
-    auto it = tmap_.find(call);
-    if (it != tmap_.end()) {
-      Call new_op = Downcast<Call>(AttachCheckedType(op));
-      new_op->type_args = it->second.type_args;
-
-      for (int i = 0; i < new_op->type_args.size(); i++) {
-        new_op->type_args.Set(i, solver_->Resolve(new_op->type_args[i]));
-      }
-
-      return new_op;
-    } else {
-      return AttachCheckedType(op);
-    }
+    return AttachCheckedType(op);
   }
 
   Expr VisitExpr_(const LetNode* op) final {
@@ -434,6 +418,18 @@ class TypeInferencer::Resolver : public ExprMutator {
       }
       new_e->checked_type_ = checked_type;
     }
+
+    if (it->second.type_args.defined()) {
+      Call call = Downcast<Call>(new_e);
+      const CallNode* const_call_ref = call.operator->();
+      CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
+      call_ref->type_args = it->second.type_args;
+
+      for (size_t i = 0; i < call->type_args.size(); i++) {
+        call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
+      }
+    }
+
     return new_e;
   }
 

From 21e24381aaaa0fac444f30050b8f30e82d7b497d Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 14:46:22 -0700
Subject: [PATCH 5/7] Fix lint

---
 src/relay/pass/type_infer.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 3ba249777ef8..351a509b2874 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -191,7 +191,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
     CHECK(!type_map_.count(op->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[op->var].checked_type = vtype; // ResolvedTypeInfo(vtype, Array<Type>(NodePtr<Node>(nullptr)));
+    type_map_[op->var].checked_type = vtype;
     return GetType(op->body);
   }
 

From 9dc874ec7b055235183a55afb88c3e7167321016 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 14:58:31 -0700
Subject: [PATCH 6/7] Fix test case

---
 tests/python/relay/test_ir_text_printer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index 29814ecc5eb7..1d272236c680 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -30,7 +30,7 @@ def test_env():
     env["myf"] = f
     text = env.astext()
     assert "def @myf" in text
-    assert "%1 = add(%0, %0) # ty=float32" in text
+    assert "%1 = add<float32, float32>(%0, %0) # ty=float32" in text
     show(text)
 
 

From ba8c4cbb22ff67f8ea2b4c099fc1cafcaee8fa27 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 23 Oct 2018 16:19:51 -0700
Subject: [PATCH 7/7] Fix printing for primitive ops

Address feedback
---
 include/tvm/relay/op.h                     | 30 ++++++++++++++++++++++
 src/relay/ir/text_printer.cc               |  3 +--
 src/relay/pass/type_infer.cc               |  4 +--
 tests/python/relay/test_ir_text_printer.py |  2 +-
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
index fe6d957e79ed..9f28fbebccfc 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/relay/op.h
@@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
   return map_.get<ValueType>(op, def_value);
 }
 
+/*!
+ * \brief Check that an expression is a "primtive operator".
+ *
+ * Will return true if the expression is an operator which
+ * matches the form of primtive operators registered directly
+ * by the Relay codebase.
+ *
+ * That is the arguments are all type variables, and there is a single
+ * type relation applied to the input and output types.
+ */
+inline bool IsPrimitiveOp(const Expr& expr) {
+  const auto* op = expr.as<OpNode>();
+
+  if (!op) {
+    return false;
+  }
+
+  const auto& fn_ty = op->op_type;
+  if (fn_ty->type_constraints.size() != 1) return false;
+
+  const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
+  if (rel == nullptr) return false;
+  // validate if the type parameter matches up
+  for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
+    if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
+  }
+
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_H_
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index 09f2ba23141f..86ca4d74a974 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -275,7 +275,6 @@ class TextPrinter :
   }
 
   TextValue VisitExpr_(const CallNode* op) final {
-    // TODO(tqchen, M.K.): support generic call
     // possibly through meta-data
     TextValue call_op = GetValue(op->op);
     std::vector<TextValue> args;
@@ -289,7 +288,7 @@ class TextPrinter :
 
     auto type_args = op->type_args;
 
-    if (type_args.size() > 0U) {
+    if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
       stream_ << "<";
       for (size_t i = 0; i < op->type_args.size(); ++i) {
         this->PrintType(type_args[i], stream_);
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 351a509b2874..87fdb1c0ffba 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -64,9 +64,7 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
 struct ResolvedTypeInfo {
   explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
       : checked_type(checked_type), type_args(type_args) {}
-  ResolvedTypeInfo(const ResolvedTypeInfo& rti)
-      : checked_type(rti.checked_type), type_args(rti.type_args) {}
-  ResolvedTypeInfo() : checked_type() {}
+  ResolvedTypeInfo() {}
 
   Type checked_type;
   // Only allocated when the expression is a call.
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index 1d272236c680..29814ecc5eb7 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -30,7 +30,7 @@ def test_env():
     env["myf"] = f
     text = env.astext()
     assert "def @myf" in text
-    assert "%1 = add<float32, float32>(%0, %0) # ty=float32" in text
+    assert "%1 = add(%0, %0) # ty=float32" in text
     show(text)