Skip to content

Commit

Permalink
[TVMScript] Migrate More to TVMScripr Printer
Browse files Browse the repository at this point in the history
This PR gradually migrates more pieces of the default printing to
TVMScript printer for TIR.

This PR gradually migrates more pieces of the default printing to
TVMScript printer for TIR. Details:
- Introduced a method `AsLegacyRepr` which preserves existing
`AsRepr` provided by `ReprPrinter`, so that the legacy behavior
could be 100% preserved.
- Introduced `Script` method to `IRModule`, `PrimFunc`, `tir.Stmt`,
`tir.PrimExpr`. The `script` method exists in python side before,
and this PR introduced them to C++ to be consistent.
- Replace TIR's `PrettyPrint` to `operator <<` that is provided by
the new `ReprPrinter`, which outputs in TVMScript format by default.
`PrettyPrint` on Relay is all preserved for backward compatibility.
  • Loading branch information
junrushao committed Jan 17, 2023
1 parent 1258863 commit 4330d6f
Show file tree
Hide file tree
Showing 99 changed files with 2,347 additions and 962 deletions.
11 changes: 11 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ class PrimExprNode : public BaseExprNode {
*/
DataType dtype;

/*!
* \brief Returns the TVMScript format
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*/
TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false,
int num_context_lines = -1,
Optional<ObjectPath> path_to_underline = NullOpt) const;

static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 38;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
Expand Down
55 changes: 30 additions & 25 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ class IRModuleNode : public Object {
parser::SourceMap source_map;
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;
/*!
* \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Map<String, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
Map<String, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;

/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<String> import_set_;

/*!
* \brief Get a module attribute.
Expand Down Expand Up @@ -304,15 +324,20 @@ class IRModuleNode : public Object {
TVM_DLL void ImportFromStd(const String& path);

/*!
* \brief Should Link Parameters into the module
* \return Whether the Executor is configured to execute with linked parameters (Default: false)
* \brief The set of imported files.
*/
TVM_DLL Bool ShouldLinkParameters() const;
TVM_DLL std::unordered_set<String> Imports() const;

/*!
* \brief The set of imported files.
* \brief Returns the TVMScript format
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*/
TVM_DLL std::unordered_set<String> Imports() const;
TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false,
int num_context_lines = -1,
Optional<ObjectPath> path_to_underline = NullOpt) const;

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand All @@ -322,26 +347,6 @@ class IRModuleNode : public Object {
private:
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Map<String, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
Map<String, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;

/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<String> import_set_;
friend class IRModule;
};

Expand Down
19 changes: 19 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,25 @@ enum TypeKind : int {
kTypeData = 6
};

/*! \brief Converts a TypeKind to a string. */
inline String TypeKind2String(TypeKind kind) {
switch (kind) {
case TypeKind::kType:
return "Type";
case TypeKind::kShapeVar:
return "ShapeVar";
case TypeKind::kBaseType:
return "BaseType";
case TypeKind::kConstraint:
return "Constraint";
case TypeKind::kAdtHandle:
return "AdtHandle";
case TypeKind::kTypeData:
return "TypeData";
}
LOG(FATAL) << "ValueError: Unknown TypeKind: " << static_cast<int>(kind);
}

/*!
* \brief Type parameter in functions.
*
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#ifndef TVM_IR_TYPE_FUNCTOR_H_
#define TVM_IR_TYPE_FUNCTOR_H_

#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/node/functor.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>

#include <string>
#include <utility>
Expand Down
32 changes: 32 additions & 0 deletions include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/node/functor.h>

#include <iostream>
#include <string>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
Expand All @@ -48,6 +49,30 @@ class ReprPrinter {
TVM_DLL static FType& vtable();
};

/*! \brief Legacy behavior of ReprPrinter. */
class ReprLegacyPrinter {
public:
/*! \brief The indentation level. */
int indent{0};

explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
/*! \brief Return the ostream it maintains */
TVM_DLL std::ostream& Stream() const;
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, ReprLegacyPrinter*)>;
TVM_DLL static FType& vtable();

private:
/*! \brief The output stream */
std::ostream& stream;
};

/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
Expand All @@ -70,6 +95,13 @@ inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLI
ReprPrinter(os).Print(n);
return os;
}

