diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index bb4c468f452f5..bfbaa7cddd4fd 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -100,6 +100,17 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "PrimExpr"; static constexpr const uint32_t _type_child_slots = 38; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7313b4f783492..f26e640f6c221 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -63,6 +63,26 @@ class IRModuleNode : public Object { parser::SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; + /*! + * \brief A map from string names to global variables that + * ensures global uniqueness. + */ + Map global_var_map_; + + /*! \brief A map from string names to global type variables (ADT names) + * that ensures global uniqueness. + */ + Map global_type_var_map_; + + /*! \brief A map from constructor tags to constructor objects + * for convenient access + */ + std::unordered_map constructor_tag_map_; + + /*! \brief The files previously imported, required to ensure + importing is idempotent for each module. + */ + std::unordered_set import_set_; /*! * \brief Get a module attribute. @@ -304,15 +324,20 @@ class IRModuleNode : public Object { TVM_DLL void ImportFromStd(const String& path); /*! - * \brief Should Link Parameters into the module - * \return Whether the Executor is configured to execute with linked parameters (Default: false) + * \brief The set of imported files. */ - TVM_DLL Bool ShouldLinkParameters() const; + TVM_DLL std::unordered_set Imports() const; /*! - * \brief The set of imported files. + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -322,26 +347,6 @@ class IRModuleNode : public Object { private: /*! \brief Helper function for registering a typedef's constructors */ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); - - /*! \brief A map from string names to global variables that - * ensures global uniqueness. - */ - Map global_var_map_; - - /*! \brief A map from string names to global type variables (ADT names) - * that ensures global uniqueness. - */ - Map global_type_var_map_; - - /*! \brief A map from constructor tags to constructor objects - * for convenient access - */ - std::unordered_map constructor_tag_map_; - - /*! \brief The files previously imported, required to ensure - importing is idempotent for each module. - */ - std::unordered_set import_set_; friend class IRModule; }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 579061e02eb6e..62328f6a074a2 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -207,6 +207,25 @@ enum TypeKind : int { kTypeData = 6 }; +/*! \brief Converts a TypeKind to a string. */ +inline String TypeKind2String(TypeKind kind) { + switch (kind) { + case TypeKind::kType: + return "Type"; + case TypeKind::kShapeVar: + return "ShapeVar"; + case TypeKind::kBaseType: + return "BaseType"; + case TypeKind::kConstraint: + return "Constraint"; + case TypeKind::kAdtHandle: + return "AdtHandle"; + case TypeKind::kTypeData: + return "TypeData"; + } + LOG(FATAL) << "ValueError: Unknown TypeKind: " << static_cast(kind); +} + /*! * \brief Type parameter in functions. * diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 11bf7d4740d0f..334a35d052e18 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -24,9 +24,9 @@ #ifndef TVM_IR_TYPE_FUNCTOR_H_ #define TVM_IR_TYPE_FUNCTOR_H_ +#include +#include #include -#include -#include #include #include diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 532425a51b3ec..e3f59fcc14a18 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -26,6 +26,7 @@ #include #include +#include namespace tvm { /*! \brief A printer class to print the AST/IR nodes. */ @@ -48,6 +49,30 @@ class ReprPrinter { TVM_DLL static FType& vtable(); }; +/*! \brief Legacy behavior of ReprPrinter. */ +class ReprLegacyPrinter { + public: + /*! \brief The indentation level. */ + int indent{0}; + + explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*) + : stream(stream) {} + + /*! \brief The node to be printed. */ + TVM_DLL void Print(const ObjectRef& node); + /*! \brief Print indent to the stream */ + TVM_DLL void PrintIndent(); + /*! \brief Return the ostream it maintains */ + TVM_DLL std::ostream& Stream() const; + // Allow registration to be printer. + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); + + private: + /*! \brief The output stream */ + std::ostream& stream; +}; + /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node @@ -70,6 +95,13 @@ inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLI ReprPrinter(os).Print(n); return os; } + +inline std::string AsLegacyRepr(const ObjectRef& n) { + std::ostringstream os; + ReprLegacyPrinter(os).Print(n); + return os.str(); +} } // namespace runtime +using runtime::AsLegacyRepr; } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index d04d8c4d028ab..54810fd55a431 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -69,6 +69,9 @@ class IRDocsifierFunctor { if ((pf = LookupDispatchTable("", type_index)) != nullptr) { return (*pf)(obj, args...); } + LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" + << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; ICHECK(false) << "ObjectFunctor calls un-registered function on type: " << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h index 289e838b52a80..b373a2be73fb1 100644 --- a/include/tvm/script/printer/printer.h +++ b/include/tvm/script/printer/printer.h @@ -55,21 +55,6 @@ struct Default { static bool& VerboseExpr() { return Instance()->verbose_expr; } }; -/*! - * \brief The entry method for TVMScript printing - * \param obj The object to be printed - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - * \return The TVMScript text format - */ -String Script(ObjectRef obj, // - int indent_spaces = 4, // - bool print_line_numbers = false, // - int num_context_lines = -1, // - Optional path_to_underline = NullOpt); - /*! * \brief Convert Doc into Python script. * \param doc Doc to be converted diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1d5e8f317a2eb..689b1c0a17add 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1191,9 +1191,6 @@ class Any : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode); }; -/*! \brief Legacy ReprPrint format for TIR */ -std::string LegacyTIRPrint(const ObjectRef& obj); - /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index cf92f97360b18..542d05e032767 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -132,6 +132,17 @@ class PrimFuncNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "tir.PrimFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 96e03477a1414..e0b7bcc868b30 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -46,6 +46,17 @@ class StmtNode : public Object { StmtNode() = default; explicit StmtNode(Span span) : span(span) {} + /*! + * \brief Returns the TVMScript format + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ + std::string Script(int indent_spaces = 4, bool print_line_numbers = false, + int num_context_lines = -1, + Optional path_to_underline = NullOpt) const; + static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4e847c0310a46..9e81dd5519e14 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -16,29 +16,47 @@ # under the License. # pylint: disable=unused-import """Common data structures across all IR variants.""" -from .base import SourceName, Span, Node, EnvFunc, load_json, save_json -from .base import structural_equal, assert_structural_equal, structural_hash -from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType -from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType -from .tensor_type import TensorType -from .affine_type import TensorAffineType, TupleAffineType -from .type_relation import TypeCall, TypeRelation -from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .op import Op, register_op_attr, register_intrin_lowering -from .function import CallingConv, BaseFunc +from . import diagnostics, instrument, transform from .adt import Constructor, TypeData -from .module import IRModule +from .affine_type import TensorAffineType, TupleAffineType from .attrs import Attrs, DictAttrs, make_node +from .base import ( + EnvFunc, + Node, + SourceName, + Span, + assert_structural_equal, + load_json, + pretty_print, + save_json, + structural_equal, + structural_hash, +) from .container import Array, Map +from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr +from .function import BaseFunc, CallingConv from .memory_pools import ( - PoolInfo, - WorkspacePoolInfo, - ConstantPoolInfo, - WorkspaceMemoryPools, ConstantMemoryPools, + ConstantPoolInfo, + PoolInfo, PoolInfoProperties, + WorkspaceMemoryPools, + WorkspacePoolInfo, ) - -from . import transform -from . import instrument -from . import diagnostics +from .module import IRModule +from .op import Op, register_intrin_lowering, register_op_attr +from .tensor_type import TensorType +from .type import ( + FuncType, + GlobalTypeVar, + IncompleteType, + PointerType, + PrimType, + RelayRefType, + TupleType, + Type, + TypeConstraint, + TypeKind, + TypeVar, +) +from .type_relation import TypeCall, TypeRelation diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index bd77c187af40e..8d185ae59a340 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -17,8 +17,8 @@ """Types for quantized Tensors.""" import tvm._ffi -from .base import Node from . import _ffi_api +from .base import Node class AffineType(Node): @@ -31,6 +31,11 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @tvm._ffi.register_object("TensorAffineType") class TensorAffineType(AffineType): diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index d754ae567c5e8..a1e1d20d88237 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -16,13 +16,16 @@ # under the License. """Common base structures.""" import tvm._ffi - import tvm.error import tvm.runtime._ffi_node_api from tvm.runtime import Object -from . import _ffi_api -from . import json_compact +from . import _ffi_api, json_compact + + +def pretty_print(obj: Object) -> None: + """Pretty print the object.""" + return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member class Node(Object): @@ -54,9 +57,6 @@ def astext(self, show_meta_data=True, annotate=None): """ return _ffi_api.AsText(self, show_meta_data, annotate) - def __str__(self): - return _ffi_api.PrettyPrint(self) - @tvm._ffi.register_object("SourceName") class SourceName(Object): diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3ed7e57cb758a..b184c3b0c3cf7 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -17,13 +17,13 @@ """IRModule that holds the functions and type definitions.""" from typing import Optional -from tvm._ffi.base import string_types import tvm._ffi +from tvm._ffi.base import string_types -from .base import Node +from . import _ffi_api from . import expr as _expr from . import type as _ty -from . import _ffi_api +from .base import Node @tvm._ffi.register_object("IRModule") @@ -252,51 +252,6 @@ def import_from_std(self, file_to_import): _ffi_api.Module_ImportFromStd(self, file_to_import) return tvm.relay.transform.InferType()(self) - def __str__(self): - return _ffi_api.PrettyPrint(self) - - def __repr__(self): - return self.astext() - - def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - tir_prefix : str - The tir namespace prefix - - show_meta : bool - Whether to show meta information - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore - - def show(self, style: Optional[str] = None, black_format: bool = True) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - - black_format: bool - - If true (default), use the formatter Black to format the TVMScript - """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel - - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style, black_format=black_format) - def get_attr(self, attr_key): """Get the IRModule attribute. @@ -331,3 +286,78 @@ def with_attr(self, attr_key, attr_value): """ return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.Module_Script( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py index 22b15a397e30a..7313f3c2b42c4 100644 --- a/python/tvm/ir/tensor_type.py +++ b/python/tvm/ir/tensor_type.py @@ -17,8 +17,8 @@ """Type relation and function for type checking.""" import tvm._ffi -from .type import Type from . import _ffi_api +from .type import Type @tvm._ffi.register_object("relay.TensorType") @@ -54,3 +54,8 @@ def concrete_shape(self): TypeError : If the shape is symbolic """ return tuple(int(x) for x in self.shape) + + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 4fe28f1d72e24..ea06aeda2030d 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -16,11 +16,12 @@ # under the License. """Unified type system in the project.""" from enum import IntEnum + import tvm import tvm._ffi -from .base import Node from . import _ffi_api +from .base import Node class Type(Node): diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 1f6d8bb9ab0b1..6c29825bc04d3 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -46,6 +46,11 @@ def register_df_node(type_key=None): class DFPattern(Node): """Base class of all Patterns.""" + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + def __call__(self, *args): args = list(args) if len(args) == 1 and args[0] is None: diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88b84bbe7ebc3..7d60e89b59b75 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -17,17 +17,20 @@ # pylint: disable=no-else-return, invalid-name, unused-import """The expression nodes of Relay.""" from __future__ import absolute_import + from numbers import Number as _Number import numpy as _np + import tvm._ffi from tvm._ffi import base as _base -from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, Node +from tvm.ir import GlobalVar, Node, RelayExpr +from tvm.runtime import NDArray +from tvm.runtime import ndarray as _nd -from .base import RelayNode from . import _ffi_api from . import ty as _ty +from .base import RelayNode # alias relay expr as Expr. Expr = RelayExpr @@ -58,6 +61,11 @@ def astype(self, dtype): """ return _ffi_api.cast(self, dtype) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + def __neg__(self): return _op_make.negative(self) @@ -710,6 +718,11 @@ class StorageInfo(Node): def __init__(self, sids, dev_types, sizes): self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @property def storage_ids(self): return _ffi_api.StorageInfoStorageIds(self) @@ -735,3 +748,8 @@ class StaticMemoryPlan(Node): def __init__(self, expr_to_storage_info): self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info) + + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 66bb858edbf05..e9bb15e1d1c6c 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1847,7 +1847,7 @@ def _impl(inputs, attr, params, mod): shape_arg = tuple(params_new.numpy().astype("int32").flatten()) except Exception: # Deal with symbolic shape case. - if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op): + if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op.name): # shape_of is the direct ancestor. return _op.reshape_like(inputs[0], pop_node.args[0]) shape_arg = pop_node diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 68d8953900cfb..ef3356450085f 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -19,11 +19,11 @@ from __future__ import absolute_import import tvm._ffi -from tvm.runtime import convert from tvm.ir import BaseFunc +from tvm.runtime import convert -from .expr import Call from . import _ffi_api +from .expr import Call @tvm._ffi.register_object("relay.Function") @@ -67,6 +67,11 @@ def __call__(self, *args): """ return Call(self, args, None, None) + def __str__(self): + from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + + return pretty_print(self) + @tvm._ffi.register_func("relay.FunctionWithFields") def FunctionWithFields( diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 1a441a6f03c2a..6fce020a66948 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -17,12 +17,14 @@ # pylint: disable=invalid-name """Patterns supported CUTLASS.""" from functools import partial + from tvm import relay -from tvm.ir.transform import Sequential, PassContext +from tvm.ir.transform import PassContext, Sequential from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import register_pattern_table # type: ignore -from ...dataflow_pattern import wildcard, is_op, is_constant + +from ...dataflow_pattern import is_constant, is_op, wildcard def make_gelu_pattern(bias_out, out_dtype="float16"): @@ -124,7 +126,7 @@ def check_dtype(lhs, rhs): def get_root_call(call, root_op_name): if not isinstance(call, relay.Call): return None - if str(call.op) == root_op_name: + if str(call.op.name) == root_op_name: return call return get_root_call(call.args[0], root_op_name) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index bdf910d704cea..7db8608d6d7c0 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -36,22 +36,25 @@ from functools import reduce import tvm.ir -from tvm.ir import Op from tvm import relay +from tvm.ir import Op +from tvm.relay import expr as _expr from tvm.relay import transform -from tvm.relay.expr import GlobalVar -from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import const - from tvm.relay.analysis import analysis as _analysis -from tvm.relay import expr as _expr +from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const +from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import Call, TupleGetItem from ... import _ffi_api -from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback +from ...dataflow_pattern import ( + DFPatternCallback, + is_constant, + is_expr, + is_op, + rewrite, + wildcard, +) from .register import register_pattern_table - logger = logging.getLogger("DNNL") supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None] @@ -762,7 +765,7 @@ def visit_call(self, call): ] ) if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in compute_intensive_ops: + if str(call.op.name) in compute_intensive_ops: self.is_compute_intensive = True return super().visit_call(call) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a86357db39fc9..bd9a7d5ba0d19 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -17,16 +17,22 @@ # pylint: disable=ungrouped-imports, import-outside-toplevel """Arm(R) Ethos(TM)-U NPU supported operators.""" import functools -from typing import Dict, List, Tuple, Callable, Optional +from typing import Callable, Dict, List, Optional, Tuple import numpy as np # type: ignore import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant, Call # type: ignore -from tvm.relay.op.contrib.register import register_pattern_table # type: ignore -from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore +from tvm.ir import Op from tvm.relay.build_module import bind_params_by_name # type: ignore +from tvm.relay.dataflow_pattern import ( # type: ignore + is_constant, + is_op, + is_tuple, + wildcard, +) +from tvm.relay.expr import Call, Constant # type: ignore +from tvm.relay.op.contrib.register import register_pattern_table # type: ignore try: # As ethos-u-vela package is an optional TVM dependency, we want to lazy load it @@ -197,20 +203,23 @@ class QnnConv2DParams: @requires_vela def __init__(self, func_body: tvm.relay.Function): from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs activation = None separate_padding = None - if str(func_body.op) in self.activation_map.keys(): + if str(func_body.op.name) in self.activation_map.keys(): activation = func_body requantize_op = activation.args[0] else: requantize_op = func_body bias_add = requantize_op.args[0] qnn_conv2d = bias_add.args[0] - if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad": + if ( + isinstance(qnn_conv2d.args[0], relay.Call) + and isinstance(qnn_conv2d.args[0].op, Op) + and str(qnn_conv2d.args[0].op.name) == "nn.pad" + ): separate_padding = qnn_conv2d.args[0] data_layout = qnn_conv2d.attrs.data_layout self.kernel_layout = qnn_conv2d.attrs.kernel_layout @@ -330,13 +339,14 @@ class QnnConv2DTransposeParams: @requires_vela def __init__(self, func_body: tvm.relay.Function): - from tvm.relay.backend.contrib.ethosu.util import QConv2DTransposeArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import ( + QConv2DTransposeArgs, # type: ignore + ) + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs requantize = func_body call = func_body.args[0] - if str(call.op) == "nn.bias_add": + if str(call.op.name) == "nn.bias_add": bias_add = call call = call.args[0] else: @@ -561,7 +571,7 @@ class MaxPool2DParams: def __init__(self, func_body: Call): clip = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": clip = func_body pool_op = clip.args[0] else: @@ -617,7 +627,7 @@ class AvgPool2DParams: def __init__(self, func_body: Call): clip = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": clip = func_body cast2 = clip.args[0] else: @@ -681,19 +691,21 @@ class BinaryElementwiseParams: """ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation: bool): - from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import ( + BinaryElementwiseArgs, + RequantArgs, + ) current_call = func_body clip = None requantize = None if is_quantized_operation: - if str(current_call.op) == "clip": + if str(current_call.op.name) == "clip": clip = current_call current_call = clip.args[0] else: - if str(current_call.op) == "qnn.requantize": + if str(current_call.op.name) == "qnn.requantize": requantize = current_call clip = current_call.args[0] current_call = clip.args[0] @@ -1101,8 +1113,7 @@ class AbsParams: composite_name = "ethos-u.abs" def __init__(self, func_body: Call): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs quantize = func_body abs_op = quantize.args[0] @@ -1157,8 +1168,7 @@ class LutActivationParams: """ def __init__(self, func_body: Call): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs layout = "NHWC" @@ -1631,18 +1641,17 @@ class FullyConnectedParams: @requires_vela def __init__(self, func_body): from tvm.relay.backend.contrib.ethosu.util import QDenseArgs # type: ignore - from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs - from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs self.activation = None - if str(func_body.op) == "clip": + if str(func_body.op.name) == "clip": self.activation = func_body requantize_op = self.activation.args[0] else: requantize_op = func_body call = requantize_op.args[0] - if str(requantize_op.args[0].op) == "nn.bias_add": + if str(requantize_op.args[0].op.name) == "nn.bias_add": bias_add = call qnn_dense = call.args[0] else: @@ -1733,8 +1742,7 @@ class HardSwishParams: composite_name = "ethos-u.hard_swish" def __init__(self, func_body): - from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs - from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs, QuantizeArgs quantize = func_body divide = quantize.args[0] diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 4008b0eb3f78f..0971770e57267 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -17,15 +17,22 @@ # pylint: disable=invalid-name, unused-argument, logging-format-interpolation """TensorRT supported operators.""" import logging -from typing import Tuple, List, Dict, Union, Optional, Any, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np # type: ignore + import tvm from tvm import relay from tvm.ir import Op from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item +from tvm.relay.dataflow_pattern import ( + is_constant, + is_op, + is_tuple, + is_tuple_get_item, + wildcard, +) from tvm.relay.expr import Call, Constant, TupleGetItem from tvm.relay.expr_functor import ExprMutator, ExprVisitor from tvm.relay.op.contrib.register import register_pattern_table @@ -1050,7 +1057,7 @@ def visit_call(self, call: relay.expr.Call) -> None: "mean", } if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in compute_intensive_ops: + if str(call.op.name) in compute_intensive_ops: self.is_compute_intensive = True return super().visit_call(call) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 11d317b657e6e..703a12f45f4b9 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -19,6 +19,7 @@ """FFI for tvm.node""" import tvm._ffi + # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. # They will be overriden via _init_api to the ones registered @@ -27,6 +28,10 @@ def AsRepr(obj): return obj.type_key() + "(" + obj.handle.value + ")" +def AsLegacyRepr(obj): + return obj.type_key() + "(" + obj.handle.value + ")" + + def NodeListAttrNames(obj): return lambda x: 0 diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index e522fd539b4e8..6a8dd65876439 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -18,22 +18,30 @@ """Runtime Object API""" import ctypes -from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from tvm._ffi.base import _FFI_MODE, _LIB, _RUNTIME_ONLY, c_str, check_call from tvm._ffi.runtime_ctypes import ObjectRValueRef + from . import _ffi_api, _ffi_node_api try: # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() - from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic - from tvm._ffi._cy3.core import ObjectBase, PyNativeObject + from tvm._ffi._cy3.core import ( + ObjectBase, + PyNativeObject, + _set_class_object, + _set_class_object_generic, + ) except (RuntimeError, ImportError) as error: # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "cython": raise error - from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject + from tvm._ffi._ctypes.packed_func import ( + _set_class_object, + _set_class_object_generic, + ) def _new_object(cls): @@ -49,6 +57,9 @@ class Object(ObjectBase): def __repr__(self): return _ffi_node_api.AsRepr(self) + def legacy_repr(self): + return _ffi_node_api.AsLegacyRepr(self) + def __dir__(self): class_names = dir(self.__class__) fnames = _ffi_node_api.NodeListAttrNames(self) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index d52fbb83c3689..dab7a175185de 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -28,15 +28,16 @@ assert(y.a == x) """ from typing import Optional, Union -from tvm import ir + import tvm._ffi +import tvm.ir._ffi_api +from tvm import ir +from tvm.ir import Op, PrimExpr from tvm.ir.base import Span +from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const -from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const -from tvm.ir import PrimExpr, Op -import tvm.ir._ffi_api -from . import generic as _generic from . import _ffi_api +from . import generic as _generic def div_ambiguity_error(): @@ -324,6 +325,81 @@ class PrimExprWithOp(ExprOp, PrimExpr): # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = PrimExpr.__hash__ + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.PrimExprScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) + class ConstExpr(PrimExprWithOp): pass diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 082faeb456d31..fb5a37c5dc17f 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -18,17 +18,18 @@ import collections import inspect -from typing import Callable, List, Mapping, Optional, Union, Tuple +from typing import Callable, List, Mapping, Optional, Tuple, Union import tvm import tvm._ffi import tvm.runtime -from tvm.runtime import Object from tvm.ir import BaseFunc, Range -from .buffer import Buffer -from .expr import Var, PrimExpr -from . import _ffi_api +from tvm.runtime import Object + from ..runtime.ndarray import NDArray +from . import _ffi_api +from .buffer import Buffer +from .expr import PrimExpr, Var @tvm._ffi.register_object("tir.PrimFunc") @@ -169,44 +170,80 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: """ return _ffi_api.Specialize(self, param_map) # type: ignore - def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: """Print IRModule into TVMScript Parameters ---------- - tir_prefix : str - The tir namespace prefix - - show_meta : bool - Whether to show meta information + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined Returns ------- script : str - The TVM Script of the PrimFunc + The TVM Script of the IRModule """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.PrimFuncScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) - def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: """A sugar for print highlighted TVM script. Parameters ---------- style : str, optional - Pygmentize printing style, auto-detected if None. See `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style, black_format=black_format) + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) @tvm._ffi.register_object("tir.TensorIntrin") diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 48850012cbb7f..64aba0e029fe3 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -239,21 +239,26 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member - @type_checked - def show(self, rand_var: RAND_VAR_TYPE) -> str: - """Returns a string representation of the value that the random variable evaluates to + def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + """A sugar for print highlighted TVM script. Parameters ---------- - rand_var : Union[ExprRV, BlockRV, LoopRV] - The random variable to be evaluated + style : str, optional - Returns - ------- - str_repr : str - The string representation + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript """ - return str(self.get(rand_var)) + mod = self.mod + if mod is not None: + mod.show(style=style, black_format=black_format) + trace = self.trace + if trace is not None: + trace.show(style=style, black_format=black_format) ########## Lookup ########## diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 4847e377dec1a..096c13653a94b 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -41,6 +41,81 @@ class Stmt(Object): """Base class of all the statements.""" + def script( + self, + *, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + if num_context_lines is None: + num_context_lines = -1 + return _ffi_api.StmtScript( # type: ignore # pylint: disable=no-member + self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + ) + + def show( + self, + *, + style: Optional[str] = None, + black_format: bool = True, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline=None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) + @tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 03a36e803be81..af6e47b7a0666 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1288,7 +1288,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " - << "occurs in " << tvm::PrettyPrint(GetRef(op)); + << "occurs in " << GetRef(op); return GetRef(op); } @@ -1321,7 +1321,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } auto opt_fused = TryFuseIters(sum, check_level_); if (!opt_fused) { - ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) + ErrorLogger(this) << "Dividend " << original_dividend << ", can't be written as a single fused IterSum"; return IterSumExpr(); } @@ -1446,8 +1446,7 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl // since the extent covers the full padding range. left_pad = floordiv(mark_left_pad, split->lower_factor); } else { - ErrorLogger(this) << "Detect incompatible left padding on " - << tvm::PrettyPrint(NormalizeIterMapToExpr(split)) + ErrorLogger(this) << "Detect incompatible left padding on " << NormalizeIterMapToExpr(split) << ", the iter mark is left padded with " << mark_left_pad; return {IterSplitExpr(), PrimExpr()}; } @@ -1522,8 +1521,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } else { // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " - << tvm::PrettyPrint(lhs->scale) << " and the divisor " - << tvm::PrettyPrint(rhs) + << lhs->scale << " and the divisor " << rhs << " cannot be simplified to remove the scaling factor."; return PrimExpr(); } @@ -1621,7 +1619,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " - << tvm::PrettyPrint(lhs->scale) << " and the right-hand " << tvm::PrettyPrint(rhs) + << lhs->scale << " and the right-hand " << rhs << " cannot be used to simplify out the scaling factor."; return PrimExpr(); } diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 3a92242276800..e03d4302c89f7 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1274,28 +1274,28 @@ String ComputeDAG::PrintDAG(bool simple_mode) const { ICHECK_LT(k, p_reduce->combiner->result.size()); PrimExpr combiner = p_reduce->combiner->result[k]; if (combiner->IsInstance()) { - ss << " += " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " += " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { - ss << " max= " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " max= " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { - ss << " min= " << LegacyTIRPrint(p_reduce->source[0]) << "\n"; + ss << " min= " << AsLegacyRepr(p_reduce->source[0]) << "\n"; } else if (combiner->IsInstance()) { const auto& select = combiner.as(); - ss << " select(" << LegacyTIRPrint(select->condition) // - << ", " << LegacyTIRPrint(select->true_value) // - << ", " << LegacyTIRPrint(select->false_value) // - << ")= (" << LegacyTIRPrint(p_reduce->source[0]) // - << ',' << LegacyTIRPrint(p_reduce->source[1]) // + ss << " select(" << AsLegacyRepr(select->condition) // + << ", " << AsLegacyRepr(select->true_value) // + << ", " << AsLegacyRepr(select->false_value) // + << ")= (" << AsLegacyRepr(p_reduce->source[0]) // + << ',' << AsLegacyRepr(p_reduce->source[1]) // << ")\n"; } else { - ss << "reduce" << LegacyTIRPrint(combiner) << "\n"; + ss << "reduce" << AsLegacyRepr(combiner) << "\n"; } } else { auto call = pop->body[k].as(); if (simple_mode && call) { - ss << " = " << LegacyTIRPrint(call->op) << "\n"; + ss << " = " << AsLegacyRepr(call->op) << "\n"; } else { - ss << " = " << LegacyTIRPrint(pop->body[k]) << "\n"; + ss << " = " << AsLegacyRepr(pop->body[k]) << "\n"; } } } diff --git a/src/ir/adt.cc b/src/ir/adt.cc index f0ce859f3f87e..3533c8c514cd4 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -21,8 +21,9 @@ * \file src/ir/adt.cc * \brief ADT type definitions. */ -#include -#include +#include +#include +#include namespace tvm { diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index af46439cff7ca..f197ac4416faa 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -53,12 +53,6 @@ DictAttrs::DictAttrs(Map dict) { data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dict; - }); - TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); diff --git a/src/ir/error.cc b/src/ir/error.cc index f0e78b954a410..26448d04005ca 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -21,15 +21,8 @@ * \file ir/error.cc * \brief Utilities for error tracking and reporting. */ - #include #include -// NOTE: reverse dependency on relay. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: use relay's printer for astext. -#include // clang-format off #include diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7ba99e34d519a..050d9b87a856d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -25,11 +25,6 @@ #include #include #include -// NOTE: reverse dependency on top/tir. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: convert from IterVar and top::Tensor #include #include @@ -168,12 +163,6 @@ TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { return GlobalVar(name, type); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); - TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/function.cc b/src/ir/function.cc index dcfddd5f69d57..ce294708b2a9f 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -23,12 +23,6 @@ */ #include #include -// NOTE: reverse dependency on relay, tir/ -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: We calls into the type specific WithAttr function -#include #include namespace tvm { @@ -41,11 +35,13 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } + if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttr")) { + if (Optional ret = (*f)(func, key, value)) { + return ret.value(); + } + } + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index def94a0468551..b6923cd1e60dc 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -23,19 +23,10 @@ */ #include #include -#include -#include -// NOTE: reverse dependency on relay. -// These dependencies do not happen at the interface-level, -// and are only used in minimum cases where they are clearly marked. -// -// Rationale: We calls into relay's analysis module to verify correctness. #include +#include #include -#include -#include -#include -#include +#include #include #include @@ -182,26 +173,11 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -void WarnIfMalformed(const IRModule& mod, relay::Function func) { - func = Downcast(relay::DeDup(func)); - // Type check the item before we add it to the module. - auto fv = relay::FreeVars(func); - auto ftv = relay::FreeTypeVars(func, mod); - // TODO(@jroesch): refactor to use diagnostic context - ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl - << PrettyPrint(func) << std::endl - << "contains free variables: " << fv; - ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl - << PrettyPrint(func) << std::endl - << "contains free type variables: " << fv; -} - void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; - if (auto* ptr = f.as()) { - WarnIfMalformed(GetRef(this), GetRef(ptr)); + if (const auto* f = runtime::Registry::Get("relay.ir.WarnIfMalformed")) { + (*f)(GetRef(this), checked_func); } - AddUnchecked(var, checked_func); } @@ -212,8 +188,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { if (it != global_var_map_.end()) { ICHECK_EQ((*it).second, var); } else { - ICHECK(global_var_map_.count(var->name_hint) == 0) - << "Duplicate global function name " << PrettyPrint(var); + ICHECK(global_var_map_.count(var->name_hint) == 0) << "Duplicate global function name " << var; } global_var_map_.Set(var->name_hint, var); @@ -243,7 +218,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& if (!update) { // set global type var map ICHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << PrettyPrint(var); + << "Duplicate global type definition name " << var; } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); @@ -266,7 +241,7 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var); + ICHECK(it != functions.end()) << "There is no definition of " << var; return (*it).second; } @@ -277,7 +252,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const { TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var); + ICHECK(it != type_definitions.end()) << "There is no definition of " << var; return (*it).second; } @@ -292,70 +267,14 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } -/*! - * \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs - * ('one') side above the rhs ('two'). - */ -struct Renamer : relay::ExprMutator, TypeMutator { - Map defs; - Map types; - std::unordered_map ctors; - - Renamer(Map defs_one, Map defs_two, - Map types_one, Map types_two, - std::unordered_map ctors_one, - std::unordered_map ctor_two) { - for (auto pair : defs_one) { - defs.Set(pair.first, pair.second); - } - - for (auto pair : defs_two) { - auto it = defs.find(pair.first); - if (it == defs.end()) { - defs.Set(pair.first, pair.second); - } - } - - for (auto pair : types_one) { - types.Set(pair.first, pair.second); - } - - for (auto pair : types_two) { - auto it = types.find(pair.first); - if (it == types.end()) { - types.Set(pair.first, pair.second); - } - } - } - - relay::Expr VisitExpr_(const GlobalVarNode* node) override { return defs.at(node->name_hint); } - - Type VisitType_(const GlobalTypeVarNode* node) override { return types.at(node->name_hint); } -}; - void IRModuleNode::Update(const IRModule& mod) { - Renamer renamer(this->global_var_map_, mod->global_var_map_, this->global_type_var_map_, - mod->global_type_var_map_, this->constructor_tag_map_, mod->constructor_tag_map_); - - this->global_var_map_ = renamer.defs; - this->global_type_var_map_ = renamer.types; - this->constructor_tag_map_ = renamer.ctors; - - for (auto pair : mod->type_definitions) { - auto tvar = renamer.types.at(pair.first->name_hint); - auto ty = renamer.ExprMutator::VisitType(pair.second); - this->AddTypeDefUnchecked(tvar, Downcast(ty), true); + if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleUpdateWithRenamer")) { + (*f)(GetRef(this), mod); + return; } - for (auto pair : mod->functions) { - if (auto rfn = pair.second.as()) { - auto gvar = renamer.defs.at(pair.first->name_hint); - auto fn = renamer.VisitExpr(GetRef(rfn)); - this->AddUnchecked(gvar, Downcast(fn)); - } else { - // TODO(@jroesch): rename into IRModule. - this->AddUnchecked(pair.first, pair.second); - } + // TODO(@jroesch): rename into IRModule. + this->AddUnchecked(pair.first, pair.second); } } @@ -379,8 +298,10 @@ std::pair IRModule::FromExprInContext( // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); } + } else if (const auto* f = runtime::Registry::Get("relay.ir.FunctionFromExprInContext")) { + func = (*f)(expr, mod); } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + LOG(FATAL) << "`relay.ir.FunctionFromExprInContext` is not registered"; } GlobalVar main_gv; @@ -418,14 +339,6 @@ void IRModuleNode::ImportFromStd(const String& path) { this->Import(std_path + "/" + path); } -Bool IRModuleNode::ShouldLinkParameters() const { - Optional executor = GetAttr(tvm::attr::kExecutor); - if (!executor.defined()) { - return Bool(false); - } - return executor.value()->ShouldLinkParameters(); -} - std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { @@ -440,29 +353,15 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return IRModule(funcs, types, {}); }); -TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { - IRModule mod = args[0]; - GlobalVar var = args[1]; - ObjectRef val = args[2]; - bool update = args[3]; - ICHECK(val->IsInstance()); - - if (val->IsInstance()) { - mod->Add(var, Downcast(val), update); - } else if (val->IsInstance()) { - GlobalVar gv = Downcast(val); - auto mod_copy = IRModule(make_object(*mod.operator->())); - mod_copy = relay::transform::EtaExpand( - /* expand_constructor */ false, - /* expand_global_var */ true)(mod_copy); - auto func = mod_copy->Lookup(gv->name_hint); - mod->Add(var, Downcast(func), update); - } else { - auto func = relay::Function({}, Downcast(val), Type(nullptr), {}); - mod->Add(var, func, update); - } - *ret = mod; -}); +TVM_REGISTER_GLOBAL("ir.Module_Add") + .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { + ICHECK(val->IsInstance()); + if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleAdd")) { + return (*f)(mod, var, val, update); + } + mod->Add(var, Downcast(val), update); + return mod; + }); TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); @@ -530,10 +429,4 @@ TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String return mod->GetAttr(key); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IRModule(" << node->functions << ")"; - }); - } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index e0f08d28fb182..bfd0a59175561 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -377,7 +377,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; - VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); mod = pass_func(std::move(mod), pass_ctx); @@ -389,8 +388,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - VLOG(1) << "Result module:" << std::endl << PrettyPrint(mod); - return mod; } diff --git a/src/ir/type.cc b/src/ir/type.cc index ee05fd03596a6..d965406e8bb09 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -65,12 +65,6 @@ TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; - }); - GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); @@ -85,12 +79,6 @@ TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) return GlobalTypeVar(name, static_cast(kind)); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; - }); - FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints, Span span) { ObjectPtr n = make_object(); @@ -110,13 +98,6 @@ TVM_REGISTER_GLOBAL("ir.FuncType") return FuncType(arg_types, ret_type, type_params, type_constraints); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " - << node->ret_type << ", " << node->type_constraints << ")"; - }); - TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -158,10 +139,4 @@ TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) { TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RelayRefTypeNode(" << node->value << ")"; - }); - } // namespace tvm diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 4663fd90762ab..c90d92f83b391 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -52,13 +52,12 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { - LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " - << tir::AsTVMScript(mod); + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod; } if (num_prim_func > 1) { LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" - << tir::AsTVMScript(mod); + << mod; } return GetRef(last_func); } diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 22d6ec849c5f9..b0fba5adb5c24 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -196,7 +196,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) << " of file " << path_tuning_record << ". The workload is:\n" - << (workload.defined() ? tir::AsTVMScript(workload->mod) : "(null)") + << (workload.defined() ? workload->mod->Script() : "(null)") << "\nThe JSONObject of TuningRecord is:\n" << json_obj << "\nThe error message is:\n" << e.what(); diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 9d859947e4fee..404ee01983c5a 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -32,7 +32,7 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { << "ValueError: Require `context.space_generator`, but it is not defined"; CHECK(ctx->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - TVM_PY_LOG(INFO, ctx->logger) << "\n" << tir::AsTVMScript(ctx->mod); + TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod; ctx->Initialize(); n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value())); this->data_ = std::move(n); @@ -124,7 +124,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r << (builder_result->error_msg.defined() ? "building" : "running") << ":\n" << err << "\n" - << tir::AsTVMScript(sch->mod()) << "\n" + << sch->mod() << "\n" << Concat(sch->trace().value()->AsPython(false), "\n"); } else { double best_ms = *std::min_element(self->latency_ms.begin(), self->latency_ms.end()); @@ -168,7 +168,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh tir::Trace trace = sch->trace().value(); trace = trace->Simplified(true); TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << i << ":\n" - << tir::AsTVMScript(sch->mod()) << "\n" + << sch->mod() << "\n" << Concat(trace->AsPython(false), "\n"); } ctx->search_strategy.value()->PreTuning(max_trials_per_task, num_trials_per_iter, design_spaces, diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 6039423844e85..9a372dde8f6de 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -48,7 +48,6 @@ #include #include -#include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" #include "../support/nd_int_set.h" diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index ea263439023fd..63bba67dd5f2a 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -51,6 +51,28 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } +void ReprLegacyPrinter::Print(const ObjectRef& node) { + static const FType& f = vtable(); + if (!node.defined()) { + stream << "(nullptr)"; + } else if (f.can_dispatch(node)) { + f(node, this); + } else { + stream << node; // Use ReprPrinter + } +} + +void ReprLegacyPrinter::PrintIndent() { + for (int i = 0; i < indent; ++i) { + stream << ' '; + } +} + +ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() { + static FType inst; + return inst; +} + void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } @@ -60,4 +82,7 @@ TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) { os << obj; return os.str(); }); + +TVM_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(runtime::AsLegacyRepr); + } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 0290b7afe3fd8..80e390d9b0ada 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -314,9 +314,9 @@ class SEqualHandlerDefault::Impl { } if (assert_mode_ && !result) { LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl - << PrettyPrint(lhs) << std::endl + << lhs << std::endl << "and rhs:" << std::endl - << PrettyPrint(rhs); + << rhs; } return result; } diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index f6ac39ce79ffa..4220aa00f5a42 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -38,9 +38,9 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "model_library_format_printer"; } std::string Print(const ObjectRef& node) { - Doc doc; - doc << text_printer_.PrintFinal(node); - return doc.str(); + std::ostringstream oss; + oss << node; + return oss.str(); } TVMRetValue GetVarName(tir::Var var) { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index afc76112879e2..925c2ebf494e6 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -409,8 +409,6 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintBody(const Stmt& body, bool indent = true); }; -String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false); - String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 274b9542cc925..c578bc53d3d36 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -2002,16 +2002,6 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { return res; } -String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { - ICHECK(mod->IsInstance() || mod->IsInstance()); - Doc doc; - doc << TVMScriptPrinter::PrintHeader(tir_prefix) - << TVMScriptPrinter(tir_prefix, show_meta).Print(mod); - return doc.str() + "\n"; -} - -TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); - String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate) { ICHECK(mod->IsInstance() || mod->IsInstance()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d71cbcfc667df..154101fc94fec 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -51,7 +51,6 @@ #include #include -#include "../../printer/text_printer.h" #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../src/meta_schedule/module_equality.h" @@ -646,7 +645,7 @@ class ScheduleBuilder : public ExprVisitor { // (dispatch & 4): controls whether to raise fatal errors for missing TIR if (dispatch & 2) { LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint << "\n" - << tir::AsTVMScript(f.value()); + << f.value(); } else { LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint; } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 07cfb27b1d35e..3ff5eaa059c1c 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -21,7 +21,11 @@ * \file src/relay/ir/function.cc * \brief Function in relay. */ +#include +#include +#include #include +#include namespace tvm { namespace relay { @@ -119,6 +123,132 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { } return nullptr; } +TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") + .set_body_typed([](IRModule mod) -> Optional { + for (const auto& it : mod->functions) { + if (it.second->IsInstance()) { + return PrettyPrint(mod); + } + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed") + .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { + if (const auto* relay_func = base_func.as()) { + Function func = Downcast(relay::DeDup(GetRef(relay_func))); + // Type check the item before we add it to the module. + auto fv = relay::FreeVars(func); + auto ftv = relay::FreeTypeVars(func, mod); + // TODO(@jroesch): refactor to use diagnostic context + ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free variables: " << fv; + ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free type variables: " << fv; + } + }); +TVM_REGISTER_GLOBAL("relay.ir.IRModuleAdd") + .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { + if (val->IsInstance()) { + mod->Add(var, Downcast(val), update); + } else if (val->IsInstance()) { + GlobalVar gv = Downcast(val); + IRModule mod_copy(make_object(*mod.operator->())); + mod_copy = relay::transform::EtaExpand( + /* expand_constructor */ false, + /* expand_global_var */ true)(mod_copy); + auto func = mod_copy->Lookup(gv->name_hint); + mod->Add(var, Downcast(func), update); + } else { + auto func = relay::Function({}, Downcast(val), Type(nullptr), {}); + mod->Add(var, func, update); + } + return mod; + }); + +TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer") + .set_body_typed([](IRModule self, IRModule mod) -> void { + struct Renamer : relay::ExprMutator, TypeMutator { + Map defs; + Map types; + std::unordered_map ctors; + + Renamer(Map defs_one, Map defs_two, + Map types_one, Map types_two, + std::unordered_map ctors_one, + std::unordered_map ctor_two) { + for (auto pair : defs_one) { + defs.Set(pair.first, pair.second); + } + + for (auto pair : defs_two) { + auto it = defs.find(pair.first); + if (it == defs.end()) { + defs.Set(pair.first, pair.second); + } + } + + for (auto pair : types_one) { + types.Set(pair.first, pair.second); + } + + for (auto pair : types_two) { + auto it = types.find(pair.first); + if (it == types.end()) { + types.Set(pair.first, pair.second); + } + } + } + + relay::Expr VisitExpr_(const GlobalVarNode* node) override { + return defs.at(node->name_hint); + } + + Type VisitType_(const GlobalTypeVarNode* node) override { + return types.at(node->name_hint); + } + }; + + Renamer renamer(self->global_var_map_, mod->global_var_map_, self->global_type_var_map_, + mod->global_type_var_map_, self->constructor_tag_map_, + mod->constructor_tag_map_); + + self->global_var_map_ = renamer.defs; + self->global_type_var_map_ = renamer.types; + self->constructor_tag_map_ = renamer.ctors; + + for (auto pair : mod->type_definitions) { + auto tvar = renamer.types.at(pair.first->name_hint); + auto ty = renamer.ExprMutator::VisitType(pair.second); + self->AddTypeDefUnchecked(tvar, Downcast(ty), true); + } + + for (auto pair : mod->functions) { + if (auto rfn = pair.second.as()) { + auto gvar = renamer.defs.at(pair.first->name_hint); + auto fn = renamer.VisitExpr(GetRef(rfn)); + self->AddUnchecked(gvar, Downcast(fn)); + } else { + // TODO(@jroesch): rename into IRModule. + self->AddUnchecked(pair.first, pair.second); + } + } + }); + +TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext") + .set_body_typed([](RelayExpr expr, IRModule mod) -> Function { + return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + }); + +TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr") + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } + return NullOpt; + }); TVM_REGISTER_NODE_TYPE(FunctionNode); diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 5ee3bbcef48fa..59f94e0cdd86b 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -312,7 +312,7 @@ class DefuncMutator : public ExprMutator { */ std::string TypeToString(const Type& t) { std::ostringstream s; - s << t; + s << t->GetTypeKey(); return s.str(); } diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index c4ecf92e9116d..5cd459be66964 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -50,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) BaseFunc func = kv.second; (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); } - return ClassDoc(IdDoc("Module"), {IR(d)}, (*f)->stmts); + return ClassDoc(IdDoc("Module"), {IR("ir_module")}, (*f)->stmts); } }); @@ -61,14 +63,76 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { - return IdDoc("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); + return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { - return IdDoc("Op")->Call({LiteralDoc::Str(op->name)}); + return IR("Op")->Call({LiteralDoc::Str(op->name)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { + return IR("TypeVar")->Call({LiteralDoc::Str(type_var->name_hint), // + LiteralDoc::Str(TypeKind2String(type_var->kind))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](GlobalTypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { + return IR("GlobalTypeVar") + ->Call({LiteralDoc::Str(type_var->name_hint), // + LiteralDoc::Str(TypeKind2String(type_var->kind))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc { + return IR("RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { + return IR("TensorType") + ->Call({d->AsDoc(type->shape, p->Attr("shape")), + LiteralDoc::DataType(type->dtype)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { + return IR("FuncType") + ->Call({ + d->AsDoc(func_type->type_params, p->Attr("type_params")), + d->AsDoc(func_type->arg_types, p->Attr("arg_types")), + d->AsDoc(func_type->ret_type, p->Attr("ret_type")), + }); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IR("IncompleteType")->Call({}); + }); + +void ReprPrintIRModule(const ObjectRef& mod, ReprPrinter* p) { + if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule")) { + if (Optional s = (*f)(mod)) { + p->stream << s.value(); + return; + } + } + std::string res = + DocToPythonScript(IRDocsifier()->AsDoc(Downcast(mod), ObjectPath::Root())); + p->stream << res; +} + +TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(GlobalTypeVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR); +TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR); +TVM_SCRIPT_REPR(RelayRefTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(IncompleteTypeNode, ReprPrintIR); +TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/ir/script_method.cc b/src/script/printer/ir/script_method.cc new file mode 100644 index 0000000000000..01d3ede7ea6cf --- /dev/null +++ b/src/script/printer/ir/script_method.cc @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { + +std::string IRModuleNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), + indent_spaces, print_line_numbers, num_context_lines, path_to_underline); +} + +TVM_REGISTER_GLOBAL("ir.Module_Script").set_body_method(&IRModuleNode::Script); + +} // namespace tvm diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index 4065b895c1bbb..820fe13df3c6c 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -28,11 +28,14 @@ #include +#include "../utils.h" + namespace tvm { namespace script { namespace printer { -inline ExprDoc IR(const IRDocsifier& d) { return IdDoc("tvm")->Attr("script"); } +/*! \brief Creates the IR common prefix, which is by default `I` */ +inline ExprDoc IR(const String& attr) { return IdDoc(Default::Prefix("ir"))->Attr(attr); } class IRFrameNode : public FrameNode { public: @@ -54,6 +57,17 @@ class IRFrame : public Frame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); }; +inline void ReprPrintIR(const ObjectRef& obj, ReprPrinter* p) { + IRDocsifier d; + With f(d); + (*f)->AddDispatchToken(d, "ir"); + try { + p->stream << DocToPythonScript(Docsify(obj, d, *f)); + } catch (const Error& e) { + HandleUnsupportedFallback(e, obj, p); + } +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc new file mode 100644 index 0000000000000..8159994284510 --- /dev/null +++ b/src/script/printer/legacy_repr.cc @@ -0,0 +1,1007 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include + +#include "../../support/str_escape.h" + +namespace tvm { + +#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \ + ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \ + p.Stream() << value; \ + return p; \ + } + +TVM_LEGACY_REPR_PRINTER_DEF_OP(int); +TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t); +TVM_LEGACY_REPR_PRINTER_DEF_OP(float); +TVM_LEGACY_REPR_PRINTER_DEF_OP(double); +TVM_LEGACY_REPR_PRINTER_DEF_OP(char); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&); +TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType); +TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*); + +std::ostream& ReprLegacyPrinter::Stream() const { return stream; } + +ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) { + p.Stream() << AsLegacyRepr(value); + return p; +} + +ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // NOLINT(*) + using tvm::tir::ForKind; + switch (type) { + case ForKind::kSerial: + out << "for"; + break; + case ForKind::kParallel: + out << "parallel"; + break; + case ForKind::kUnrolled: + out << "unrolled"; + break; + case ForKind::kVectorized: + out << "vectorized"; + break; + case ForKind::kThreadBinding: + out << "launch_thread"; + break; + } + return out; +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + (*p) << ", "; + } + p->Print(op->at(i)); + } + (*p) << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '{'; + for (auto it = op->begin(); it != op->end(); ++it) { + if (it != op->begin()) { + (*p) << ", "; + } + if (it->first->IsInstance()) { + (*p) << '\"' << Downcast(it->first) << "\": "; + } else { + p->Print(it->first); + (*p) << ": "; + } + p->Print(it->second); + } + (*p) << '}'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '['; + for (size_t i = 0; i < op->size; ++i) { + if (i != 0) { + (*p) << ", "; + } + (*p) << op->data[i]; + } + (*p) << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + (*p) << op->value; + } else { + (*p) << "(" << op->dtype << ")" << op->value; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + switch (op->dtype.bits()) { + case 64: + (*p) << op->value; + break; + case 32: + (*p) << op->value << 'f'; + break; + case 16: + (*p) << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << node->dtype; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + if (!node->storage_scope.empty()) { + (*p) << node->storage_scope << " "; + } + p->Print(node->element_type); + (*p) << '*'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "TupleTypeNode(" << node->fields << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->dict; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "GlobalVar(" << node->name_hint << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "IRModule(" << node->functions << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "FuncType(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type + << ", " << node->type_constraints << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + auto* node = static_cast(ref.get()); + (*p) << "RelayRefTypeNode(" << node->value << ")"; + }); + +} // namespace tvm + +namespace tvm { +namespace tir { + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "buffer(" << op->name << ", " << op << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + (*p) << op->name_hint; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "iter_var("; + if (op->var->name_hint.length() != 0) { + (*p) << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + (*p) << op->dom; + } + if (op->thread_tag.length() != 0) { + (*p) << ", " << op->thread_tag; + } + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '\"' << support::StrEscape(op->value) << '\"'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->dtype << '('; + p->Print(op->value); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " + "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " - "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << "*"; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << "/"; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " % "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "floordiv(" << op->a << ", " << op->b << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "floormod(" << op->a << ", " << op->b << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "min("; + p->Print(op->a); + (*p) << ", "; + p->Print(op->b); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "max("; + p->Print(op->a); + (*p) << ", "; + p->Print(op->b); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " == "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " != "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " < "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " <= "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " > "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " >= "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " && "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '('; + p->Print(op->a); + (*p) << " || "; + p->Print(op->b); + (*p) << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << '!'; + p->Print(op->a); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "select("; + p->Print(op->condition); + (*p) << ", "; + p->Print(op->true_value); + (*p) << ", "; + p->Print(op->false_value); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer_var << "["; + p->Print(op->index); + (*p) << "]"; + if (!is_one(op->predicate)) { + (*p) << " if "; + p->Print(op->predicate); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "ramp("; + p->Print(op->base); + (*p) << ", "; + p->Print(op->stride); + (*p) << ", " << op->lanes << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "x" << op->lanes << "("; + p->Print(op->value); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "(let " << op->var << " = "; + p->Print(op->value); + (*p) << " in "; + p->Print(op->body); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + if (auto* ptr_op = op->op.as()) { + (*p) << ptr_op->name << "("; + } else { + auto* ptr_gvar = op->op.as(); + ICHECK(ptr_gvar != nullptr); + (*p) << "@" << ptr_gvar->name_hint << "("; + } + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + (*p) << ", "; + } + } + (*p) << ")"; + }); + +template +void PrintList(const Array& exprs, ReprLegacyPrinter* p) { + for (size_t i = 0; i < exprs.size(); ++i) { + p->Print(exprs[i]); + if (i < exprs.size() - 1) { + (*p) << ", "; + } + } +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "shuffle("; + PrintList(op->vectors, p); + (*p) << ", "; + PrintList(op->indices, p); + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs << ", rhs=" << op->rhs + << ", identity_element=" << op->identity_element << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << "reduce(combiner=" << op->combiner; + (*p) << ", source=" << op->source; + (*p) << ", init=" << op->init; + (*p) << ", axis=" << op->axis; + (*p) << ", where=" << op->condition; + (*p) << ", value_index=" << op->value_index; + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { (*p) << "?"; }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + (*p) << ", "; + } + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + (*p) << ", "; + } + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. + auto* node = static_cast(ref.get()); + (*p) << "PrimFunc(" << node->params << ") "; + if (node->attrs.defined()) { + (*p) << "attrs=" << node->attrs; + } + (*p) << " {\n"; + p->indent += 2; + p->Print(node->body); + p->indent -= 2; + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "let " << op->var << " = "; + p->Print(op->value); + (*p) << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "// attr ["; + p->Print(op->node); + (*p) << "] " << op->attr_key << " = "; + p->Print(op->value); + (*p) << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "assert("; + p->Print(op->condition); + (*p) << ", "; + p->Print(op->message); + (*p) << ")\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->kind << " (" << op->loop_var << ", "; + p->Print(op->min); + (*p) << ", "; + p->Print(op->extent); + (*p) << ") {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "while(" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer_var << "["; + p->Print(op->index); + (*p) << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + (*p) << " if "; + p->Print(op->predicate); + } + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) (*p) << ", "; + } + (*p) << "]"; + (*p) << " ="; + p->Print(op->value); + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + const auto* ptr_type = op->buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + p->PrintIndent(); + (*p) << "allocate " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + (*p) << " * "; + p->Print(op->extents[i]); + } + (*p) << "], storage_scope = " << ptr_type->storage_scope; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << "\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "constant " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + (*p) << " * "; + p->Print(op->extents[i]); + } + (*p) << "]"; + (*p) << "\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "decl_buffer " << op->buffer << "\n"; + (*p) << op->body; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "producer_realize " << op->producer->GetNameHint() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "prefetch " << op->buffer << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + while (true) { + (*p) << "if (" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->then_case); + p->indent -= 2; + + if (!op->else_case) { + break; + } + + if (const IfThenElseNode* nested_if = op->else_case.as()) { + p->PrintIndent(); + (*p) << "} else "; + op = nested_if; + } else { + p->PrintIndent(); + (*p) << "} else {\n"; + p->indent += 2; + p->Print(op->else_case); + p->indent -= 2; + break; + } + } + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->Print(op->value); + (*p) << "\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) (*p) << ", "; + } + (*p) << "]"; + (*p) << " = "; + p->Print(op->value); + (*p) << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + (*p) << "["; + p->Print(op->bounds[i]->min); + (*p) << ", "; + p->Print(op->bounds[i]->extent); + (*p) << "]"; + if (i < op->bounds.size() - 1) (*p) << ", "; + } + (*p) << ")"; + if (!is_one(op->condition)) { + (*p) << " if "; + p->Print(op->condition); + } + (*p) << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + (*p) << op->buffer->name; + (*p) << "["; + for (size_t i = 0; i < op->region.size(); ++i) { + const auto& range = op->region[i]; + p->Print(range->min); + if (!is_one(range->extent)) { + (*p) << ":"; + p->Print(range->min + range->extent); + } + if (i != op->region.size() - 1) (*p) << ", "; + } + (*p) << "]"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + (*p) << op->buffer->name << " = match_buffer("; + p->Print(op->source); + (*p) << ")\n"; + }); + +void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) { + (*p) << "block " << op->name_hint << "("; + for (size_t i = 0; i < op->iter_vars.size(); i++) { + p->Print(op->iter_vars[i]); + if (i < op->iter_vars.size() - 1) (*p) << ", "; + } + (*p) << ")"; +} + +void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) { + // print read/write regions + p->PrintIndent(); + (*p) << "reads("; + p->Print(op->reads); + (*p) << ")\n"; + p->PrintIndent(); + (*p) << "writes("; + p->Print(op->writes); + (*p) << ")\n"; + // Print alloc_buffers + for (const auto& alloc_buf : op->alloc_buffers) { + p->PrintIndent(); + (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; + for (size_t i = 0; i < alloc_buf->shape.size(); ++i) { + if (i > 0) (*p) << ", "; + p->Print(alloc_buf->shape[i]); + } + (*p) << "])\n"; + } + // Print match_buffer_regions + for (const auto& match_buf : op->match_buffers) { + p->Print(match_buf); + } + if (!op->annotations.empty()) { + p->PrintIndent(); + (*p) << "annotations(" << op->annotations << ")\n"; + } +} + +void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) { + // Print init + if (op->init.defined()) { + p->PrintIndent(); + (*p) << "with init() {\n"; + p->indent += 2; + p->Print(op->init.value()); + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + } + // Print body + p->Print(op->body); +} + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + PrintBlockTitle(op, p); + (*p) << " {\n"; + p->indent += 2; + + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(op, p); + // Print block init and body + PrintBlockBody(op, p); + + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); + auto* block_op = op->block.get(); + p->PrintIndent(); + PrintBlockTitle(block_op, p); + (*p) << " {\n"; + p->indent += 2; + + // Print binding iter_values + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + p->PrintIndent(); + (*p) << "bind("; + p->Print(block_op->iter_vars[i]->var); + (*p) << ", "; + p->Print(op->iter_values[i]); + (*p) << ")\n"; + } + // Print predicate + if (!is_one(op->predicate)) { + p->PrintIndent(); + (*p) << "where("; + p->Print(op->predicate); + (*p) << ")\n"; + } + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(block_op, p); + // Print block init and body + PrintBlockBody(block_op, p); + + p->indent -= 2; + p->PrintIndent(); + (*p) << "}\n"; + }); + +} // namespace tir +} // namespace tvm diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 8f008375ff874..e7f733864cc59 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -140,8 +140,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return PrintBlock(d, block, p, NullOpt, NullOpt); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BlockRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index b9eef12abc77b..5400328fe219f 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -247,14 +247,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ScopeDoc(NullOpt, prefix, (*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 317201fa3d747..1f2ba97700cbb 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -134,10 +134,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array vars; vars.reserve(n_vars + n_vars); for (int i = 0; i < n_vars; ++i) { - vars.push_back(DefineVar(r->lhs[i], *f, d)); + vars.push_back(Downcast(DefineVar(r->lhs[i], *f, d))); } for (int i = 0; i < n_vars; ++i) { - vars.push_back(DefineVar(r->rhs[i], *f, d)); + vars.push_back(Downcast(DefineVar(r->rhs[i], *f, d))); } int n_results = r->result.size(); Array results; @@ -190,7 +190,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }; ExprDoc prefix{nullptr}; if (const auto* op = call->op.as()) { - String name = op_names[GetRef(op)]; + String name = op_names.get(GetRef(op), op->name); + if (op_names.count(GetRef(op)) == 0) { + LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; + } prefix = TIR(name); } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint); @@ -278,39 +281,39 @@ TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max"); #undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR #undef TVM_SCRIPT_PRINTER_DEF_BINARY -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::VarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SizeVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::IterVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::StringImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AddNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SubNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MulNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::DivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::FloorDivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::FloorModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MinNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::MaxNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::EQNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::NENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::GTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::GENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AndNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::OrNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::NotNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SelectNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::RampNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BroadcastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::LoadNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 239b8e565f355..c8e2580f9c6fa 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts); } // Step 3. If not `T.grid`, print loop kind accordingly - IdDoc lhs = DefineVar(loop->loop_var, *f, d); + ExprDoc lhs = DefineVar(loop->loop_var, *f, d); Optional min = NullOpt; Optional max = NullOpt; Optional annotations = NullOpt; @@ -117,7 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ForDoc(lhs, rhs, (*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 55e8c075deb7d..43dce28135117 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -76,11 +76,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) /*body=*/(*frame)->stmts); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); - p->stream << res; - }); +void ReprPrintPrimFunc(const ObjectRef& obj, ReprPrinter* p) { + std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); + p->stream << res; +} + +TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintPrimFunc); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 5fea278a4444a..ad00c42119f61 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -89,24 +89,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR("Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("IncompleteType")->Call({}); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { Map config = target->Export(); return TIR("target")->Call({d->AsDoc(config, p)}); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(FloatImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(RangeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(PrimTypeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(PointerTypeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(TupleTypeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/script_method.cc b/src/script/printer/tir/script_method.cc new file mode 100644 index 0000000000000..5cda9a9626db7 --- /dev/null +++ b/src/script/printer/tir/script_method.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { + +std::string PrimExprNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + IRDocsifier d; + ObjectRef obj = GetRef(this); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, + num_context_lines, path_to_underline); +} + +namespace tir { + +std::string StmtNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + IRDocsifier d; + ObjectRef obj = GetRef(this); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, + num_context_lines, path_to_underline); +} + +std::string PrimFuncNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, + Optional path_to_underline) const { + using namespace tvm::script::printer; + return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), + indent_spaces, print_line_numbers, num_context_lines, path_to_underline); +} + +TVM_REGISTER_GLOBAL("tir.PrimFuncScript").set_body_method(&PrimFuncNode::Script); +TVM_REGISTER_GLOBAL("tir.StmtScript").set_body_method(&StmtNode::Script); +TVM_REGISTER_GLOBAL("tir.PrimExprScript").set_body_method(&PrimExprNode::Script); + +} // namespace tir +} // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 436f2b202d854..7344cb4d98d5f 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -352,19 +352,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(ReprPrint); +TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::StoreNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 7f67c3a11c73a..047513dcb316b 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -33,6 +33,8 @@ #include #include +#include "../utils.h" + namespace tvm { namespace script { namespace printer { @@ -81,7 +83,10 @@ inline ExprDoc TIR(const String& attr) { return IdDoc(Default::Prefix("tir"))->A * \param frame The frame to define the variable in * \return The IdDoc corresponding to the variable */ -inline IdDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { +inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { + if (Optional doc = d->GetVarDoc(var)) { + return doc.value(); + } return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint); } @@ -181,26 +186,14 @@ inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root, } /*! \brief Redirected method for the ReprPrinter */ -inline void ReprPrint(const ObjectRef& stmt, ReprPrinter* p) { +inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { IRDocsifier d; - With f(MakeDispatchFrame(d, stmt, ObjectRef(nullptr))); - Doc doc = d->AsDoc(stmt, ObjectPath::Root()); - if (const auto* expr_doc = doc.as()) { - if (!Default::VerboseExpr()) { - (*f)->stmts.clear(); - } - (*f)->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); - } else if (const auto* stmt_doc = doc.as()) { - (*f)->stmts.push_back(GetRef(stmt_doc)); - } else if (const auto* stmt_block = doc.as()) { - for (const StmtDoc& d : stmt_block->stmts) { - (*f)->stmts.push_back(d); - } - } else { - LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); + With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); + try { + p->stream << DocToPythonScript(Docsify(obj, d, *f)); + } catch (const tvm::Error& e) { + HandleUnsupportedFallback(e, obj, p); } - std::string res = DocToPythonScript(StmtBlockDoc((*f)->stmts)); - p->stream << res; } /*! diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 0000000000000..9f9a7d8299c4a --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +#define TVM_SCRIPT_REPR(ObjectType, Method) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(Method); + +inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f) { + Doc doc = d->AsDoc(obj, ObjectPath::Root()); + if (const auto* expr_doc = doc.as()) { + if (!Default::VerboseExpr()) { + f->stmts.clear(); + } + f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); + } else if (const auto* stmt_doc = doc.as()) { + f->stmts.push_back(GetRef(stmt_doc)); + } else if (const auto* stmt_block = doc.as()) { + for (const StmtDoc& d : stmt_block->stmts) { + f->stmts.push_back(d); + } + } else { + LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); + } + return StmtBlockDoc(f->stmts); +} + +inline void HandleUnsupportedFallback(const tvm::Error& error, const ObjectRef& obj, + ReprPrinter* p) { + LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" + << error.what(); + p->stream << AsLegacyRepr(obj); +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index fe495b212ad86..d2d1d3f78d74f 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -218,8 +218,7 @@ class InterfaceCNode : public runtime::ModuleNode { code_ << '\n'; } else { - LOG(FATAL) << "No constant data in constant pool found " - << PrettyPrint(GetRef(pool_info)); + LOG(FATAL) << "No constant data in constant pool found " << GetRef(pool_info); } } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ce5f5d5b53575..ccc15fc1ee490 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -329,8 +329,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "};"; code_ << "// of total size " << allocated_size << " bytes\n"; } else { - LOG(FATAL) << "No constant data in constant pool found " - << PrettyPrint(GetRef(pool_info)); + LOG(FATAL) << "No constant data in constant pool found " << GetRef(pool_info); } } diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 2e537450d2328..de9da80140e4a 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -25,6 +25,7 @@ #include "control_flow_graph.h" #include +#include #include #include #include @@ -1623,8 +1624,8 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, } auto it = control_flow_lookup_.find(context.get()); - ICHECK(it != control_flow_lookup_.end()) - << "Context " << PrettyPrint(context) << " did not occur within analyzed statement"; + ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n" + << context; const auto& context_block = control_flow_[it->second]; auto [store_touch, free_params] = context_block.MakeBufferTouch( diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index a3d3501a9aae2..dbe114df49738 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -24,7 +24,6 @@ #include #include "../../arith/ir_visitor_with_analyzer.h" -#include "../../printer/text_printer.h" #include "../schedule/error.h" namespace tvm { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 80d6897011d5d..9d932d2363557 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -182,7 +182,7 @@ std::vector VerifyMemory_(const PrimFunc& func) { VLOG(1) << "verifying memory for target '" << target.value()->str() << "' for primitive:" << std::endl - << PrettyPrint(func); + << func; if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { diff --git a/src/tir/ir/legacy_printer.cc b/src/tir/ir/legacy_printer.cc deleted file mode 100644 index 4c2fd5037b652..0000000000000 --- a/src/tir/ir/legacy_printer.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include - -#include "../../support/str_escape.h" - -namespace tvm { -namespace tir { - -std::string LegacyTIRPrint(const ObjectRef& obj) { - using namespace tvm::tir; - class LegacyTIRPrinter : private tir::ExprVisitor { - public: - explicit LegacyTIRPrinter(std::ostream& os) : stream(os) {} - - void Print(const ObjectRef& obj) { - if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else if (const auto* op = obj.as()) { - Print_(op); - } else { - VisitExpr(Downcast(obj)); - } - } - - private: - void VisitExpr_(const VarNode* op) final { stream << op->name_hint; } - - void VisitExpr_(const SizeVarNode* op) final { - stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - } - - void VisitExpr_(const IntImmNode* op) final { - if (op->dtype == DataType::Int(32)) { - stream << op->value; - } else { - stream << "(" << op->dtype << ")" << op->value; - } - } - - void VisitExpr_(const FloatImmNode* op) final { - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - } - void VisitExpr_(const StringImmNode* op) final { - stream << '\"' << support::StrEscape(op->value) << '\"'; - } - void VisitExpr_(const CastNode* op) final { - stream << op->dtype << '('; - VisitExpr(op->value); - stream << ')'; - } - void VisitExpr_(const AddNode* op) final { PrintBinary(op->a, op->b, " + "); } - void VisitExpr_(const SubNode* op) final { PrintBinary(op->a, op->b, " - "); } - void VisitExpr_(const MulNode* op) final { PrintBinary(op->a, op->b, "*"); } - void VisitExpr_(const DivNode* op) final { PrintBinary(op->a, op->b, "/"); } - void VisitExpr_(const ModNode* op) final { PrintBinary(op->a, op->b, " % "); } - void VisitExpr_(const FloorDivNode* op) final { PrintCall("floordiv", op->a, op->b); } - void VisitExpr_(const FloorModNode* op) final { PrintCall("floormod", op->a, op->b); } - void VisitExpr_(const MinNode* op) final { PrintCall("min", op->a, op->b); } - void VisitExpr_(const MaxNode* op) final { PrintCall("max", op->a, op->b); } - void VisitExpr_(const EQNode* op) final { PrintBinary(op->a, op->b, " == "); } - void VisitExpr_(const NENode* op) final { PrintBinary(op->a, op->b, " != "); } - void VisitExpr_(const LTNode* op) final { PrintBinary(op->a, op->b, " < "); } - void VisitExpr_(const LENode* op) final { PrintBinary(op->a, op->b, " <= "); } - void VisitExpr_(const GTNode* op) final { PrintBinary(op->a, op->b, " > "); } - void VisitExpr_(const GENode* op) final { PrintBinary(op->a, op->b, " >= "); } - void VisitExpr_(const AndNode* op) final { PrintBinary(op->a, op->b, " && "); } - void VisitExpr_(const OrNode* op) final { PrintBinary(op->a, op->b, " || "); } - - void VisitExpr_(const NotNode* op) final { - stream << "!"; - VisitExpr(op->a); - } - - void VisitExpr_(const SelectNode* op) final { - stream << "select("; - VisitExpr(op->condition); - stream << ", "; - VisitExpr(op->true_value); - stream << ", "; - VisitExpr(op->false_value); - stream << ')'; - } - - void VisitExpr_(const RampNode* op) final { - stream << "ramp("; - VisitExpr(op->base); - stream << ", "; - VisitExpr(op->stride); - stream << ", " << op->lanes << ')'; - } - - void VisitExpr_(const BroadcastNode* op) final { - stream << "x" << op->lanes << "("; - VisitExpr(op->value); - stream << ")"; - } - - void VisitExpr_(const LetNode* op) final { - stream << "(let " << op->var << " = "; - VisitExpr(op->value); - stream << " in "; - VisitExpr(op->body); - stream << ")"; - } - - void VisitExpr_(const CallNode* op) final { - if (auto* ptr_op = op->op.as()) { - stream << ptr_op->name << "("; - } else { - auto* p = op->op.as(); - ICHECK(p != nullptr); - stream << "@" << p->name_hint << "("; - } - for (size_t i = 0; i < op->args.size(); ++i) { - VisitExpr(op->args[i]); - if (i < op->args.size() - 1) { - stream << ", "; - } - } - stream << ")"; - } - - void VisitExpr_(const ShuffleNode* op) final { - stream << "shuffle("; - PrintList(op->vectors.GetArrayNode()); - stream << ", "; - PrintList(op->indices.GetArrayNode()); - stream << ")"; - } - - void VisitExpr_(const ReduceNode* op) final { - stream << "reduce(combiner="; - Print_(op->combiner.get()); - stream << ", source="; - PrintList(op->source.GetArrayNode()); - stream << ", init="; - PrintList(op->init.GetArrayNode()); - stream << ", axis="; - PrintList(op->axis.GetArrayNode()); - stream << ", where="; - VisitExpr(op->condition); - stream << ", value_index=" << op->value_index; - stream << ")"; - } - - void VisitExpr_(const AnyNode* op) final { stream << "?"; } - - void VisitExpr_(const BufferLoadNode* op) final { - stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - VisitExpr(op->indices[i]); - if (i < op->indices.size() - 1) { - stream << ", "; - } - } - stream << "]"; - } - - void VisitExpr_(const ProducerLoadNode* op) final { - stream << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - VisitExpr(op->indices[i]); - if (i < op->indices.size() - 1) { - stream << ", "; - } - } - stream << "]"; - } - - private: - void Print_(const CommReducerNode* op) { - stream << "comm_reducer(result="; - PrintList(op->result.GetArrayNode()); - stream << ", lhs="; - PrintList(op->lhs.GetArrayNode()); - stream << ", rhs="; - PrintList(op->rhs.GetArrayNode()); - stream << ", identity_element="; - PrintList(op->identity_element.GetArrayNode()); - stream << ")"; - } - - void Print_(const IterVarNode* op) { - stream << "{" << op->var->name_hint << "|" << op->var->name_hint << " in ["; - VisitExpr(op->dom->min); - stream << ", "; - VisitExpr(op->dom->extent); - stream << ")}"; - } - - void Print_(const RangeNode* op) { - stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - } - - void Print_(const OpNode* op) { stream << "Op(" << op->name << ")"; } - - private: - void PrintBinary(const PrimExpr& a, const PrimExpr& b, const std::string& sign) { - stream << '('; - VisitExpr(a); - stream << sign; - VisitExpr(b); - stream << ')'; - } - - void PrintCall(const std::string& call, const PrimExpr& a, const PrimExpr& b) { - stream << call << '('; - VisitExpr(a); - stream << ", "; - VisitExpr(b); - stream << ')'; - } - - void PrintList(const ArrayNode* exprs) { - int n = static_cast(exprs->size()); - for (int i = 0; i < n; ++i) { - VisitExpr(Downcast(exprs->at(i))); - if (i < n - 1) { - stream << ", "; - } - } - } - - std::ostream& stream; - }; - std::ostringstream os; - LegacyTIRPrinter(os).Print(obj); - return os.str(); -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index e9ee7227f6fb0..ef45f7f8c701b 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -234,7 +234,7 @@ void VerifyCachedFlags(const ScheduleState& self) { os << std::endl; } LOG(FATAL) << "Schedule verification failed. The IR is:\n" - << AsTVMScript(self->mod) << "\nThe errors are:\n" + << self->mod << "\nThe errors are:\n" << os.str(); throw; } diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 32e5c2455a857..55d751c3311ed 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../../printer/text_printer.h" #include "./utils.h" namespace tvm { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index e4771c8b19f66..d21149437f08d 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -225,11 +225,11 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { } String DetailRenderTemplate() const final { - return "ScheduleError: The producer block {0} has a non-trivial predicate " + - PrettyPrint(producer_->predicate) + - " that cannot be implied " - "by the synthesized predicate " + - PrettyPrint(new_predicate_) + " of the new inlined block."; + std::ostringstream os; + os << "ScheduleError: The producer block {0} has a non-trivial predicate " + << producer_->predicate << " that cannot be implied by the synthesized predicate " + << new_predicate_ << " of the new inlined block."; + return os.str(); } IRModule mod() const final { return mod_; } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index a9b367c4b7d9e..6aff85da720dd 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -17,6 +17,8 @@ * under the License. */ +#include + #include #include @@ -1266,7 +1268,7 @@ class OpaqueNewIterTypeError : public ScheduleError { String DetailRenderTemplate() const final { std::ostringstream os; - os << "Cannot detect the block iter type for new iter value " << PrettyPrint(iter_value_) + os << "Cannot detect the block iter type for new iter value " << iter_value_ << " in {0} because it contains more than one type of original iter vars."; return os.str(); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index bcc8b7facbc9c..d40906209fb95 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,6 @@ #include "../../arith/pattern_match.h" #include "../../node/attr_registry.h" -#include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/array.h" #include "../../support/nd_int_set.h" diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 5cf6f231dd805..acda9220b731d 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -151,8 +151,8 @@ bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair> SyntacticToSemanticComputations( [](std::pair a, std::pair b) { std::stringstream a_stream; std::stringstream b_stream; - a_stream << LegacyTIRPrint(a.first); - b_stream << LegacyTIRPrint(b.first); + a_stream << AsLegacyRepr(a.first); + b_stream << AsLegacyRepr(b.first); return a_stream.str().compare(b_stream.str()) < 0; }); diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index 4daa1aafe8cc8..bc9002ee841fa 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -23,7 +23,7 @@ the location to which the ops would be printed */ -#include "install_debug_spans.h" +#include "./install_debug_spans.h" #include diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index e1dc2f5bf113c..e9c57eb78e262 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -30,7 +30,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" -#include "../../printer/text_printer.h" namespace tvm { namespace tir { diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index 2bded7b4877b0..3acceab6e31bb 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -99,7 +99,7 @@ class PoolInfoAssigner : public StmtExprMutator { }; WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) { - VLOG(1) << "Creating default memory pool for:" << std::endl << PrettyPrint(module); + VLOG(1) << "Creating default memory pool for:" << std::endl << module; Map target_access; tir::PrimFunc tir_main_func = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); @@ -134,7 +134,7 @@ Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { Map annotations = Map(op->annotations); if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0) - << "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_); + << "Target " << tgt << " not found among " << target_pool_infos_; annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]); } Stmt body = VisitStmt(op->body); diff --git a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py index 91458f60e1725..062637b3bb942 100644 --- a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py +++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py @@ -81,6 +81,6 @@ def expected(): exp = expected() global_vars = [str(gv) for gv in after.get_global_vars()] - assert "@ext_func" in global_vars - assert "@ext_func_2" not in global_vars + assert 'I.GlobalVar("ext_func")' in global_vars + assert 'I.GlobalVar("ext_func_2")' not in global_vars assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"]) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 9e39821fd3173..6b3da0fd06d55 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np -import pytest import itertools import logging from typing import Tuple +import numpy as np +import pytest + try: # See issue #9362. import torch @@ -28,13 +29,12 @@ pass import tvm -import tvm.testing import tvm.relay.testing - +import tvm.testing from tvm import relay +from tvm.contrib.download import download from tvm.relay import Any, GlobalVar from tvm.relay.expr_functor import ExprVisitor -from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt SUPPORTED_DTYPES = ["float16", "float32"] @@ -615,7 +615,7 @@ def __init__(self, op_list): def visit_call(self, call): if isinstance(call.op, tvm.tir.op.Op): - if str(call.op) in self.op_list: + if str(call.op.name) in self.op_list: self.on_graph = True return super().visit_call(call) diff --git a/tests/python/contrib/test_uma/test_partition.py b/tests/python/contrib/test_uma/test_partition.py index ec2107f881bc6..d029036109335 100644 --- a/tests/python/contrib/test_uma/test_partition.py +++ b/tests/python/contrib/test_uma/test_partition.py @@ -16,15 +16,12 @@ # under the License. import pytest - import tvm import tvm.relay as relay - +from tvm.relay.backend.contrib.uma import uma_available from tvm.relay.backend.contrib.uma.api import UMAPartitioner from tvm.relay.op.contrib.register import get_pattern_table -from tvm.relay.testing import resnet, mlp -from tvm.relay.backend.contrib.uma import uma_available - +from tvm.relay.testing import mlp, resnet pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 1fae75f23eae4..e9fbe12e97540 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -17,27 +17,24 @@ """ Tests on quantized torch model conversion """ import os -from PIL import Image - import numpy as np - import torch +import tvm +import tvm.testing +from PIL import Image from torch import nn from torch.quantization import ( - QuantStub, DeQuantStub, - fuse_modules, + QuantStub, QuantWrapper, - prepare_qat, + fuse_modules, get_default_qat_qconfig, + prepare_qat, ) - -import tvm -import tvm.testing from tvm import relay -from tvm.relay.frontend.pytorch_utils import is_version_greater_than from tvm.contrib.download import download_testdata -from tvm.relay.op.contrib.register import register_pattern_table, get_pattern_table +from tvm.relay.frontend.pytorch_utils import is_version_greater_than +from tvm.relay.op.contrib.register import get_pattern_table, register_pattern_table def torch_version_check(): @@ -66,8 +63,10 @@ def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=Fal def get_qconfig(per_channel): - from torch.quantization.observer import MovingAverageMinMaxObserver - from torch.quantization.observer import default_weight_observer + from torch.quantization.observer import ( + MovingAverageMinMaxObserver, + default_weight_observer, + ) if per_channel: return torch.quantization.get_default_qconfig("fbgemm") @@ -396,11 +395,13 @@ def get_imagenet_input(): pt_tensor = preprocess(im) return np.expand_dims(pt_tensor.numpy(), 0) - from torchvision.models.quantization import resnet as qresnet - from torchvision.models.quantization import mobilenet as qmobilenet - from torchvision.models.quantization import inception as qinception from torchvision.models.quantization import googlenet as qgooglenet - from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large + from torchvision.models.quantization import inception as qinception + from torchvision.models.quantization import mobilenet as qmobilenet + from torchvision.models.quantization import ( + mobilenet_v3_large as qmobilenet_v3_large, + ) + from torchvision.models.quantization import resnet as qresnet per_channel = True qmodels = [ @@ -596,7 +597,7 @@ def forward(self, inp): def make_qnn_add_pattern(): - from tvm.relay.dataflow_pattern import wildcard, is_op + from tvm.relay.dataflow_pattern import is_op, wildcard lhs = wildcard() rhs = wildcard() @@ -782,7 +783,7 @@ def forward(self, input): assert isinstance(output, relay.Tuple) and len(output) == 2 dq1, dq2 = output - assert str(dq1.op) == "qnn.dequantize" and str(dq2.op) == "qnn.dequantize" + assert dq1.op.name == "qnn.dequantize" and dq2.op.name == "qnn.dequantize" scale1 = dq1.args[1].data.numpy().item() scale2 = dq2.args[1].data.numpy().item() assert scale1 != scale2 diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 0915df3051dbb..d5e0303b05b2c 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -64,14 +64,14 @@ def test_deduce(): e2 = tvm.te.max(5, a * 4) < 0 res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf: handle" - assert str(res2.min_value) == "pos_inf: handle" + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" # expression containing variable a is on rhs e2 = zero < tvm.te.max(5, a * 4) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf: handle" - assert str(res2.min_value) == "pos_inf: handle" + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" e3 = (-b) + a * c - d res3 = tvm.arith.deduce_bound(a, e3 >= 0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) @@ -88,8 +88,8 @@ def test_deduce(): # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) - assert str(res5.max_value) == "neg_inf: handle" - assert str(res5.min_value) == "pos_inf: handle" + assert str(res5.max_value) == "neg_inf" + assert str(res5.min_value) == "pos_inf" # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) @@ -111,15 +111,15 @@ def test_deduce(): # Unsatisfiable Mul in `EQ` e5 = 4 * a == b res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) - assert str(res9.max_value) == "neg_inf: handle" - assert str(res9.min_value) == "pos_inf: handle" + assert str(res9.max_value) == "neg_inf" + assert str(res9.min_value) == "pos_inf" # Unsatisfiable Mul in `EQ` res10 = tvm.arith.deduce_bound( a, (b * a == b), {b: b_s}, {} ) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf: handle" - assert str(res10.min_value) == "pos_inf: handle" + assert str(res10.max_value) == "neg_inf" + assert str(res10.min_value) == "pos_inf" def test_check(): diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 8b504df120e08..69478b4518931 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pickle as pkl + import pytest import tvm from tvm import te -import pickle as pkl def test_schedule_create(): @@ -297,8 +298,8 @@ def intrin_func(ins, outs, sp): stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 - assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" - assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" + assert str(stmt.body.body.value.args[3]) == "i * i" + assert str(stmt.body.body.value.args[4]) == "i + j" def test_legalize_invalid_attach(): diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 83cd64fa229be..d4ae84a556d78 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import pytest import tvm -from tvm import te, ir -import numpy as np +from tvm import ir, te def test_const(): @@ -142,7 +142,7 @@ def test_basic(): a = te.var("a") b = te.var("b") c = a + b - assert str(c) == "(%s: int32 + %s: int32)" % (a.name, b.name) + assert str(c) == "%s + %s" % (a.name, b.name) def test_stmt(): @@ -176,8 +176,8 @@ def test_any(): assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name) - assert str(tvm.tir.any(x < y, x > z)) == "((%s: int32 < %s: int32) || (%s > %s: int32))" % ( + assert str(tvm.tir.any(x < y)) == "%s < %s" % (x.name, y.name) + assert str(tvm.tir.any(x < y, x > z)) == "%s < %s or %s > %s" % ( x.name, y.name, x.name, @@ -185,7 +185,7 @@ def test_any(): ) assert str( tvm.tir.any(x < y, y > z + 1, x < z * 2) - ) == "(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))" % ( + ) == "%s < %s or %s > %s + 1 or %s < %s * 2" % ( x.name, y.name, y.name, @@ -209,8 +209,8 @@ def test_all(): assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name) - assert str(tvm.tir.all(x < y, x > z)) == "((%s: int32 < %s: int32) && (%s > %s: int32))" % ( + assert str(tvm.tir.all(x < y)) == "%s < %s" % (x.name, y.name) + assert str(tvm.tir.all(x < y, x > z)) == "%s < %s and %s > %s" % ( x.name, y.name, x.name, @@ -218,7 +218,7 @@ def test_all(): ) assert str( tvm.tir.all(x < y, y > z + 1, x < z * 2) - ) == "(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))" % ( + ) == "%s < %s and %s > %s + 1 and %s < %s * 2" % ( x.name, y.name, y.name, @@ -231,19 +231,19 @@ def test_all(): def test_bitwise(): x = te.var("x") y = te.var("y") - assert str(x << y) == "@tir.shift_left(x: int32, y: int32, dtype=int32)" - assert str(x >> y) == "@tir.shift_right(x: int32, y: int32, dtype=int32)" - assert str(x & y) == "@tir.bitwise_and(x: int32, y: int32, dtype=int32)" - assert str(x | y) == "@tir.bitwise_or(x: int32, y: int32, dtype=int32)" - assert str(x ^ y) == "@tir.bitwise_xor(x: int32, y: int32, dtype=int32)" - assert str(10 & x) == "@tir.bitwise_and(10, x: int32, dtype=int32)" - assert str(10 | x) == "@tir.bitwise_or(10, x: int32, dtype=int32)" - assert str(10 ^ x) == "@tir.bitwise_xor(10, x: int32, dtype=int32)" - assert str(10 >> x) == "@tir.shift_right(10, x: int32, dtype=int32)" - assert str(10 << x) == "@tir.shift_left(10, x: int32, dtype=int32)" - assert str(10 % x) == "floormod(10, x: int32)" - - assert str(~x) == "@tir.bitwise_not(x: int32, dtype=int32)" + assert str(x << y) == "T.shift_left(x, y)" + assert str(x >> y) == "T.shift_right(x, y)" + assert str(x & y) == "T.bitwise_and(x, y)" + assert str(x | y) == "T.bitwise_or(x, y)" + assert str(x ^ y) == "T.bitwise_xor(x, y)" + assert str(10 & x) == "T.bitwise_and(10, x)" + assert str(10 | x) == "T.bitwise_or(10, x)" + assert str(10 ^ x) == "T.bitwise_xor(10, x)" + assert str(10 >> x) == "T.shift_right(10, x)" + assert str(10 << x) == "T.shift_left(10, x)" + assert str(10 % x) == "10 % x" + + assert str(~x) == "T.bitwise_not(x)" assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert (te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -302,17 +302,17 @@ def test_divide_by_zero(): def test_infinity(): - assert str(tvm.tir.infinity("float16")) == "inff16" - assert str(tvm.tir.infinity("float32")) == "inff32" - assert str(tvm.tir.infinity("float64")) == "inff64" + assert str(tvm.tir.infinity("float16")) == 'T.float16("inf")' + assert str(tvm.tir.infinity("float32")) == 'T.float32("inf")' + assert str(tvm.tir.infinity("float64")) == 'T.float64("inf")' def test_isnan(): x = te.var("x", "float32") - assert str(tvm.tir.isnan(x)) == "@tir.isnan(x: float32, dtype=bool)" + assert str(tvm.tir.isnan(x)) == "T.isnan(x)" assert str(tvm.tir.isnan(x).dtype) == "bool" y = te.var("y", "float16") - assert str(tvm.tir.isnan(y)) == "@tir.isnan(cast(float32, y: float16), dtype=bool)" + assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))' z = te.var("z", "int32") assert str(tvm.tir.isnan(z)) == "False" k = te.var("k", "int8x2") diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 7062d51297135..adf3d9da05cea 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -14,17 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm.script import tir as T import numpy as np +import tvm import tvm.testing +from tvm.script import tir as T def count_cp_async(stmt): num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async": + if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_cp_async": num_alloc[0] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 006b67d626977..cf01d7700725f 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1507,7 +1507,7 @@ def test_async_pipelined_mma_gemm_simple(): assert body.block.body.body[1].block.body.body.value == 3 assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count" - assert str(epilogue.block.body.body.block.body.body.value) == "(2 - k_0_0: int32)" + assert str(epilogue.block.body.body.block.body.body.value) == "2 - k_0_0" build_and_run(sch) @@ -1554,7 +1554,7 @@ def test_async_nested_pipeline_mma_gemm_ideal_annotation(): assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count" assert body.block.body.body[1].block.body.body.value == 2 - assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - k_0_0: int32)" + assert str(epilogue.block.body.body[0].block.body.body.value) == "1 - k_0_0" build_and_run(sch) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index e78ed98d85694..47bb7bf228d42 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -83,19 +83,17 @@ def test_variable_passed_from_args(): # Arguments unpacking assignment = _find_assignment(func.body, "arg.input_buffer") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' assignment = _find_assignment(func.body, "arg.not_device_context") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 1, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")' assignment = _find_assignment(func.body, "input_buffer") - assert ( - str(assignment.value) == "@tir.tvm_struct_get(arg.input_buffer: handle, 0, 1, dtype=handle)" - ) + assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var assignment = _find_assignment(func.body, "not_device_context") - assert str(assignment.value) == "arg.not_device_context: handle" + assert str(assignment.value) == "arg_not_device_context" unpacked_not_device_context = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) @@ -131,12 +129,10 @@ def test_device_api_context_implicit_resource_handle(): # Arguments unpacking assignment = _find_assignment(func.body, "arg.input_buffer") - assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' assignment = _find_assignment(func.body, "input_buffer") - assert ( - str(assignment.value) == "@tir.tvm_struct_get(arg.input_buffer: handle, 0, 1, dtype=handle)" - ) + assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 0c5d77d02b91c..b2a0581d69801 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -92,7 +92,7 @@ def ir(A, B): stmt = ir(A, B) func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) mod = run_passes(func) - assert "@tir.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod) @tvm.testing.requires_cuda @@ -115,7 +115,7 @@ def func(p0_arg: T.Buffer[(1, 2, 1, 1), "float32"], p1: T.Buffer[2, "float32"]) result_local[0] = result_local[0] + temp_shared[0] * p1[1] mod = run_passes(func) - assert "@tir.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py b/tests/python/unittest/test_tvmscript_printer_ir.py new file mode 100644 index 0000000000000..e8836bb06d4fe --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_ir.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from tvm import IRModule +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T + + +def _assert_print(obj, expected): + assert str(obj).strip() == expected.strip() + assert repr(obj).strip() == expected.strip() + if isinstance(obj, IRModule): + assert obj.script().strip() == expected.strip() + + +def test_ir_module(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module(): + with T.prim_func(): + T.func_name("foo") + mod = ib.get() + _assert_print( + mod, + """ +@I.ir_module +class Module: + @T.prim_func + def foo() -> None: + T.evaluate(0)""", + ) + + +if __name__ == "__main__": + test_ir_module() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index fd3bb3788cfb7..0fd061a4298fb 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -35,6 +35,9 @@ def verbose_expr(): def _assert_print(obj, expected): with verbose_expr(): + if isinstance(obj, (tir.PrimFunc, tir.PrimExpr, tir.Stmt)): + assert obj.script().strip() == expected.strip() + assert str(obj).strip() == expected.strip() assert repr(obj).strip() == expected.strip()