From eff1ec6484b01ea274a9074f9eaef3fdbe99fa65 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 23 Oct 2018 16:19:51 -0700 Subject: [PATCH] Fix printing for primitive ops --- include/tvm/relay/op.h | 19 +++++++++++++++++++ src/relay/ir/text_printer.cc | 3 +-- tests/python/relay/test_ir_text_printer.py | 2 +- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fe6d957e79edc..7f9218599fce1 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -485,6 +485,25 @@ inline ValueType OpMap::get(const Op& op, return map_.get(op, def_value); } +inline bool IsPrimitiveOp(const Expr& expr) { + if (!expr.as()) { + return false; + } + + auto op = Downcast(expr); + auto fn_ty = op->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_H_ diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 09f2ba23141f8..86ca4d74a974e 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -275,7 +275,6 @@ class TextPrinter : } TextValue VisitExpr_(const CallNode* op) final { - // TODO(tqchen, M.K.): support generic call // possibly through meta-data TextValue call_op = GetValue(op->op); std::vector args; @@ -289,7 +288,7 @@ class TextPrinter : auto type_args = op->type_args; - if (type_args.size() > 0U) { + if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { stream_ << "<"; for (size_t i = 0; i < op->type_args.size(); ++i) { this->PrintType(type_args[i], stream_); diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 1d272236c6803..29814ecc5eb77 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,7 +30,7 @@ def test_env(): env["myf"] = f text = env.astext() assert "def @myf" in text - assert "%1 = add(%0, %0) # ty=float32" in text + assert "%1 = add(%0, %0) # ty=float32" in text show(text)