Skip to content

Commit

Permalink
Fix printing for primitive ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Oct 23, 2018
1 parent 9dc874e commit eff1ec6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
19 changes: 19 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,25 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
return map_.get<ValueType>(op, def_value);
}

inline bool IsPrimitiveOp(const Expr& expr) {
if (!expr.as<OpNode>()) {
return false;
}

auto op = Downcast<Op>(expr);
auto fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;

const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}

return true;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
3 changes: 1 addition & 2 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ class TextPrinter :
}

TextValue VisitExpr_(const CallNode* op) final {
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
Expand All @@ -289,7 +288,7 @@ class TextPrinter :

auto type_args = op->type_args;

if (type_args.size() > 0U) {
if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
stream_ << "<";
for (size_t i = 0; i < op->type_args.size(); ++i) {
this->PrintType(type_args[i], stream_);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_env():
env["myf"] = f
text = env.astext()
assert "def @myf" in text
assert "%1 = add<float32, float32>(%0, %0) # ty=float32" in text
assert "%1 = add(%0, %0) # ty=float32" in text
show(text)


Expand Down

0 comments on commit eff1ec6

Please sign in to comment.