inline std::string AsLegacyRepr(const ObjectRef& n) {
std::ostringstream os;
ReprLegacyPrinter(os).Print(n);
return os.str();
}
} // namespace runtime
using runtime::AsLegacyRepr;
} // namespace tvm
#endif // TVM_NODE_REPR_PRINTER_H_
3 changes: 3 additions & 0 deletions include/tvm/script/printer/ir_docsifier_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class IRDocsifierFunctor {
if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
return (*pf)(obj, args...);
}
LOG(WARNING) << "ObjectFunctor calls un-registered function on type: "
<< runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
<< ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
<< runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
<< ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
Expand Down
15 changes: 0 additions & 15 deletions include/tvm/script/printer/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,6 @@ struct Default {
static bool& VerboseExpr() { return Instance()->verbose_expr; }
};

/*!
* \brief The entry method for TVMScript printing
* \param obj The object to be printed
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
* \return The TVMScript text format
*/
String Script(ObjectRef obj, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt);

/*!
* \brief Convert Doc into Python script.
* \param doc Doc to be converted
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,6 @@ class Any : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};

/*! \brief Legacy ReprPrint format for TIR */
std::string LegacyTIRPrint(const ObjectRef& obj);

/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ class PrimFuncNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;

/*!
* \brief Returns the TVMScript format
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*/
std::string Script(int indent_spaces = 4, bool print_line_numbers = false,
int num_context_lines = -1,
Optional<ObjectPath> path_to_underline = NullOpt) const;

static constexpr const char* _type_key = "tir.PrimFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class StmtNode : public Object {
StmtNode() = default;
explicit StmtNode(Span span) : span(span) {}

/*!
* \brief Returns the TVMScript format
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*/
std::string Script(int indent_spaces = 4, bool print_line_numbers = false,
int num_context_lines = -1,
Optional<ObjectPath> path_to_underline = NullOpt) const;

static constexpr const char* _type_key = "tir.Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down
56 changes: 37 additions & 19 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,47 @@
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .base import structural_equal, assert_structural_equal, structural_hash
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .affine_type import TensorAffineType, TupleAffineType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .op import Op, register_op_attr, register_intrin_lowering
from .function import CallingConv, BaseFunc
from . import diagnostics, instrument, transform
from .adt import Constructor, TypeData
from .module import IRModule
from .affine_type import TensorAffineType, TupleAffineType
from .attrs import Attrs, DictAttrs, make_node
from .base import (
EnvFunc,
Node,
SourceName,
Span,
assert_structural_equal,
load_json,
pretty_print,
save_json,
structural_equal,
structural_hash,
)
from .container import Array, Map
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
from .function import BaseFunc, CallingConv
from .memory_pools import (
PoolInfo,
WorkspacePoolInfo,
ConstantPoolInfo,
WorkspaceMemoryPools,
ConstantMemoryPools,
ConstantPoolInfo,
PoolInfo,
PoolInfoProperties,
WorkspaceMemoryPools,
WorkspacePoolInfo,
)

from . import transform
from . import instrument
from . import diagnostics
from .module import IRModule
from .op import Op, register_intrin_lowering, register_op_attr
from .tensor_type import TensorType
from .type import (
FuncType,
GlobalTypeVar,
IncompleteType,
PointerType,
PrimType,
RelayRefType,
TupleType,
Type,
TypeConstraint,
TypeKind,
TypeVar,
)
from .type_relation import TypeCall, TypeRelation
7 changes: 6 additions & 1 deletion python/tvm/ir/affine_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
"""Types for quantized Tensors."""
import tvm._ffi

from .base import Node
from . import _ffi_api
from .base import Node


class AffineType(Node):
Expand All @@ -31,6 +31,11 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def __str__(self):
from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel

return pretty_print(self)


@tvm._ffi.register_object("TensorAffineType")
class TensorAffineType(AffineType):
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.
"""Common base structures."""
import tvm._ffi

import tvm.error
import tvm.runtime._ffi_node_api
from tvm.runtime import Object

from . import _ffi_api
from . import json_compact
from . import _ffi_api, json_compact


def pretty_print(obj: Object) -> None:
"""Pretty print the object."""
return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member


class Node(Object):
Expand Down Expand Up @@ -54,9 +57,6 @@ def astext(self, show_meta_data=True, annotate=None):
"""
return _ffi_api.AsText(self, show_meta_data, annotate)

def __str__(self):
return _ffi_api.PrettyPrint(self)


@tvm._ffi.register_object("SourceName")
class SourceName(Object):
Expand Down
Loading

0 comments on commit 4330d6f

Please sign in to comment.