From 4166672009bbfc7fcc3ca98bd331d6b57e11a68c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 20 Jan 2020 20:06:17 -0800 Subject: [PATCH] [REFACTOR] Establish printer in the source folder (#4752) * [REFACTOR] Establish printer in the source folder. As we move towards the unified IR, we will eventually want to build a unified printers for both relay and TIR. This PR isolate the printer component into a separate folder in src as a first step. - Refactored the Doc DSL using Object, clean up APIs. - Isolate out the meta data into a header. - move printer into relay_text_printer, add comments about further TODos. * Rename NodePrinter -> ReprPrinter to distinguish it from other printers --- CMakeLists.txt | 1 + apps/lldb/tvm.py | 4 +- include/tvm/ir/module.h | 28 ++ include/tvm/node/functor.h | 14 +- include/tvm/node/node.h | 2 +- .../tvm/node/{printer.h => repr_printer.h} | 16 +- include/tvm/relay/expr.h | 16 +- src/arith/const_int_bound.cc | 4 +- src/arith/int_set.cc | 4 +- src/arith/modular_set.cc | 4 +- src/ir/adt.cc | 8 +- src/ir/attrs.cc | 4 +- src/ir/env_func.cc | 4 +- src/ir/error.cc | 2 +- src/ir/expr.cc | 28 +- src/ir/module.cc | 4 +- src/ir/op.cc | 4 +- src/ir/span.cc | 8 +- src/ir/tensor_type.cc | 6 +- src/ir/transform.cc | 20 +- src/ir/type.cc | 28 +- src/ir/type_relation.cc | 8 +- src/node/{printer.cc => repr_printer.cc} | 10 +- src/printer/doc.cc | 173 +++++++++ src/printer/doc.h | 165 +++++++++ src/printer/meta_data.h | 140 ++++++++ .../relay_text_printer.cc} | 335 +++++++----------- src/relay/backend/interpreter.cc | 16 +- src/relay/ir/adt.cc | 24 +- src/relay/ir/doc.cc | 126 ------- src/relay/ir/doc.h | 130 ------- src/relay/ir/expr.cc | 47 +-- src/relay/ir/transform.cc | 6 +- src/relay/pass/fuse_ops.cc | 2 +- src/relay/pass/quantize/quantize.cc | 4 +- src/target/generic_func.cc | 2 +- src/target/target.cc | 10 +- src/target/target_info.cc | 6 +- src/tir/ir/buffer.cc | 4 +- src/tir/ir/data_layout.cc | 8 +- src/tir/ir/expr.cc | 98 ++--- src/tir/ir/lowered_func.cc | 4 +- src/tir/ir/stmt.cc | 62 ++-- src/top/operation/compute_op.cc | 4 +- src/top/operation/extern_op.cc | 4 +- src/top/operation/hybrid_op.cc | 4 +- src/top/operation/placeholder_op.cc | 4 +- src/top/operation/scan_op.cc | 4 +- src/top/operation/tensor_compute_op.cc | 4 +- src/top/schedule/schedule_lang.cc | 16 +- src/top/tensor.cc | 12 +- 51 files changed, 901 insertions(+), 740 deletions(-) rename include/tvm/node/{printer.h => repr_printer.h} (85%) rename src/node/{printer.cc => repr_printer.cc} (87%) create mode 100644 src/printer/doc.cc create mode 100644 src/printer/doc.h create mode 100644 src/printer/meta_data.h rename src/{relay/ir/pretty_printer.cc => printer/relay_text_printer.cc} (74%) delete mode 100644 src/relay/ir/doc.cc delete mode 100644 src/relay/ir/doc.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 554dabe96f99..2fa023dbeb35 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -133,6 +133,7 @@ file(GLOB_RECURSE COMPILER_SRCS src/autotvm/*.cc src/tir/*.cc src/driver/*.cc + src/printer/*.cc src/api/*.cc ) diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index a2607b7baa15..811d32db6c75 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -144,7 +144,7 @@ def _GetContext(debugger): def PrettyPrint(debugger, command, result, internal_dict): ctx = _GetContext(debugger) rc = ctx.EvaluateExpression( - "tvm::relay::PrettyPrint({command})".format(command=command) + "tvm::PrettyPrint({command})".format(command=command) ) result.AppendMessage(str(rc)) @@ -175,7 +175,7 @@ def _EvalExpressionAsString(logger, ctx, expr): def _EvalAsNodeRef(logger, ctx, value): return _EvalExpressionAsString( - logger, ctx, "tvm::relay::PrettyPrint({name})".format(name=value.name) + logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name) ) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 8f922c0d42f7..ad5f4d9b8ccb 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -308,5 +308,33 @@ class IRModule : public ObjectRef { TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); }; +/*! + * \brief Pretty print a node for debug purposes. + * + * \param node The node to be printed. + * \return The text reperesentation. + * \note This function does not show version or meta-data. + * Use AsText if you want to store the text. + * \sa AsText. + */ +TVM_DLL std::string PrettyPrint(const ObjectRef& node); + +/*! + * \brief Render the node as a string in the text format. + * + * \param node The node to be rendered. + * \param show_meta_data Whether to print meta data section. + * \param annotate An optional callback function for attaching + * additional comment block to an expr. + * + * \note We support a limited set of IR nodes that are part of + * relay IR and + * + * \sa PrettyPrint. + * \return The text representation. + */ +TVM_DLL std::string AsText(const ObjectRef& node, + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index d925fbde4671..e11fda892c30 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -139,11 +139,11 @@ class NodeFunctor { * \brief Useful macro to set NodeFunctor dispatch in a global static field. * * \code - * // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern. + * // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern. * // vtable allows easy patch of new Node types, without changing - * // interface of NodePrinter. + * // interface of ReprPrinter. * - * class NodePrinter { + * class ReprPrinter { * public: * std::ostream& stream; * // the dispatch function. @@ -152,18 +152,18 @@ class NodeFunctor { * f(e, this); * } * - * using FType = NodeFunctor; + * using FType = NodeFunctor; * // function to return global function table * static FType& vtable(); * }; * * // in cpp/cc file - * NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*) + * ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*) * static FType inst; return inst; * } * - * TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) - * .set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + * TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + * .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { * auto* n = static_cast(ref.get()); * p->print(n->a); * p->stream << '+' diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 10c577a890be..3ea3d763df74 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -38,7 +38,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/node/printer.h b/include/tvm/node/repr_printer.h similarity index 85% rename from include/tvm/node/printer.h rename to include/tvm/node/repr_printer.h index a4c6a696633c..41789a34d342 100644 --- a/include/tvm/node/printer.h +++ b/include/tvm/node/repr_printer.h @@ -17,25 +17,25 @@ * under the License. */ /*! - * \file tvm/node/printer.h + * \file tvm/node/repr_printer.h * \brief Printer class to print repr string of each AST/IR nodes. */ -#ifndef TVM_NODE_PRINTER_H_ -#define TVM_NODE_PRINTER_H_ +#ifndef TVM_NODE_REPR_PRINTER_H_ +#define TVM_NODE_REPR_PRINTER_H_ #include #include namespace tvm { /*! \brief A printer class to print the AST/IR nodes. */ -class NodePrinter { +class ReprPrinter { public: /*! \brief The output stream */ std::ostream& stream; /*! \brief The indentation level. */ int indent{0}; - explicit NodePrinter(std::ostream& stream) // NOLINT(*) + explicit ReprPrinter(std::ostream& stream) // NOLINT(*) : stream(stream) {} /*! \brief The node to be printed. */ @@ -43,7 +43,7 @@ class NodePrinter { /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. - using FType = NodeFunctor; + using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; @@ -60,9 +60,9 @@ namespace runtime { // default print function for all objects // provide in the runtime namespace as this is where objectref originally comes from. inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) - NodePrinter(os).Print(n); + ReprPrinter(os).Print(n); return os; } } // namespace runtime } // namespace tvm -#endif // TVM_NODE_PRINTER_H_ +#endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 1062c20bb4f9..72523dd08dd3 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include "./base.h" @@ -40,6 +41,7 @@ using BaseFunc = tvm::BaseFunc; using BaseFuncNode = tvm::BaseFuncNode; using GlobalVar = tvm::GlobalVar; using GlobalVarNode = tvm::GlobalVarNode; +using tvm::PrettyPrint; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. @@ -539,20 +541,6 @@ class TempExpr : public Expr { TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode); }; -/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ -std::string PrettyPrint(const ObjectRef& node); - -/*! - * \brief Render the node as a string in the Relay text format. - * \param node The node to be rendered. - * \param show_meta_data Whether to print meta data section. - * \param annotate An optional callback function for attaching - * additional comment block to an expr. - * \return The text representation. - */ -std::string AsText(const ObjectRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); /*! \brief namespace of the attributes that are attached to a function. */ namespace attr { diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index a75e86a32660..7fb90a5e87c1 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { } } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "ConstIntBound["; PrintBoundValue(p->stream, op->min_value); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 27cdffee02b1..728cca1b5705 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -813,8 +813,8 @@ IntSet EvalSet(Range r, TVM_REGISTER_NODE_TYPE(IntervalSetNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "IntervalSet" << "[" << op->min_value << ", " diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 8b5309272efe..c3031ca0edfc 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { data_ = std::move(node); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "ModularSet(" << "coeff=" << op->coeff << ", base=" diff --git a/src/ir/adt.cc b/src/ir/adt.cc index 2914779ae7ad..f94284090e26 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor") return Constructor(name_hint, inputs, belong_to); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", " << node->belong_to << ")"; @@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData") return TypeData(header, type_vars, constructors); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " << node->constructors << ")"; diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 8c6e191ce287..c5d7446d2955 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map dict) { return Attrs(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->dict; }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index b041d73949fd..b125c0318853 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -31,8 +31,8 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "EnvFunc(" << op->name << ")"; }); diff --git a/src/ir/error.cc b/src/ir/error.cc index 62faf502e1ca..9d498288d2ba 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // // The annotation callback will annotate the error messages // contained in the map. - annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) { + annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) { auto it = err_map.find(expr); if (it != err_map.end()) { CHECK_NE(it->second.size(), 0); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f81eb33ba0df..f194a386e359 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm") TVM_REGISTER_NODE_TYPE(IntImmNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); if (op->dtype == DataType::Int(32)) { p->stream << op->value; @@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm") TVM_REGISTER_NODE_TYPE(FloatImmNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); auto& stream = p->stream; switch (op->dtype.bits()) { @@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); @@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar") return GlobalVar(name); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "GlobalVar(" << node->name_hint << ")"; }); // Container printer -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '['; for (size_t i = 0 ; i < op->data.size(); ++i) { @@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ']'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '{'; for (auto it = op->data.begin(); it != op->data.end(); ++it) { @@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << '}'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '{'; for (auto it = op->data.begin(); it != op->data.end(); ++it) { diff --git a/src/ir/module.cc b/src/ir/module.cc index 01a8baaedb82..7f3796ed07f5 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") mod->ImportFromStd(path); });; -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IRModuleNode( " << node->functions << ")"; }); diff --git a/src/ir/op.cc b/src/ir/op.cc index 2bdb04d729ac..558b69891ae6 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode) return static_cast(n)->name; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Op(" << node->name << ")"; }); diff --git a/src/ir/span.cc b/src/ir/span.cc index 2519321eba38..2ea7095c89ac 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -48,8 +48,8 @@ SourceName SourceName::Get(const std::string& name) { TVM_REGISTER_GLOBAL("relay._make.SourceName") .set_body_typed(SourceName::Get); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "SourceName(" << node->name << ", " << node << ")"; }); @@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_GLOBAL("relay._make.Span") .set_body_typed(SpanNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset << ")"; diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc index 0a9ed4eed327..5e7c51c72d9b 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -27,7 +27,7 @@ namespace tvm { -using tvm::NodePrinter; +using tvm::ReprPrinter; using namespace tvm::runtime; TensorType::TensorType(Array shape, DataType dtype) { @@ -60,8 +60,8 @@ TVM_REGISTER_GLOBAL("relay._make.TensorType") return TensorType(shape, dtype); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 5d0f5d845833..1da010c5979d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include // TODO(tqchen): Update to use String container after it is merged. @@ -38,7 +38,7 @@ namespace transform { using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; -using tvm::NodePrinter; +using tvm::ReprPrinter; struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -341,8 +341,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info") *ret = pass->Info(); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "The meta data of the pass: "; p->stream << "pass name: " << node->name; @@ -371,8 +371,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass") *ret = pass(mod); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Module pass: " << info->name @@ -391,8 +391,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential") *ret = Sequential(passes, pass_info); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Sequential pass: " << info->name @@ -421,8 +421,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") *ret = pctx; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; diff --git a/src/ir/type.cc b/src/ir/type.cc index 233274a79e02..02ddfc9371fd 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -38,8 +38,8 @@ TVM_REGISTER_GLOBAL("relay._make.PrimType") return PrimType(dtype); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << node->dtype; }); @@ -59,8 +59,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar") return TypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; @@ -81,8 +81,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") return GlobalTypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; @@ -110,8 +110,8 @@ TVM_REGISTER_GLOBAL("relay._make.FuncType") return FuncType(arg_types, ret_type, type_params, type_constraints); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type << ", " @@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("relay._make.TupleType") return TupleType(fields); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TupleTypeNode(" << node->fields << ")"; }); @@ -156,8 +156,8 @@ TVM_REGISTER_GLOBAL("relay._make.IncompleteType") return IncompleteType(static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); @@ -176,8 +176,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefType") TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RelayRefTypeNode(" << node->value << ")"; }); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index 361525c55044..1d80f95b10c9 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -40,8 +40,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeCall") return TypeCall(func, type); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; @@ -69,8 +69,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeRelation") return TypeRelation(func, args, num_inputs, attrs); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TypeRelationNode(" << node->func->name diff --git a/src/node/printer.cc b/src/node/repr_printer.cc similarity index 87% rename from src/node/printer.cc rename to src/node/repr_printer.cc index e0176d245d80..ef91d2f5965e 100644 --- a/src/node/printer.cc +++ b/src/node/repr_printer.cc @@ -19,13 +19,13 @@ /*! * Printer utilities - * \file node/printer.cc + * \file node/repr_printer.cc */ -#include +#include namespace tvm { -void NodePrinter::Print(const ObjectRef& node) { +void ReprPrinter::Print(const ObjectRef& node) { static const FType& f = vtable(); if (!node.defined()) { stream << "(nullptr)"; @@ -39,13 +39,13 @@ void NodePrinter::Print(const ObjectRef& node) { } } -void NodePrinter::PrintIndent() { +void ReprPrinter::PrintIndent() { for (int i = 0; i < indent; ++i) { stream << ' '; } } -NodePrinter::FType& NodePrinter::vtable() { +ReprPrinter::FType& ReprPrinter::vtable() { static FType inst; return inst; } diff --git a/src/printer/doc.cc b/src/printer/doc.cc new file mode 100644 index 000000000000..9072fd6bda33 --- /dev/null +++ b/src/printer/doc.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/doc.cc + * \brief Doc ADT used for pretty printing. + * + * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 + */ +#include +#include +#include +#include "doc.h" + +namespace tvm { + +/*! + * \brief Represent a piece of text in the doc. + */ +class DocTextNode : public DocAtomNode { + public: + /*! \brief The str content in the text. */ + std::string str; + + explicit DocTextNode(std::string str_val) + : str(str_val) { + if (str.find_first_of("\t\n") != str.npos) { + LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; + } + } + + static constexpr const char* _type_key = "printer.DocText"; + TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode); +}; + +TVM_REGISTER_OBJECT_TYPE(DocTextNode); + +class DocText : public DocAtom { + public: + explicit DocText(std::string str) { + data_ = runtime::make_object(str); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode); +}; + +/*! + * \brief Represent a line breaker in the doc. + */ +class DocLineNode : public DocAtomNode { + public: + /*! \brief The amount of indent in newline. */ + int indent; + + explicit DocLineNode(int indent) + : indent(indent) {} + + static constexpr const char* _type_key = "printer.DocLine"; + TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode); +}; + +TVM_REGISTER_OBJECT_TYPE(DocLineNode); + +class DocLine : public DocAtom { + public: + explicit DocLine(int indent) { + data_ = runtime::make_object(indent); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode); +}; + +// DSL function implementations +Doc& Doc::operator<<(const Doc& right) { + CHECK(this != &right); + this->stream_.insert( + this->stream_.end(), right.stream_.begin(), right.stream_.end()); + return *this; +} + +Doc& Doc::operator<<(std::string right) { + return *this << DocText(right); +} + +Doc& Doc::operator<<(const DocAtom& right) { + this->stream_.push_back(right); + return *this; +} + +std::string Doc::str() { + std::ostringstream os; + for (auto atom : this->stream_) { + if (auto* text = atom.as()) { + os << text->str; + } else if (auto* line = atom.as()) { + os << "\n" << std::string(line->indent, ' '); + } else { + LOG(FATAL) << "do not expect type " << atom->GetTypeKey(); + } + } + return os.str(); +} + +Doc Doc::NewLine(int indent) { + return Doc() << DocLine(indent); +} + +Doc Doc::Text(std::string text) { + return Doc() << DocText(text); +} + +Doc Doc::Indent(int indent, Doc doc) { + for (size_t i = 0; i < doc.stream_.size(); ++i) { + if (auto* line = doc.stream_[i].as()) { + doc.stream_[i] = DocLine(indent + line->indent); + } + } + return doc; +} + +Doc Doc::StrLiteral(const std::string& value, std::string quote) { + // TODO(M.K.): add escape. + Doc doc; + return doc << quote << value << quote; +} + +Doc Doc::PyBoolLiteral(bool value) { + if (value) { + return Doc::Text("True"); + } else { + return Doc::Text("False"); + } +} + +Doc Doc::Brace(std::string open, + const Doc& body, + std::string close, + int indent) { + Doc doc; + doc << open; + doc << Indent(indent, NewLine() << body) << NewLine(); + doc << close; + return doc; +} + +Doc Doc::Concat(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + if (vec.size() == 1) return vec[0]; + seq << vec[0]; + for (size_t i = 1; i < vec.size(); ++i) { + seq << sep << vec[i]; + } + } + return seq; +} +} // namespace tvm diff --git a/src/printer/doc.h b/src/printer/doc.h new file mode 100644 index 000000000000..34a284b0f116 --- /dev/null +++ b/src/printer/doc.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/printer/doc.h + * \brief Doc ADT used for pretty printing. + * + * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 + */ +#ifndef TVM_PRINTER_DOC_H_ +#define TVM_PRINTER_DOC_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Doc atom node for the ADT. + * \sa DocAtom + */ +class DocAtomNode : public Object { + public: + static constexpr const char* _type_key = "printer.DocAtom"; + TVM_DECLARE_BASE_OBJECT_INFO(DocAtomNode, Object); +}; + +/*! + * \brief Managed reference to DocAtomNode. + * \sa DocAtomNode. +*/ +class DocAtom : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode); +}; + +/*! + * \brief Stream-like interface for Doc DSL. + * + * The Doc DSL de-couples the layout decision from the printing decision. + * + * The layout(code formating) decisions include: + * - Change indentation. + * - Break single line into multiple ones(subjected to future improvements). + */ +class Doc { + public: + /*! \brief default constructor */ + Doc() {} + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + */ + Doc& operator<<(const Doc& right); + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + * \note pass by value to allow copy elison optimization. + */ + Doc& operator<<(std::string right); + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + */ + Doc& operator<<(const DocAtom& right); + /*! + * \brief Convert value to string via std::ostreamstream + * the append to the current doc stream. + * \param right The doc to be appended. + * \tparam T the type of the value. + * \return reference to self. + */ + template::value>::type> + Doc& operator<<(const T& value) { + std::ostringstream os; + os << value; + return *this << os.str(); + } + /*! + * \brief Convert the doc stream into string. + * \return The string representation. + */ + std::string str(); + /*! + * \brief Create a doc that represents text content. + * \return The created doc. + */ + static Doc Text(std::string value); + /*! + * \brief Create a doc that represents a new line. + * \return The created doc. + */ + static Doc NewLine(int indent = 0); + /*! + * \brief Create a new doc that adds indentation to everyline of the doc. + * \param indent The indent to be added. + * \param doc The doc to be indented. + * \return The created doc. + * \note pass by value to allow copy elison optimization. + */ + static Doc Indent(int indent, Doc doc); + /*! + * \brief Create a Doc that represents a string literal. + * \param value The content of the string literal. + * \param quote The quote in the literal. + * \return The created doc. + */ + static Doc StrLiteral(const std::string& value, std::string quote = "\""); + /*! + * \brief Create a Doc that represents a boolean literal in python syntax. + * \param value The bool value. + * \return The created doc. + */ + static Doc PyBoolLiteral(bool value); + /*! + * \brief Enclose body by brace and add indent. + * \param body The body + * \param open The open brace. + * \param close The close brace. + * \param indent amount of indentation. + * \return The created doc. + */ + static Doc Brace(std::string open, + const Doc& body, + std::string close, + int indent = 2); + /*! + * \brief Create a doc by concatenating together with separator. + * \param vec The docs to be concatenated. + * \param sep The seperator. + * \return The created doc. + */ + static Doc Concat(const std::vector& vec, const Doc& sep = Text(", ")); + + private: + /*! \brief Internal doc stream. */ + std::vector stream_; +}; + +} // namespace tvm +#endif // TVM_PRINTER_DOC_H_ diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h new file mode 100644 index 000000000000..6c300fd85176 --- /dev/null +++ b/src/printer/meta_data.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/printer/meta_data.h + * \brief Meta data context for printers. + */ +#ifndef TVM_PRINTER_META_DATA_H_ +#define TVM_PRINTER_META_DATA_H_ + +#include +#include +#include +#include +#include "doc.h" + +namespace tvm { +/*! + * \brief Meta data context for Printers + * + * This is an important part to enable bi-directional serializability. + * We use tvm's Node system to build the current IR. + * It can be hard to design a text format for all the possible nodes + * as the set of nodes can grow when we do more extensions. + * + * Instead of trying to design readable text format for every node, + * we support a meta data section in the text format. + * We allow the text format to refer to a node in the meta data section. + * + * The meta data section is a json serialized string of an Map>. + * Each element in the meta data section can be referenced by the text format. + * Each meta data node is printed in the following format. + * + * meta[type-key-of-node>][] + * + * Specifically, consider the following IR(constructed by python). + * + * \code + * + * n = tvm.var("n") + * x = tvm.relay.var("x", shape=(n, 1)) + * f = tvm.relay.Function([x], x) + * print(f.astext()) + * + * \endcode + * + * The corresponding text format is shown in the following code block. + * + * \code + * + * fn (%x: Tensor[(meta[Variable][0],), float32]) { + * %x + * } + * # Meta data section is a json-serialized string + * # of the following array. + * # [tvm.var("n")] + * + * \endcode + * + * Note that we store tvm.var("n") in the meta data section. + * Since it is stored in the index-0 in the meta data section, + * we print it as meta[Variable][0]. + * + * The text parser can recover this object by loading from the corresponding + * location in the meta data section. + * + * This is is a design trade-off. + * It allows us to embedded any meta data in the text format, + * while still being able to tweak the text part of the printed IR easily. + */ +class TextMetaDataContext { + public: + /*! + * \brief Get text representation of meta node. + * \param node The node to be converted to meta node. + * \return A string representation of the meta node. + */ + Doc GetMetaNode(const ObjectRef& node) { + auto it = meta_repr_.find(node); + if (it != meta_repr_.end()) { + return it->second; + } + std::string type_key = node->GetTypeKey(); + CHECK(!type_key.empty()); + Array& mvector = + meta_data_[type_key]; + int64_t index = static_cast(mvector.size()); + mvector.push_back(node); + Doc doc; + doc << "meta[" << type_key << "][" << index << "]"; + meta_repr_[node] = doc; + return meta_repr_[node]; + } + + /*! + * \brief Print a key value pair + */ + Doc PrintKeyValue(const std::string& str, const Doc& v) const { + return Doc() << "\"" << str << "\": " << v; + } + + /*! + * \brief Get the metadata section in json format. + * \return the meta data string. + */ + Doc GetMetaSection() const { + if (meta_data_.size() == 0) return Doc(); + return Doc::Text( + SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); + } + + /*! \return whether the meta data context is empty. */ + bool empty() const { + return meta_data_.empty(); + } + + private: + /*! \brief additional metadata stored in TVM json format */ + std::unordered_map > meta_data_; + /*! \brief map from meta data into its string representation */ + std::unordered_map meta_repr_; +}; +} // namespace tvm +#endif // TVM_PRINTER_META_DATA_H_ diff --git a/src/relay/ir/pretty_printer.cc b/src/printer/relay_text_printer.cc similarity index 74% rename from src/relay/ir/pretty_printer.cc rename to src/printer/relay_text_printer.cc index c21f565f430c..0fa4da5b5077 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -18,9 +18,11 @@ */ /*! - * \file pretty_printer.cc - * \brief Pretty printer for Relay programs - * Supports ANF, GNF, and metadata. + * \file text_format_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. + * + * Supports ANF, GNF in relay and metadata. * * Inlining heuristics: * - Always inline: @@ -31,142 +33,27 @@ * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ #include -#include -#include #include +#include #include #include "doc.h" -#include "../pass/dependency_graph.h" -#include "../../ir/attr_functor.h" +#include "meta_data.h" +#include "../relay/pass/dependency_graph.h" +#include "../ir/attr_functor.h" namespace tvm { namespace relay { -static const char* kSemVer = "v0.0.4"; - -Doc Brace(const Doc& d, - const std::string& open = "{", - const std::string& close = "}", - int indent = 2) { - Doc doc; - doc << open; - doc << Indent(indent, PrintNewLine() << d) << PrintNewLine(); - doc << close; - return doc; -} - -/*! - * \brief Meta data context for PrettyPrinter. - * - * This is an important part to enable bi-directional serializability. - * We use tvm's Node system to build the current IR. - * It can be hard to design a text format for all the possible nodes - * as the set of nodes can grow when we do more extensions. - * - * Instead of trying to design readable text format for every node, - * we support a meta data section in the text format. - * We allow the text format to refer to a node in the meta data section. - * - * The meta data section is a json serialized string of an Map>. - * Each element in the meta data section can be referenced by the text format. - * Each meta data node is printed in the following format. - * - * meta[type-key-of-node>][] - * - * Specifically, consider the following IR(constructed by python). - * - * \code - * - * n = tvm.var("n") - * x = tvm.relay.var("x", shape=(n, 1)) - * f = tvm.relay.Function([x], x) - * print(f.astext()) - * - * \endcode - * - * The corresponding text format is shown in the following code block. - * - * \code - * - * fn (%x: Tensor[(meta[Variable][0],), float32]) { - * %x - * } - * # Meta data section is a json-serialized string - * # of the following array. - * # [tvm.var("n")] - * - * \endcode - * - * Note that we store tvm.var("n") in the meta data section. - * Since it is stored in the index-0 in the meta data section, - * we print it as meta[Variable][0]. - * - * The text parser can recover this object by loading from the corresponding - * location in the meta data section. - * - * This is is a design trade-off. - * It allows us to embedded any meta data in the text format, - * while still being able to tweak the text part of the printed IR easily. - */ -class TextMetaDataContext { - public: - /*! - * \brief Get text representation of meta node. - * \param node The node to be converted to meta node. - * \return A string representation of the meta node. - */ - Doc GetMetaNode(const ObjectRef& node) { - auto it = meta_repr_.find(node); - if (it != meta_repr_.end()) { - return it->second; - } - std::string type_key = node->GetTypeKey(); - CHECK(!type_key.empty()); - Array& mvector = - meta_data_[type_key]; - int64_t index = static_cast(mvector.size()); - mvector.push_back(node); - Doc doc; - doc << "meta[" << type_key << "][" << index << "]"; - meta_repr_[node] = doc; - return meta_repr_[node]; - } - - Doc PrintKeyValue(const std::string& str, const Doc& v) const { - return Doc("\"") << str << "\": " << v; - } - - /*! - * \brief Get the metadata section in json format. - * \return the meta data string. - */ - Doc GetMetaSection() const { - if (meta_data_.size() == 0) return Doc(); - return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); - } - - /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } - - private: - /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; - /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; -}; - -class PrettyPrinter : +class RelayTextPrinter : public ExprFunctor, public PatternFunctor, public TypeFunctor, public AttrFunctor { public: - explicit PrettyPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) : - show_meta_data_(show_meta_data), - annotate_(annotate) {} + explicit RelayTextPrinter(bool show_meta_data, + runtime::TypedPackedFunc annotate) + : show_meta_data_(show_meta_data), + annotate_(annotate) {} /*! * \brief Print additional info about expr in comment. @@ -194,7 +81,7 @@ class PrettyPrinter : Doc doc; Doc body; doc << "{"; - doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine(); + doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); doc << "}"; return doc; } @@ -220,10 +107,10 @@ class PrettyPrinter : Doc doc; doc << PrintScope(node); if (!meta_.empty()) { - doc << PrintNewLine(); + doc << Doc::NewLine(); if (show_meta_data_) { // append meta data in the end. - doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection(); + doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); } else { doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; } @@ -244,8 +131,9 @@ class PrettyPrinter : } else if (node.as()) { return PrintMod(Downcast(node)); } else { - Doc doc; - return doc << node; + std::ostringstream os; + os << node; + return Doc() << os.str(); } } @@ -278,23 +166,23 @@ class PrettyPrinter : } } name_alloc_map_[unique_prefix] = 0; - return Doc(unique_prefix); + return Doc::Text(unique_prefix); } Doc Print(Kind k) { switch (k) { case kType: - return Doc("Type"); + return Doc::Text("Type"); case kShapeVar: - return Doc("Shape"); + return Doc::Text("Shape"); case kBaseType: - return Doc("BaseType"); + return Doc::Text("BaseType"); case kConstraint: - return Doc("Constraint"); + return Doc::Text("Constraint"); case kAdtHandle: - return Doc("AdtHandle"); + return Doc::Text("AdtHandle"); case kTypeData: - return Doc("TypeData"); + return Doc::Text("TypeData"); default: LOG(ERROR) << "Unknown Kind"; throw; @@ -387,7 +275,7 @@ class PrettyPrinter : // wrap GNFed let in brackets Doc body; printed_expr << "("; - printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine(); + printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); printed_expr << ")"; } else { printed_expr = VisitExpr(expr); @@ -399,7 +287,7 @@ class PrettyPrinter : if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << PrintNewLine(); + doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { @@ -408,7 +296,7 @@ class PrettyPrinter : } else { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine(); + doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); return temp_var; } } @@ -419,6 +307,28 @@ class PrettyPrinter : return AllocVar(GetRef(op)); } + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ + template + static Doc ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Int(32)) { + os << value; + } else if (dtype == DataType::Float(32)) { + os << value << 'f'; + } else if (dtype == DataType::Float(64)) { + os << value; + } else if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; + } + return Doc::Text(os.str()); + } + Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalars directly. if (op->is_scalar()) { @@ -426,15 +336,15 @@ class PrettyPrinter : DataType dtype = DataType(op->data->dtype); CHECK_EQ(op->data->ctx.device_type, kDLCPU); if (dtype == DataType::Int(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Int(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Bool()) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } } // default fall-back, record it as meta node. @@ -448,7 +358,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintSep(fields); + doc << "(" << Doc::Concat(fields); // conform to python tuple format (1,) if (op->fields.size() == 1) { doc << ","; @@ -478,7 +388,7 @@ class PrettyPrinter : << " = " << Print(op->value, false, true) << ";" - << PrintNewLine(); + << Doc::NewLine(); // we use a scope here so GNF hoisting doesn't escape too far // and nested, unique lets are not hoisted doc << PrintScope(op->body); @@ -492,9 +402,9 @@ class PrettyPrinter : doc << "["; std::vector type_params; for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc(tv->name_hint)); + type_params.push_back(Doc::Text(tv->name_hint)); } - doc << PrintSep(type_params); + doc << Doc::Concat(type_params); doc << "]"; } doc << "("; @@ -505,7 +415,7 @@ class PrettyPrinter : for (const Doc& d : PrintFuncAttrs(fn->attrs)) { params.push_back(d); } - doc << PrintSep(params) << ") "; + doc << Doc::Concat(params) << ") "; if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } @@ -530,36 +440,36 @@ class PrettyPrinter : // type definitions for (const auto& kv : mod->type_definitions) { if (counter++ != 0) { - doc << PrintNewLine(); + doc << Doc::NewLine(); } doc << Print(kv.second); - doc << PrintNewLine(); + doc << Doc::NewLine(); } // functions for (const auto& kv : mod->functions) { dg_ = DependencyGraph::Create(&arena_, kv.second); if (counter++ != 0) { - doc << PrintNewLine(); + doc << Doc::NewLine(); } std::ostringstream os; os << "def @" << kv.first->name_hint; - doc << PrintFunc(Doc(os.str()), kv.second); - doc << PrintNewLine(); + doc << PrintFunc(Doc::Text(os.str()), kv.second); + doc << Doc::NewLine(); } return doc; } Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Doc("fn "), GetRef(op)); + return PrintFunc(Doc::Text("fn "), GetRef(op)); } Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc('@' + op->name_hint); + return Doc::Text('@' + op->name_hint); } Doc VisitExpr_(const OpNode* op) final { - return Doc(op->name); + return Doc::Text(op->name); } Doc VisitExpr_(const CallNode* op) final { @@ -584,7 +494,7 @@ class PrettyPrinter : // don't print as a call if it's a 0-arity cons return doc; } else { - return doc << "(" << PrintSep(args) << ")"; + return doc << "(" << Doc::Concat(args) << ")"; } } @@ -619,13 +529,13 @@ class PrettyPrinter : Doc rhs_doc = PrintScope(clause->rhs); if (clause->rhs.as()) { // only add braces if there are multiple lines on the rhs - rhs_doc = Brace(rhs_doc); + rhs_doc = Doc::Brace("{", rhs_doc, "}"); } clause_doc << rhs_doc << ","; clause_docs.push_back(clause_doc); } - doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine())) - << PrintNewLine() << "}"; + doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) + << Doc::NewLine() << "}"; return doc; } @@ -651,7 +561,7 @@ class PrettyPrinter : for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } - doc << PrintSep(pats) << ")"; + doc << Doc::Concat(pats) << ")"; } return doc; } @@ -663,12 +573,12 @@ class PrettyPrinter : for (const auto& pat : pt->patterns) { pats.push_back(Print(pat)); } - doc << PrintSep(pats) << ")"; + doc << Doc::Concat(pats) << ")"; return doc; } Doc VisitPattern_(const PatternWildcardNode* pw) final { - return Doc("_"); + return Doc::Text("_"); } Doc VisitPattern_(const PatternVarNode* pv) final { @@ -684,7 +594,7 @@ class PrettyPrinter : for (Type input : n->inputs) { inputs.push_back(Print(input)); } - doc << PrintSep(inputs) << ")"; + doc << Doc::Concat(inputs) << ")"; } return doc; } @@ -711,11 +621,11 @@ class PrettyPrinter : } Doc VisitType_(const TypeVarNode* node) final { - return Doc(node->name_hint); + return Doc::Text(node->name_hint); } Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc(node->name_hint); + return Doc::Text(node->name_hint); } Doc VisitType_(const TypeCallNode* node) final { @@ -725,11 +635,15 @@ class PrettyPrinter : args.push_back(PrintType(t, false)); } doc << "["; - doc << PrintSep(args); + doc << Doc::Concat(args); doc << "]"; return doc; } + Doc PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); + } + Doc VisitType_(const TensorTypeNode* node) final { // scalar type if (node->shape.size() == 0) { @@ -741,7 +655,7 @@ class PrettyPrinter : for (ObjectRef shape : node->shape) { shapes.push_back(PrintAttr(shape)); } - doc << PrintSep(shapes); + doc << Doc::Concat(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; } @@ -751,7 +665,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintSep(fields); + doc << "(" << Doc::Concat(fields); // conform to python tuple format (1,) if (node->fields.size() == 1) { doc << ","; @@ -768,14 +682,14 @@ class PrettyPrinter : for (Type type_param : node->type_params) { type_params.push_back(Print(type_param)); } - doc << PrintSep(type_params); + doc << Doc::Concat(type_params); doc << "]"; } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RelayRefTypeNode* node) final { @@ -795,7 +709,7 @@ class PrettyPrinter : for (Type type_var : node->type_vars) { type_vars.push_back(Print(type_var)); } - doc << PrintSep(type_vars) << "]"; + doc << Doc::Concat(type_vars) << "]"; } doc << " "; @@ -804,14 +718,14 @@ class PrettyPrinter : constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); } Doc separator; - separator << "," << PrintNewLine(); + separator << "," << Doc::NewLine(); Doc adt_body; - adt_body << PrintSep(constructor_docs, separator); + adt_body << Doc::Concat(constructor_docs, separator); // add trailing comma if there are any constructors if (!constructor_docs.empty()) { adt_body << ","; } - doc << Brace(adt_body); + doc << Doc::Brace("{", adt_body, "}"); in_adt_def_ = false; return doc; } @@ -832,7 +746,7 @@ class PrettyPrinter : } return printed_attr; } else { - return Doc("None"); + return Doc::Text("None"); } } @@ -847,28 +761,28 @@ class PrettyPrinter : for (auto val : op->data) { arr_vals.push_back(PrintAttr(val)); } - doc << PrintSep(arr_vals); + doc << Doc::Concat(arr_vals); doc << "]"; return doc; } Doc VisitAttr_(const tir::IntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); + return ScalarLiteral(op->dtype, op->value); } Doc VisitAttr_(const tir::FloatImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); + return ScalarLiteral(op->dtype, op->value); } Doc VisitAttr_(const tir::StringImmNode* op) final { - return PrintString(op->value); + return Doc::StrLiteral(op->value); } private: /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; + runtime::TypedPackedFunc annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ @@ -896,9 +810,11 @@ class PrettyPrinter : /*! * \brief Attribute printer which prints the attributes in the call. */ -class PrettyPrinter::AttrPrinter : public AttrVisitor { +class RelayTextPrinter::AttrPrinter : + public AttrVisitor { public: - AttrPrinter(std::vector* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {} + AttrPrinter(std::vector* doc, RelayTextPrinter* parent) + : docs(doc), parent_(parent) {} template void PrintKV(const char* key, const T& value) { @@ -922,16 +838,16 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { PrintKV(key, *value); } void Visit(const char* key, bool* value) final { - PrintKV(key, PrintBool(*value)); + PrintKV(key, Doc::PyBoolLiteral(*value)); } void Visit(const char* key, std::string* value) final { - PrintKV(key, PrintString(*value)); + PrintKV(key, Doc::StrLiteral(*value)); } void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; } void Visit(const char* key, DataType* value) final { - PrintKV(key, PrintString(runtime::DLDataType2String(*value))); + PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); } void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; @@ -942,10 +858,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { private: std::vector* docs; - PrettyPrinter* parent_; + RelayTextPrinter* parent_; }; -std::vector PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { +std::vector RelayTextPrinter::PrintCallAttrs( + const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); @@ -962,7 +879,7 @@ std::vector PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& o } } -std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { +std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { std::vector docs; if (!attrs.defined()) return docs; const auto* dict_attrs = attrs.as(); @@ -974,30 +891,34 @@ std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { } return docs; } +} // namespace relay -std::string PrettyPrint_(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - Doc doc; - doc << kSemVer << PrintNewLine() - << PrettyPrinter(show_meta_data, annotate).PrintFinal(node); - return doc.str(); -} +static const char* kSemVer = "v0.0.4"; +// TODO(tvm-team): split into files, related: arith/analyzer.h +// +// - text_printer.h (common header) +// - text_printer.cc (prints modules dispatch into relay and tir files) +// - type_text_printer.cc(specific printing logics for types, +// can also consider put under type_text_printer) +// - Implements AsText +// - relay_text_printer.cc (specific printing logics for relay) +// - tir_text_printer.cc (specific printing logics for TIR) std::string PrettyPrint(const ObjectRef& node) { Doc doc; - doc << PrettyPrinter(false, runtime::TypedPackedFunc()).PrintFinal(node); + doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); return doc.str(); } std::string AsText(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - return PrettyPrint_(node, show_meta_data, annotate); + bool show_meta_data, + runtime::TypedPackedFunc annotate) { + Doc doc; + doc << kSemVer << Doc::NewLine() + << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); + return doc.str(); } TVM_REGISTER_GLOBAL("relay._expr.AsText") .set_body_typed(AsText); - -} // namespace relay } // namespace tvm diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 224ff778ff34..0e95a69d466b 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -47,8 +47,8 @@ InterpreterClosure::InterpreterClosure(tvm::Map env, data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; }); @@ -68,8 +68,8 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) { data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RecClosureObj(" << node->clos << ")"; }); @@ -87,8 +87,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefValue") TVM_REGISTER_NODE_TYPE(RefValueObj); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefValueObj(" << node->value << ")"; }); @@ -111,8 +111,8 @@ TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") TVM_REGISTER_NODE_TYPE(ConstructorValueObj); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index bf9c9189e926..485a0a283a31 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") .set_body_typed(PatternWildcardNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { p->stream << "PatternWildcardNode()"; }); @@ -53,8 +53,8 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_GLOBAL("relay._make.PatternVar") .set_body_typed(PatternVarNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternVarNode(" << node->var << ")"; }); @@ -72,8 +72,8 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") .set_body_typed(PatternConstructorNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")"; @@ -90,8 +90,8 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode); TVM_REGISTER_GLOBAL("relay._make.PatternTuple") .set_body_typed(PatternTupleNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PatternTupleNode(" << node->patterns << ")"; }); @@ -108,8 +108,8 @@ TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_GLOBAL("relay._make.Clause") .set_body_typed(ClauseNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; @@ -128,8 +128,8 @@ TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay._make.Match") .set_body_typed(MatchNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete << ")"; diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc deleted file mode 100644 index 26aec39e5282..000000000000 --- a/src/relay/ir/doc.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/tvm/relay/doc.cc - * \brief Doc ADT used for pretty printing. - * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. - */ -#include -#include -#include "doc.h" - -namespace tvm { -namespace relay { - -// Text constructor -DocAtom Text(const std::string& str) { - return std::make_shared(str); -} - -// Line constructor -DocAtom Line(int indent = 0) { - return std::make_shared(indent); -} - -Doc::Doc(const std::string& str) { - if (str == "\n") { - this->stream_ = {Line()}; - } else { - this->stream_ = {Text(str)}; - } -} - -// DSL function implementations - -Doc& Doc::operator<<(const Doc& right) { - CHECK(this != &right); - this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); - return *this; -} - -Doc& Doc::operator<<(const std::string& right) { - return *this << Doc(right); -} - -Doc& Doc::operator<<(const DocAtom& right) { - this->stream_.push_back(right); - return *this; -} - -Doc Indent(int indent, const Doc& doc) { - Doc ret; - for (auto atom : doc.stream_) { - if (auto text = std::dynamic_pointer_cast(atom)) { - ret.stream_.push_back(text); - } else if (auto line = std::dynamic_pointer_cast(atom)) { - ret.stream_.push_back(Line(indent + line->indent)); - } else {CHECK(false);} - } - return ret; -} - -std::string Doc::str() { - std::ostringstream os; - for (auto atom : this->stream_) { - if (auto text = std::dynamic_pointer_cast(atom)) { - os << text->str; - } else if (auto line = std::dynamic_pointer_cast(atom)) { - os << "\n" << std::string(line->indent, ' '); - } else {CHECK(false);} - } - return os.str(); -} - -Doc PrintSep(const std::vector& vec, const Doc& sep) { - Doc seq; - if (vec.size() != 0) { - seq = vec[0]; - for (size_t i = 1; i < vec.size(); i++) { - seq << sep << vec[i]; - } - } - return seq; -} - -Doc PrintBool(bool value) { - if (value) { - return Doc("True"); - } else { - return Doc("False"); - } -} - -Doc PrintDType(DataType dtype) { - return Doc(runtime::DLDataType2String(dtype)); -} - -Doc PrintString(const std::string& value) { - // TODO(M.K.): add escape. - Doc doc; - return doc << "\"" << value << "\""; -} - -Doc PrintNewLine(int ident) { - Doc doc; - return doc << Line(ident); -} - -} // namespace relay -} // namespace tvm diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h deleted file mode 100644 index a41fd6145d26..000000000000 --- a/src/relay/ir/doc.h +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relay/doc.h - * \brief Doc ADT used for pretty printing. - * Based on Section 1 of - * https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with - * a vector instead of an implicitly linked list. - */ -#ifndef TVM_RELAY_IR_DOC_H_ -#define TVM_RELAY_IR_DOC_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -// Doc Atom ADT -struct DocAtomNode { - virtual ~DocAtomNode() = default; -}; - -using DocAtom = std::shared_ptr; - -struct TextNode : DocAtomNode { - std::string str; - - explicit TextNode(const std::string& str) : str(str) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; - } - } -}; - -struct LineNode : DocAtomNode { - int indent; - - explicit LineNode(int indent) : indent(indent) {} -}; - -// Doc is a stream-like interface -class Doc { - public: - Doc() {} - explicit Doc(const std::string& str); - template - explicit Doc(const T& str) { - (*this) << str; - } - - // Append right to this. - Doc& operator<<(const Doc& right); - // Like above. - Doc& operator<<(const std::string& right); - // Like above. - Doc& operator<<(const DocAtom& right); - // Like above, but converts right to a string first. - template - Doc& operator<<(const T& right) { - std::ostringstream os; - os << right; - return *this << os.str(); - } - - // Indent a doc stream. - friend Doc Indent(int indent, const Doc& doc); - - // Wadler's `layout` - std::string str(); - - private: - std::vector stream_; -}; - -// DSL functions - -// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3 -Doc PrintSep(const std::vector& vec, const Doc& sep = Doc(", ")); -// Print a constant bool value. -Doc PrintBool(bool value); -// Print a data type. -Doc PrintDType(DataType dtype); -// Print a string. -Doc PrintString(const std::string& value); -// Print a newline. -Doc PrintNewLine(int indent = 0); -/*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - */ -template -Doc PrintConstScalar(DataType dtype, const T* data) { - std::ostringstream os; - if (dtype == DataType::Int(32)) { - os << data[0]; - } else if (dtype == DataType::Float(32)) { - os << data[0] << 'f'; - } else if (dtype == DataType::Bool()) { - return PrintBool(data[0] != 0); - } else { - // todo(@M.K.) this is unsafe. fix. - os << data[0]; - } - return Doc(os.str()); -} - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_IR_DOC_H_ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3d8cc3a85b2b..89395bb742c1 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -21,12 +21,13 @@ * \file src/tvm/relay/ir/expr.cc * \brief The expression AST nodes of Relay. */ +#include #include namespace tvm { namespace relay { -using tvm::NodePrinter; +using tvm::ReprPrinter; using namespace tvm::runtime; Constant ConstantNode::make(runtime::NDArray data) { @@ -40,8 +41,8 @@ TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_GLOBAL("relay._make.Constant") .set_body_typed(ConstantNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PackedFunc* fprint = Registry::Get("relay._constant_repr"); CHECK(fprint) << "unable to find printing function for constants"; @@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay._make.Tuple") .set_body_typed(TupleNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Tuple(" << node->fields << ")"; }); @@ -98,8 +99,8 @@ TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay._make.Var") .set_body_typed(static_cast(VarNode::make)); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Var(" << node->name_hint(); if (node->type_annotation.defined()) { @@ -208,8 +209,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay._make.Function") .set_body_typed(FunctionNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body << ", " << node->type_params << ", " @@ -231,8 +232,8 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay._make.Call") .set_body_typed(CallNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " << node->type_args << ")"; @@ -251,8 +252,8 @@ TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay._make.Let") .set_body_typed(LetNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; @@ -271,8 +272,8 @@ TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay._make.If") .set_body_typed(IfNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " << node->false_branch << ")"; @@ -290,8 +291,8 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") .set_body_typed(TupleGetItemNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); @@ -307,8 +308,8 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay._make.RefCreate") .set_body_typed(RefCreateNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefCreateNode(" << node->value << ")"; }); @@ -324,8 +325,8 @@ TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay._make.RefRead") .set_body_typed(RefReadNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefReadNode(" << node->ref << ")"; }); @@ -342,8 +343,8 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay._make.RefWrite") .set_body_typed(RefWriteNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 1f2e8ed52f8d..ac0f36cf2205 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include @@ -157,8 +157,8 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") .set_body_typed(FunctionPassNode::make); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Function pass: " << info->name diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3d37d61448f0..fd8e6fd0f59f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -957,7 +957,7 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { - std::string text = AsText(body, false, [this](const Expr& expr) -> std::string { + std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { auto it = gmap_.find(expr.get()); if (it == gmap_.end()) return ""; std::ostringstream os; diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index 2441f6e65d88..b3a8733c45e1 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -116,8 +116,8 @@ QConfig& QConfig::Current() { TVM_REGISTER_NODE_TYPE(QConfigNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* op = static_cast(ref.get()); p->stream << "qconfig("; p->stream << "nbit_input=" << op->nbit_input << ", "; diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 884bf762c940..817d48f0cdbf 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -23,7 +23,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/target/target.cc b/src/target/target.cc index a75e146586f5..245425a63921 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -23,7 +23,7 @@ #include #include -#include +#include #include #include @@ -39,8 +39,8 @@ using runtime::PackedFunc; TVM_REGISTER_NODE_TYPE(TargetNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->str(); }); @@ -381,8 +381,8 @@ tvm::BuildConfig BuildConfig::Current() { TVM_REGISTER_NODE_TYPE(BuildConfigNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "build_config("; p->stream << "data_alignment=" << op->data_alignment << ", "; diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 6c332e77b9ba..73fe011cc936 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -21,13 +21,13 @@ * \file target/target_info.cc */ #include -#include +#include #include namespace tvm { -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "mem-info(" << "unit_bits=" << op->unit_bits << ", " diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index c2fc581a3904..ff67e8d9cbc2 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -453,8 +453,8 @@ Buffer BufferNode::make(Var data, return Buffer(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "buffer(" << op->name << ", " << op << ")"; }); diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 59fa2af41631..8a5125bca193 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -198,8 +198,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { return -1; } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* l = static_cast(node.get()); p->stream << "Layout(" << l->name << ")"; }); @@ -365,8 +365,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, return BijectiveLayout(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* b = static_cast(node.get()); p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() << ")"; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 0cdbfdc71c97..d06c33f79dcc 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -57,8 +57,8 @@ IterVar IterVarNode::make(Range dom, return IterVar(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "iter_var("; if (op->var->name_hint.length() != 0) { @@ -339,8 +339,8 @@ PrimExpr AnyNode::make() { return PrimExpr(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); auto& stream = p->stream; stream << '"'; @@ -375,24 +375,24 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) stream << '"'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->dtype << '('; p->Print(op->value); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; p->stream << op->name_hint; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -400,7 +400,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -408,7 +408,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -416,7 +416,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -424,7 +424,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -432,7 +432,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "min("; p->Print(op->a); @@ -440,7 +440,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "max("; p->Print(op->a); @@ -448,7 +448,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -456,7 +456,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -464,7 +464,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -472,7 +472,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -480,7 +480,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -488,7 +488,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -497,20 +497,20 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -519,8 +519,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); @@ -529,15 +529,15 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '!'; p->Print(op->a); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "select("; p->Print(op->condition); @@ -548,8 +548,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->buffer_var << "["; p->Print(op->index); @@ -560,8 +560,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "ramp("; p->Print(op->base); @@ -570,16 +570,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ", " << op->lanes << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "x" << op->lanes << "("; p->Print(op->value); p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->name << "("; for (size_t i = 0; i < op->args.size(); ++i) { @@ -591,8 +591,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "(let " << op->var << " = "; p->Print(op->value); @@ -601,13 +601,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "reduce(combiner=" << op->combiner; @@ -618,8 +618,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs diff --git a/src/tir/ir/lowered_func.cc b/src/tir/ir/lowered_func.cc index a2755343fc43..c1331fbd4c1f 100644 --- a/src/tir/ir/lowered_func.cc +++ b/src/tir/ir/lowered_func.cc @@ -24,8 +24,8 @@ namespace tvm { namespace tir { -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "LoweredFunc(" << op->name << ", " << op << ")"; }); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6d39d99f6939..0cd2aba319ee 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -248,8 +248,8 @@ Stmt EvaluateNode::make(PrimExpr value) { // Printers -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "let " << op->var << " = "; @@ -258,8 +258,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "// attr ["; @@ -271,8 +271,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "assert("; @@ -283,8 +283,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); if (op->is_producer) { p->PrintIndent(); @@ -317,8 +317,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) return out; } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->for_type << " (" << op->loop_var << ", "; @@ -335,8 +335,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->buffer_var << "["; @@ -350,8 +350,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->func->func_name() << "("; @@ -368,8 +368,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; @@ -386,16 +386,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "free " << op->buffer_var; p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "realize " << op->func->func_name() << "("; @@ -425,8 +425,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "prefetch " << op->func->func_name() << "("; @@ -444,16 +444,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); for (Stmt stmt : op->seq) { p->Print(stmt); } }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); while (true) { @@ -483,8 +483,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->Print(op->value); @@ -492,7 +492,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); template -void PrintList(const Array &exprs, NodePrinter* p) { +void PrintList(const Array &exprs, ReprPrinter* p) { for (size_t i = 0; i < exprs.size(); ++i) { p->Print(exprs[i]); if (i < exprs.size() - 1) { @@ -501,8 +501,8 @@ void PrintList(const Array &exprs, NodePrinter* p) { } } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "shuffle("; PrintList(op->vectors, p); diff --git a/src/top/operation/compute_op.cc b/src/top/operation/compute_op.cc index f325ae85002c..598a0f74405d 100644 --- a/src/top/operation/compute_op.cc +++ b/src/top/operation/compute_op.cc @@ -39,8 +39,8 @@ namespace tvm { namespace top { using namespace tir; -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "compute(" << op->name << ", " << op << ")"; }); diff --git a/src/top/operation/extern_op.cc b/src/top/operation/extern_op.cc index 276b5ebf6bc7..8bc812e52a23 100644 --- a/src/top/operation/extern_op.cc +++ b/src/top/operation/extern_op.cc @@ -31,8 +31,8 @@ namespace tvm { namespace top { using namespace tir; // ExternOpNode -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "extern(" << op->name << ", " << op << ")"; }); diff --git a/src/top/operation/hybrid_op.cc b/src/top/operation/hybrid_op.cc index f4e3850650a3..4c1ab7073c89 100644 --- a/src/top/operation/hybrid_op.cc +++ b/src/top/operation/hybrid_op.cc @@ -37,8 +37,8 @@ namespace tvm { namespace top { using namespace tir; // HybridOpNode -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "hybrid(" << op->name << ", " << op << ")"; }); diff --git a/src/top/operation/placeholder_op.cc b/src/top/operation/placeholder_op.cc index 284752b3661c..13311a894123 100644 --- a/src/top/operation/placeholder_op.cc +++ b/src/top/operation/placeholder_op.cc @@ -27,8 +27,8 @@ namespace tvm { namespace top { // PlaceholderOpNode -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "placeholder(" << op->name << ", " << op << ")"; }); diff --git a/src/top/operation/scan_op.cc b/src/top/operation/scan_op.cc index 2ddb6bd11cc8..62ddecb2851f 100644 --- a/src/top/operation/scan_op.cc +++ b/src/top/operation/scan_op.cc @@ -31,8 +31,8 @@ namespace tvm { namespace top { using namespace tir; -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "scan(" << op->name << ", " << op << ")"; }); diff --git a/src/top/operation/tensor_compute_op.cc b/src/top/operation/tensor_compute_op.cc index 2cc821928809..5011c1662296 100644 --- a/src/top/operation/tensor_compute_op.cc +++ b/src/top/operation/tensor_compute_op.cc @@ -34,8 +34,8 @@ namespace tvm { namespace top { using namespace tir; // TensorComputeOpNode -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; }); diff --git a/src/top/schedule/schedule_lang.cc b/src/top/schedule/schedule_lang.cc index 10d5ddc48b7b..130d06bda848 100644 --- a/src/top/schedule/schedule_lang.cc +++ b/src/top/schedule/schedule_lang.cc @@ -795,8 +795,8 @@ TVM_REGISTER_NODE_TYPE(SingletonNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); // Printer -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); if (op->op.defined()) { p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; @@ -804,11 +804,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "group-stage(" << op << ")"; } }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << IterVarType2String(op->iter_type); }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "split(parent="; p->Print(op->parent); @@ -818,7 +818,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->inner); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "split("; p->stream << "outer="; @@ -829,7 +829,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->fused); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "rebase("; p->stream << "parent="; @@ -838,13 +838,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->rebased); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "singleton("; p->Print(op->iter); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "schedule(" << op << ")"; }); diff --git a/src/top/tensor.cc b/src/top/tensor.cc index c848cc4b4367..85232b9f719c 100644 --- a/src/top/tensor.cc +++ b/src/top/tensor.cc @@ -82,8 +82,8 @@ Tensor TensorNode::make(Array shape, return Tensor(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* t = static_cast(node.get()); p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; @@ -114,8 +114,8 @@ TensorIntrin TensorIntrinNode::make(std::string name, return TensorIntrin(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; }); @@ -139,8 +139,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, return TensorIntrinCall(n); } -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* n = static_cast(node.get()); p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; });