From e6abad226821d08fbb3453384f756edf032ae65e Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 00:35:10 -0700 Subject: [PATCH 1/7] Add support for populating type args --- include/tvm/relay/expr.h | 2 +- src/relay/pass/type_infer.cc | 108 +++++++++++++++++++------- tests/python/relay/test_type_infer.py | 11 +++ 3 files changed, 93 insertions(+), 28 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 142982d48907..f459663ad705 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -267,7 +267,7 @@ class CallNode : public ExprNode { * * \endcode */ - tvm::Array<Type> type_args; + mutable tvm::Array<Type> type_args; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("op", &op); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 0cbce833aed9..63ff7860a0f4 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -61,6 +61,18 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); +struct ResolvedTypeInfo { + explicit ResolvedTypeInfo(Type checked_type) : checked_type(checked_type), type_args() {} + explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) : checked_type(checked_type), type_args() {} + explicit ResolvedTypeInfo(Array<Type> type_args) : checked_type(), type_args(type_args) {} + ResolvedTypeInfo(const ResolvedTypeInfo& rti) : checked_type(rti.checked_type), type_args(rti.type_args) {} + ResolvedTypeInfo() : checked_type(), type_args() {} + + Type checked_type; + // Only allocated when the expression is a call. + Array<Type> type_args; +}; + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -87,7 +99,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { Environment env_; // map from expression to checked type // type inferencer will populate it up - std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_; + std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_; + // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -111,11 +124,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { auto it = type_map_.find(expr); - if (it != type_map_.end()) { - return it->second; + if (it != type_map_.end() && it->second.checked_type.defined()) { + return it->second.checked_type; } Type ret = this->VisitExpr(expr); - type_map_[expr] = ret; + ResolvedTypeInfo& rti = type_map_[expr]; + rti.checked_type = ret; return ret; } @@ -176,7 +190,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { } CHECK(!type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var] = vtype; + type_map_[op->var].checked_type = { vtype }; return GetType(op->body); } @@ -224,6 +238,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { subst_map.Set(ty_param, fresh); ty_args->push_back(fresh); } + Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -234,6 +249,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { if (!ret_type.defined()) { ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); @@ -241,49 +257,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { return Downcast<FuncType>(inst_ty); } + void AddTypeArgs(const Expr& expr, Array<Type> type_args) { + auto type_info = type_map_.find(expr); + if (type_info == type_map_.end()) { + type_map_.insert({expr, ResolvedTypeInfo(type_args) }); + } else { + CHECK(!type_info->second.type_args.defined()); + type_info->second.type_args = type_args; + } + } + // Handle general call node. - Type GeneralCall(const CallNode* op, Array<Type> arg_types) { - Type ftype = GetType(op->op); + Type GeneralCall(const CallNode* call, Array<Type> arg_types) { + Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as<FuncTypeNode>(); + CHECK(fn_ty_node != nullptr) << "only expressions with function types can be called, at " - << op->span; + << call->span; Array<Type> type_args; FuncType fn_ty = Instantiate(fn_ty_node, &type_args); + + AddTypeArgs(GetRef<Call>(call), type_args); + size_t type_arity = fn_ty->arg_types.size(); size_t number_of_args = arg_types.size(); if (type_arity != number_of_args) { if (type_arity < number_of_args) { - LOG(FATAL) << "the function is provided too many arguments " << op->span; + LOG(FATAL) << "the function is provided too many arguments " << call->span; } else { - LOG(FATAL) << "the function is provided too few arguments" << op->span; + LOG(FATAL) << "the function is provided too few arguments" << call->span; } } + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span); } for (auto cs : fn_ty->type_constraints) { - solver_.AddConstraint(cs); + if (auto tr = cs.as<TypeRelationNode>()) { + solver_.AddConstraint( + TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs)); + } else { + solver_.AddConstraint(cs); + } } + return fn_ty->ret_type; } - Type VisitExpr_(const CallNode* op) final { - // Fast path: well-formed primitive op + Type VisitExpr_(const CallNode* call) final { Array<Type> arg_types; - for (Expr arg : op->args) { + for (Expr arg : call->args) { arg_types.push_back(GetType(arg)); } - if (const OpNode* opnode = op->op.as<OpNode>()) { + + if (const OpNode* opnode = call->op.as<OpNode>()) { Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, - op->attrs); - if (rtype.defined()) return rtype; + call->attrs); + if (rtype.defined()) { + AddTypeArgs(GetRef<Call>(call), arg_types); + return rtype; + } } - return GeneralCall(op, arg_types); + + return GeneralCall(call, arg_types); } Type VisitExpr_(const FunctionNode* f) final { @@ -312,7 +353,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { class TypeInferencer::Resolver : public ExprMutator { public: - Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap, + Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { } @@ -346,7 +387,20 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const CallNode* op) final { - return AttachCheckedType(op); + auto call = GetRef<Call>(op); + auto it = tmap_.find(call); + if (it != tmap_.end()) { + Call new_op = Downcast<Call>(AttachCheckedType(op)); + new_op->type_args = it->second.type_args; + + for (int i = 0; i < new_op->type_args.size(); i++) { + new_op->type_args.Set(i, solver_->Resolve(new_op->type_args[i])); + } + + return new_op; + } else { + return AttachCheckedType(op); + } } Expr VisitExpr_(const LetNode* op) final { @@ -362,7 +416,7 @@ class TypeInferencer::Resolver : public ExprMutator { Expr AttachCheckedType(const T* op) { auto it = tmap_.find(GetRef<Expr>(op)); CHECK(it != tmap_.end()); - Type checked_type = solver_->Resolve(it->second); + Type checked_type = solver_->Resolve(it->second.checked_type); CHECK(checked_type.as<IncompleteTypeNode>() == nullptr) << "Cannot resolve type of " << GetRef<Expr>(op) << " at " << op->span; @@ -379,22 +433,22 @@ class TypeInferencer::Resolver : public ExprMutator { return new_e; } - Type VisitType(const Type& t) final { + Type VisitType(const Type &t) final { return solver_->Resolve(t); } private: - const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_; + const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_; TypeSolver* solver_; }; Expr TypeInferencer::Infer(Expr expr) { - // step 0: populate the constraints + // Step 0: Populate the constraints. GetType(expr); - // step 1: solve the constraints + // Step 1: Solve the constraints. solver_.Solve(); - // step 2: attach resolved types to checked_type field + // Step 2: Attach resolved types to checked_type field. return Resolver(type_map_, &solver_).VisitExpr(expr); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 2d8f98974639..d59f90b6484e 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -91,6 +91,17 @@ def test_free_expr(): yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.scalar_type("float32") +def test_type_args(): + x = relay.var("x", shape=(10, 10)) + y = relay.var("y", shape=(10, 5)) + z = relay.add(x, y) + ty_z = relay.ir_pass.infer_type(z) + ty_args = ty_z.type_args + assert len(ty_args) == 2 + assert ty_args[0].dtype == "float32" + assert ty_args[1].dtype == "float32" + assert ty_args[0].shape == (10, 10) + assert ty_args[1].shape == (10, 5) if __name__ == "__main__": test_free_expr() From 1660b311f40a87a46dd45accc2db4acb9e07106b Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 00:46:03 -0700 Subject: [PATCH 2/7] Fix text printer and get test green --- src/relay/ir/text_printer.cc | 20 +++++++++++++++++--- tests/python/relay/test_type_infer.py | 12 +++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 66ef86641fae..09f2ba23141f 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -277,8 +277,6 @@ class TextPrinter : TextValue VisitExpr_(const CallNode* op) final { // TODO(tqchen, M.K.): support generic call // possibly through meta-data - CHECK_EQ(op->type_args.size(), 0U) - << "generic call not yet supported"; TextValue call_op = GetValue(op->op); std::vector<TextValue> args; for (Expr arg : op->args) { @@ -286,7 +284,23 @@ class TextPrinter : } TextValue id = this->AllocTempVar(); this->PrintIndent(); - stream_ << id << " = " << call_op << "("; + + stream_ << id << " = " << call_op; + + auto type_args = op->type_args; + + if (type_args.size() > 0U) { + stream_ << "<"; + for (size_t i = 0; i < op->type_args.size(); ++i) { + this->PrintType(type_args[i], stream_); + if (i + 1 != type_args.size()) { + stream_ << ", "; + } + } + stream_ << ">"; + } + + stream_ << "("; for (size_t i = 0; i < args.size(); ++i) { stream_ << args[i]; if (i + 1 != args.size()) { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index d59f90b6484e..e1d749e75863 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -93,15 +93,19 @@ def test_free_expr(): def test_type_args(): x = relay.var("x", shape=(10, 10)) - y = relay.var("y", shape=(10, 5)) + y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) ty_z = relay.ir_pass.infer_type(z) ty_args = ty_z.type_args assert len(ty_args) == 2 assert ty_args[0].dtype == "float32" assert ty_args[1].dtype == "float32" - assert ty_args[0].shape == (10, 10) - assert ty_args[1].shape == (10, 5) + sh1 = ty_args[0].shape + sh2 = ty_args[1].shape + assert sh1[0].value == 10 + assert sh1[1].value == 10 + assert sh2[0].value == 1 + assert sh2[1].value == 10 if __name__ == "__main__": test_free_expr() @@ -111,3 +115,5 @@ def test_type_args(): test_decl() test_recursion() test_tuple() + test_free_expr() + test_type_args() From d55cc315d8f2f34a757a429405e10934be2391cb Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 00:52:16 -0700 Subject: [PATCH 3/7] Fix lint --- src/relay/pass/type_infer.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 63ff7860a0f4..39ea112f33f0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -62,10 +62,14 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") TupleGetItemRel); struct ResolvedTypeInfo { - explicit ResolvedTypeInfo(Type checked_type) : checked_type(checked_type), type_args() {} - explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) : checked_type(checked_type), type_args() {} - explicit ResolvedTypeInfo(Array<Type> type_args) : checked_type(), type_args(type_args) {} - ResolvedTypeInfo(const ResolvedTypeInfo& rti) : checked_type(rti.checked_type), type_args(rti.type_args) {} + explicit ResolvedTypeInfo(Type checked_type) + : checked_type(checked_type), type_args() {} + explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) + : checked_type(checked_type), type_args() {} + explicit ResolvedTypeInfo(Array<Type> type_args) + : checked_type(), type_args(type_args) {} + ResolvedTypeInfo(const ResolvedTypeInfo& rti) + : checked_type(rti.checked_type), type_args(rti.type_args) {} ResolvedTypeInfo() : checked_type(), type_args() {} Type checked_type; From ca33fc2b44883e19022ad8e6006917ee795a5695 Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 14:44:01 -0700 Subject: [PATCH 4/7] Addresss some comments --- include/tvm/relay/expr.h | 2 +- src/relay/pass/type_infer.cc | 42 ++++++++++++++++-------------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index f459663ad705..142982d48907 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -267,7 +267,7 @@ class CallNode : public ExprNode { * * \endcode */ - mutable tvm::Array<Type> type_args; + tvm::Array<Type> type_args; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("op", &op); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 39ea112f33f0..3ba249777ef8 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -62,19 +62,16 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") TupleGetItemRel); struct ResolvedTypeInfo { - explicit ResolvedTypeInfo(Type checked_type) - : checked_type(checked_type), type_args() {} explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) - : checked_type(checked_type), type_args() {} - explicit ResolvedTypeInfo(Array<Type> type_args) - : checked_type(), type_args(type_args) {} + : checked_type(checked_type), type_args(type_args) {} ResolvedTypeInfo(const ResolvedTypeInfo& rti) : checked_type(rti.checked_type), type_args(rti.type_args) {} - ResolvedTypeInfo() : checked_type(), type_args() {} + ResolvedTypeInfo() : checked_type() {} Type checked_type; // Only allocated when the expression is a call. - Array<Type> type_args; + + Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr)); }; // @@ -194,7 +191,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { } CHECK(!type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = { vtype }; + type_map_[op->var].checked_type = vtype; // ResolvedTypeInfo(vtype, Array<Type>(NodePtr<Node>(nullptr))); return GetType(op->body); } @@ -264,7 +261,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { void AddTypeArgs(const Expr& expr, Array<Type> type_args) { auto type_info = type_map_.find(expr); if (type_info == type_map_.end()) { - type_map_.insert({expr, ResolvedTypeInfo(type_args) }); + type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)}); } else { CHECK(!type_info->second.type_args.defined()); type_info->second.type_args = type_args; @@ -391,20 +388,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const CallNode* op) final { - auto call = GetRef<Call>(op); - auto it = tmap_.find(call); - if (it != tmap_.end()) { - Call new_op = Downcast<Call>(AttachCheckedType(op)); - new_op->type_args = it->second.type_args; - - for (int i = 0; i < new_op->type_args.size(); i++) { - new_op->type_args.Set(i, solver_->Resolve(new_op->type_args[i])); - } - - return new_op; - } else { - return AttachCheckedType(op); - } + return AttachCheckedType(op); } Expr VisitExpr_(const LetNode* op) final { @@ -434,6 +418,18 @@ class TypeInferencer::Resolver : public ExprMutator { } new_e->checked_type_ = checked_type; } + + if (it->second.type_args.defined()) { + Call call = Downcast<Call>(new_e); + const CallNode* const_call_ref = call.operator->(); + CallNode* call_ref = const_cast<CallNode*>(const_call_ref); + call_ref->type_args = it->second.type_args; + + for (size_t i = 0; i < call->type_args.size(); i++) { + call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i])); + } + } + return new_e; } From 21e24381aaaa0fac444f30050b8f30e82d7b497d Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 14:46:22 -0700 Subject: [PATCH 5/7] Fix lint --- src/relay/pass/type_infer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3ba249777ef8..351a509b2874 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -191,7 +191,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { } CHECK(!type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = vtype; // ResolvedTypeInfo(vtype, Array<Type>(NodePtr<Node>(nullptr))); + type_map_[op->var].checked_type = vtype; return GetType(op->body); } From 9dc874ec7b055235183a55afb88c3e7167321016 Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 14:58:31 -0700 Subject: [PATCH 6/7] Fix test case --- tests/python/relay/test_ir_text_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 29814ecc5eb7..1d272236c680 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,7 +30,7 @@ def test_env(): env["myf"] = f text = env.astext() assert "def @myf" in text - assert "%1 = add(%0, %0) # ty=float32" in text + assert "%1 = add<float32, float32>(%0, %0) # ty=float32" in text show(text) From ba8c4cbb22ff67f8ea2b4c099fc1cafcaee8fa27 Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 23 Oct 2018 16:19:51 -0700 Subject: [PATCH 7/7] Fix printing for primitive ops Address feedback --- include/tvm/relay/op.h | 30 ++++++++++++++++++++++ src/relay/ir/text_printer.cc | 3 +-- src/relay/pass/type_infer.cc | 4 +-- tests/python/relay/test_ir_text_printer.py | 2 +- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fe6d957e79ed..9f28fbebccfc 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op, return map_.get<ValueType>(op, def_value); } +/*! + * \brief Check that an expression is a "primtive operator". + * + * Will return true if the expression is an operator which + * matches the form of primtive operators registered directly + * by the Relay codebase. + * + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ +inline bool IsPrimitiveOp(const Expr& expr) { + const auto* op = expr.as<OpNode>(); + + if (!op) { + return false; + } + + const auto& fn_ty = op->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + + const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_H_ diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 09f2ba23141f..86ca4d74a974 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -275,7 +275,6 @@ class TextPrinter : } TextValue VisitExpr_(const CallNode* op) final { - // TODO(tqchen, M.K.): support generic call // possibly through meta-data TextValue call_op = GetValue(op->op); std::vector<TextValue> args; @@ -289,7 +288,7 @@ class TextPrinter : auto type_args = op->type_args; - if (type_args.size() > 0U) { + if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { stream_ << "<"; for (size_t i = 0; i < op->type_args.size(); ++i) { this->PrintType(type_args[i], stream_); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 351a509b2874..87fdb1c0ffba 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -64,9 +64,7 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) : checked_type(checked_type), type_args(type_args) {} - ResolvedTypeInfo(const ResolvedTypeInfo& rti) - : checked_type(rti.checked_type), type_args(rti.type_args) {} - ResolvedTypeInfo() : checked_type() {} + ResolvedTypeInfo() {} Type checked_type; // Only allocated when the expression is a call. diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 1d272236c680..29814ecc5eb7 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,7 +30,7 @@ def test_env(): env["myf"] = f text = env.astext() assert "def @myf" in text - assert "%1 = add<float32, float32>(%0, %0) # ty=float32" in text + assert "%1 = add(%0, %0) # ty=float32" in text show(text)