From ba8c4cbb22ff67f8ea2b4c099fc1cafcaee8fa27 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 23 Oct 2018 16:19:51 -0700 Subject: [PATCH] Fix printing for primitive ops Address feedback --- include/tvm/relay/op.h | 30 ++++++++++++++++++++++ src/relay/ir/text_printer.cc | 3 +-- src/relay/pass/type_infer.cc | 4 +-- tests/python/relay/test_ir_text_printer.py | 2 +- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fe6d957e79ed..9f28fbebccfc 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -485,6 +485,36 @@ inline ValueType OpMap::get(const Op& op, return map_.get(op, def_value); } +/*! + * \brief Check that an expression is a "primtive operator". + * + * Will return true if the expression is an operator which + * matches the form of primtive operators registered directly + * by the Relay codebase. + * + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ +inline bool IsPrimitiveOp(const Expr& expr) { + const auto* op = expr.as(); + + if (!op) { + return false; + } + + const auto& fn_ty = op->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_H_ diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 09f2ba23141f..86ca4d74a974 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -275,7 +275,6 @@ class TextPrinter : } TextValue VisitExpr_(const CallNode* op) final { - // TODO(tqchen, M.K.): support generic call // possibly through meta-data TextValue call_op = GetValue(op->op); std::vector args; @@ -289,7 +288,7 @@ class TextPrinter : auto type_args = op->type_args; - if (type_args.size() > 0U) { + if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { stream_ << "<"; for (size_t i = 0; i < op->type_args.size(); ++i) { this->PrintType(type_args[i], stream_); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 351a509b2874..87fdb1c0ffba 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -64,9 +64,7 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} - ResolvedTypeInfo(const ResolvedTypeInfo& rti) - : checked_type(rti.checked_type), type_args(rti.type_args) {} - ResolvedTypeInfo() : checked_type() {} + ResolvedTypeInfo() {} Type checked_type; // Only allocated when the expression is a call. diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 1d272236c680..29814ecc5eb7 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,7 +30,7 @@ def test_env(): env["myf"] = f text = env.astext() assert "def @myf" in text - assert "%1 = add(%0, %0) # ty=float32" in text + assert "%1 = add(%0, %0) # ty=float32" in text show(text)