
Commit 40fd2ef

[REFACTOR] M1: Change parser/printer to only depend on struct info (apache#319)

* [REFACTOR] StructInfo M1: Parser/printer/Var/Function to only depend on struct info field

* Update src/relax/backend/vm/vm_shape_lower.cc

Co-authored-by: Ruihang Lai <[email protected]>

* Address comments

* Allow function to have default value

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
3 people authored and junrushao committed Feb 5, 2023
1 parent 053312f commit 40fd2ef
Showing 49 changed files with 750 additions and 956 deletions.
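
In short, this change collapses the separate shape/type annotations on Var (and the ret_type/ret_shape pair on Function) into a single StructInfo annotation. A minimal before/after sketch, using only constructors that appear in the diffs below (n and m are symbolic dimensions; the old-style call is the API this commit removes):

import tvm
from tvm import relax, tir

n, m = tir.Var("n", "int64"), tir.Var("m", "int64")

# Before this commit (removed API): shape and type passed separately.
#   data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
# After this commit: a single struct info annotation carries both.
data = relax.Var("data", relax.TensorStructInfo([n, m], "float32"))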
5 changes: 2 additions & 3 deletions apps/relax_examples/mlp.py
@@ -19,7 +19,6 @@
 
 
 import tvm
-from tvm.relay import Call
 from tvm import relax, tir, topi
 import numpy as np
@@ -40,8 +39,8 @@ def build_mlp(data, weight):
     # symbolic dimensions
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
     # create data and weight variables
-    data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
-    weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32"))
+    data = relax.Var("data", relax.TensorStructInfo([n, m], "float32"))
+    weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32"))
 
     # construct a mlp model
     mod = build_mlp(data, weight)
4 changes: 2 additions & 2 deletions include/tvm/relax/binding_rewrite.h
@@ -46,8 +46,8 @@ class DataflowBlockRewriteNode : public Object {
   void Add(Binding binding);
   /*! \brief Insert an expression as VarBinding with variable name. */
   void Add(String var_name, Expr expr, bool is_dfvar = false) {
-    auto var = is_dfvar ? DataflowVar(var_name, expr->shape(), expr->checked_type())
-                        : Var(var_name, expr->shape(), expr->checked_type());
+    auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr))  //
+                        : Var(var_name, GetStructInfo(expr));
     Add(VarBinding(std::move(var), std::move(expr)));
   }
   /*! \brief Insert an expression as VarBinding with automatic variable name. */
41 changes: 18 additions & 23 deletions include/tvm/relax/expr.h
@@ -486,12 +486,11 @@ class VarNode : public ExprNode {
 
 class Var : public Expr {
  public:
-  TVM_DLL explicit Var(String name_hint, runtime::Optional<Expr> shape_annotation,
-                       runtime::Optional<Type> type_annotation, Span span = Span())
-      : Var(Id(name_hint), shape_annotation, type_annotation, span) {}
+  TVM_DLL explicit Var(String name_hint, Optional<StructInfo> struct_info_annotation,
+                       Span span = Span())
+      : Var(Id(name_hint), struct_info_annotation, span) {}
 
-  TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
-                       runtime::Optional<Type> type_annotation, Span span = Span());
+  TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
 };
@@ -529,12 +528,12 @@ class DataflowVarNode : public VarNode {
 
 class DataflowVar : public Var {
  public:
-  TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional<Expr> shape_annotation,
-                               runtime::Optional<Type> type_annotation, Span span = Span())
-      : DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {}
+  TVM_DLL explicit DataflowVar(String name_hint, Optional<StructInfo> struct_info_annotation,
+                               Span span = Span())
+      : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
 
-  TVM_DLL explicit DataflowVar(Id vid, runtime::Optional<Expr> shape_annotation,
-                               runtime::Optional<Type> type_annotation, Span span = Span());
+  TVM_DLL explicit DataflowVar(Id vid, Optional<StructInfo> struct_info_annotation,
+                               Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode);
@@ -791,15 +790,12 @@ class FunctionNode : public BaseFuncNode {
   /*! \brief The body of the function. */
   Expr body;
   /*! \brief The return type of the function. */
-  Type ret_type;
-  /*! \brief The return shape of the function. */
-  Expr ret_shape;
+  StructInfo ret_struct_info;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("params", &params);
     v->Visit("body", &body);
-    v->Visit("ret_type", &ret_type);
-    v->Visit("ret_shape", &ret_shape);
+    v->Visit("ret_struct_info", &ret_struct_info);
     v->Visit("_checked_type_", &checked_type_);
     v->Visit("shape_", &shape_);
     v->Visit("struct_info_", &struct_info_);
@@ -810,7 +806,7 @@
   bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
     equal->MarkGraphNode();
     return equal.DefEqual(params, other->params) && equal(body, other->body) &&
-           equal(ret_type, other->ret_type) && equal(ret_shape, other->ret_shape) &&
+           equal(ret_struct_info, other->ret_struct_info) &&
            equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_) &&
            equal(attrs, other->attrs);
   }
@@ -819,8 +815,7 @@
     hash_reduce->MarkGraphNode();
     hash_reduce.DefHash(params);
     hash_reduce(body);
-    hash_reduce(ret_type);
-    hash_reduce(ret_shape);
+    hash_reduce(ret_struct_info);
     hash_reduce(checked_type_);
     hash_reduce(shape_);
     hash_reduce(attrs);
@@ -834,15 +829,15 @@
 
 class Function : public BaseFunc {
  public:
-  TVM_DLL explicit Function(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
+  TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
                             DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   /*!
-   * \brief Mimics the constructor but without type checking.
+   * \brief Mimics the constructor but without the body Expr.
+   * \note ret_struct_info is required, since it cannot be deduced from the body.
    */
-  TVM_DLL static Function CreateUnchecked(Array<Var> params, Expr body, Type ret_type,
-                                          Expr ret_shape, DictAttrs attrs = NullValue<DictAttrs>(),
-                                          Span span = Span());
+  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
+                                      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
22 changes: 8 additions & 14 deletions include/tvm/script/ir_builder/relax/frame.h
@@ -88,22 +88,16 @@ class FunctionFrameNode : public SeqExprFrameNode {
   /*! \brief The function params. */
   Array<tvm::relax::Var> params;
   /*!
-   * \brief The function return type.
+   * \brief The function return struct info.
    * \note Usually the function return type can be deduced by the function body.
    * But we can use this field to specify a more "accurate" return type.
-   * i.e. If the `ret_type` is None, try to use the deduced type from body
-   * If the `ret_type` is not None, check the deduced type is a base type of the given one.
+   * i.e. If the `ret_struct_info` is None, try to use the struct info deduced from the body.
+   * If the `ret_struct_info` is not None, we still take body.struct_info
+   * when `ret_struct_info` is a base of body.struct_info; otherwise we
+   * take the specified `ret_struct_info`.
    */
-  Optional<Type> ret_type;
-  /*!
-   * \brief The function return shape.
-   * \sa ret_type
-   */
-  Optional<tvm::relax::Expr> ret_shape;
-  /*!
-   * \brief The function return struct info.
-   */
-  Optional<tvm::relax::StructInfo> ret_sinfo;
+  Optional<tvm::relax::StructInfo> ret_struct_info;
 
   /*! \brief The function attributes. */
   Map<String, ObjectRef> attrs;
   /*! \brief The block builder to create Relax function. */
@@ -113,7 +107,7 @@
     SeqExprFrameNode::VisitAttrs(v);
     v->Visit("name", &name);
     v->Visit("params", &params);
-    v->Visit("ret_type", &ret_type);
+    v->Visit("ret_struct_info", &ret_struct_info);
     v->Visit("attrs", &attrs);
     v->Visit("binding_blocks", &binding_blocks);
     v->Visit("output", &output);
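
A small Python sketch of the selection rule described in the comment above (a hypothetical helper, not part of this commit; it assumes the StructInfo.is_base_of method added in python/tvm/relax/expr.py below):

from typing import Optional
from tvm import relax as rx

def pick_ret_struct_info(
    declared: Optional[rx.StructInfo], deduced: rx.StructInfo
) -> rx.StructInfo:
    """Hypothetical helper mirroring FunctionFrameNode's rule."""
    if declared is None:
        return deduced  # no annotation: use the struct info deduced from the body
    if declared.is_base_of(deduced):
        return deduced  # annotation is a base: keep the more precise deduced info
    return declared  # otherwise honor the explicit annotation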
28 changes: 11 additions & 17 deletions python/tvm/relax/block_builder.py
@@ -90,8 +90,7 @@ def __init__(self, block_builder, def_vars):
             else:
                 raise ValueError("def_vars only can take tir.Var")
         # setup a dummy var so shape is in scope.
-        sparam = tvm.relax.Var("sparam")
-        tvm.relax.expr._update_struct_info(sparam, tvm.relax.ShapeStructInfo(shape_vars))
+        sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars))
         self._scope_params = [sparam]
 
     def __enter__(self):
@@ -113,10 +112,8 @@ class BlockBuilder(Object):
         m = tir.Var("m", "int32")
         n = tir.Var("n", "int32")
-        type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
-        type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
-        x = rx.Var("x", [m, n], type_anno0)
-        y = rx.Var("y", [n], type_anno1)
+        x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
+        y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
         bb = rx.BlockBuilder()
         with bb.function([x, y], "func"):
             with bb.dataflow() as df:
@@ -126,7 +123,7 @@ class BlockBuilder(Object):
             bb.emit_func_output(gv0)
         mod = bb.get()
 
-    BlockBuilder can also be used to contruct neural networks with nn.Module API
+    BlockBuilder can also be used to construct neural networks with nn.Module API
 
     .. code-block:: python
@@ -454,9 +451,8 @@ def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
             bb = rx.BlockBuilder()
             n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
-            type_anno = rx.DynTensorType(2, "float32")
-            x = rx.Var("x", [n, m], type_anno)
-            y = rx.Var("y", [n, m], type_anno)
+            x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
+            y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
 
             def te_func(args, args_dict, msg):
                 A = args[0]
@@ -505,9 +501,8 @@ def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor:
             bb = relax.BlockBuilder()
             n = tir.Var("n", "int64")
-            type_anno = relax.DynTensorType(1, "float32")
-            x = relax.Var("x", [n], type_anno)
-            y = relax.Var("y", [n + 1], type_anno)
+            x = relax.Var("x", relax.TensorStructInfo([n], "float32"))
+            y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32"))
 
             def te_func(A):
                 C = te.compute((n + 1), lambda i: A[i])
@@ -628,10 +623,9 @@ def emit_func_output(
         self._blocks.append(block)
         seqe = self.normalize(rx.SeqExpr(self._blocks, output))
 
-        # The function's checked_type_ relies on the function body(seqe) to have deduced type
-        # TODO(@yuchen): handle the case where the body's checked_type_ is null
-        # TODO: Deduce the ret shape too
-        func = rx.Function(self._func_params, seqe, None, rx.RuntimeDepShape())
+        # Do not specify ret_struct_info; let the constructor deduce it
+        # from seqe.struct_info.
+        func = rx.Function(self._func_params, seqe)
         for key, value in self._func_attrs.items():
             func = func.with_attr(key, value)
         self.end_scope()
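
An end-to-end sketch of that deduction path (hedged: it reuses the BlockBuilder docstring example from this file, including its bb.function([x, y], "func") argument order, and rx.op.multiply stands in for the elided body):

from tvm import relax as rx, tir

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))

bb = rx.BlockBuilder()
with bb.function([x, y], "func"):
    with bb.dataflow() as df:
        lv0 = bb.emit(rx.op.multiply(x, y))
        gv0 = bb.emit_output(lv0)
    bb.emit_func_output(gv0)

mod = bb.get()
print(mod["func"].ret_struct_info)  # deduced from the body, not passed in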
91 changes: 62 additions & 29 deletions python/tvm/relax/expr.py
@@ -41,8 +41,40 @@
 Type = Union[relay.Type]
 GlobalVar = Union[relay.GlobalVar]
 
 # will be registered afterwards
 _op_make = None
+
+
+# NOTE: place base struct info in expr to avoid cyclic dep
+# from expr to struct info.
+class StructInfo(Node):
+    """The base class of all StructInfo.
+
+    StructInfo contains both the static type
+    and runtime structural information.
+    """
+
+    def __eq__(self, other):
+        """Compare two struct info for structural equivalence."""
+        return tvm.ir.structural_equal(self, other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def same_as(self, other):
+        """Overload with structural equality."""
+        return super().__eq__(other)
+
+    def is_base_of(self, derived: "StructInfo") -> bool:
+        """Check if self is a base of the derived struct info.
+
+        Parameters
+        ----------
+        derived : StructInfo
+            The derived struct info to be checked.
+
+        Returns
+        -------
+        result : bool
+            The check result.
+        """
+        return _ffi_api.StructInfoIsBaseOf(self, derived)  # type: ignore
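
For instance, a hedged usage sketch of the API above (ObjectStructInfo and TensorStructInfo are the concrete subclasses used elsewhere in this commit):

from tvm import relax as rx, tir

n = tir.Var("n", "int64")
obj = rx.ObjectStructInfo()                      # the most general struct info
tensor = rx.TensorStructInfo([n, 4], "float32")  # a 2-D float32 tensor

assert obj.is_base_of(tensor)      # every struct info derives from Object
assert not tensor.is_base_of(obj)  # the reverse does not hold
assert tensor == rx.TensorStructInfo([n, 4], "float32")  # __eq__ is structural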


@tvm._ffi.register_object("relax.expr.Call")
@@ -199,22 +231,24 @@ class Var(Expr):
     """The variable class for all Relax bindings."""
 
     vid: Id
-    type_annotation: Optional[Type]
+    struct_info: Optional[StructInfo]
 
     def __init__(
         self,
         name_hint: str,
-        shape_annotation: Optional[Union[List[Any], typing.Tuple[Any, ...]]] = None,
-        type_annotation: Optional[Type] = None,
+        struct_info: Optional[StructInfo] = None,
         span: Span = None,
     ) -> None:
-        if isinstance(shape_annotation, (list, tuple)):
-            shape_annotation = make_shape(shape_annotation)
+        if struct_info is not None and not isinstance(struct_info, StructInfo):
+            raise TypeError(
+                "struct_info needs to be an instance of StructInfo. "
+                "If you attempt to pass in shape, "
+                "use relax.TensorStructInfo(shape, dtype)."
+            )
         self.__init_handle_by_constructor__(
             _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId,  # type: ignore
             name_hint,
-            shape_annotation,
-            type_annotation,
+            struct_info,
            span,
        )
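
A short sketch of the new constructor contract above (it mirrors the TypeError guard added in this hunk):

from tvm import relax as rx

x = rx.Var("x", rx.TensorStructInfo([2, 3], "float32"))  # ok: a StructInfo
try:
    rx.Var("y", [2, 3])  # a bare shape list is rejected by the guard above
except TypeError as err:
    print(err)  # suggests relax.TensorStructInfo(shape, dtype) instead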

@@ -249,23 +283,28 @@ class DataflowVar(Var):
     """A sub-type of the variable node used to mark dataflow variables from
     normal visible "function local" bindings."""
 
+    vid: Id
+    struct_info: Optional[StructInfo]
+
     def __init__(
         self,
         name_hint: Union[str, Id],
-        shape_annotation: Optional[Union[List[Any], typing.Tuple[Any, ...]]] = None,
-        type_annotation: Optional[Type] = None,
+        struct_info: Optional[StructInfo] = None,
         span: Span = None,
     ) -> None:
-        if isinstance(shape_annotation, (list, tuple)):
-            shape_annotation = make_shape(shape_annotation)
+        if struct_info is not None and not isinstance(struct_info, StructInfo):
+            raise TypeError(
+                "struct_info needs to be an instance of StructInfo. "
+                "If you attempt to pass in shape, "
+                "use relax.TensorStructInfo(shape, dtype)."
+            )
+
         self.__init_handle_by_constructor__(
             _ffi_api.DataflowVar  # type: ignore
             if isinstance(name_hint, str)
             else _ffi_api.DataflowVarFromId,  # type: ignore
             name_hint,
-            shape_annotation,
-            type_annotation,
+            struct_info,
             span,
         )

@@ -338,36 +377,30 @@ class Function(BaseFunc):
 
     params: List[Var]
     body: Expr
-    ret_type: Type
-    ret_shape: Expr
+    ret_struct_info: StructInfo
     attrs: Optional[tvm.ir.DictAttrs]
 
     def __init__(
         self,
         params: List[Var],
         body: Expr,
-        ret_type: Type,
-        ret_shape: Expr,
+        ret_struct_info: Optional[StructInfo] = None,
         attrs: Optional[tvm.ir.DictAttrs] = None,
         span: Optional[Span] = None,
     ) -> None:
         self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, ret_shape, attrs, span  # type: ignore
+            _ffi_api.Function, params, body, ret_struct_info, attrs, span  # type: ignore
         )
 
     @staticmethod
-    def create_unchecked(
+    def create_empty(
         params: List[Var],
-        body: Expr,
-        ret_type: Type,
-        ret_shape: Expr,
+        ret_struct_info: StructInfo,
         attrs: Optional[tvm.ir.DictAttrs] = None,
         span: Optional[Span] = None,
     ):
-        """Construct a relax.Function but without type checking."""
-        return _ffi_api.Function_CreateUnchecked(  # type: ignore
-            params, body, ret_type, ret_shape, attrs, span
-        )
+        """Construct a relax.Function but without a body."""
+        return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span)  # type: ignore
 
     def __call__(self, *args):
         """Invoke the global function.
@@ -476,5 +509,5 @@ def te_tensor(value: Expr, name: str = "rxplaceholder"):
     return _ffi_api.TETensor(value, name)  # type: ignore
 
 
-def _update_struct_info(expr: Expr, struct_info: Optional["tvm.relax.StructInfo"]) -> None:
+def _update_struct_info(expr: Expr, struct_info: Optional[StructInfo]) -> None:
     _ffi_api.UpdateStructInfo(expr, struct_info)  # type: ignore
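
A hedged sketch of create_empty above, e.g. for declaring a function whose body is supplied later (the TensorStructInfo arguments are illustrative):

from tvm import relax as rx

x = rx.Var("x", rx.TensorStructInfo([2, 3], "float32"))
# Body-less declaration: ret_struct_info is mandatory because there is
# no body to deduce it from (see the \note on CreateEmpty in expr.h).
decl = rx.Function.create_empty([x], rx.TensorStructInfo([2, 3], "float32"))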