Skip to content

Commit

Permalink
[TIR] Enhance Python Type Annotations for TIR Expr (apache#16083)
Browse files Browse the repository at this point in the history
This PR enhances the Python annotations for the TIR expr,
adding class member variables annotations.
  • Loading branch information
Hzfengsy authored Nov 8, 2023
1 parent db4290b commit 3f3473e
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 151 deletions.
4 changes: 2 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class RangeNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};

/*! \brief Range constainer */
/*! \brief Range container */
class Range : public ObjectRef {
public:
/*!
Expand All @@ -736,7 +736,7 @@ class Range : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};

// implementataions
// implementations
inline const Type& RelayExprNode::checked_type() const {
ICHECK(checked_type_.defined()) << "internal error: the type checker has "
<< "not populated the checked_type "
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class IterVarNode : public Object {
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
* set this if this is bound already to a known thread tag.
*/
String thread_tag;
/*!
Expand Down
40 changes: 29 additions & 11 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
# under the License.
"""Common expressions data structures in the IR."""
from numbers import Number
from typing import Callable, Optional

import tvm._ffi

from ..runtime import Scriptable, const, convert
from ..runtime import Object, Scriptable, const, convert
from . import _ffi_api
from .base import Node
from .base import Node, Span
from .type import Type


class BaseExpr(Node):
"""Base class of all the expressions."""

span: Optional[Span]


class PrimExpr(BaseExpr):
"""Base class of all primitive expressions.
Expand All @@ -35,6 +39,8 @@ class PrimExpr(BaseExpr):
optimizations and integer analysis.
"""

dtype: str


class RelayExpr(BaseExpr):
"""Base class of all non-primitive expressions."""
Expand Down Expand Up @@ -67,10 +73,12 @@ class GlobalVar(RelayExpr):
The name of the variable.
"""

def __init__(self, name_hint, type_annot=None):
name_hint: str

def __init__(self, name_hint: str, type_annot: Optional[Type] = None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)

def __call__(self, *args):
def __call__(self, *args: RelayExpr) -> BaseExpr:
"""Call the global variable.
Parameters
Expand All @@ -94,7 +102,9 @@ def __call__(self, *args):
arg_types = [type(x) for x in args]
raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}")

def astext(self, show_meta_data=True, annotate=None):
def astext(
self, show_meta_data: bool = True, annotate: Optional[Callable[[Object], str]] = None
) -> str:
"""Get the text format of the expression.
Parameters
Expand Down Expand Up @@ -140,22 +150,30 @@ class Range(Node, Scriptable):
The end value of the range.
span : Optional[Span]
The location of this itervar in the source code.
The location of this node in the source code.
Note
----
The constructor creates the range `[begin, end)`
if the end argument is not None. Otherwise, it creates `[0, begin)`.
"""

def __init__(self, begin, end=None, span=None):
min: PrimExpr
extent: PrimExpr
span: Optional[Span]

def __init__(
self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None
) -> None:
if end is None:
end = convert(begin)
begin = const(0, dtype=end.dtype, span=span)
self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span)

@staticmethod
def from_min_extent(min_value, extent, span=None):
def from_min_extent(
min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None
) -> "Range":
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Expand All @@ -169,7 +187,7 @@ def from_min_extent(min_value, extent, span=None):
The extent of the range.
span : Optional[Span]
The location of this itervar in the source code.
The location of this node in the source code.
Returns
-------
Expand All @@ -178,8 +196,8 @@ def from_min_extent(min_value, extent, span=None):
"""
return _ffi_api.Range_from_min_extent(min_value, extent, span)

def __eq__(self, other):
def __eq__(self, other: Object) -> bool:
return tvm.ir.structural_equal(self, other)

def __ne__(self, other):
def __ne__(self, other: Object) -> bool:
return not self.__eq__(other)
Loading

0 comments on commit 3f3473e

Please sign in to comment.