[Refactor][AST] sinfo_args A1: Update API of call_tir, call_packed, etc. (apache#386)

* [Refactor][AST] `sinfo_args` A1: Update API of call_tir, call_packed, etc

Following apache#377 and apache#379, this PR follows up by updating the APIs of
`call_tir`, `call_packed`, `call_builtin` and `invoke_closure`.

* The API of `call_tir` is changed to
  ```python
  def call_tir(
      func: Union[str, Expr],
      args: Expr,
      out_sinfo: Union[TensorStructInfo, TupleStructInfo],
      tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
  ) -> Call:
      ...
  ```
  where the former `shape` and `dtype` parameters are combined into `out_sinfo`. In
  the concrete CallNode of `call_tir`, the output shape is no longer passed as a
  CallNode argument; it is carried in `sinfo_args` instead (see the usage sketch
  after this list).

* For `call_packed`, `call_builtin` and `invoke_closure`, we change the
  `type_args` parameter to `sinfo_args`. For example, the API of `call_packed`
  is updated to
  ```python
  def call_packed(
      func: str,
      *args: Expr,
      sinfo_args: Union[StructInfo, List[StructInfo]],
      **kwargs: Any,
  ) -> Call:
      ...
  ```
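
For concreteness, here is a minimal sketch of calls under the updated APIs. It is
illustrative only: `x` and `y` are assumed in-scope relax expressions, `"my_add"`
and `"my_packed_func"` are hypothetical function names, and the exact import paths
may differ at this commit.
```python
from tvm import relax

# call_tir: the output shape and dtype now travel together as `out_sinfo`.
tir_call = relax.call_tir(
    "my_add",  # hypothetical PrimFunc name
    (x, y),    # input arguments
    out_sinfo=relax.TensorStructInfo((128, 128), "float32"),
)

# call_packed: the former `type_args` keyword is renamed to `sinfo_args`.
packed_call = relax.op.call_packed(
    "my_packed_func",  # hypothetical packed function name
    x,
    sinfo_args=relax.TensorStructInfo((128, 128), "float32"),
)
```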

---

One thing specific to note concerns our existing dataflow pattern language.
Previously we had a sugar function for `call_tir`:
```python
def is_call_tir(
    func_name: str,
    args: Union[List, Tuple, TuplePattern] = None,
    shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
    ...
```
Since the API of `call_tir` changed (the CallNode no longer carries the output
shape as an argument), and the dataflow pattern language does not yet support
matching StructInfo, there is no way to match the shape anymore. Therefore, the
`shape` parameter is removed from `is_call_tir`. I left some todo items there
for future `sinfo_args` matching support.
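
Under the new signature, a pattern can still constrain the callee and its
arguments, just not the output struct info. A minimal sketch, assuming the
helpers live in `tvm.relax.dpl` as in this PR (the PrimFunc name `"te_func"`
is a placeholder):
```python
from tvm.relax.dpl import is_call_tir, wildcard

# Match any call_tir to a PrimFunc named "te_func", with arbitrary arguments.
any_args = is_call_tir("te_func")

# Match a call_tir to "te_func" taking exactly two (arbitrary) arguments;
# the output shape cannot be constrained until StructInfo matching lands.
two_args = is_call_tir("te_func", [wildcard(), wildcard()])
```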

* Update call_tir API and address comments

Now the expected type of `out_sinfo` is:
```python
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]]
```
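So a `call_tir` producing multiple tensors passes one `TensorStructInfo` per
returned tensor, as in this hypothetical two-output sketch (`"my_split"` and
`x` are placeholders):
```python
from tvm import relax

# Hypothetical PrimFunc "my_split" returning two tensors.
call = relax.call_tir(
    "my_split",
    (x,),  # `x` is assumed to be an in-scope relax expression
    out_sinfo=[
        relax.TensorStructInfo((64, 128), "float32"),
        relax.TensorStructInfo((64, 128), "float32"),
    ],
)
```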
MasterJH5574 authored and junrushao committed Feb 9, 2023
1 parent 68314b9 commit 7228f48
Showing 45 changed files with 448 additions and 550 deletions.
8 changes: 5 additions & 3 deletions include/tvm/relax/dataflow_pattern.h
@@ -379,6 +379,8 @@ class CallPatternNode : public DFPatternNode {
*/
bool varg_default_wildcard; /*!< #args can be < #real args with the rest padded by Wildcard() */

// Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
@@ -786,10 +788,10 @@ ExprPattern IsExpr(const Expr& expr);
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */
ExprPattern IsOp(const String& op_name);
/*! \brief Syntatic Sugar for call_tir (return a tensor) */
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt,
Optional<Array<PrimExpr>> oshape = NullOpt);
// Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
CallPattern IsCallTIR(const String& name, TuplePattern var_args, Array<Array<PrimExpr>> oshapes);
CallPattern IsCallTIR(const String& name, TuplePattern var_args);
/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
5 changes: 0 additions & 5 deletions python/tvm/relax/analysis/analysis.py
@@ -48,11 +48,6 @@ def get_static_type(sinfo: StructInfo) -> Type:
return _ffi_api.GetStaticType(sinfo) # type: ignore


# Todo(ruihang): introduced to make call_packed work. To be removed in the followup PR.
def struct_info_from_type(ty: Type): # pylint: disable=invalid-name
return _ffi_api.StructInfoFromType(ty) # type: ignore


def erase_to_well_defined(
sinfo: StructInfo,
shape_var_map: Dict[tir.Var, tir.PrimExpr] = None,
24 changes: 8 additions & 16 deletions python/tvm/relax/block_builder.py
@@ -418,9 +418,6 @@ def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1:
te_out = te_out[0]

outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

@@ -446,27 +443,22 @@ def _shape_with_old_tir_var(
# Invert the TIR variable mapping, to convert the output shape back
# with old set of variables.
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
output_shape = (
_shape_with_old_tir_var(outs[0].shape, tir_var_inverse_map)
if isinstance(te_out, tvm.te.tensor.Tensor)
else Tuple([_shape_with_old_tir_var(x.shape, tir_var_inverse_map) for x in outs])
)

output_dtype = (
te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs]
)
output_sinfo = [
TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
for out in outs
]

# add arguments for extra parameters from unbound var
if len(unbound_tir_vars) > 0:
call = call_tir(
gvar,
call_args,
output_shape,
output_dtype,
output_sinfo,
tir_vars=_shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map),
)
else:
call = call_tir(gvar, call_args, output_shape, output_dtype)
call = call_tir(gvar, call_args, output_sinfo)
return call

def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
@@ -539,7 +531,7 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle,
@R.function
def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor:
# block 0
gv = relax.call_tir("te_func", (x, y), (128, 128), dtype="float32")
gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32"))
return gv
Example
@@ -584,7 +576,7 @@ def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> N
def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32"))
-> Tensor(None, "float32", ndim=-1):
# block 0
gv = relax.call_tir(te_func, (y,), ((n + 1),), (n,), dtype="float32")
gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,))
return gv
"""
return self.emit(self.call_te(func, *args, **kwargs))
22 changes: 5 additions & 17 deletions python/tvm/relax/dpl/pattern.py
@@ -800,30 +800,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
return PrimArrPattern(shape)


# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
def _is_call_tir(
func_pattern: DFPattern,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

if shape is None:
shape = wildcard()
elif isinstance(shape, (list, Array)):
shape = PrimArrPattern(shape)
elif isinstance(shape, (tuple)):
shape = is_tuple(shape) # multiple shape patterns

return is_op("relax.call_tir")(func_pattern, args, shape)
return is_op("relax.call_tir")(func_pattern, args)


# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
def is_call_tir(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
"""
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
@@ -834,22 +827,19 @@ def is_call_tir(
Name of the CPS function to call.
args : Union[List[DFPattern], Tuple[DFPattern]], optional
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional
Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s)
Returns
-------
CallPattern
The resulting CallPattern
"""
func_pattern = GlobalVarPattern(func_name)
return _is_call_tir(func_pattern, args, shape)
return _is_call_tir(func_pattern, args)


def is_call_tir_extern(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
) -> CallPattern:
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
@@ -859,16 +849,14 @@ def is_call_tir_extern(
Name of the CPS function to call.
args : Union[List[DFPattern], Tuple[DFPattern]], optional
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional
Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s)
Returns
-------
CallPattern
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
return _is_call_tir(func_pattern, args, shape)
return _is_call_tir(func_pattern, args)


def is_call_packed(
82 changes: 22 additions & 60 deletions python/tvm/relax/op/base.py
@@ -24,8 +24,8 @@
from . import _ffi_api
from ..expr import Expr, ShapeExpr, Call, ExternFunc
from ..expr import Tuple as RxTuple
from ..ty import DynTensorType, TupleType
from ...ir import Array, Type, PrimExpr
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
from ..utils import args_converter


@@ -47,8 +47,7 @@ def null_value() -> Call:
def call_tir(
func: Union[str, Expr],
args: Expr,
shape: Union[RxTuple, ShapeExpr, List[int]],
dtype: Union[str, List[str]],
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
) -> Call:
"""
@@ -62,11 +61,10 @@
args : Expr
The input arguments.
shape: Union[RxTuple, ShapeExpr, List[int]]
The output shape. Tuple(ShapeExpr) if multiple outputs, ShapeExpr if single output.
dtype: Union[str, List[str]]
The output dtype. List[str] if multiple outputs, str if single output.
out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_tir output.
It should be a single or a list of TensorStructInfo. Each one denotes the
structure info of a returned tensor.
tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]]
ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used
@@ -79,60 +77,24 @@
if isinstance(func, str):
func = ExternFunc(func)

def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:
shape_array = []
for x in shape:
if isinstance(x, int):
shape_array.append(tvm.tir.IntImm("int64", x))
elif isinstance(x, tvm.tir.IntImm):
shape_array.append(x if x.dtype == "int64" else tvm.tir.IntImm("int64", x.value))
elif isinstance(x, PrimExpr):
if x.dtype != "int64":
raise TypeError("Expect int64 dtype for shape")
shape_array.append(x)
else:
raise TypeError("Expect int or PrimExpr for shape")
return ShapeExpr(shape_array)

if isinstance(shape, (list, tuple, Array)):
if all([not isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]):
shape = _create_shape(shape) # type: ignore
elif all([isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]):
shape = RxTuple(
[
_create_shape(x) if not isinstance(x, ShapeExpr) else x # type: ignore
for x in shape
]
)
else:
raise TypeError(
f"The shape is expected to be ShapeExpr or Tuple[ShapeExpr], bot got: f{shape}"
)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

if isinstance(dtype, str):
output_type = DynTensorType(len(shape), dtype)
elif isinstance(dtype, (list, tuple)):
if len(shape) != len(dtype):
raise ValueError("The number of output_shape and output_dtype of call_tir mismatch")
output_type = TupleType([DynTensorType(len(x), y) for x, y in zip(shape, dtype)])
else:
raise TypeError("Not supported dtype for call_tir: " + str(type(dtype)))
if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]

if isinstance(tir_vars, (list, tuple)):
tir_vars = ShapeExpr(tir_vars)

return _ffi_api.call_tir(func, args, shape, output_type, tir_vars) # type: ignore
return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore


@args_converter.auto
def call_builtin(
func: Union[str, Expr],
args: Expr,
*,
type_args: Optional[Union[Type, List[Type]]] = None,
sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None,
int_args: Optional[List[int]] = None,
dtype_arg: Optional[str] = None,
str_args: Optional[List[str]] = None,
@@ -148,8 +110,8 @@ def call_builtin(
args : Expr
The input arguments.
type_args: Optional[Union[Type, List[Type]]]
The type arguments to the call node.
sinfo_args: Optional[Union[StructInfo, List[StructInfo]]]
The struct info arguments to the call node.
int_args: Optional[List[int]]
List of additional int arguments passed to the builtin.
@@ -171,11 +133,11 @@
if isinstance(func, str):
func = ExternFunc(func)

if type_args is not None and not isinstance(type_args, (list, tuple)):
type_args = [type_args]
if sinfo_args is not None and not isinstance(sinfo_args, (list, tuple)):
sinfo_args = [sinfo_args]

return _ffi_api.call_builtin( # type: ignore
func, args, type_args, int_args, dtype_arg, str_args, require_ctx # type: ignore
func, args, sinfo_args, int_args, dtype_arg, str_args, require_ctx # type: ignore
)


@@ -209,7 +171,7 @@ def make_closure(
def invoke_closure(
closure: Expr,
args: Expr,
type_args: Union[List[Type], Type],
sinfo_args: Union[List[StructInfo], StructInfo],
) -> Object:
"""
Invoke a closure.
@@ -222,19 +184,19 @@
args : Expr
The input arguments.
type_args: Union[Tuple[Type], Type]
The type_args of the CallNode
type_args: Union[List[StructInfo], StructInfo]
The structure info arguments of the CallNode
Returns
-------
ret: Object
The result.
"""

if not isinstance(type_args, (list, tuple)):
type_args = (type_args,)
if not isinstance(sinfo_args, (list, tuple)):
sinfo_args = [sinfo_args]

return _ffi_api.invoke_closure(closure, args, type_args) # type: ignore
return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore


def render_object(val: tvm.Object) -> str:
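As a usage sketch for the renamed parameter above (illustrative only: `clo` and
`x` are assumed in-scope relax expressions, and `ObjectStructInfo` stands in for
whatever result info applies):
```python
from tvm import relax

# invoke_closure: the result's struct info is now passed as `sinfo_args`.
call = relax.op.invoke_closure(
    clo,               # an Expr evaluating to a closure
    relax.Tuple([x]),  # input arguments
    sinfo_args=relax.ObjectStructInfo(),
)
```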
4 changes: 3 additions & 1 deletion python/tvm/relax/testing/relay_translator.py
@@ -169,7 +169,9 @@ def visit_func(node):

if translate_op_with_tir and op_name in translate_op_with_tir:
tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name)
call = relax.call_tir(tir_gvar, new_args, out_type.shape, out_type.dtype)
call = relax.call_tir(
tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype)
)
var = bb.emit(call)
else:
with target:
