From a3fb7a861a0eb7eeb43372d687aec2583b1bbbf9 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jan 2023 21:56:15 -0800 Subject: [PATCH] [TVMScript] Migrate More to TVMScripr Printer This PR gradually migrates more pieces of the default printing to TVMScript printer for TIR. This PR gradually migrates more pieces of the default printing to TVMScript printer for TIR. Details: - Introduced a method `AsLegacyRepr` which preserves existing `AsRepr` provided by `ReprPrinter`, so that the legacy behavior could be 100% preserved. - Introduced `Script` method to `IRModule`, `PrimFunc`, `tir.Stmt`, `tir.PrimExpr`. The `script` method exists in python side before, and this PR introduced them to C++ to be consistent. - Replace TIR's `PrettyPrint` to `operator <<` that is provided by the new `ReprPrinter`, which outputs in TVMScript format by default. `PrettyPrint` on Relay is all preserved for backward compatibility. --- include/tvm/ir/expr.h | 11 + include/tvm/ir/module.h | 55 +- include/tvm/ir/type.h | 19 + include/tvm/ir/type_functor.h | 4 +- include/tvm/node/repr_printer.h | 32 + .../tvm/script/printer/ir_docsifier_functor.h | 3 + include/tvm/script/printer/printer.h | 15 - include/tvm/tir/expr.h | 3 - include/tvm/tir/function.h | 11 + include/tvm/tir/stmt.h | 11 + python/tvm/ir/__init__.py | 56 +- python/tvm/ir/affine_type.py | 7 +- python/tvm/ir/base.py | 12 +- python/tvm/ir/module.py | 126 ++- python/tvm/ir/tensor_type.py | 7 +- python/tvm/ir/type.py | 3 +- python/tvm/relay/dataflow_pattern/__init__.py | 5 + python/tvm/relay/expr.py | 24 +- python/tvm/relay/frontend/tensorflow_ops.py | 2 +- python/tvm/relay/function.py | 9 +- python/tvm/relay/op/contrib/cutlass.py | 8 +- python/tvm/relay/op/contrib/dnnl.py | 23 +- python/tvm/relay/op/contrib/ethosu.py | 64 +- python/tvm/relay/op/contrib/tensorrt.py | 13 +- python/tvm/runtime/_ffi_node_api.py | 5 + python/tvm/runtime/object.py | 19 +- python/tvm/tir/expr.py | 86 +- python/tvm/tir/function.py | 81 +- python/tvm/tir/schedule/schedule.py | 25 +- python/tvm/tir/stmt.py | 75 ++ src/arith/iter_affine_map.cc | 12 +- src/auto_scheduler/compute_dag.cc | 22 +- src/ir/adt.cc | 5 +- src/ir/attrs.cc | 6 - src/ir/error.cc | 7 - src/ir/expr.cc | 11 - src/ir/function.cc | 16 +- src/ir/module.cc | 157 +-- src/ir/transform.cc | 3 - src/ir/type.cc | 25 - src/meta_schedule/arg_info.cc | 5 +- src/meta_schedule/database/json_database.cc | 2 +- .../task_scheduler/task_scheduler.cc | 6 +- src/meta_schedule/utils.h | 1 - src/node/repr_printer.cc | 25 + src/node/structural_equal.cc | 4 +- src/printer/model_library_format_printer.cc | 6 +- src/printer/text_printer.h | 2 - src/printer/tvmscript_printer.cc | 10 - src/relay/backend/te_compiler_cache.cc | 3 +- src/relay/ir/function.cc | 130 +++ src/relay/transforms/defunctionalization.cc | 2 +- src/script/printer/ir/ir.cc | 70 +- src/script/printer/ir/script_method.cc | 34 + src/script/printer/ir/utils.h | 16 +- src/script/printer/legacy_repr.cc | 1008 +++++++++++++++++ src/script/printer/tir/block.cc | 4 +- src/script/printer/tir/buffer.cc | 16 +- src/script/printer/tir/expr.cc | 75 +- src/script/printer/tir/for_loop.cc | 4 +- src/script/printer/tir/function.cc | 20 +- src/script/printer/tir/ir.cc | 18 +- src/script/printer/tir/script_method.cc | 59 + src/script/printer/tir/stmt.cc | 26 +- src/script/printer/tir/utils.h | 31 +- src/script/printer/utils.h | 73 ++ src/target/source/interface_c.cc | 3 +- src/target/source/source_module.cc | 3 +- src/tir/analysis/control_flow_graph.cc | 5 +- src/tir/analysis/oob_checker.cc | 1 - src/tir/analysis/verify_memory.cc | 2 +- src/tir/ir/legacy_printer.cc | 270 ----- src/tir/schedule/analysis/verify.cc | 2 +- src/tir/schedule/error.cc | 1 + src/tir/schedule/primitive/compute_inline.cc | 10 +- .../primitive/layout_transformation.cc | 4 +- src/tir/schedule/utils.h | 2 +- src/tir/transforms/common_subexpr_elim.cc | 4 +- .../transforms/common_subexpr_elim_tools.cc | 4 +- src/tir/transforms/install_debug_spans.cc | 2 +- src/tir/transforms/narrow_datatype.cc | 1 - src/tir/usmp/transform/assign_pool_info.cc | 4 +- .../test_ethosu/test_encode_constants.py | 24 +- .../test_outline_compiler_functions.py | 4 +- .../test_ethosu/test_remove_concatenates.py | 7 +- .../test_ethosu/test_replace_conv2d.py | 11 +- .../contrib/test_ethosu/test_replace_copy.py | 13 +- tests/python/contrib/test_tensorrt.py | 12 +- .../python/contrib/test_uma/test_partition.py | 7 +- tests/python/frontend/pytorch/qnn_test.py | 39 +- .../unittest/test_arith_deduce_bound.py | 20 +- .../test_meta_schedule_schedule_rule_mlt.py | 112 +- tests/python/unittest/test_te_schedule.py | 7 +- tests/python/unittest/test_tir_nodes.py | 54 +- ...est_tir_transform_inject_ptx_async_copy.py | 6 +- ...est_tir_transform_inject_rolling_buffer.py | 12 +- ..._tir_transform_inject_software_pipeline.py | 4 +- .../test_tir_transform_make_packed_api.py | 16 +- .../test_tir_transform_thread_sync.py | 4 +- .../unittest/test_tvmscript_complete.py | 2 +- tests/python/unittest/test_tvmscript_ops.py | 8 +- .../unittest/test_tvmscript_printer_ir.py | 49 + .../unittest/test_tvmscript_printer_tir.py | 5 +- .../unittest/test_tvmscript_regression.py | 4 +- .../unittest/test_tvmscript_roundtrip.py | 23 +- 105 files changed, 2484 insertions(+), 1005 deletions(-) create mode 100644 src/script/printer/ir/script_method.cc create mode 100644 src/script/printer/legacy_repr.cc create mode 100644 src/script/printer/tir/script_method.cc create mode 100644 src/script/printer/utils.h delete mode 100644 src/tir/ir/legacy_printer.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_ir.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index bb4c468f452f..bfbaa7cddd4f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -100,6 +100,17 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "PrimExpr"; static constexpr const uint32_t _type_child_slots = 38; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7313b4f78349..f26e640f6c22 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -63,6 +63,26 @@ class IRModuleNode : public Object { parser::SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; + /*! + * \brief A map from string names to global variables that + * ensures global uniqueness. + */ + Map global_var_map_; + + /*! \brief A map from string names to global type variables (ADT names) + * that ensures global uniqueness. + */ + Map global_type_var_map_; + + /*! \brief A map from constructor tags to constructor objects + * for convenient access + */ + std::unordered_map constructor_tag_map_; + + /*! \brief The files previously imported, required to ensure + importing is idempotent for each module. + */ + std::unordered_set import_set_; /*! * \brief Get a module attribute. @@ -304,15 +324,20 @@ class IRModuleNode : public Object { TVM_DLL void ImportFromStd(const String& path); /*! - * \brief Should Link Parameters into the module - * \return Whether the Executor is configured to execute with linked parameters (Default: false) + * \brief The set of imported files. */ - TVM_DLL Bool ShouldLinkParameters() const; + TVM_DLL std::unordered_set Imports() const; /*! - * \brief The set of imported files. + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -322,26 +347,6 @@ class IRModuleNode : public Object { private: /*! \brief Helper function for registering a typedef's constructors */ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); - - /*! \brief A map from string names to global variables that - * ensures global uniqueness. - */ - Map global_var_map_; - - /*! \brief A map from string names to global type variables (ADT names) - * that ensures global uniqueness. - */ - Map global_type_var_map_; - - /*! \brief A map from constructor tags to constructor objects - * for convenient access - */ - std::unordered_map constructor_tag_map_; - - /*! \brief The files previously imported, required to ensure - importing is idempotent for each module. - */ - std::unordered_set import_set_; friend class IRModule; }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 579061e02eb6..62328f6a074a 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -207,6 +207,25 @@ enum TypeKind : int { kTypeData = 6 }; +/*! \brief Converts a TypeKind to a string. */ +inline String TypeKind2String(TypeKind kind) { + switch (kind) { + case TypeKind::kType: + return "Type"; + case TypeKind::kShapeVar: + return "ShapeVar"; + case TypeKind::kBaseType: + return "BaseType"; + case TypeKind::kConstraint: + return "Constraint"; + case TypeKind::kAdtHandle: + return "AdtHandle"; + case TypeKind::kTypeData: + return "TypeData"; + } + LOG(FATAL) << "ValueError: Unknown TypeKind: " << static_cast(kind); +} + /*! * \brief Type parameter in functions. * diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 11bf7d4740d0..334a35d052e1 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -24,9 +24,9 @@ #ifndef TVM_IR_TYPE_FUNCTOR_H_ #define TVM_IR_TYPE_FUNCTOR_H_ +#include +#include #include -#include -#include #include #include diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 532425a51b3e..e3f59fcc14a1 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -26,6 +26,7 @@ #include #include +#include namespace tvm { /*! \brief A printer class to print the AST/IR nodes. */ @@ -48,6 +49,30 @@ class ReprPrinter { TVM_DLL static FType& vtable(); }; +/*! \brief Legacy behavior of ReprPrinter. */ +class ReprLegacyPrinter { + public: + /*! \brief The indentation level. */ + int indent{0}; + + explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*) + : stream(stream) {} + + /*! \brief The node to be printed. */ + TVM_DLL void Print(const ObjectRef& node); + /*! \brief Print indent to the stream */ + TVM_DLL void PrintIndent(); + /*! \brief Return the ostream it maintains */ + TVM_DLL std::ostream& Stream() const; + // Allow registration to be printer. + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); + + private: + /*! \brief The output stream */ + std::ostream& stream; +}; + /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node @@ -70,6 +95,13 @@ inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLI ReprPrinter(os).Print(n); return os; } + +inline std::string AsLegacyRepr(const ObjectRef& n) { + std::ostringstream os; + ReprLegacyPrinter(os).Print(n); + return os.str(); +} } // namespace runtime +using runtime::AsLegacyRepr; } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index d04d8c4d028a..54810fd55a43 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -69,6 +69,9 @@ class IRDocsifierFunctor { if ((pf = LookupDispatchTable("", type_index)) != nullptr) { return (*pf)(obj, args...); } + LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" + << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; ICHECK(false) << "ObjectFunctor calls un-registered function on type: " << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h index 289e838b52a8..b373a2be73fb 100644 --- a/include/tvm/script/printer/printer.h +++ b/include/tvm/script/printer/printer.h @@ -55,21 +55,6 @@ struct Default { static bool& VerboseExpr() { return Instance()->verbose_expr; } }; -/*! - * \brief The entry method for TVMScript printing - * \param obj The object to be printed - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - * \return The TVMScript text format - */ -String Script(ObjectRef obj, // - int indent_spaces = 4, // - bool print_line_numbers = false, // - int num_context_lines = -1, // - Optional path_to_underline = NullOpt); - /*! * \brief Convert Doc into Python script. * \param doc Doc to be converted diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1d5e8f317a2e..689b1c0a17ad 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1191,9 +1191,6 @@ class Any : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode); }; -/*! \brief Legacy ReprPrint format for TIR */ -std::string LegacyTIRPrint(const ObjectRef& obj); - /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 9f7c0fa16b06..17e7de930260 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -132,6 +132,17 @@ class PrimFuncNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "tir.PrimFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 96e03477a141..e0b7bcc868b3 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -46,6 +46,17 @@ class StmtNode : public Object { StmtNode() = default; explicit StmtNode(Span span) : span(span) {} + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4e847c0310a4..9e81dd5519e1 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -16,29 +16,47 @@ # under the License. # pylint: disable=unused-import """Common data structures across all IR variants.""" -from .base import SourceName, Span, Node, EnvFunc, load_json, save_json -from .base import structural_equal, assert_structural_equal, structural_hash -from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType -from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType -from .tensor_type import TensorType -from .affine_type import TensorAffineType, TupleAffineType -from .type_relation import TypeCall, TypeRelation -from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .op import Op, register_op_attr, register_intrin_lowering -from .function import CallingConv, BaseFunc +from . import diagnostics, instrument, transform from .adt import Constructor, TypeData -from .module import IRModule +from .affine_type import TensorAffineType, TupleAffineType from .attrs import Attrs, DictAttrs, make_node +from .base import ( + EnvFunc, + Node, + SourceName, + Span, + assert_structural_equal, + load_json, + pretty_print, + save_json, + structural_equal, + structural_hash, +) from .container import Array, Map +from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr +from .function import BaseFunc, CallingConv from .memory_pools import ( - PoolInfo, - WorkspacePoolInfo, - ConstantPoolInfo, - WorkspaceMemoryPools, ConstantMemoryPools, + ConstantPoolInfo, + PoolInfo, PoolInfoProperties, + WorkspaceMemoryPools, + WorkspacePoolInfo, ) - -from . import transform -from . import instrument -from . import diagnostics +from .module import IRModule +from .op import Op, register_intrin_lowering, register_op_attr +from .tensor_type import TensorType +from .type import ( + FuncType, + GlobalTypeVar, + IncompleteType, + PointerType, + PrimType, + RelayRefType, + TupleType, + Type, + TypeConstraint, + TypeKind, + TypeVar, +) +from .type_relation import TypeCall, TypeRelation diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index bd77c187af40..8d185ae59a34 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -17,8 +17,8 @@ """Types for quantized Tensors.""" import tvm._ffi -from .base import Node from . import _ffi_api +from .base import Node class AffineType(Node): @@ -31,6 +31,11 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @tvm._ffi.register_object("TensorAffineType") class TensorAffineType(AffineType): diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index d754ae567c5e..a1e1d20d8823 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -16,13 +16,16 @@ # under the License. """Common base structures.""" import tvm._ffi - import tvm.error import tvm.runtime._ffi_node_api from tvm.runtime import Object -from . import _ffi_api -from . import json_compact +from . import _ffi_api, json_compact + + +def pretty_print(obj: Object) -> None: + """Pretty print the object.""" + return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member class Node(Object): @@ -54,9 +57,6 @@ def astext(self, show_meta_data=True, annotate=None): """ return _ffi_api.AsText(self, show_meta_data, annotate) - def __str__(self): - return _ffi_api.PrettyPrint(self) - @tvm._ffi.register_object("SourceName") class SourceName(Object): diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3ed7e57cb758..b184c3b0c3cf 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -17,13 +17,13 @@ """IRModule that holds the functions and type definitions.""" from typing import Optional -from tvm._ffi.base import string_types import tvm._ffi +from tvm._ffi.base import string_types -from .base import Node +from . import _ffi_api from . import expr as _expr from . import type as _ty -from . import _ffi_api +from .base import Node @tvm._ffi.register_object("IRModule") @@ -252,51 +252,6 @@ def import_from_std(self, file_to_import): _ffi_api.Module_ImportFromStd(self, file_to_import) return tvm.relay.transform.InferType()(self) - def __str__(self): - return _ffi_api.PrettyPrint(self) - - def __repr__(self): - return self.astext() - - def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - tir_prefix : str - The tir namespace prefix - - show_meta : bool - Whether to show meta information - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore - - def show(self, style: Optional[str] = None, black_format: bool = True) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - - black_format: bool - - If true (default), use the formatter Black to format the TVMScript - """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel - - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style, black_format=black_format) - def get_attr(self, attr_key): """Get the IRModule attribute. @@ -331,3 +286,78 @@ def with_attr(self, attr_key, attr_value): """ return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.Module_Script( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py index 22b15a397e30..7313f3c2b42c 100644 --- a/python/tvm/ir/tensor_type.py +++ b/python/tvm/ir/tensor_type.py @@ -17,8 +17,8 @@ """Type relation and function for type checking.""" import tvm._ffi -from .type import Type from . import _ffi_api +from .type import Type @tvm._ffi.register_object("relay.TensorType") @@ -54,3 +54,8 @@ def concrete_shape(self): TypeError : If the shape is symbolic """ return tuple(int(x) for x in self.shape) + + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 4fe28f1d72e2..ea06aeda2030 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -16,11 +16,12 @@ # under the License. """Unified type system in the project.""" from enum import IntEnum + import tvm import tvm._ffi -from .base import Node from . import _ffi_api +from .base import Node class Type(Node): diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 1f6d8bb9ab0b..6c29825bc04d 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -46,6 +46,11 @@ def register_df_node(type_key=None): class DFPattern(Node): """Base class of all Patterns.""" + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + def __call__(self, *args): args = list(args) if len(args) == 1 and args[0] is None: diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88b84bbe7ebc..7d60e89b59b7 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -17,17 +17,20 @@ # pylint: disable=no-else-return, invalid-name, unused-import """The expression nodes of Relay.""" from __future__ import absolute_import + from numbers import Number as _Number import numpy as _np + import tvm._ffi from tvm._ffi import base as _base -from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, Node +from tvm.ir import GlobalVar, Node, RelayExpr +from tvm.runtime import NDArray +from tvm.runtime import ndarray as _nd -from .base import RelayNode from . import _ffi_api from . import ty as _ty +from .base import RelayNode # alias relay expr as Expr. Expr = RelayExpr @@ -58,6 +61,11 @@ def astype(self, dtype): """ return _ffi_api.cast(self, dtype) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + def __neg__(self): return _op_make.negative(self) @@ -710,6 +718,11 @@ class StorageInfo(Node): def __init__(self, sids, dev_types, sizes): self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @property def storage_ids(self): return _ffi_api.StorageInfoStorageIds(self) @@ -735,3 +748,8 @@ class StaticMemoryPlan(Node): def __init__(self, expr_to_storage_info): self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info) + + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 66bb858edbf0..e9bb15e1d1c6 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1847,7 +1847,7 @@ def _impl(inputs, attr, params, mod): shape_arg = tuple(params_new.numpy().astype("int32").flatten()) except Exception: # Deal with symbolic shape case. - if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op): + if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op.name): # shape_of is the direct ancestor. return _op.reshape_like(inputs[0], pop_node.args[0]) shape_arg = pop_node diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 68d8953900cf..ef3356450085 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -19,11 +19,11 @@ from __future__ import absolute_import import tvm._ffi -from tvm.runtime import convert from tvm.ir import BaseFunc +from tvm.runtime import convert -from .expr import Call from . import _ffi_api +from .expr import Call @tvm._ffi.register_object("relay.Function") @@ -67,6 +67,11 @@ def __call__(self, *args): """ return Call(self, args, None, None) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @tvm._ffi.register_func("relay.FunctionWithFields") def FunctionWithFields( diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 1a441a6f03c2..6fce020a6694 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -17,12 +17,14 @@ # pylint: disable=invalid-name """Patterns supported CUTLASS.""" from functools import partial + from tvm import relay -from tvm.ir.transform import Sequential, PassContext +from tvm.ir.transform import PassContext, Sequential from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import register_pattern_table # type: ignore -from ...dataflow_pattern import wildcard, is_op, is_constant + +from ...dataflow_pattern import is_constant, is_op, wildcard def make_gelu_pattern(bias_out, out_dtype="float16"): @@ -124,7 +126,7 @@ def check_dtype(lhs, rhs): def get_root_call(call, root_op_name): if not isinstance(call, relay.Call): return None - if str(call.op) == root_op_name: + if str(call.op.name) == root_op_name: return call return get_root_call(call.args[0], root_op_name) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index bdf910d704ce..7db8608d6d7c 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -36,22 +36,25 @@ from functools import reduce import tvm.ir -from tvm.ir import Op from tvm import relay +from tvm.ir import Op +from tvm.relay import expr as _expr from tvm.relay import transform -from tvm.relay.expr import GlobalVar -from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import const - from tvm.relay.analysis import analysis as _analysis -from tvm.relay import expr as _expr +from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const +from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import Call, TupleGetItem from ... import _ffi_api -from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback +from ...dataflow_pattern import ( + DFPatternCallback, + is_constant, + is_expr, + is_op, + rewrite, + wildcard, +) from .register import register_pattern_table - logger = logging.getLogger("DNNL") supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None] @@ -762,7 +765,7 @@ def visit_call(self, call): ] ) if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in compute_intensive_ops: + if str(call.op.name) in compute_intensive_ops: self.is_compute_intensive = True return super().visit_call(call) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a86357db39fc..bd9a7d5ba0d1 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -17,16 +17,22 @@ # pylint: disable=ungrouped-imports, import-outside-toplevel """Arm(R) Ethos(TM)-U NPU supported operators.""" import functools -from typing import Dict, List, Tuple, Callable, Optional +from typing import Callable, Dict, List, Optional, Tuple import numpy as np # type: ignore import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant, Call # type: ignore -from tvm.relay.op.contrib.register import register_pattern_table # type: ignore -from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore +from tvm.ir import Op from tvm.relay.build_module import bind_params_by_name # type: ignore +from tvm.relay.dataflow_pattern import ( # type: ignore + is_constant, + is_op, + is_tuple, + wildcard, +) +from tvm.relay.expr import Call, Constant # type: ignore +from tvm.relay.op.contrib.register import register_pattern_table # type: ignore try: # As ethos-u-vela package is an optional TVM dependency, we want to lazy load it @@ -197,20 +203,23 @@ class QnnConv2DParams: @requires_vela def __init__(self, func_body: tvm.relay.Function): from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs activation = None separate_padding = None - if str(func_body.op) in self.activation_map.keys(): + if str(func_body.op.name) in self.activation_map.keys(): activation = func_body requantize_op = activation.args[0] else: requantize_op = func_body bias_add = requantize_op.args[0] qnn_conv2d = bias_add.args[0] - if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad": + if ( + isinstance(qnn_conv2d.args[0], relay.Call) + and isinstance(qnn_conv2d.args[0].op, Op) + and str(qnn_conv2d.args[0].op.name) == "nn.pad" + ): separate_padding = qnn_conv2d.args[0] data_layout = qnn_conv2d.attrs.data_layout self.kernel_layout = qnn_conv2d.attrs.kernel_layout @@ -330,13 +339,14 @@ class QnnConv2DTransposeParams: @requires_vela def __init__(self, func_body: tvm.relay.Function): - from tvm.relay.backend.contrib.ethosu.util import QConv2DTransposeArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import ( + QConv2DTransposeArgs, # type: ignore + ) + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs requantize = func_body call = func_body.args[0] - if str(call.op) == "nn.bias_add": + if str(call.op.name) == "nn.bias_add": bias_add = call call = call.args[0] else: @@ -561,7 +571,7 @@ class MaxPool2DParams: def __init__(self, func_body: Call): clip = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": clip = func_body pool_op = clip.args[0] else: @@ -617,7 +627,7 @@ class AvgPool2DParams: def __init__(self, func_body: Call): clip = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": clip = func_body cast2 = clip.args[0] else: @@ -681,19 +691,21 @@ class BinaryElementwiseParams: """ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation: bool): - from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import ( + BinaryElementwiseArgs, + RequantArgs, + ) current_call = func_body clip = None requantize = None if is_quantized_operation: - if str(current_call.op) == "clip": + if str(current_call.op.name) == "clip": clip = current_call current_call = clip.args[0] else: - if str(current_call.op) == "qnn.requantize": + if str(current_call.op.name) == "qnn.requantize": requantize = current_call clip = current_call.args[0] current_call = clip.args[0] @@ -1101,8 +1113,7 @@ class AbsParams: composite_name = "ethos-u.abs" def __init__(self, func_body: Call): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs quantize = func_body abs_op = quantize.args[0] @@ -1157,8 +1168,7 @@ class LutActivationParams: """ def __init__(self, func_body: Call): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs layout = "NHWC" @@ -1631,18 +1641,17 @@ class FullyConnectedParams: @requires_vela def __init__(self, func_body): from tvm.relay.backend.contrib.ethosu.util import QDenseArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs self.activation = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": self.activation = func_body requantize_op = self.activation.args[0] else: requantize_op = func_body call = requantize_op.args[0] - if str(requantize_op.args[0].op) == "nn.bias_add": + if str(requantize_op.args[0].op.name) == "nn.bias_add": bias_add = call qnn_dense = call.args[0] else: @@ -1733,8 +1742,7 @@ class HardSwishParams: composite_name = "ethos-u.hard_swish" def __init__(self, func_body): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs quantize = func_body divide = quantize.args[0] diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 4008b0eb3f78..0971770e5726 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -17,15 +17,22 @@ # pylint: disable=invalid-name, unused-argument, logging-format-interpolation """TensorRT supported operators.""" import logging -from typing import Tuple, List, Dict, Union, Optional, Any, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np # type: ignore + import tvm from tvm import relay from tvm.ir import Op from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item +from tvm.relay.dataflow_pattern import ( + is_constant, + is_op, + is_tuple, + is_tuple_get_item, + wildcard, +) from tvm.relay.expr import Call, Constant, TupleGetItem from tvm.relay.expr_functor import ExprMutator, ExprVisitor from tvm.relay.op.contrib.register import register_pattern_table @@ -1050,7 +1057,7 @@ def visit_call(self, call: relay.expr.Call) -> None: "mean", } if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in compute_intensive_ops: + if str(call.op.name) in compute_intensive_ops: self.is_compute_intensive = True return super().visit_call(call) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 11d317b657e6..703a12f45f4b 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -19,6 +19,7 @@ """FFI for tvm.node""" import tvm._ffi + # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. # They will be overriden via _init_api to the ones registered @@ -27,6 +28,10 @@ def AsRepr(obj): return obj.type_key() + "(" + obj.handle.value + ")" +def AsLegacyRepr(obj): + return obj.type_key() + "(" + obj.handle.value + ")" + + def NodeListAttrNames(obj): return lambda x: 0 diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index e522fd539b4e..6a8dd6587643 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -18,22 +18,30 @@ """Runtime Object API""" import ctypes -from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from tvm._ffi.base import _FFI_MODE, _LIB, _RUNTIME_ONLY, c_str, check_call from tvm._ffi.runtime_ctypes import ObjectRValueRef + from . import _ffi_api, _ffi_node_api try: # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() - from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic - from tvm._ffi._cy3.core import ObjectBase, PyNativeObject + from tvm._ffi._cy3.core import ( + ObjectBase, + PyNativeObject, + _set_class_object, + _set_class_object_generic, + ) except (RuntimeError, ImportError) as error: # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "cython": raise error - from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject + from tvm._ffi._ctypes.packed_func import ( + _set_class_object, + _set_class_object_generic, + ) def _new_object(cls): @@ -49,6 +57,9 @@ class Object(ObjectBase): def __repr__(self): return _ffi_node_api.AsRepr(self) + def legacy_repr(self): + return _ffi_node_api.AsLegacyRepr(self) + def __dir__(self): class_names = dir(self.__class__) fnames = _ffi_node_api.NodeListAttrNames(self) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index d52fbb83c368..dab7a175185d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -28,15 +28,16 @@ assert(y.a == x) """ from typing import Optional, Union -from tvm import ir + import tvm._ffi +import tvm.ir._ffi_api +from tvm import ir +from tvm.ir import Op, PrimExpr from tvm.ir.base import Span +from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const -from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const -from tvm.ir import PrimExpr, Op -import tvm.ir._ffi_api -from . import generic as _generic from . import _ffi_api +from . import generic as _generic def div_ambiguity_error(): @@ -324,6 +325,81 @@ class PrimExprWithOp(ExprOp, PrimExpr): # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = PrimExpr.__hash__ + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.PrimExprScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) + class ConstExpr(PrimExprWithOp): pass diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 082faeb456d3..fb5a37c5dc17 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -18,17 +18,18 @@ import collections import inspect -from typing import Callable, List, Mapping, Optional, Union, Tuple +from typing import Callable, List, Mapping, Optional, Tuple, Union import tvm import tvm._ffi import tvm.runtime -from tvm.runtime import Object from tvm.ir import BaseFunc, Range -from .buffer import Buffer -from .expr import Var, PrimExpr -from . import _ffi_api +from tvm.runtime import Object + from ..runtime.ndarray import NDArray +from . import _ffi_api +from .buffer import Buffer +from .expr import PrimExpr, Var @tvm._ffi.register_object("tir.PrimFunc") @@ -169,44 +170,80 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: """ return _ffi_api.Specialize(self, param_map) # type: ignore - def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: """Print IRModule into TVMScript Parameters ---------- - tir_prefix : str - The tir namespace prefix - - show_meta : bool - Whether to show meta information + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined Returns ------- script : str - The TVM Script of the PrimFunc + The TVM Script of the IRModule """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.PrimFuncScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) - def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: """A sugar for print highlighted TVM script. Parameters ---------- style : str, optional - Pygmentize printing style, auto-detected if None. See `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style, black_format=black_format) + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) @tvm._ffi.register_object("tir.TensorIntrin") diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 48850012cbb7..64aba0e029fe 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -239,21 +239,26 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member - @type_checked - def show(self, rand_var: RAND_VAR_TYPE) -> str: - """Returns a string representation of the value that the random variable evaluates to + def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + """A sugar for print highlighted TVM script. Parameters ---------- - rand_var : Union[ExprRV, BlockRV, LoopRV] - The random variable to be evaluated + style : str, optional - Returns - ------- - str_repr : str - The string representation + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript """ - return str(self.get(rand_var)) + mod = self.mod + if mod is not None: + mod.show(style=style, black_format=black_format) + trace = self.trace + if trace is not None: + trace.show(style=style, black_format=black_format) ########## Lookup ########## diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 4847e377dec1..096c13653a94 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -41,6 +41,81 @@ class Stmt(Object): """Base class of all the statements.""" + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.StmtScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) + @tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 03a36e803be8..af6e47b7a066 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1288,7 +1288,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " - << "occurs in " << tvm::PrettyPrint(GetRef(op)); + << "occurs in " << GetRef(op); return GetRef(op); } @@ -1321,7 +1321,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } auto opt_fused = TryFuseIters(sum, check_level_); if (!opt_fused) { - ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) + ErrorLogger(this) << "Dividend " << original_dividend << ", can't be written as a single fused IterSum"; return IterSumExpr(); } @@ -1446,8 +1446,7 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl // since the extent covers the full padding range. left_pad = floordiv(mark_left_pad, split->lower_factor); } else { - ErrorLogger(this) << "Detect incompatible left padding on " - << tvm::PrettyPrint(NormalizeIterMapToExpr(split)) + ErrorLogger(this) << "Detect incompatible left padding on " << NormalizeIterMapToExpr(split) << ", the iter mark is left padded with " << mark_left_pad; return {IterSplitExpr(), PrimExpr()}; } @@ -1522,8 +1521,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } else { // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " - << tvm::PrettyPrint(lhs->scale) << " and the divisor " - << tvm::PrettyPrint(rhs) + << lhs->scale << " and the divisor " << rhs << " cannot be simplified to remove the scaling factor."; return PrimExpr(); } @@ -1621,7 +1619,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " - << tvm::PrettyPrint(lhs->scale) << " and the right-hand " << tvm::PrettyPrint(rhs) + << lhs->scale << " and the right-hand " << rhs << " cannot be used to simplify out the scaling factor."; return PrimExpr(); } diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 3a9224227680..e03d4302c89f 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1274,28 +1274,28 @@ String ComputeDAG::PrintDAG(bool simple_mode) const { ICHECK_LT(k, p_reduce->combiner->result.size()); PrimExpr combiner = p_reduce->combiner->result[k]; if (combiner->IsInstance()) { - ss << " += " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " += " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { - ss << " max= " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " max= " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { - ss << " min= " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " min= " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { const auto& select = combiner.as(); - ss << " select(" << LegacyTIRPrint(select->condition) // - << ", " << LegacyTIRPrint(select->true_value) // - << ", " << LegacyTIRPrint(select->false_value) // - << ")= (" << LegacyTIRPrint(p_reduce->source[0]) // - << ',' << LegacyTIRPrint(p_reduce->source[1]) // + ss << " select(" << AsLegacyRepr(select->condition) // + << ", " << AsLegacyRepr(select->true_value) // + << ", " << AsLegacyRepr(select->false_value) // + << ")= (" << AsLegacyRepr(p_reduce->source[0]) // + << ',' << AsLegacyRepr(p_reduce->source[1]) // << ")\n"; } else { - ss << "reduce" << LegacyTIRPrint(combiner) << "\n"; + ss << "reduce" << AsLegacyRepr(combiner) << "\n"; } } else { auto call = pop->body[k].as(); if (simple_mode && call) { - ss << " = " << LegacyTIRPrint(call->op) << "\n"; + ss << " = " << AsLegacyRepr(call->op) << "\n"; } else { - ss << " = " << LegacyTIRPrint(pop->body[k]) << "\n"; + ss << " = " << AsLegacyRepr(pop->body[k]) << "\n"; } } } diff --git a/src/ir/adt.cc b/src/ir/adt.cc index f0ce859f3f87..3533c8c514cd 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -21,8 +21,9 @@ * \file src/ir/adt.cc * \brief ADT type definitions. */ -#include -#include +#include +#include +#include namespace tvm { diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index af46439cff7c..f197ac4416fa 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -53,12 +53,6 @@ DictAttrs::DictAttrs(Map dict) { data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dict; - }); - TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); diff --git a/src/ir/error.cc b/src/ir/error.cc index f0e78b954a41..26448d04005c 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -21,15 +21,8 @@ * \file ir/error.cc * \brief Utilities for error tracking and reporting. */ - #include #include -// NOTE: reverse dependency on relay. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: use relay's printer for astext. -#include // clang-format off #include diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7ba99e34d519..050d9b87a856 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -25,11 +25,6 @@ #include #include #include -// NOTE: reverse dependency on top/tir. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: convert from IterVar and top::Tensor #include #include @@ -168,12 +163,6 @@ TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { return GlobalVar(name, type); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); - TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/function.cc b/src/ir/function.cc index dcfddd5f69d5..ce294708b2a9 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -23,12 +23,6 @@ */ #include #include -// NOTE: reverse dependency on relay, tir/ -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: We calls into the type specific WithAttr function -#include #include namespace tvm { @@ -41,11 +35,13 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } + if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttr")) { + if (Optional ret = (*f)(func, key, value)) { + return ret.value(); + } + } + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index def94a046855..b6923cd1e60d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -23,19 +23,10 @@ */ #include #include -#include -#include -// NOTE: reverse dependency on relay. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: We calls into relay's analysis module to verify correctness. #include +#include #include -#include -#include -#include -#include +#include #include #include @@ -182,26 +173,11 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -void WarnIfMalformed(const IRModule& mod, relay::Function func) { - func = Downcast(relay::DeDup(func)); - // Type check the item before we add it to the module. - auto fv = relay::FreeVars(func); - auto ftv = relay::FreeTypeVars(func, mod); - // TODO(@jroesch): refactor to use diagnostic context - ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl - << PrettyPrint(func) << std::endl - << "contains free variables: " << fv; - ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl - << PrettyPrint(func) << std::endl - << "contains free type variables: " << fv; -} - void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; - if (auto* ptr = f.as()) { - WarnIfMalformed(GetRef(this), GetRef(ptr)); + if (const auto* f = runtime::Registry::Get("relay.ir.WarnIfMalformed")) { + (*f)(GetRef(this), checked_func); } - AddUnchecked(var, checked_func); } @@ -212,8 +188,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { if (it != global_var_map_.end()) { ICHECK_EQ((*it).second, var); } else { - ICHECK(global_var_map_.count(var->name_hint) == 0) - << "Duplicate global function name " << PrettyPrint(var); + ICHECK(global_var_map_.count(var->name_hint) == 0) << "Duplicate global function name " << var; } global_var_map_.Set(var->name_hint, var); @@ -243,7 +218,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& if (!update) { // set global type var map ICHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << PrettyPrint(var); + << "Duplicate global type definition name " << var; } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); @@ -266,7 +241,7 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var); + ICHECK(it != functions.end()) << "There is no definition of " << var; return (*it).second; } @@ -277,7 +252,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const { TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var); + ICHECK(it != type_definitions.end()) << "There is no definition of " << var; return (*it).second; } @@ -292,70 +267,14 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } -/*! - * \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs - * ('one') side above the rhs ('two'). - */ -struct Renamer : relay::ExprMutator, TypeMutator { - Map defs; - Map types; - std::unordered_map ctors; - - Renamer(Map defs_one, Map defs_two, - Map types_one, Map types_two, - std::unordered_map ctors_one, - std::unordered_map ctor_two) { - for (auto pair : defs_one) { - defs.Set(pair.first, pair.second); - } - - for (auto pair : defs_two) { - auto it = defs.find(pair.first); - if (it == defs.end()) { - defs.Set(pair.first, pair.second); - } - } - - for (auto pair : types_one) { - types.Set(pair.first, pair.second); - } - - for (auto pair : types_two) { - auto it = types.find(pair.first); - if (it == types.end()) { - types.Set(pair.first, pair.second); - } - } - } - - relay::Expr VisitExpr_(const GlobalVarNode* node) override { return defs.at(node->name_hint); } - - Type VisitType_(const GlobalTypeVarNode* node) override { return types.at(node->name_hint); } -}; - void IRModuleNode::Update(const IRModule& mod) { - Renamer renamer(this->global_var_map_, mod->global_var_map_, this->global_type_var_map_, - mod->global_type_var_map_, this->constructor_tag_map_, mod->constructor_tag_map_); - - this->global_var_map_ = renamer.defs; - this->global_type_var_map_ = renamer.types; - this->constructor_tag_map_ = renamer.ctors; - - for (auto pair : mod->type_definitions) { - auto tvar = renamer.types.at(pair.first->name_hint); - auto ty = renamer.ExprMutator::VisitType(pair.second); - this->AddTypeDefUnchecked(tvar, Downcast(ty), true); + if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleUpdateWithRenamer")) { + (*f)(GetRef(this), mod); + return; } - for (auto pair : mod->functions) { - if (auto rfn = pair.second.as()) { - auto gvar = renamer.defs.at(pair.first->name_hint); - auto fn = renamer.VisitExpr(GetRef(rfn)); - this->AddUnchecked(gvar, Downcast(fn)); - } else { - // TODO(@jroesch): rename into IRModule. - this->AddUnchecked(pair.first, pair.second); - } + // TODO(@jroesch): rename into IRModule. + this->AddUnchecked(pair.first, pair.second); } } @@ -379,8 +298,10 @@ std::pair IRModule::FromExprInContext( // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); } + } else if (const auto* f = runtime::Registry::Get("relay.ir.FunctionFromExprInContext")) { + func = (*f)(expr, mod); } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + LOG(FATAL) << "`relay.ir.FunctionFromExprInContext` is not registered"; } GlobalVar main_gv; @@ -418,14 +339,6 @@ void IRModuleNode::ImportFromStd(const String& path) { this->Import(std_path + "/" + path); } -Bool IRModuleNode::ShouldLinkParameters() const { - Optional executor = GetAttr(tvm::attr::kExecutor); - if (!executor.defined()) { - return Bool(false); - } - return executor.value()->ShouldLinkParameters(); -} - std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { @@ -440,29 +353,15 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return IRModule(funcs, types, {}); }); -TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { - IRModule mod = args[0]; - GlobalVar var = args[1]; - ObjectRef val = args[2]; - bool update = args[3]; - ICHECK(val->IsInstance()); - - if (val->IsInstance()) { - mod->Add(var, Downcast(val), update); - } else if (val->IsInstance()) { - GlobalVar gv = Downcast(val); - auto mod_copy = IRModule(make_object(*mod.operator->())); - mod_copy = relay::transform::EtaExpand( - /* expand_constructor */ false, - /* expand_global_var */ true)(mod_copy); - auto func = mod_copy->Lookup(gv->name_hint); - mod->Add(var, Downcast(func), update); - } else { - auto func = relay::Function({}, Downcast(val), Type(nullptr), {}); - mod->Add(var, func, update); - } - *ret = mod; -}); +TVM_REGISTER_GLOBAL("ir.Module_Add") + .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { + ICHECK(val->IsInstance()); + if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleAdd")) { + return (*f)(mod, var, val, update); + } + mod->Add(var, Downcast(val), update); + return mod; + }); TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); @@ -530,10 +429,4 @@ TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String return mod->GetAttr(key); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IRModule(" << node->functions << ")"; - }); - } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index e0f08d28fb18..bfd0a5917556 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -377,7 +377,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; - VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); mod = pass_func(std::move(mod), pass_ctx); @@ -389,8 +388,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - VLOG(1) << "Result module:" << std::endl << PrettyPrint(mod); - return mod; } diff --git a/src/ir/type.cc b/src/ir/type.cc index ee05fd03596a..d965406e8bb0 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -65,12 +65,6 @@ TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; - }); - GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); @@ -85,12 +79,6 @@ TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) return GlobalTypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; - }); - FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints, Span span) { ObjectPtr n = make_object(); @@ -110,13 +98,6 @@ TVM_REGISTER_GLOBAL("ir.FuncType") return FuncType(arg_types, ret_type, type_params, type_constraints); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " - << node->ret_type << ", " << node->type_constraints << ")"; - }); - TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -158,10 +139,4 @@ TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) { TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RelayRefTypeNode(" << node->value << ")"; - }); - } // namespace tvm diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 4663fd90762a..c90d92f83b39 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -52,13 +52,12 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { - LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " - << tir::AsTVMScript(mod); + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod; } if (num_prim_func > 1) { LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" - << tir::AsTVMScript(mod); + << mod; } return GetRef(last_func); } diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 22d6ec849c5f..b0fba5adb5c2 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -196,7 +196,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) << " of file " << path_tuning_record << ". The workload is:\n" - << (workload.defined() ? tir::AsTVMScript(workload->mod) : "(null)") + << (workload.defined() ? workload->mod->Script() : "(null)") << "\nThe JSONObject of TuningRecord is:\n" << json_obj << "\nThe error message is:\n" << e.what(); diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 9d859947e4fe..404ee01983c5 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -32,7 +32,7 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { << "ValueError: Require `context.space_generator`, but it is not defined"; CHECK(ctx->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - TVM_PY_LOG(INFO, ctx->logger) << "\n" << tir::AsTVMScript(ctx->mod); + TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod; ctx->Initialize(); n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value())); this->data_ = std::move(n); @@ -124,7 +124,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r << (builder_result->error_msg.defined() ? "building" : "running") << ":\n" << err << "\n" - << tir::AsTVMScript(sch->mod()) << "\n" + << sch->mod() << "\n" << Concat(sch->trace().value()->AsPython(false), "\n"); } else { double best_ms = *std::min_element(self->latency_ms.begin(), self->latency_ms.end()); @@ -168,7 +168,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh tir::Trace trace = sch->trace().value(); trace = trace->Simplified(true); TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << i << ":\n" - << tir::AsTVMScript(sch->mod()) << "\n" + << sch->mod() << "\n" << Concat(trace->AsPython(false), "\n"); } ctx->search_strategy.value()->PreTuning(max_trials_per_task, num_trials_per_iter, design_spaces, diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 6039423844e8..9a372dde8f6d 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -48,7 +48,6 @@ #include #include -#include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" #include "../support/nd_int_set.h" diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index ea263439023f..63bba67dd5f2 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -51,6 +51,28 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } +void ReprLegacyPrinter::Print(const ObjectRef& node) { + static const FType& f = vtable(); + if (!node.defined()) { + stream << "(nullptr)"; + } else if (f.can_dispatch(node)) { + f(node, this); + } else { + stream << node; // Use ReprPrinter + } +} + +void ReprLegacyPrinter::PrintIndent() { + for (int i = 0; i < indent; ++i) { + stream << ' '; + } +} + +ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() { + static FType inst; + return inst; +} + void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } @@ -60,4 +82,7 @@ TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) { os << obj; return os.str(); }); + +TVM_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(runtime::AsLegacyRepr); + } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 0290b7afe3fd..80e390d9b0ad 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -314,9 +314,9 @@ class SEqualHandlerDefault::Impl { } if (assert_mode_ && !result) { LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl - << PrettyPrint(lhs) << std::endl + << lhs << std::endl << "and rhs:" << std::endl - << PrettyPrint(rhs); + << rhs; } return result; } diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index f6ac39ce79ff..4220aa00f5a4 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -38,9 +38,9 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "model_library_format_printer"; } std::string Print(const ObjectRef& node) { - Doc doc; - doc << text_printer_.PrintFinal(node); - return doc.str(); + std::ostringstream oss; + oss << node; + return oss.str(); } TVMRetValue GetVarName(tir::Var var) { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index afc76112879e..925c2ebf494e 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -409,8 +409,6 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintBody(const Stmt& body, bool indent = true); }; -String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false); - String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 274b9542cc92..c578bc53d3d3 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -2002,16 +2002,6 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { return res; } -String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { - ICHECK(mod->IsInstance() || mod->IsInstance()); - Doc doc; - doc << TVMScriptPrinter::PrintHeader(tir_prefix) - << TVMScriptPrinter(tir_prefix, show_meta).Print(mod); - return doc.str() + "\n"; -} - -TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); - String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate) { ICHECK(mod->IsInstance() || mod->IsInstance()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d71cbcfc667d..154101fc94fe 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -51,7 +51,6 @@ #include #include -#include "../../printer/text_printer.h" #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../src/meta_schedule/module_equality.h" @@ -646,7 +645,7 @@ class ScheduleBuilder : public ExprVisitor { // (dispatch & 4): controls whether to raise fatal errors for missing TIR if (dispatch & 2) { LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint << "\n" - << tir::AsTVMScript(f.value()); + << f.value(); } else { LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint; } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 07cfb27b1d35..3ff5eaa059c1 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -21,7 +21,11 @@ * \file src/relay/ir/function.cc * \brief Function in relay. */ +#include +#include +#include #include +#include namespace tvm { namespace relay { @@ -119,6 +123,132 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { } return nullptr; } +TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") + .set_body_typed([](IRModule mod) -> Optional { + for (const auto& it : mod->functions) { + if (it.second->IsInstance()) { + return PrettyPrint(mod); + } + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed") + .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { + if (const auto* relay_func = base_func.as()) { + Function func = Downcast(relay::DeDup(GetRef(relay_func))); + // Type check the item before we add it to the module. + auto fv = relay::FreeVars(func); + auto ftv = relay::FreeTypeVars(func, mod); + // TODO(@jroesch): refactor to use diagnostic context + ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free variables: " << fv; + ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free type variables: " << fv; + } + }); +TVM_REGISTER_GLOBAL("relay.ir.IRModuleAdd") + .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { + if (val->IsInstance()) { + mod->Add(var, Downcast(val), update); + } else if (val->IsInstance()) { + GlobalVar gv = Downcast(val); + IRModule mod_copy(make_object(*mod.operator->())); + mod_copy = relay::transform::EtaExpand( + /* expand_constructor */ false, + /* expand_global_var */ true)(mod_copy); + auto func = mod_copy->Lookup(gv->name_hint); + mod->Add(var, Downcast(func), update); + } else { + auto func = relay::Function({}, Downcast(val), Type(nullptr), {}); + mod->Add(var, func, update); + } + return mod; + }); + +TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer") + .set_body_typed([](IRModule self, IRModule mod) -> void { + struct Renamer : relay::ExprMutator, TypeMutator { + Map defs; + Map types; + std::unordered_map ctors; + + Renamer(Map defs_one, Map defs_two, + Map types_one, Map types_two, + std::unordered_map ctors_one, + std::unordered_map ctor_two) { + for (auto pair : defs_one) { + defs.Set(pair.first, pair.second); + } + + for (auto pair : defs_two) { + auto it = defs.find(pair.first); + if (it == defs.end()) { + defs.Set(pair.first, pair.second); + } + } + + for (auto pair : types_one) { + types.Set(pair.first, pair.second); + } + + for (auto pair : types_two) { + auto it = types.find(pair.first); + if (it == types.end()) { + types.Set(pair.first, pair.second); + } + } + } + + relay::Expr VisitExpr_(const GlobalVarNode* node) override { + return defs.at(node->name_hint); + } + + Type VisitType_(const GlobalTypeVarNode* node) override { + return types.at(node->name_hint); + } + }; + + Renamer renamer(self->global_var_map_, mod->global_var_map_, self->global_type_var_map_, + mod->global_type_var_map_, self->constructor_tag_map_, + mod->constructor_tag_map_); + + self->global_var_map_ = renamer.defs; + self->global_type_var_map_ = renamer.types; + self->constructor_tag_map_ = renamer.ctors; + + for (auto pair : mod->type_definitions) { + auto tvar = renamer.types.at(pair.first->name_hint); + auto ty = renamer.ExprMutator::VisitType(pair.second); + self->AddTypeDefUnchecked(tvar, Downcast(ty), true); + } + + for (auto pair : mod->functions) { + if (auto rfn = pair.second.as()) { + auto gvar = renamer.defs.at(pair.first->name_hint); + auto fn = renamer.VisitExpr(GetRef(rfn)); + self->AddUnchecked(gvar, Downcast(fn)); + } else { + // TODO(@jroesch): rename into IRModule. + self->AddUnchecked(pair.first, pair.second); + } + } + }); + +TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext") + .set_body_typed([](RelayExpr expr, IRModule mod) -> Function { + return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + }); + +TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr") + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } + return NullOpt; + }); TVM_REGISTER_NODE_TYPE(FunctionNode); diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 5ee3bbcef48f..59f94e0cdd86 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -312,7 +312,7 @@ class DefuncMutator : public ExprMutator { */ std::string TypeToString(const Type& t) { std::ostringstream s; - s << t; + s << t->GetTypeKey(); return s.str(); } diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index c4ecf92e9116..5cd459be6696 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -50,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) BaseFunc func = kv.second; (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); } - return ClassDoc(IdDoc("Module"), {IR(d)}, (*f)->stmts); + return ClassDoc(IdDoc("Module"), {IR("ir_module")}, (*f)->stmts); } }); @@ -61,14 +63,76 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { - return IdDoc("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); + return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { - return IdDoc("Op")->Call({LiteralDoc::Str(op->name)}); + return IR("Op")->Call({LiteralDoc::Str(op->name)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { + return IR("TypeVar")->Call({LiteralDoc::Str(type_var->name_hint), // + LiteralDoc::Str(TypeKind2String(type_var->kind))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](GlobalTypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { + return IR("GlobalTypeVar") + ->Call({LiteralDoc::Str(type_var->name_hint), // + LiteralDoc::Str(TypeKind2String(type_var->kind))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc { + return IR("RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { + return IR("TensorType") + ->Call({d->AsDoc(type->shape, p->Attr("shape")), + LiteralDoc::DataType(type->dtype)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { + return IR("FuncType") + ->Call({ + d->AsDoc(func_type->type_params, p->Attr("type_params")), + d->AsDoc(func_type->arg_types, p->Attr("arg_types")), + d->AsDoc(func_type->ret_type, p->Attr("ret_type")), + }); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IR("IncompleteType")->Call({}); + }); + +void ReprPrintIRModule(const ObjectRef& mod, ReprPrinter* p) { + if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule")) { + if (Optional s = (*f)(mod)) { + p->stream << s.value(); + return; + } + } + std::string res = + DocToPythonScript(IRDocsifier()->AsDoc(Downcast(mod), ObjectPath::Root())); + p->stream << res; +} + +TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(GlobalTypeVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR); +TVM_SCRIPT_REPR(RelayRefTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(IncompleteTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/ir/script_method.cc b/src/script/printer/ir/script_method.cc new file mode 100644 index 000000000000..01d3ede7ea6c --- /dev/null +++ b/src/script/printer/ir/script_method.cc @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { + +std::string IRModuleNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), + indent_spaces, print_line_numbers, num_context_lines, path_to_underline); +} + +TVM_REGISTER_GLOBAL("ir.Module_Script").set_body_method(&IRModuleNode::Script); + +} // namespace tvm diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index 4065b895c1bb..820fe13df3c6 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -28,11 +28,14 @@ #include +#include "../utils.h" + namespace tvm { namespace script { namespace printer { -inline ExprDoc IR(const IRDocsifier& d) { return IdDoc("tvm")->Attr("script"); } +/*! \brief Creates the IR common prefix, which is by default `I` */ +inline ExprDoc IR(const String& attr) { return IdDoc(Default::Prefix("ir"))->Attr(attr); } class IRFrameNode : public FrameNode { public: @@ -54,6 +57,17 @@ class IRFrame : public Frame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); }; +inline void ReprPrintIR(const ObjectRef& obj, ReprPrinter* p) { + IRDocsifier d; + With f(d); + (*f)->AddDispatchToken(d, "ir"); + try { + p->stream << DocToPythonScript(Docsify(obj, d, *f)); + } catch (const Error& e) { + HandleUnsupportedFallback(e, obj, p); + } +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc new file mode 100644 index 000000000000..f264dfee8d50 --- /dev/null +++ b/src/script/printer/legacy_repr.cc @@ -0,0 +1,1008 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include + +#include "../../support/str_escape.h" + +namespace tvm { + +#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \ + ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \ + p.Stream() << value; \ + return p; \ + } + +TVM_LEGACY_REPR_PRINTER_DEF_OP(int); +TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t); +TVM_LEGACY_REPR_PRINTER_DEF_OP(float); +TVM_LEGACY_REPR_PRINTER_DEF_OP(double); +TVM_LEGACY_REPR_PRINTER_DEF_OP(char); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&); +TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const String&); + +std::ostream& ReprLegacyPrinter::Stream() const { return stream; } + +ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) { + p.Stream() << AsLegacyRepr(value); + return p; +} + +ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // NOLINT(*) + using tvm::tir::ForKind; + switch (type) { + case ForKind::kSerial: + out << "for"; + break; + case ForKind::kParallel: + out << "parallel"; + break; + case ForKind::kUnrolled: + out << "unrolled"; + break; + case ForKind::kVectorized: + out << "vectorized"; + break; + case ForKind::kThreadBinding: + out << "launch_thread"; + break; + } + return out; +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + (*p) << ", "; + } + p->Print(op->at(i)); + } + (*p) << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '{'; + for (auto it = op->begin(); it != op->end(); ++it) { + if (it != op->begin()) { + (*p) << ", "; + } + if (it->first->IsInstance()) { + (*p) << '\"' << Downcast(it->first) << "\": "; + } else { + p->Print(it->first); + (*p) << ": "; + } + p->Print(it->second); + } + (*p) << '}'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '['; + for (size_t i = 0; i < op->size; ++i) { + if (i != 0) { + (*p) << ", "; + } + (*p) << op->data[i]; + } + (*p) << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + (*p) << op->value; + } else { + (*p) << "(" << op->dtype << ")" << op->value; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + switch (op->dtype.bits()) { + case 64: + (*p) << op->value; + break; + case 32: + (*p) << op->value << 'f'; + break; + case 16: + (*p) << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << node->dtype; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + if (!node->storage_scope.empty()) { + (*p) << node->storage_scope << " "; + } + p->Print(node->element_type); + (*p) << '*'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "TupleTypeNode(" << node->fields << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->dict; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "GlobalVar(" << node->name_hint << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "IRModule(" << node->functions << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "FuncType(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type + << ", " << node->type_constraints << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "RelayRefTypeNode(" << node->value << ")"; + }); + +} // namespace tvm + +namespace tvm { +namespace tir { + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "buffer(" << op->name << ", " << op << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + (*p) << op->name_hint; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "iter_var("; + if (op->var->name_hint.length() != 0) { + (*p) << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + (*p) << op->dom; + } + if (op->thread_tag.length() != 0) { + (*p) << ", " << op->thread_tag; + } + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '\"' << support::StrEscape(op->value) << '\"'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->dtype << '('; + p->Print(op->value); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " + "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " - "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << "*"; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << "/"; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " % "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "floordiv(" << op->a << ", " << op->b << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "floormod(" << op->a << ", " << op->b << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "min("; + p->Print(op->a); + (*p) << ", "; + p->Print(op->b); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "max("; + p->Print(op->a); + (*p) << ", "; + p->Print(op->b); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " == "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " != "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " < "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " <= "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " > "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " >= "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " && "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " || "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '!'; + p->Print(op->a); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "select("; + p->Print(op->condition); + (*p) << ", "; + p->Print(op->true_value); + (*p) << ", "; + p->Print(op->false_value); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer_var << "["; + p->Print(op->index); + (*p) << "]"; + if (!is_one(op->predicate)) { + (*p) << " if "; + p->Print(op->predicate); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "ramp("; + p->Print(op->base); + (*p) << ", "; + p->Print(op->stride); + (*p) << ", " << op->lanes << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "x" << op->lanes << "("; + p->Print(op->value); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "(let " << op->var << " = "; + p->Print(op->value); + (*p) << " in "; + p->Print(op->body); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + if (auto* ptr_op = op->op.as()) { + (*p) << ptr_op->name << "("; + } else { + auto* ptr_gvar = op->op.as(); + ICHECK(ptr_gvar != nullptr); + (*p) << "@" << ptr_gvar->name_hint << "("; + } + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + (*p) << ", "; + } + } + (*p) << ")"; + }); + +template +void PrintList(const Array& exprs, ReprLegacyPrinter* p) { + for (size_t i = 0; i < exprs.size(); ++i) { + p->Print(exprs[i]); + if (i < exprs.size() - 1) { + (*p) << ", "; + } + } +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "shuffle("; + PrintList(op->vectors, p); + (*p) << ", "; + PrintList(op->indices, p); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs << ", rhs=" << op->rhs + << ", identity_element=" << op->identity_element << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "reduce(combiner=" << op->combiner; + (*p) << ", source=" << op->source; + (*p) << ", init=" << op->init; + (*p) << ", axis=" << op->axis; + (*p) << ", where=" << op->condition; + (*p) << ", value_index=" << op->value_index; + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { (*p) << "?"; }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + (*p) << ", "; + } + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + (*p) << ", "; + } + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. + auto* node = static_cast(ref.get()); + (*p) << "PrimFunc(" << node->params << ") "; + if (node->attrs.defined()) { + (*p) << "attrs=" << node->attrs; + } + (*p) << " {\n"; + p->indent += 2; + p->Print(node->body); + p->indent -= 2; + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "let " << op->var << " = "; + p->Print(op->value); + (*p) << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "// attr ["; + p->Print(op->node); + (*p) << "] " << op->attr_key << " = "; + p->Print(op->value); + (*p) << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "assert("; + p->Print(op->condition); + (*p) << ", "; + p->Print(op->message); + (*p) << ")\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->kind << " (" << op->loop_var << ", "; + p->Print(op->min); + (*p) << ", "; + p->Print(op->extent); + (*p) << ") {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "while(" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer_var << "["; + p->Print(op->index); + (*p) << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + (*p) << " if "; + p->Print(op->predicate); + } + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) (*p) << ", "; + } + (*p) << "]"; + (*p) << " ="; + p->Print(op->value); + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + const auto* ptr_type = op->buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + p->PrintIndent(); + (*p) << "allocate " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + (*p) << " * "; + p->Print(op->extents[i]); + } + (*p) << "], storage_scope = " << ptr_type->storage_scope; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << "\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "constant " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + (*p) << " * "; + p->Print(op->extents[i]); + } + (*p) << "]"; + (*p) << "\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "decl_buffer " << op->buffer << "\n"; + (*p) << op->body; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "producer_realize " << op->producer->GetNameHint() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "prefetch " << op->buffer << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + while (true) { + (*p) << "if (" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->then_case); + p->indent -= 2; + + if (!op->else_case) { + break; + } + + if (const IfThenElseNode* nested_if = op->else_case.as()) { + p->PrintIndent(); + (*p) << "} else "; + op = nested_if; + } else { + p->PrintIndent(); + (*p) << "} else {\n"; + p->indent += 2; + p->Print(op->else_case); + p->indent -= 2; + break; + } + } + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->Print(op->value); + (*p) << "\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) (*p) << ", "; + } + (*p) << "]"; + (*p) << " = "; + p->Print(op->value); + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer->name; + (*p) << "["; + for (size_t i = 0; i < op->region.size(); ++i) { + const auto& range = op->region[i]; + p->Print(range->min); + if (!is_one(range->extent)) { + (*p) << ":"; + p->Print(range->min + range->extent); + } + if (i != op->region.size() - 1) (*p) << ", "; + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer->name << " = match_buffer("; + p->Print(op->source); + (*p) << ")\n"; + }); + +void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) { + (*p) << "block " << op->name_hint << "("; + for (size_t i = 0; i < op->iter_vars.size(); i++) { + p->Print(op->iter_vars[i]); + if (i < op->iter_vars.size() - 1) (*p) << ", "; + } + (*p) << ")"; +} + +void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) { + // print read/write regions + p->PrintIndent(); + (*p) << "reads("; + p->Print(op->reads); + (*p) << ")\n"; + p->PrintIndent(); + (*p) << "writes("; + p->Print(op->writes); + (*p) << ")\n"; + // Print alloc_buffers + for (const auto& alloc_buf : op->alloc_buffers) { + p->PrintIndent(); + (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; + for (size_t i = 0; i < alloc_buf->shape.size(); ++i) { + if (i > 0) (*p) << ", "; + p->Print(alloc_buf->shape[i]); + } + (*p) << "])\n"; + } + // Print match_buffer_regions + for (const auto& match_buf : op->match_buffers) { + p->Print(match_buf); + } + if (!op->annotations.empty()) { + p->PrintIndent(); + (*p) << "annotations(" << op->annotations << ")\n"; + } +} + +void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) { + // Print init + if (op->init.defined()) { + p->PrintIndent(); + (*p) << "with init() {\n"; + p->indent += 2; + p->Print(op->init.value()); + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + } + // Print body + p->Print(op->body); +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + PrintBlockTitle(op, p); + (*p) << " {\n"; + p->indent += 2; + + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(op, p); + // Print block init and body + PrintBlockBody(op, p); + + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + auto* block_op = op->block.get(); + p->PrintIndent(); + PrintBlockTitle(block_op, p); + (*p) << " {\n"; + p->indent += 2; + + // Print binding iter_values + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + p->PrintIndent(); + (*p) << "bind("; + p->Print(block_op->iter_vars[i]->var); + (*p) << ", "; + p->Print(op->iter_values[i]); + (*p) << ")\n"; + } + // Print predicate + if (!is_one(op->predicate)) { + p->PrintIndent(); + (*p) << "where("; + p->Print(op->predicate); + (*p) << ")\n"; + } + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(block_op, p); + // Print block init and body + PrintBlockBody(block_op, p); + + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +} // namespace tir +} // namespace tvm diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 8f008375ff87..e7f733864cc5 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -140,8 +140,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return PrintBlock(d, block, p, NullOpt, NullOpt); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BlockRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index b9eef12abc77..5400328fe219 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -247,14 +247,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ScopeDoc(NullOpt, prefix, (*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 317201fa3d74..1f2ba97700cb 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -134,10 +134,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array vars; vars.reserve(n_vars + n_vars); for (int i = 0; i < n_vars; ++i) { - vars.push_back(DefineVar(r->lhs[i], *f, d)); + vars.push_back(Downcast(DefineVar(r->lhs[i], *f, d))); } for (int i = 0; i < n_vars; ++i) { - vars.push_back(DefineVar(r->rhs[i], *f, d)); + vars.push_back(Downcast(DefineVar(r->rhs[i], *f, d))); } int n_results = r->result.size(); Array results; @@ -190,7 +190,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }; ExprDoc prefix{nullptr}; if (const auto* op = call->op.as()) { - String name = op_names[GetRef(op)]; + String name = op_names.get(GetRef(op), op->name); + if (op_names.count(GetRef(op)) == 0) { + LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; + } prefix = TIR(name); } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint); @@ -278,39 +281,39 @@ TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max"); #undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR #undef TVM_SCRIPT_PRINTER_DEF_BINARY -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::VarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SizeVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::IterVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::StringImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AddNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SubNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MulNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::DivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::FloorDivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::FloorModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MinNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MaxNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::EQNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::NENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::GTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::GENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AndNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::OrNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::NotNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SelectNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::RampNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BroadcastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LoadNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 239b8e565f35..c8e2580f9c6f 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts); } // Step 3. If not `T.grid`, print loop kind accordingly - IdDoc lhs = DefineVar(loop->loop_var, *f, d); + ExprDoc lhs = DefineVar(loop->loop_var, *f, d); Optional min = NullOpt; Optional max = NullOpt; Optional annotations = NullOpt; @@ -117,7 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ForDoc(lhs, rhs, (*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 55e8c075deb7..f0f84e81d57c 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -68,19 +68,27 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 4. Handle `func->body` AsDocBody(func->body, p->Attr("body"), frame->get(), d); + Optional ret_type = NullOpt; + if (func->ret_type.defined()) { + const auto* as_tuple = func->ret_type.as(); + if (!as_tuple || as_tuple->fields.size()) { + ret_type = d->AsDoc(func->ret_type, p->Attr("ret_type")); + } + } return FunctionDoc( /*name=*/IdDoc(FindFunctionName(d, func)), /*args=*/args, /*decorators=*/{TIR("prim_func")}, - /*return_type=*/d->AsDoc(func->ret_type, p->Attr("ret_type")), + /*return_type=*/ret_type, /*body=*/(*frame)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); - p->stream << res; - }); +void ReprPrintPrimFunc(const ObjectRef& obj, ReprPrinter* p) { + std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); + p->stream << res; +} + +TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintPrimFunc); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 5fea278a4444..ad00c42119f6 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -89,24 +89,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR("Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("IncompleteType")->Call({}); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { Map config = target->Export(); return TIR("target")->Call({d->AsDoc(config, p)}); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(FloatImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(RangeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(PrimTypeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(PointerTypeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(TupleTypeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/script_method.cc b/src/script/printer/tir/script_method.cc new file mode 100644 index 000000000000..5cda9a9626db --- /dev/null +++ b/src/script/printer/tir/script_method.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { + +std::string PrimExprNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + IRDocsifier d; + ObjectRef obj = GetRef(this); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, + num_context_lines, path_to_underline); +} + +namespace tir { + +std::string StmtNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + IRDocsifier d; + ObjectRef obj = GetRef(this); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, + num_context_lines, path_to_underline); +} + +std::string PrimFuncNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), + indent_spaces, print_line_numbers, num_context_lines, path_to_underline); +} + +TVM_REGISTER_GLOBAL("tir.PrimFuncScript").set_body_method(&PrimFuncNode::Script); +TVM_REGISTER_GLOBAL("tir.StmtScript").set_body_method(&StmtNode::Script); +TVM_REGISTER_GLOBAL("tir.PrimExprScript").set_body_method(&PrimExprNode::Script); + +} // namespace tir +} // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 436f2b202d85..7344cb4d98d5 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -352,19 +352,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::StoreNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 7f67c3a11c73..047513dcb316 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -33,6 +33,8 @@ #include #include +#include "../utils.h" + namespace tvm { namespace script { namespace printer { @@ -81,7 +83,10 @@ inline ExprDoc TIR(const String& attr) { return IdDoc(Default::Prefix("tir"))->A * \param frame The frame to define the variable in * \return The IdDoc corresponding to the variable */ -inline IdDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { +inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { + if (Optional doc = d->GetVarDoc(var)) { + return doc.value(); + } return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint); } @@ -181,26 +186,14 @@ inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root, } /*! \brief Redirected method for the ReprPrinter */ -inline void ReprPrint(const ObjectRef& stmt, ReprPrinter* p) { +inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { IRDocsifier d; - With f(MakeDispatchFrame(d, stmt, ObjectRef(nullptr))); - Doc doc = d->AsDoc(stmt, ObjectPath::Root()); - if (const auto* expr_doc = doc.as()) { - if (!Default::VerboseExpr()) { - (*f)->stmts.clear(); - } - (*f)->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); - } else if (const auto* stmt_doc = doc.as()) { - (*f)->stmts.push_back(GetRef(stmt_doc)); - } else if (const auto* stmt_block = doc.as()) { - for (const StmtDoc& d : stmt_block->stmts) { - (*f)->stmts.push_back(d); - } - } else { - LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + try { + p->stream << DocToPythonScript(Docsify(obj, d, *f)); + } catch (const tvm::Error& e) { + HandleUnsupportedFallback(e, obj, p); } - std::string res = DocToPythonScript(StmtBlockDoc((*f)->stmts)); - p->stream << res; } /*! diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 000000000000..9f9a7d8299c4 --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +#define TVM_SCRIPT_REPR(ObjectType, Method) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(Method); + +inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f) { + Doc doc = d->AsDoc(obj, ObjectPath::Root()); + if (const auto* expr_doc = doc.as()) { + if (!Default::VerboseExpr()) { + f->stmts.clear(); + } + f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); + } else if (const auto* stmt_doc = doc.as()) { + f->stmts.push_back(GetRef(stmt_doc)); + } else if (const auto* stmt_block = doc.as()) { + for (const StmtDoc& d : stmt_block->stmts) { + f->stmts.push_back(d); + } + } else { + LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); + } + return StmtBlockDoc(f->stmts); +} + +inline void HandleUnsupportedFallback(const tvm::Error& error, const ObjectRef& obj, + ReprPrinter* p) { + LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" + << error.what(); + p->stream << AsLegacyRepr(obj); +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index fe495b212ad8..d2d1d3f78d74 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -218,8 +218,7 @@ class InterfaceCNode : public runtime::ModuleNode { code_ << '\n'; } else { - LOG(FATAL) << "No constant data in constant pool found " - << PrettyPrint(GetRef(pool_info)); + LOG(FATAL) << "No constant data in constant pool found " << GetRef(pool_info); } } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ce5f5d5b5357..ccc15fc1ee49 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -329,8 +329,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "};"; code_ << "// of total size " << allocated_size << " bytes\n"; } else { - LOG(FATAL) << "No constant data in constant pool found " - << PrettyPrint(GetRef(pool_info)); + LOG(FATAL) << "No constant data in constant pool found " << GetRef(pool_info); } } diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 2e537450d232..de9da80140e4 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -25,6 +25,7 @@ #include "control_flow_graph.h" #include +#include #include #include #include @@ -1623,8 +1624,8 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, } auto it = control_flow_lookup_.find(context.get()); - ICHECK(it != control_flow_lookup_.end()) - << "Context " << PrettyPrint(context) << " did not occur within analyzed statement"; + ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n" + << context; const auto& context_block = control_flow_[it->second]; auto [store_touch, free_params] = context_block.MakeBufferTouch( diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index a3d3501a9aae..dbe114df4973 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -24,7 +24,6 @@ #include #include "../../arith/ir_visitor_with_analyzer.h" -#include "../../printer/text_printer.h" #include "../schedule/error.h" namespace tvm { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 80d6897011d5..9d932d236355 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -182,7 +182,7 @@ std::vector VerifyMemory_(const PrimFunc& func) { VLOG(1) << "verifying memory for target '" << target.value()->str() << "' for primitive:" << std::endl - << PrettyPrint(func); + << func; if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { diff --git a/src/tir/ir/legacy_printer.cc b/src/tir/ir/legacy_printer.cc deleted file mode 100644 index 4c2fd5037b65..000000000000 --- a/src/tir/ir/legacy_printer.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include - -#include "../../support/str_escape.h" - -namespace tvm { -namespace tir { - -std::string LegacyTIRPrint(const ObjectRef& obj) { - using namespace tvm::tir; - class LegacyTIRPrinter : private tir::ExprVisitor { - public: - explicit LegacyTIRPrinter(std::ostream& os) : stream(os) {} - - void Print(const ObjectRef& obj) { - if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else { - VisitExpr(Downcast(obj)); - } - } - - private: - void VisitExpr_(const VarNode* op) final { stream << op->name_hint; } - - void VisitExpr_(const SizeVarNode* op) final { - stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - } - - void VisitExpr_(const IntImmNode* op) final { - if (op->dtype == DataType::Int(32)) { - stream << op->value; - } else { - stream << "(" << op->dtype << ")" << op->value; - } - } - - void VisitExpr_(const FloatImmNode* op) final { - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - } - void VisitExpr_(const StringImmNode* op) final { - stream << '\"' << support::StrEscape(op->value) << '\"'; - } - void VisitExpr_(const CastNode* op) final { - stream << op->dtype << '('; - VisitExpr(op->value); - stream << ')'; - } - void VisitExpr_(const AddNode* op) final { PrintBinary(op->a, op->b, " + "); } - void VisitExpr_(const SubNode* op) final { PrintBinary(op->a, op->b, " - "); } - void VisitExpr_(const MulNode* op) final { PrintBinary(op->a, op->b, "*"); } - void VisitExpr_(const DivNode* op) final { PrintBinary(op->a, op->b, "/"); } - void VisitExpr_(const ModNode* op) final { PrintBinary(op->a, op->b, " % "); } - void VisitExpr_(const FloorDivNode* op) final { PrintCall("floordiv", op->a, op->b); } - void VisitExpr_(const FloorModNode* op) final { PrintCall("floormod", op->a, op->b); } - void VisitExpr_(const MinNode* op) final { PrintCall("min", op->a, op->b); } - void VisitExpr_(const MaxNode* op) final { PrintCall("max", op->a, op->b); } - void VisitExpr_(const EQNode* op) final { PrintBinary(op->a, op->b, " == "); } - void VisitExpr_(const NENode* op) final { PrintBinary(op->a, op->b, " != "); } - void VisitExpr_(const LTNode* op) final { PrintBinary(op->a, op->b, " < "); } - void VisitExpr_(const LENode* op) final { PrintBinary(op->a, op->b, " <= "); } - void VisitExpr_(const GTNode* op) final { PrintBinary(op->a, op->b, " > "); } - void VisitExpr_(const GENode* op) final { PrintBinary(op->a, op->b, " >= "); } - void VisitExpr_(const AndNode* op) final { PrintBinary(op->a, op->b, " && "); } - void VisitExpr_(const OrNode* op) final { PrintBinary(op->a, op->b, " || "); } - - void VisitExpr_(const NotNode* op) final { - stream << "!"; - VisitExpr(op->a); - } - - void VisitExpr_(const SelectNode* op) final { - stream << "select("; - VisitExpr(op->condition); - stream << ", "; - VisitExpr(op->true_value); - stream << ", "; - VisitExpr(op->false_value); - stream << ')'; - } - - void VisitExpr_(const RampNode* op) final { - stream << "ramp("; - VisitExpr(op->base); - stream << ", "; - VisitExpr(op->stride); - stream << ", " << op->lanes << ')'; - } - - void VisitExpr_(const BroadcastNode* op) final { - stream << "x" << op->lanes << "("; - VisitExpr(op->value); - stream << ")"; - } - - void VisitExpr_(const LetNode* op) final { - stream << "(let " << op->var << " = "; - VisitExpr(op->value); - stream << " in "; - VisitExpr(op->body); - stream << ")"; - } - - void VisitExpr_(const CallNode* op) final { - if (auto* ptr_op = op->op.as()) { - stream << ptr_op->name << "("; - } else { - auto* p = op->op.as(); - ICHECK(p != nullptr); - stream << "@" << p->name_hint << "("; - } - for (size_t i = 0; i < op->args.size(); ++i) { - VisitExpr(op->args[i]); - if (i < op->args.size() - 1) { - stream << ", "; - } - } - stream << ")"; - } - - void VisitExpr_(const ShuffleNode* op) final { - stream << "shuffle("; - PrintList(op->vectors.GetArrayNode()); - stream << ", "; - PrintList(op->indices.GetArrayNode()); - stream << ")"; - } - - void VisitExpr_(const ReduceNode* op) final { - stream << "reduce(combiner="; - Print_(op->combiner.get()); - stream << ", source="; - PrintList(op->source.GetArrayNode()); - stream << ", init="; - PrintList(op->init.GetArrayNode()); - stream << ", axis="; - PrintList(op->axis.GetArrayNode()); - stream << ", where="; - VisitExpr(op->condition); - stream << ", value_index=" << op->value_index; - stream << ")"; - } - - void VisitExpr_(const AnyNode* op) final { stream << "?"; } - - void VisitExpr_(const BufferLoadNode* op) final { - stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - VisitExpr(op->indices[i]); - if (i < op->indices.size() - 1) { - stream << ", "; - } - } - stream << "]"; - } - - void VisitExpr_(const ProducerLoadNode* op) final { - stream << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - VisitExpr(op->indices[i]); - if (i < op->indices.size() - 1) { - stream << ", "; - } - } - stream << "]"; - } - - private: - void Print_(const CommReducerNode* op) { - stream << "comm_reducer(result="; - PrintList(op->result.GetArrayNode()); - stream << ", lhs="; - PrintList(op->lhs.GetArrayNode()); - stream << ", rhs="; - PrintList(op->rhs.GetArrayNode()); - stream << ", identity_element="; - PrintList(op->identity_element.GetArrayNode()); - stream << ")"; - } - - void Print_(const IterVarNode* op) { - stream << "{" << op->var->name_hint << "|" << op->var->name_hint << " in ["; - VisitExpr(op->dom->min); - stream << ", "; - VisitExpr(op->dom->extent); - stream << ")}"; - } - - void Print_(const RangeNode* op) { - stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - } - - void Print_(const OpNode* op) { stream << "Op(" << op->name << ")"; } - - private: - void PrintBinary(const PrimExpr& a, const PrimExpr& b, const std::string& sign) { - stream << '('; - VisitExpr(a); - stream << sign; - VisitExpr(b); - stream << ')'; - } - - void PrintCall(const std::string& call, const PrimExpr& a, const PrimExpr& b) { - stream << call << '('; - VisitExpr(a); - stream << ", "; - VisitExpr(b); - stream << ')'; - } - - void PrintList(const ArrayNode* exprs) { - int n = static_cast(exprs->size()); - for (int i = 0; i < n; ++i) { - VisitExpr(Downcast(exprs->at(i))); - if (i < n - 1) { - stream << ", "; - } - } - } - - std::ostream& stream; - }; - std::ostringstream os; - LegacyTIRPrinter(os).Print(obj); - return os.str(); -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index e9ee7227f6fb..ef45f7f8c701 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -234,7 +234,7 @@ void VerifyCachedFlags(const ScheduleState& self) { os << std::endl; } LOG(FATAL) << "Schedule verification failed. The IR is:\n" - << AsTVMScript(self->mod) << "\nThe errors are:\n" + << self->mod << "\nThe errors are:\n" << os.str(); throw; } diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 32e5c2455a85..55d751c3311e 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../../printer/text_printer.h" #include "./utils.h" namespace tvm { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index e4771c8b19f6..d21149437f08 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -225,11 +225,11 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { } String DetailRenderTemplate() const final { - return "ScheduleError: The producer block {0} has a non-trivial predicate " + - PrettyPrint(producer_->predicate) + - " that cannot be implied " - "by the synthesized predicate " + - PrettyPrint(new_predicate_) + " of the new inlined block."; + std::ostringstream os; + os << "ScheduleError: The producer block {0} has a non-trivial predicate " + << producer_->predicate << " that cannot be implied by the synthesized predicate " + << new_predicate_ << " of the new inlined block."; + return os.str(); } IRModule mod() const final { return mod_; } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index a9b367c4b7d9..6aff85da720d 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -17,6 +17,8 @@ * under the License. */ +#include + #include #include @@ -1266,7 +1268,7 @@ class OpaqueNewIterTypeError : public ScheduleError { String DetailRenderTemplate() const final { std::ostringstream os; - os << "Cannot detect the block iter type for new iter value " << PrettyPrint(iter_value_) + os << "Cannot detect the block iter type for new iter value " << iter_value_ << " in {0} because it contains more than one type of original iter vars."; return os.str(); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index bcc8b7facbc9..d40906209fb9 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,6 @@ #include "../../arith/pattern_match.h" #include "../../node/attr_registry.h" -#include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/array.h" #include "../../support/nd_int_set.h" diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 5cf6f231dd80..acda9220b731 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -151,8 +151,8 @@ bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair> SyntacticToSemanticComputations( [](std::pair a, std::pair b) { std::stringstream a_stream; std::stringstream b_stream; - a_stream << LegacyTIRPrint(a.first); - b_stream << LegacyTIRPrint(b.first); + a_stream << AsLegacyRepr(a.first); + b_stream << AsLegacyRepr(b.first); return a_stream.str().compare(b_stream.str()) < 0; }); diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index 4daa1aafe8cc..bc9002ee841f 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -23,7 +23,7 @@ the location to which the ops would be printed */ -#include "install_debug_spans.h" +#include "./install_debug_spans.h" #include diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index e1dc2f5bf113..e9c57eb78e26 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -30,7 +30,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" -#include "../../printer/text_printer.h" namespace tvm { namespace tir { diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index 2bded7b4877b..3acceab6e31b 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -99,7 +99,7 @@ class PoolInfoAssigner : public StmtExprMutator { }; WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) { - VLOG(1) << "Creating default memory pool for:" << std::endl << PrettyPrint(module); + VLOG(1) << "Creating default memory pool for:" << std::endl << module; Map target_access; tir::PrimFunc tir_main_func = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); @@ -134,7 +134,7 @@ Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { Map annotations = Map(op->annotations); if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0) - << "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_); + << "Target " << tgt << " not found among " << target_pool_infos_; annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]); } Stmt body = VisitStmt(op->body); diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index a70e091b2cee..0728840ee96b 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -14,20 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest import numpy as np +import pytest pytest.importorskip("ethosu.vela") import tvm from tvm import relay -from tvm.script import tir as T -from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import OperatorCompute -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import ( + OperatorCompute, + copy_constants, +) +from tvm.relay.testing import run_opt_pass +from tvm.script import tir as T -from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise +from .infra import make_ethosu_binary_elementwise, make_ethosu_conv2d # fmt: off @@ -140,7 +142,7 @@ def _get_func(): with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): func = _get_func() mod, consts = _lower_to_tir(func, cascader=_planner) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -242,7 +244,7 @@ def _get_func(): with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): func = _get_func() mod, consts = _lower_to_tir(func, cascader=_cascader) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -340,7 +342,7 @@ def _get_func(): func = _get_func() mod, consts = _lower_to_tir(func) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -474,7 +476,7 @@ def _get_func(): func = _get_func() mod, consts = _lower_to_tir(func, cascader=_planner) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py index 91458f60e172..062637b3bb94 100644 --- a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py +++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py @@ -81,6 +81,6 @@ def expected(): exp = expected() global_vars = [str(gv) for gv in after.get_global_vars()] - assert "@ext_func" in global_vars - assert "@ext_func_2" not in global_vars + assert 'I.GlobalVar("ext_func")' in global_vars + assert 'I.GlobalVar("ext_func_2")' not in global_vars assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"]) diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index 4b4ba52b86f6..b8ce7f0d60c9 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -19,10 +19,11 @@ pytest.importorskip("ethosu.vela") import tvm import tvm.script -from tvm.script import tir as T from tvm import relay -from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir +from tvm.relay.testing import run_opt_pass +from tvm.script import tir as T + from .infra import make_ethosu_conv2d @@ -73,7 +74,7 @@ def _get_func(): func = _get_func() mod, _ = _lower_to_tir(func) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) reference_mod = ReferenceModule diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 649f2a611d50..bdc0447bc718 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -18,11 +18,12 @@ pytest.importorskip("ethosu.vela") import tvm -from tvm.script import tir as T from tvm import relay -from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader +from tvm.relay.testing import run_opt_pass +from tvm.script import tir as T + from .infra import make_ethosu_conv2d @@ -634,7 +635,7 @@ def _get_func( params = trial[1:] func = _get_func(*params[:-1]) mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) - script = mod.script(show_meta=True) + script = mod.script() mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -693,7 +694,7 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): params = trial[1:] func = _get_func(*params) mod, _ = _lower_to_tir(func) - script = mod.script(show_meta=True) + script = mod.script() mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -795,7 +796,7 @@ def _get_func(ifm_shape, reshaped, ifm_layout): params = trial[1:] func = _get_func(*params) mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) - script = mod.script(show_meta=True) + script = mod.script() mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 07124c62ae8b..e23954f4cb67 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -18,11 +18,14 @@ pytest.importorskip("ethosu.vela") import tvm -from tvm.script import tir as T from tvm import relay -from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, OperatorCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import ( + OperatorCompute, + copy_constants, +) +from tvm.relay.testing import run_opt_pass +from tvm.script import tir as T from .infra import make_ethosu_conv2d @@ -65,7 +68,7 @@ def _get_func(): func = _get_func() mod, _ = _lower_to_tir(func, cascader=copy_constants()) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) reference_mod = ReferenceModule tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -125,7 +128,7 @@ def _get_func(): func = _get_func() mod, _ = _lower_to_tir(func, cascader=_cascader) - script = mod.script(show_meta=True) + script = mod.script() test_mod = tvm.script.from_source(script) reference_mod = WeightStream tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 9e39821fd317..6b3da0fd06d5 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np -import pytest import itertools import logging from typing import Tuple +import numpy as np +import pytest + try: # See issue #9362. import torch @@ -28,13 +29,12 @@ pass import tvm -import tvm.testing import tvm.relay.testing - +import tvm.testing from tvm import relay +from tvm.contrib.download import download from tvm.relay import Any, GlobalVar from tvm.relay.expr_functor import ExprVisitor -from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt SUPPORTED_DTYPES = ["float16", "float32"] @@ -615,7 +615,7 @@ def __init__(self, op_list): def visit_call(self, call): if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in self.op_list: + if str(call.op.name) in self.op_list: self.on_graph = True return super().visit_call(call) diff --git a/tests/python/contrib/test_uma/test_partition.py b/tests/python/contrib/test_uma/test_partition.py index ec2107f881bc..d02903610933 100644 --- a/tests/python/contrib/test_uma/test_partition.py +++ b/tests/python/contrib/test_uma/test_partition.py @@ -16,15 +16,12 @@ # under the License. import pytest - import tvm import tvm.relay as relay - +from tvm.relay.backend.contrib.uma import uma_available from tvm.relay.backend.contrib.uma.api import UMAPartitioner from tvm.relay.op.contrib.register import get_pattern_table -from tvm.relay.testing import resnet, mlp -from tvm.relay.backend.contrib.uma import uma_available - +from tvm.relay.testing import mlp, resnet pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 1fae75f23eae..e9fbe12e9754 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -17,27 +17,24 @@ """ Tests on quantized torch model conversion """ import os -from PIL import Image - import numpy as np - import torch +import tvm +import tvm.testing +from PIL import Image from torch import nn from torch.quantization import ( - QuantStub, DeQuantStub, - fuse_modules, + QuantStub, QuantWrapper, - prepare_qat, + fuse_modules, get_default_qat_qconfig, + prepare_qat, ) - -import tvm -import tvm.testing from tvm import relay -from tvm.relay.frontend.pytorch_utils import is_version_greater_than from tvm.contrib.download import download_testdata -from tvm.relay.op.contrib.register import register_pattern_table, get_pattern_table +from tvm.relay.frontend.pytorch_utils import is_version_greater_than +from tvm.relay.op.contrib.register import get_pattern_table, register_pattern_table def torch_version_check(): @@ -66,8 +63,10 @@ def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=Fal def get_qconfig(per_channel): - from torch.quantization.observer import MovingAverageMinMaxObserver - from torch.quantization.observer import default_weight_observer + from torch.quantization.observer import ( + MovingAverageMinMaxObserver, + default_weight_observer, + ) if per_channel: return torch.quantization.get_default_qconfig("fbgemm") @@ -396,11 +395,13 @@ def get_imagenet_input(): pt_tensor = preprocess(im) return np.expand_dims(pt_tensor.numpy(), 0) - from torchvision.models.quantization import resnet as qresnet - from torchvision.models.quantization import mobilenet as qmobilenet - from torchvision.models.quantization import inception as qinception from torchvision.models.quantization import googlenet as qgooglenet - from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large + from torchvision.models.quantization import inception as qinception + from torchvision.models.quantization import mobilenet as qmobilenet + from torchvision.models.quantization import ( + mobilenet_v3_large as qmobilenet_v3_large, + ) + from torchvision.models.quantization import resnet as qresnet per_channel = True qmodels = [ @@ -596,7 +597,7 @@ def forward(self, inp): def make_qnn_add_pattern(): - from tvm.relay.dataflow_pattern import wildcard, is_op + from tvm.relay.dataflow_pattern import is_op, wildcard lhs = wildcard() rhs = wildcard() @@ -782,7 +783,7 @@ def forward(self, input): assert isinstance(output, relay.Tuple) and len(output) == 2 dq1, dq2 = output - assert str(dq1.op) == "qnn.dequantize" and str(dq2.op) == "qnn.dequantize" + assert dq1.op.name == "qnn.dequantize" and dq2.op.name == "qnn.dequantize" scale1 = dq1.args[1].data.numpy().item() scale2 = dq2.args[1].data.numpy().item() assert scale1 != scale2 diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 0915df3051db..d5e0303b05b2 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -64,14 +64,14 @@ def test_deduce(): e2 = tvm.te.max(5, a * 4) < 0 res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf: handle" - assert str(res2.min_value) == "pos_inf: handle" + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" # expression containing variable a is on rhs e2 = zero < tvm.te.max(5, a * 4) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf: handle" - assert str(res2.min_value) == "pos_inf: handle" + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" e3 = (-b) + a * c - d res3 = tvm.arith.deduce_bound(a, e3 >= 0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) @@ -88,8 +88,8 @@ def test_deduce(): # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) - assert str(res5.max_value) == "neg_inf: handle" - assert str(res5.min_value) == "pos_inf: handle" + assert str(res5.max_value) == "neg_inf" + assert str(res5.min_value) == "pos_inf" # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) @@ -111,15 +111,15 @@ def test_deduce(): # Unsatisfiable Mul in `EQ` e5 = 4 * a == b res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) - assert str(res9.max_value) == "neg_inf: handle" - assert str(res9.min_value) == "pos_inf: handle" + assert str(res9.max_value) == "neg_inf" + assert str(res9.min_value) == "pos_inf" # Unsatisfiable Mul in `EQ` res10 = tvm.arith.deduce_bound( a, (b * a == b), {b: b_s}, {} ) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf: handle" - assert str(res10.min_value) == "pos_inf: handle" + assert str(res10.max_value) == "neg_inf" + assert str(res10.min_value) == "pos_inf" def test_check(): diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index 6d4dcd996475..bb9602279404 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -22,6 +22,7 @@ from tvm.meta_schedule.testing.space_generation import ( check_sketches, generate_design_space, + print_sketches, ) from tvm.script import tir as T from tvm.target import Target @@ -625,6 +626,97 @@ def cpu_conv2d_nhwc( def test_cache_read_specify_consumer(): + @T.prim_func + def cache_read_specify_consumer_0( + A: T.Buffer((512, 512), "float32"), + B: T.Buffer((512, 512), "float32"), + T_add: T.Buffer((512, 512), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + C = T.alloc_buffer((512, 512)) + C_local = T.alloc_buffer((512, 512), scope="local") + A_shared = T.alloc_buffer((512, 512), scope="shared") + B_shared = T.alloc_buffer((512, 512), scope="shared") + for i_0_j_0_fused in T.thread_binding(2, thread="blockIdx.x"): + for i_1_j_1_fused in T.thread_binding(512, thread="vthread.x"): + for i_2_j_2_fused in T.thread_binding(16, thread="threadIdx.x"): + for k_0 in range(2): + for ax0_ax1_fused in range(131072): + with T.block("A_shared"): + v0 = T.axis.spatial(512, ax0_ax1_fused // 256) + v1 = T.axis.spatial(512, k_0 * 256 + ax0_ax1_fused % 256) + T.reads(A[v0, v1]) + T.writes(A_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in range(65536): + with T.block("B_shared"): + v0 = T.axis.spatial(512, k_0 * 256 + ax0_ax1_fused // 256) + v1 = T.axis.spatial(512, i_0_j_0_fused * 256 + ax0_ax1_fused % 256) + T.reads(B[v0, v1]) + T.writes(B_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + B_shared[v0, v1] = B[v0, v1] + for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(64, 1, 1, 4, 1, 16): + with T.block("C"): + v_i = T.axis.spatial( + 512, + i_1_j_1_fused // 8 * 8 + i_2_j_2_fused // 2 + i_3 + i_4, + ) + v_j = T.axis.spatial( + 512, + i_0_j_0_fused * 256 + + i_1_j_1_fused % 8 * 32 + + i_2_j_2_fused % 2 * 16 + + j_3 * 16 + + j_4, + ) + v_k = T.axis.reduce(512, k_0 * 256 + k_1 * 4 + k_2) + T.reads(A_shared[v_i, v_k], B_shared[v_k, v_j]) + T.writes(C_local[v_i, v_j]) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + with T.init(): + C_local[v_i, v_j] = T.float32(0) + C_local[v_i, v_j] = ( + C_local[v_i, v_j] + A_shared[v_i, v_k] * B_shared[v_k, v_j] + ) + for ax0, ax1 in T.grid(1, 16): + with T.block("C_local"): + v0 = T.axis.spatial( + 512, + i_1_j_1_fused // 8 * 8 + i_2_j_2_fused // 2 + ax0, + ) + v1 = T.axis.spatial( + 512, + i_0_j_0_fused * 256 + + i_1_j_1_fused % 8 * 32 + + i_2_j_2_fused % 2 * 16 + + ax1, + ) + T.reads(C_local[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_local[v0, v1] + for ax0, ax1 in T.grid(512, 512): + with T.block("T_add"): + v_ax0 = T.axis.spatial(512, ax0) + v_ax1 = T.axis.spatial(512, ax1) + T.reads(C[v_ax0, v_ax1], A[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + A[v_ax0, v_ax1] + + decision_0 = [ + ("SamplePerfectTile", [1, 64, 8, 1, 1]), + ("SamplePerfectTile", [2, 8, 2, 1, 16]), + ("SamplePerfectTile", [2, 64, 4]), + ("SampleCategorical", 1), + ("SampleCategorical", 2), + ] A, B, C = te_workload.matmul(512, 512, 512) mod = te.create_prim_func([A, B, C + A]) @@ -634,17 +726,12 @@ def test_cache_read_specify_consumer(): target=Target("nvidia/geforce-rtx-3080"), types=ms.schedule_rule.MultiLevelTiling, ) - - residual_block = """ - for ax0, ax1 in T.grid(512, 512): - with T.block("T_add"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(C[v_ax0, v_ax1], A[v_ax0, v_ax1]) - T.writes(T_add[v_ax0, v_ax1]) - T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + A[v_ax0, v_ax1] - """ - - assert residual_block in space[0].mod.script() + check_sketches( + mod, + sketches=space, + expected_mods=[cache_read_specify_consumer_0], + expected_decisions=[decision_0], + ) def test_max_pool_blocked(): @@ -798,4 +885,5 @@ def max_pool_blocked_compute(height, width, channel): if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_cache_read_specify_consumer() diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 8b504df120e0..69478b451893 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pickle as pkl + import pytest import tvm from tvm import te -import pickle as pkl def test_schedule_create(): @@ -297,8 +298,8 @@ def intrin_func(ins, outs, sp): stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 - assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" - assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" + assert str(stmt.body.body.value.args[3]) == "i * i" + assert str(stmt.body.body.value.args[4]) == "i + j" def test_legalize_invalid_attach(): diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 83cd64fa229b..d4ae84a556d7 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import pytest import tvm -from tvm import te, ir -import numpy as np +from tvm import ir, te def test_const(): @@ -142,7 +142,7 @@ def test_basic(): a = te.var("a") b = te.var("b") c = a + b - assert str(c) == "(%s: int32 + %s: int32)" % (a.name, b.name) + assert str(c) == "%s + %s" % (a.name, b.name) def test_stmt(): @@ -176,8 +176,8 @@ def test_any(): assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name) - assert str(tvm.tir.any(x < y, x > z)) == "((%s: int32 < %s: int32) || (%s > %s: int32))" % ( + assert str(tvm.tir.any(x < y)) == "%s < %s" % (x.name, y.name) + assert str(tvm.tir.any(x < y, x > z)) == "%s < %s or %s > %s" % ( x.name, y.name, x.name, @@ -185,7 +185,7 @@ def test_any(): ) assert str( tvm.tir.any(x < y, y > z + 1, x < z * 2) - ) == "(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))" % ( + ) == "%s < %s or %s > %s + 1 or %s < %s * 2" % ( x.name, y.name, y.name, @@ -209,8 +209,8 @@ def test_all(): assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name) - assert str(tvm.tir.all(x < y, x > z)) == "((%s: int32 < %s: int32) && (%s > %s: int32))" % ( + assert str(tvm.tir.all(x < y)) == "%s < %s" % (x.name, y.name) + assert str(tvm.tir.all(x < y, x > z)) == "%s < %s and %s > %s" % ( x.name, y.name, x.name, @@ -218,7 +218,7 @@ def test_all(): ) assert str( tvm.tir.all(x < y, y > z + 1, x < z * 2) - ) == "(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))" % ( + ) == "%s < %s and %s > %s + 1 and %s < %s * 2" % ( x.name, y.name, y.name, @@ -231,19 +231,19 @@ def test_all(): def test_bitwise(): x = te.var("x") y = te.var("y") - assert str(x << y) == "@tir.shift_left(x: int32, y: int32, dtype=int32)" - assert str(x >> y) == "@tir.shift_right(x: int32, y: int32, dtype=int32)" - assert str(x & y) == "@tir.bitwise_and(x: int32, y: int32, dtype=int32)" - assert str(x | y) == "@tir.bitwise_or(x: int32, y: int32, dtype=int32)" - assert str(x ^ y) == "@tir.bitwise_xor(x: int32, y: int32, dtype=int32)" - assert str(10 & x) == "@tir.bitwise_and(10, x: int32, dtype=int32)" - assert str(10 | x) == "@tir.bitwise_or(10, x: int32, dtype=int32)" - assert str(10 ^ x) == "@tir.bitwise_xor(10, x: int32, dtype=int32)" - assert str(10 >> x) == "@tir.shift_right(10, x: int32, dtype=int32)" - assert str(10 << x) == "@tir.shift_left(10, x: int32, dtype=int32)" - assert str(10 % x) == "floormod(10, x: int32)" - - assert str(~x) == "@tir.bitwise_not(x: int32, dtype=int32)" + assert str(x << y) == "T.shift_left(x, y)" + assert str(x >> y) == "T.shift_right(x, y)" + assert str(x & y) == "T.bitwise_and(x, y)" + assert str(x | y) == "T.bitwise_or(x, y)" + assert str(x ^ y) == "T.bitwise_xor(x, y)" + assert str(10 & x) == "T.bitwise_and(10, x)" + assert str(10 | x) == "T.bitwise_or(10, x)" + assert str(10 ^ x) == "T.bitwise_xor(10, x)" + assert str(10 >> x) == "T.shift_right(10, x)" + assert str(10 << x) == "T.shift_left(10, x)" + assert str(10 % x) == "10 % x" + + assert str(~x) == "T.bitwise_not(x)" assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert (te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -302,17 +302,17 @@ def test_divide_by_zero(): def test_infinity(): - assert str(tvm.tir.infinity("float16")) == "inff16" - assert str(tvm.tir.infinity("float32")) == "inff32" - assert str(tvm.tir.infinity("float64")) == "inff64" + assert str(tvm.tir.infinity("float16")) == 'T.float16("inf")' + assert str(tvm.tir.infinity("float32")) == 'T.float32("inf")' + assert str(tvm.tir.infinity("float64")) == 'T.float64("inf")' def test_isnan(): x = te.var("x", "float32") - assert str(tvm.tir.isnan(x)) == "@tir.isnan(x: float32, dtype=bool)" + assert str(tvm.tir.isnan(x)) == "T.isnan(x)" assert str(tvm.tir.isnan(x).dtype) == "bool" y = te.var("y", "float16") - assert str(tvm.tir.isnan(y)) == "@tir.isnan(cast(float32, y: float16), dtype=bool)" + assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))' z = te.var("z", "int32") assert str(tvm.tir.isnan(z)) == "False" k = te.var("k", "int8x2") diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 7062d5129713..adf3d9da05ce 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -14,17 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm.script import tir as T import numpy as np +import tvm import tvm.testing +from tvm.script import tir as T def count_cp_async(stmt): num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async": + if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_cp_async": num_alloc[0] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py index 70c14b02f0eb..d75fb2b03e39 100644 --- a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py @@ -14,15 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np +import pytest import tvm import tvm.script -from tvm.script import tir as T -from tvm import te -from tvm import topi +from tvm import te, topi from tvm.driver.build_module import get_binds -import numpy as np - -import pytest +from tvm.script import tir as T def _tile_nd(s, tensor, tile): @@ -271,7 +269,7 @@ def main(A: T.handle, tensor: T.handle) -> None: def test_rolling_buffer_ir_transform(): mod = PreRollingBuffer mod = tvm.tir.transform.InjectRollingBuffer()(mod) - script = mod.script(show_meta=True) + script = mod.script() mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], PostRollingBuffer["main"], True) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 006b67d62697..cf01d7700725 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1507,7 +1507,7 @@ def test_async_pipelined_mma_gemm_simple(): assert body.block.body.body[1].block.body.body.value == 3 assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count" - assert str(epilogue.block.body.body.block.body.body.value) == "(2 - k_0_0: int32)" + assert str(epilogue.block.body.body.block.body.body.value) == "2 - k_0_0" build_and_run(sch) @@ -1554,7 +1554,7 @@ def test_async_nested_pipeline_mma_gemm_ideal_annotation(): assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count" assert body.block.body.body[1].block.body.body.value == 2 - assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - k_0_0: int32)" + assert str(epilogue.block.body.body[0].block.body.body.value) == "1 - k_0_0" build_and_run(sch) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index e78ed98d8569..47bb7bf228d4 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -83,19 +83,17 @@ def test_variable_passed_from_args(): # Arguments unpacking assignment = _find_assignment(func.body, "arg.input_buffer") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' assignment = _find_assignment(func.body, "arg.not_device_context") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 1, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")' assignment = _find_assignment(func.body, "input_buffer") - assert ( - str(assignment.value) == "@tir.tvm_struct_get(arg.input_buffer: handle, 0, 1, dtype=handle)" - ) + assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var assignment = _find_assignment(func.body, "not_device_context") - assert str(assignment.value) == "arg.not_device_context: handle" + assert str(assignment.value) == "arg_not_device_context" unpacked_not_device_context = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) @@ -131,12 +129,10 @@ def test_device_api_context_implicit_resource_handle(): # Arguments unpacking assignment = _find_assignment(func.body, "arg.input_buffer") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' assignment = _find_assignment(func.body, "input_buffer") - assert ( - str(assignment.value) == "@tir.tvm_struct_get(arg.input_buffer: handle, 0, 1, dtype=handle)" - ) + assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 0c5d77d02b91..b2a0581d6980 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -92,7 +92,7 @@ def ir(A, B): stmt = ir(A, B) func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) mod = run_passes(func) - assert "@tir.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod) @tvm.testing.requires_cuda @@ -115,7 +115,7 @@ def func(p0_arg: T.Buffer[(1, 2, 1, 1), "float32"], p1: T.Buffer[2, "float32"]) result_local[0] = result_local[0] + temp_shared[0] * p1[1] mod = run_passes(func) - assert "@tir.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 29ac5dc5da0d..2f81b0302626 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -315,7 +315,7 @@ def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: def test_complete_alloc_buffer(): - rt_func = tvm.script.from_source(alloc_buffer_func.script(show_meta=True)) + rt_func = tvm.script.from_source(alloc_buffer_func.script()) tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index 3f30c6ddb0bc..e10681338727 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. -import tvm -from tvm.script import tir as T import numpy as np +import tvm import tvm.testing +from tvm.script import tir as T @T.prim_func @@ -152,8 +152,8 @@ def _check_alloc_zero_dim_buffer(f): def test_alloc_zero_dim_buffer_round_trip(): func = alloc_zero_dim_buffer func_with_block = alloc_zero_dim_buffer_block - rt_func = tvm.script.from_source(func.script(show_meta=True)) - rt_func_with_block = tvm.script.from_source(func_with_block.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) + rt_func_with_block = tvm.script.from_source(func_with_block.script()) rt_mod = tvm.build(rt_func, "llvm") rt_mod_with_block = tvm.build(rt_func_with_block, "llvm") tvm.ir.assert_structural_equal(func, func_with_block) diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py b/tests/python/unittest/test_tvmscript_printer_ir.py new file mode 100644 index 000000000000..c3da3d8c702b --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_ir.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from tvm import IRModule +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T + + +def _assert_print(obj, expected): + assert str(obj).strip() == expected.strip() + assert repr(obj).strip() == expected.strip() + if isinstance(obj, IRModule): + assert obj.script().strip() == expected.strip() + + +def test_ir_module(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module(): + with T.prim_func(): + T.func_name("foo") + mod = ib.get() + _assert_print( + mod, + """ +@I.ir_module +class Module: + @T.prim_func + def foo(): + T.evaluate(0)""", + ) + + +if __name__ == "__main__": + test_ir_module() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index fd3bb3788cfb..9c15fbc88949 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -35,6 +35,9 @@ def verbose_expr(): def _assert_print(obj, expected): with verbose_expr(): + if isinstance(obj, (tir.PrimFunc, tir.PrimExpr, tir.Stmt)): + assert obj.script().strip() == expected.strip() + assert str(obj).strip() == expected.strip() assert repr(obj).strip() == expected.strip() @@ -54,7 +57,7 @@ def test_prim_func(): func, expected=""" @T.prim_func -def main(a: T.handle, b: T.handle) -> None: +def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (256, 256)) T.evaluate(0)""", diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index 44d3036596ba..6678c10acd7a 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. import numpy - import tvm import tvm.testing from tvm.script import tir as T - # This numpy array is used to test the comparison between the global objects and the # `tvm.script.tir` submodule. np_array = numpy.array([0, 1, 2, 3]) @@ -42,7 +40,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def test_multi_element_array_in_outmost_namespace(): func = matmul - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0e9be0463943..0a6a2a26380c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2591,7 +2591,7 @@ def test_module_define(): def test_matmul_original(): func = matmul_original() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2605,7 +2605,7 @@ def test_matmul_original(): def test_element_wise(): func = element_wise() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2621,7 +2621,7 @@ def test_element_wise(): def test_predicate(): func = predicate() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2648,7 +2648,7 @@ def for_thread_binding(a: T.handle, b: T.handle) -> None: def test_for_thread_binding(): func = for_thread_binding() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body, tir.stmt.For) @@ -2682,7 +2682,7 @@ def match_buffer_region(a: T.handle, b: T.handle) -> None: def test_match_buffer_region(): func = match_buffer_region() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body, tir.stmt.BlockRealize) @@ -2727,7 +2727,7 @@ def block_elements(a: T.handle, b: T.handle) -> None: def test_block_elements(): func = block_elements() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2763,7 +2763,7 @@ def opaque_block(a: T.handle, b: T.handle) -> None: def test_opaque_block(): func = opaque_block() - rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) root_block = rt_func.body.block @@ -2945,14 +2945,9 @@ def var_with_same_name(a: T.handle) -> None: def test_same_name_var(): func = var_with_same_name() - out_str = func.script(tir_prefix="T", show_meta=True) + out_str = func.script() rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) - - assert out_str.count('vi, vj = T.axis.remap("SS", [i, j])') == 2 - assert out_str.find("vi_") == -1 - assert out_str.find("vj_") == -1 - assert out_str.count("for i, j in T.grid(16, 16)") == 2 assert out_str.find("i_") == -1 assert out_str.find("i_") == -1 @@ -3621,7 +3616,7 @@ def func(): def test_roundtrip(ir_generator): original = ir_generator() - after_roundtrip = tvm.script.from_source(original.script(show_meta=True)) + after_roundtrip = tvm.script.from_source(original.script()) tvm.ir.assert_structural_equal(original, after_roundtrip, True)