Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Update the type_keys to reflect the code-org #5074

Merged
merged 1 commit into from
Mar 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.GlobalVar";
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;

static constexpr const char* _type_key = "relay.Module";
static constexpr const char* _type_key = "IRModule";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "relay.SourceName";
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};

Expand Down Expand Up @@ -89,7 +89,7 @@ class SpanNode : public Object {

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "relay.Span";
static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class PassContextNode : public Object {
v->Visit("disabled_pass", &disabled_pass);
}

static constexpr const char* _type_key = "relay.PassContext";
static constexpr const char* _type_key = "transform.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};

Expand Down Expand Up @@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v->Visit("required", &required);
}

static constexpr const char* _type_key = "relay.PassInfo";
static constexpr const char* _type_key = "transform.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};

Expand Down Expand Up @@ -265,7 +265,7 @@ class PassNode : public Object {

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.Pass";
static constexpr const char* _type_key = "transform.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};

Expand Down
18 changes: 10 additions & 8 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
mutable Span span;

static constexpr const char* _type_key = "relay.Type";
static constexpr const char* _type_key = "Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}

static constexpr const char* _type_key = "relay.PrimType";
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};

Expand Down Expand Up @@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeVar";
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};

Expand Down Expand Up @@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}

static constexpr const char* _type_key = "relay.GlobalTypeVar";
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};

Expand Down Expand Up @@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TupleType";
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};

Expand Down Expand Up @@ -289,7 +289,7 @@ inline Type VoidType() {
*/
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
static constexpr const char* _type_key = "TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};

Expand Down Expand Up @@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.FuncType";
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};

Expand Down Expand Up @@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.IncompleteType";
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};

Expand Down Expand Up @@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}

// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeCall";
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};

Expand Down Expand Up @@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.TypeReporter";
static constexpr const char* _type_key = "TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};

Expand Down Expand Up @@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeRelation";
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __str__(self):
return _ffi_api.PrettyPrint(self)


@tvm._ffi.register_object("relay.SourceName")
@tvm._ffi.register_object("SourceName")
class SourceName(Object):
"""A identifier for a source location.

Expand All @@ -69,7 +69,7 @@ def __init__(self, name):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)


@tvm._ffi.register_object("relay.Span")
@tvm._ffi.register_object("Span")
class Span(Object):
"""Specifies a location in a source program.

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def checked_type(self):
return ret


@tvm._ffi.register_object("relay.GlobalVar")
@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.

Expand Down
24 changes: 24 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,35 @@ def _ftype_var(item, nodes):
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
assert item["type_key"].startswith("relay.")
item["type_key"] = item["type_key"][len("relay."):]
return item

def _rename(new_name):
def _convert(item, _):
item["type_key"] = new_name
return item
return _convert

node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
"relay.FuncType": _rename("FuncType"),
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
"relay.GlobalVar": _rename("GlobalVar"),
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from . import _ffi_api


@tvm._ffi.register_object("relay.Module")
@tvm._ffi.register_object("IRModule")
class IRModule(Node):
"""IRModule that holds functions and type definitions.

Expand Down
10 changes: 5 additions & 5 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from . import _ffi_transform_api

@tvm._ffi.register_object("relay.PassInfo")
@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
Expand All @@ -51,7 +51,7 @@ def __init__(self, opt_level, name, required=None):
_ffi_transform_api.PassInfo, opt_level, name, required)


@tvm._ffi.register_object("relay.PassContext")
@tvm._ffi.register_object("transform.PassContext")
class PassContext(Object):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
Expand Down Expand Up @@ -112,7 +112,7 @@ def current():
return _ffi_transform_api.GetCurrentPassContext()


@tvm._ffi.register_object("relay.Pass")
@tvm._ffi.register_object("transform.Pass")
class Pass(Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
Expand Down Expand Up @@ -141,7 +141,7 @@ def __call__(self, mod):
return _ffi_transform_api.RunPass(self, mod)


@tvm._ffi.register_object("relay.ModulePass")
@tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass):
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
Expand All @@ -152,7 +152,7 @@ class ModulePass(Pass):
"""


@tvm._ffi.register_object("relay.Sequential")
@tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
Expand Down
25 changes: 19 additions & 6 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData = 6


@tvm._ffi.register_object("relay.TypeVar")
class PrimType(Type):
"""Primitive data type in the low level IR

Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def __init__(self, dtype):
self.__init_handle_by_constructor__(
_ffi_api.PrimType, dtype)


@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.

Expand Down Expand Up @@ -85,7 +98,7 @@ def __call__(self, *args):
return TypeCall(self, args)


@tvm._ffi.register_object("relay.GlobalTypeVar")
@tvm._ffi.register_object("GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases.

Expand Down Expand Up @@ -120,7 +133,7 @@ def __call__(self, *args):
return TypeCall(self, args)


@tvm._ffi.register_object("relay.TupleType")
@tvm._ffi.register_object("TupleType")
class TupleType(Type):
"""The type of tuple values.

Expand All @@ -135,12 +148,12 @@ def __init__(self, fields):
_ffi_api.TupleType, fields)


@tvm._ffi.register_object("relay.TypeConstraint")
@tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""


@tvm._ffi.register_object("relay.FuncType")
@tvm._ffi.register_object("FuncType")
class FuncType(Type):
"""Function type.

Expand Down Expand Up @@ -179,7 +192,7 @@ def __init__(self,
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)


@tvm._ffi.register_object("relay.IncompleteType")
@tvm._ffi.register_object("IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/ir/type_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import _ffi_api


@tvm._ffi.register_object("TypeCall")
class TypeCall(Type):
"""Type function application.

Expand All @@ -41,7 +42,7 @@ def __init__(self, func, args):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)


@tvm._ffi.register_object("relay.TypeRelation")
@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.

Expand Down
4 changes: 2 additions & 2 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }

static constexpr const char* _type_key = "relay.ModulePass";
static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};

Expand Down Expand Up @@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;

static constexpr const char* _type_key = "relay.Sequential";
static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};

Expand Down
Loading