From b6aeee82d8c380cad886f4631601f5fb1bba22b0 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 4 Jan 2020 20:09:18 -0800 Subject: [PATCH] [REFACTOR] IRPrinter->NodePrinter, move to node/printer.h (#4622) Rationale: printer is a common infra that is shared across all nodes. --- include/tvm/expr.h | 31 ----- include/tvm/node/functor.h | 20 +-- include/tvm/node/node.h | 1 + include/tvm/node/printer.h | 61 ++++++++++ src/arithmetic/const_int_bound.cc | 4 +- src/arithmetic/int_set.cc | 4 +- src/arithmetic/modular_set.cc | 4 +- src/codegen/build_module.cc | 8 +- src/codegen/llvm/codegen_arm.cc | 3 + src/codegen/llvm/codegen_x86_64.cc | 2 + src/codegen/llvm/llvm_module.cc | 2 + src/codegen/source_module.cc | 1 + src/contrib/hybrid/codegen_hybrid.cc | 1 + src/ir/span.cc | 9 +- src/ir/type.cc | 13 +- src/lang/attrs.cc | 4 +- src/lang/buffer.cc | 4 +- src/lang/data_layout.cc | 8 +- src/lang/expr.cc | 37 +----- src/lang/ir.cc | 174 +++++++++++++-------------- src/lang/lowered_func.cc | 4 +- src/lang/target_info.cc | 4 +- src/lang/tensor.cc | 12 +- src/node/env_func.cc | 6 +- src/node/printer.cc | 52 ++++++++ src/op/compute_op.cc | 4 +- src/op/extern_op.cc | 4 +- src/op/hybrid_op.cc | 4 +- src/op/placeholder_op.cc | 4 +- src/op/scan_op.cc | 4 +- src/op/tensor_compute_op.cc | 4 +- src/relay/backend/interpreter.cc | 24 ++-- src/relay/ir/adt.cc | 32 ++--- src/relay/ir/expr.cc | 50 ++++---- src/relay/ir/module.cc | 6 +- src/relay/ir/op.cc | 4 +- src/relay/ir/type.cc | 26 ++-- src/relay/pass/pass_manager.cc | 22 ++-- src/relay/pass/quantize/quantize.cc | 4 +- src/schedule/schedule_lang.cc | 16 +-- 40 files changed, 374 insertions(+), 303 deletions(-) create mode 100644 include/tvm/node/printer.h create mode 100644 src/node/printer.cc diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 0605cc512690..aee565dcbc9c 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -470,37 +470,6 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } return ret; } - -// Printer infra. -/*! \brief A Pretty printer class to print the IR. */ -class IRPrinter { - public: - /*! \brief The output stream */ - std::ostream& stream; - /*! \brief The indentation level. */ - int indent{0}; - explicit IRPrinter(std::ostream& stream) // NOLINT(*) - : stream(stream) {} - - /*! \brief The node to be printed. */ - TVM_DLL void Print(const ObjectRef& node); - /*! \brief Print indent to the stream */ - TVM_DLL void PrintIndent(); - // Allow registration to be printer. - using FType = NodeFunctor; - TVM_DLL static FType& vtable(); -}; -} // namespace tvm - -namespace tvm { -namespace runtime { -// default print function for all objects -// provide in the runtime namespace as this is where objectref originally comes from. -inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) - IRPrinter(os).Print(n); - return os; -} -} // namespace runtime } // namespace tvm namespace std { diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index d56fb8dde799..d925fbde4671 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -24,14 +24,16 @@ #define TVM_NODE_FUNCTOR_H_ #include -#include -#include +#include #include #include #include namespace tvm { + +using runtime::ObjectRef; + /*! * \brief A dynamically dispatched functor on the type of the first argument. * @@ -137,11 +139,11 @@ class NodeFunctor { * \brief Useful macro to set NodeFunctor dispatch in a global static field. * * \code - * // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern. + * // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern. * // vtable allows easy patch of new Node types, without changing - * // interface of IRPrinter. + * // interface of NodePrinter. * - * class IRPrinter { + * class NodePrinter { * public: * std::ostream& stream; * // the dispatch function. @@ -150,18 +152,18 @@ class NodeFunctor { * f(e, this); * } * - * using FType = NodeFunctor; + * using FType = NodeFunctor; * // function to return global function table * static FType& vtable(); * }; * * // in cpp/cc file - * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*) + * NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*) * static FType inst; return inst; * } * - * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) - * .set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + * TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) + * .set_dispatch([](const ObjectRef& ref, NodePrinter* p) { * auto* n = static_cast(ref.get()); * p->print(n->a); * p->stream << '+' diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index bb5da415c463..54eb436de9d3 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -38,6 +38,7 @@ #include #include #include +#include #include #include diff --git a/include/tvm/node/printer.h b/include/tvm/node/printer.h new file mode 100644 index 000000000000..1e6c3e5bb8fb --- /dev/null +++ b/include/tvm/node/printer.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/printer.h + * \brief Printer class to print repr string of each AST/IR nodes. + */ +#ifndef TVM_NODE_PRINTER_H_ +#define TVM_NODE_PRINTER_H_ + +#include +#include + +namespace tvm { +/*! \brief A printer class to print the AST/IR nodes. */ +class NodePrinter { + public: + /*! \brief The output stream */ + std::ostream& stream; + /*! \brief The indentation level. */ + int indent{0}; + + explicit NodePrinter(std::ostream& stream) // NOLINT(*) + : stream(stream) {} + + /*! \brief The node to be printed. */ + TVM_DLL void Print(const ObjectRef& node); + /*! \brief Print indent to the stream */ + TVM_DLL void PrintIndent(); + // Allow registration to be printer. + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); +}; +} // namespace tvm + +namespace tvm { +namespace runtime { +// default print function for all objects +// provide in the runtime namespace as this is where objectref originally comes from. +inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) + NodePrinter(os).Print(n); + return os; +} +} // namespace runtime +} // namespace tvm +#endif // TVM_NODE_PRINTER_H_ diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 16e489a9c818..ef405d8026c9 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { } } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "ConstIntBound["; PrintBoundValue(p->stream, op->min_value); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index bdfcc1ae0fff..bf1cdf0466b7 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -811,8 +811,8 @@ IntSet EvalSet(Range r, TVM_REGISTER_NODE_TYPE(IntervalSetNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "IntervalSet" << "[" << op->min_value << ", " diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 5ab1bd386748..a83e98760baa 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { data_ = std::move(node); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "ModularSet(" << "coeff=" << op->coeff << ", base=" diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index eab220ea972e..38f0b9532133 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -41,8 +41,8 @@ using runtime::PackedFunc; TVM_REGISTER_NODE_TYPE(TargetNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << op->str(); }); @@ -665,8 +665,8 @@ tvm::BuildConfig BuildConfig::Current() { TVM_REGISTER_NODE_TYPE(BuildConfigNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "build_config("; p->stream << "data_alignment=" << op->data_alignment << ", "; diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 4c092dfe377a..39d51147bbff 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -22,6 +22,9 @@ * \brief ARM specific code generator */ #ifdef TVM_LLVM_VERSION + +#include + #include "codegen_cpu.h" namespace tvm { diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 5d72b56df376..d6138830bfb4 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -22,6 +22,8 @@ * \brief X86-64 specific code generator */ #ifdef TVM_LLVM_VERSION + +#include #include "codegen_cpu.h" #include "llvm/MC/MCSubtargetInfo.h" diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index e042081b1b9f..937a1103eda4 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -22,7 +22,9 @@ * \brief LLVM runtime module for TVM */ #ifdef TVM_LLVM_VERSION + #include +#include #include #include #include "llvm_common.h" diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index e23ce60223f2..b9807b37bb73 100644 --- a/src/codegen/source_module.cc +++ b/src/codegen/source_module.cc @@ -22,6 +22,7 @@ * \brief Source code module, only for viewing */ #include +#include #include "codegen_source_base.h" #include "../runtime/file_util.h" #include "../runtime/meta_data.h" diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index beda99df7b15..301602fb8238 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -20,6 +20,7 @@ /*! * \file codegen_hybrid.cc */ +#include #include #include #include "codegen_hybrid.h" diff --git a/src/ir/span.cc b/src/ir/span.cc index 1d9f07951183..1be4e32fb037 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -21,6 +21,7 @@ * \brief The span data structure. */ #include +#include #include namespace tvm { @@ -48,8 +49,8 @@ SourceName SourceName::Get(const std::string& name) { TVM_REGISTER_GLOBAL("relay._make.SourceName") .set_body_typed(SourceName::Get); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "SourceName(" << node->name << ", " << node << ")"; }); @@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_GLOBAL("relay._make.Span") .set_body_typed(SpanNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset << ")"; diff --git a/src/ir/type.cc b/src/ir/type.cc index ef5f75b86a2c..3c56f8d3ae53 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -22,6 +22,7 @@ * \brief Common type system AST nodes throughout the IR. */ #include +#include #include namespace tvm { @@ -40,8 +41,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar") return TypeVarNode::make(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; @@ -61,8 +62,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") return GlobalTypeVarNode::make(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; @@ -85,8 +86,8 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_GLOBAL("relay._make.FuncType") .set_body_typed(FuncTypeNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type << ", " diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index fd28268cb480..d69e3e2ad703 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -61,8 +61,8 @@ Attrs DictAttrsNode::make(Map dict) { return Attrs(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << op->dict; }); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 9bbd8d62105f..22efa1dcc1bf 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -450,8 +450,8 @@ Buffer BufferNode::make(Var data, return Buffer(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "buffer(" << op->name << ", " << op << ")"; }); diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index 58f033b69e51..c4a6b35c0724 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -194,8 +194,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { return -1; } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* l = static_cast(node.get()); p->stream << "Layout(" << l->name << ")"; }); @@ -361,8 +361,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, return BijectiveLayout(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* b = static_cast(node.get()); p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() << ")"; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 5a54f0407c8d..eed693808708 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -97,33 +97,8 @@ Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } -void IRPrinter::Print(const ObjectRef& ir) { - static const FType& f = vtable(); - if (!ir.defined()) { - stream << "(nullptr)"; - } else { - if (f.can_dispatch(ir)) { - f(ir, this); - } else { - // default value, output type key and addr. - stream << ir->GetTypeKey() << "(" << ir.get() << ")"; - } - } -} - -void IRPrinter::PrintIndent() { - for (int i = 0; i < indent; ++i) { - stream << ' '; - } -} - -IRPrinter::FType& IRPrinter::vtable() { - static FType inst; - return inst; -} - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); if (op->dtype == DataType::Int(32)) { p->stream << op->value; @@ -132,8 +107,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "iter_var("; if (op->var->name_hint.length() != 0) { @@ -148,8 +123,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index d5cc285ac861..5b410d1e3741 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -552,14 +552,14 @@ Stmt Evaluate::make(Expr value) { } // Printers -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "(" << op->dtype << ")" << op->value; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); auto& stream = p->stream; switch (op->dtype.bits()) { @@ -577,8 +577,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); auto& stream = p->stream; stream << '"'; @@ -613,20 +613,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) stream << '"'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << op->dtype << '('; p->Print(op->value); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; p->stream << op->name_hint; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -634,7 +634,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -642,7 +642,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -650,7 +650,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch
([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch
([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -658,7 +658,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -666,7 +666,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "min("; p->Print(op->a); @@ -674,7 +674,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "max("; p->Print(op->a); @@ -682,7 +682,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -690,7 +690,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -698,7 +698,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -706,7 +706,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -714,7 +714,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -722,7 +722,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -731,20 +731,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -753,8 +753,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -763,15 +763,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '!'; p->Print(op->a); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "select("; p->Print(op->condition); @@ -782,8 +782,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << op->buffer_var << "["; p->Print(op->index); @@ -794,8 +794,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "ramp("; p->Print(op->base); @@ -804,16 +804,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ", " << op->lanes << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "x" << op->lanes << "("; p->Print(op->value); p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << op->name << "("; for (size_t i = 0; i < op->args.size(); ++i) { @@ -825,8 +825,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "(let " << op->var << " = "; p->Print(op->value); @@ -835,8 +835,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "let " << op->var << " = "; @@ -845,8 +845,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "// attr ["; @@ -858,8 +858,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "assert("; @@ -870,8 +870,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); if (op->is_producer) { p->PrintIndent(); @@ -904,8 +904,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) return out; } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->for_type << " (" << op->loop_var << ", "; @@ -922,8 +922,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->buffer_var << "["; @@ -937,8 +937,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->func->func_name() << "("; @@ -955,8 +955,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; @@ -973,16 +973,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "free " << op->buffer_var; p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "realize " << op->func->func_name() << "("; @@ -1012,8 +1012,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "prefetch " << op->func->func_name() << "("; @@ -1031,15 +1031,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->Print(op->first); if (op->rest.defined()) p->Print(op->rest); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); while (true) { @@ -1069,8 +1069,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->Print(op->value); @@ -1078,7 +1078,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); template -void PrintList(const Array &exprs, IRPrinter* p) { +void PrintList(const Array &exprs, NodePrinter* p) { for (size_t i = 0; i < exprs.size(); ++i) { p->Print(exprs[i]); if (i < exprs.size() - 1) { @@ -1087,8 +1087,8 @@ void PrintList(const Array &exprs, IRPrinter* p) { } } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "shuffle("; PrintList(op->vectors, p); @@ -1098,8 +1098,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); // Container printer -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '['; for (size_t i = 0 ; i < op->data.size(); ++i) { @@ -1111,8 +1111,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ']'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '{'; for (auto it = op->data.begin(); it != op->data.end(); ++it) { @@ -1126,8 +1126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << '}'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << '{'; for (auto it = op->data.begin(); it != op->data.end(); ++it) { @@ -1140,8 +1140,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << '}'; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "reduce(combiner=" << op->combiner; @@ -1152,8 +1152,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs @@ -1162,8 +1162,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { p->stream << "?"; }); diff --git a/src/lang/lowered_func.cc b/src/lang/lowered_func.cc index 2ef648b975c8..a6b6908d95f9 100644 --- a/src/lang/lowered_func.cc +++ b/src/lang/lowered_func.cc @@ -24,8 +24,8 @@ namespace tvm { -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "LoweredFunc(" << op->name << ", " << op << ")"; }); diff --git a/src/lang/target_info.cc b/src/lang/target_info.cc index 8c45a19cf818..6bdcf8800967 100644 --- a/src/lang/target_info.cc +++ b/src/lang/target_info.cc @@ -26,8 +26,8 @@ namespace tvm { -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "mem-info(" << "unit_bits=" << op->unit_bits << ", " diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index e9ca89a4b31e..d0e81b9ca4d7 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -67,8 +67,8 @@ Tensor TensorNode::make(Array shape, return Tensor(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* t = static_cast(node.get()); p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; @@ -99,8 +99,8 @@ TensorIntrin TensorIntrinNode::make(std::string name, return TensorIntrin(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; }); @@ -124,8 +124,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, return TensorIntrinCall(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* n = static_cast(node.get()); p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; }); diff --git a/src/node/env_func.cc b/src/node/env_func.cc index 52bb61d7517c..4b5bc4cbde5a 100644 --- a/src/node/env_func.cc +++ b/src/node/env_func.cc @@ -26,13 +26,13 @@ namespace tvm { + using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter *p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "EnvFunc(" << op->name << ")"; }); diff --git a/src/node/printer.cc b/src/node/printer.cc new file mode 100644 index 000000000000..15171dfefc4c --- /dev/null +++ b/src/node/printer.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Printer utilities + * \file node/printer.cc + */ +#include + +namespace tvm { + +void NodePrinter::Print(const ObjectRef& node) { + static const FType& f = vtable(); + if (!node.defined()) { + stream << "(nullptr)"; + } else { + if (f.can_dispatch(node)) { + f(node, this); + } else { + // default value, output type key and addr. + stream << node->GetTypeKey() << "(" << node.get() << ")"; + } + } +} + +void NodePrinter::PrintIndent() { + for (int i = 0; i < indent; ++i) { + stream << ' '; + } +} + +NodePrinter::FType& NodePrinter::vtable() { + static FType inst; + return inst; +} +} // namespace tvm diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 85459b4d723d..939327890fec 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -39,8 +39,8 @@ namespace tvm { using namespace ir; -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "compute(" << op->name << ", " << op << ")"; }); diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index b921c86f3556..c6102ed556e0 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -30,8 +30,8 @@ namespace tvm { using namespace ir; // ExternOpNode -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "extern(" << op->name << ", " << op << ")"; }); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 4de5f1cff18d..b4f29f5a36c8 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -36,8 +36,8 @@ namespace tvm { using namespace ir; // HybridOpNode -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "hybrid(" << op->name << ", " << op << ")"; }); diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 7863c8a52265..6414d5c39ac1 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -26,8 +26,8 @@ namespace tvm { // PlaceholderOpNode -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "placeholder(" << op->name << ", " << op << ")"; }); diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 57f16f82c54b..ef2d1efcc089 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -31,8 +31,8 @@ namespace tvm { using namespace ir; -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "scan(" << op->name << ", " << op << ")"; }); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index d82363e496ca..76ecf3417d36 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -33,8 +33,8 @@ namespace tvm { using namespace ir; // TensorComputeOpNode -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; }); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 203fbfab27c4..c1e4fd59d042 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -54,8 +54,8 @@ Closure ClosureNode::make(tvm::Map env, Function func) { TVM_REGISTER_GLOBAL("relay._make.Closure") .set_body_typed(ClosureNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; }); @@ -73,8 +73,8 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) { TVM_REGISTER_GLOBAL("relay._make.RecClosure") .set_body_typed(RecClosureNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RecClosureNode(" << node->clos << ")"; }); @@ -88,8 +88,8 @@ TupleValue TupleValueNode::make(tvm::Array value) { TVM_REGISTER_GLOBAL("relay._make.TupleValue") .set_body_typed(TupleValueNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TupleValueNode(" << node->fields << ")"; }); @@ -100,8 +100,8 @@ TensorValue TensorValueNode::make(runtime::NDArray data) { return TensorValue(n); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); auto to_str = GetPackedFunc("relay._tensor_value_repr"); std::string data_str = to_str(GetRef(node)); @@ -122,8 +122,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefValue") TVM_REGISTER_NODE_TYPE(RefValueNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefValueNode(" << node->value << ")"; }); @@ -143,8 +143,8 @@ TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") TVM_REGISTER_NODE_TYPE(ConstructorValueNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ConstructorValueNode(" << node->tag << "," << node->fields << ")"; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index ff477897f412..1769298a4433 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") .set_body_typed(PatternWildcardNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { p->stream << "PatternWildcardNode()"; }); @@ -53,8 +53,8 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_GLOBAL("relay._make.PatternVar") .set_body_typed(PatternVarNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternVarNode(" << node->var << ")"; }); @@ -72,8 +72,8 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") .set_body_typed(PatternConstructorNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")"; @@ -90,8 +90,8 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode); TVM_REGISTER_GLOBAL("relay._make.PatternTuple") .set_body_typed(PatternTupleNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternTupleNode(" << node->patterns << ")"; }); @@ -111,8 +111,8 @@ TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_GLOBAL("relay._make.Constructor") .set_body_typed(ConstructorNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", " << node->belong_to << ")"; @@ -133,8 +133,8 @@ TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_GLOBAL("relay._make.TypeData") .set_body_typed(TypeDataNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " << node->constructors << ")"; @@ -152,8 +152,8 @@ TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_GLOBAL("relay._make.Clause") .set_body_typed(ClauseNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; @@ -172,8 +172,8 @@ TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay._make.Match") .set_body_typed(MatchNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete << ")"; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 11689b079c67..e8f18e6ff734 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -26,7 +26,7 @@ namespace tvm { namespace relay { -using tvm::IRPrinter; +using tvm::NodePrinter; using namespace tvm::runtime; Constant ConstantNode::make(runtime::NDArray data) { @@ -40,8 +40,8 @@ TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_GLOBAL("relay._make.Constant") .set_body_typed(ConstantNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); const PackedFunc* fprint = Registry::Get("relay._constant_repr"); CHECK(fprint) << "unable to find printing function for constants"; @@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay._make.Tuple") .set_body_typed(TupleNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Tuple(" << node->fields << ")"; }); @@ -98,8 +98,8 @@ TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay._make.Var") .set_body_typed(static_cast(VarNode::make)); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Var(" << node->name_hint(); if (node->type_annotation.defined()) { @@ -120,8 +120,8 @@ TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_GLOBAL("relay._make.GlobalVar") .set_body_typed(GlobalVarNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "GlobalVar(" << node->name_hint << ")"; }); @@ -226,8 +226,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay._make.Function") .set_body_typed(FunctionNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body << ", " << node->type_params << ", " @@ -249,8 +249,8 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay._make.Call") .set_body_typed(CallNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " << node->type_args << ")"; @@ -269,8 +269,8 @@ TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay._make.Let") .set_body_typed(LetNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; @@ -289,8 +289,8 @@ TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay._make.If") .set_body_typed(IfNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " << node->false_branch << ")"; @@ -308,8 +308,8 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") .set_body_typed(TupleGetItemNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); @@ -325,8 +325,8 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay._make.RefCreate") .set_body_typed(RefCreateNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefCreateNode(" << node->value << ")"; }); @@ -342,8 +342,8 @@ TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay._make.RefRead") .set_body_typed(RefReadNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefReadNode(" << node->ref << ")"; }); @@ -360,8 +360,8 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay._make.RefWrite") .set_body_typed(RefWriteNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 4e57258981c6..195914612b3f 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relay { -using tvm::IRPrinter; +using tvm::NodePrinter; using namespace runtime; Module ModuleNode::make(tvm::Map global_funcs, @@ -414,8 +414,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") mod->ImportFromStd(path); });; -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ModuleNode( " << node->functions << ")"; }); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 4bef724957de..eb42477e954a 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -224,8 +224,8 @@ TVM_REGISTER_NODE_TYPE(OpNode) return static_cast(n)->name; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Op(" << node->name << ")"; }); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index aa9d37649586..8a386aeb0b6e 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -26,7 +26,7 @@ namespace tvm { namespace relay { -using tvm::IRPrinter; +using tvm::NodePrinter; using namespace tvm::runtime; TensorType TensorTypeNode::make(Array shape, DataType dtype) { @@ -57,8 +57,8 @@ TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_GLOBAL("relay._make.TensorType") .set_body_typed(TensorTypeNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); @@ -75,8 +75,8 @@ TVM_REGISTER_NODE_TYPE(TypeCallNode); TVM_REGISTER_GLOBAL("relay._make.TypeCall") .set_body_typed(TypeCallNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; @@ -95,8 +95,8 @@ TVM_REGISTER_GLOBAL("relay._make.IncompleteType") return IncompleteTypeNode::make(static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); @@ -118,8 +118,8 @@ TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_GLOBAL("relay._make.TypeRelation") .set_body_typed(TypeRelationNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeRelationNode(" << node->func->name @@ -137,8 +137,8 @@ TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_GLOBAL("relay._make.TupleType") .set_body_typed(TupleTypeNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TupleTypeNode(" << node->fields << ")"; }); @@ -154,8 +154,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefType") TVM_REGISTER_NODE_TYPE(RefTypeNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefTypeNode(" << node->value << ")"; }); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index ae02d7008842..e02dcc0dcda5 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -34,7 +34,7 @@ namespace tvm { namespace relay { namespace transform { -using tvm::IRPrinter; +using tvm::NodePrinter; struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -453,8 +453,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info") *ret = pass->Info(); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, tvm::NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "The meta data of the pass: "; p->stream << "pass name: " << node->name; @@ -479,8 +479,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass") *ret = pass(mod); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Module pass: " << info->name @@ -492,8 +492,8 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") .set_body_typed(FunctionPassNode::make); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Function pass: " << info->name @@ -512,8 +512,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential") *ret = Sequential(passes, pass_info); }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Sequential pass: " << info->name @@ -542,8 +542,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") *ret = pctx; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index c995994757a7..33e2c2f24cde 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -116,8 +116,8 @@ QConfig& QConfig::Current() { TVM_REGISTER_NODE_TYPE(QConfigNode); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { auto* op = static_cast(ref.get()); p->stream << "qconfig("; p->stream << "nbit_input=" << op->nbit_input << ", "; diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 91d3726f0bab..be4251354916 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -798,8 +798,8 @@ TVM_REGISTER_NODE_TYPE(SingletonNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); // Printer -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); if (op->op.defined()) { p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; @@ -807,11 +807,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "group-stage(" << op << ")"; } }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << IterVarType2String(op->iter_type); }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "split(parent="; p->Print(op->parent); @@ -821,7 +821,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->inner); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "split("; p->stream << "outer="; @@ -832,7 +832,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->fused); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "rebase("; p->stream << "parent="; @@ -841,13 +841,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->Print(op->rebased); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "singleton("; p->Print(op->iter); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); p->stream << "schedule(" << op << ")"; });