diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 33f509d7dbcd..8d3825b58669 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -379,6 +379,8 @@ class CallPatternNode : public DFPatternNode {
   */
  bool varg_default_wildcard; /*!< #args can be < #real args with the rest padded by Wildcard() */
 
+  // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args
+
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("args", &args);
@@ -786,10 +788,10 @@ ExprPattern IsExpr(const Expr& expr);
 /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */
 ExprPattern IsOp(const String& op_name);
 /*! \brief Syntatic Sugar for call_tir (return a tensor) */
-CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt,
-                      Optional<Array<PrimExpr>> oshape = NullOpt);
+// Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
+CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
 /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
-CallPattern IsCallTIR(const String& name, TuplePattern var_args, Array<Array<PrimExpr>> oshapes);
+CallPattern IsCallTIR(const String& name, TuplePattern var_args);
 /*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
 /*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index 3ac2adf2ebb9..e92d09e266f6 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -48,11 +48,6 @@ def get_static_type(sinfo: StructInfo) -> Type:
     return _ffi_api.GetStaticType(sinfo)  # type: ignore
 
 
-# Todo(ruihang): introduced to make call_packed work. To be removed in the followup PR.
-def struct_info_from_type(ty: Type):  # pylint: disable=invalid-name
-    return _ffi_api.StructInfoFromType(ty)  # type: ignore
-
-
 def erase_to_well_defined(
     sinfo: StructInfo,
     shape_var_map: Dict[tir.Var, tir.PrimExpr] = None,
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 3b11c1a2b7f0..013e0318720e 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -418,9 +418,6 @@ def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
             and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
         ), "only support te.tensor or tuple/list/Array of te.tensor as function output"
 
-        if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1:
-            te_out = te_out[0]
-
         outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
         unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)
 
@@ -446,27 +443,22 @@ def _shape_with_old_tir_var(
         # Invert the TIR variable mapping, to convert the output shape back
         # with old set of variables.
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - output_shape = ( - _shape_with_old_tir_var(outs[0].shape, tir_var_inverse_map) - if isinstance(te_out, tvm.te.tensor.Tensor) - else Tuple([_shape_with_old_tir_var(x.shape, tir_var_inverse_map) for x in outs]) - ) - output_dtype = ( - te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs] - ) + output_sinfo = [ + TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) + for out in outs + ] # add arguments for extra parameters from unbound var if len(unbound_tir_vars) > 0: call = call_tir( gvar, call_args, - output_shape, - output_dtype, + output_sinfo, tir_vars=_shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map), ) else: - call = call_tir(gvar, call_args, output_shape, output_dtype) + call = call_tir(gvar, call_args, output_sinfo) return call def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: @@ -539,7 +531,7 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, @R.function def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor: # block 0 - gv = relax.call_tir("te_func", (x, y), (128, 128), dtype="float32") + gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32")) return gv Example @@ -584,7 +576,7 @@ def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> N def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) -> Tensor(None, "float32", ndim=-1): # block 0 - gv = relax.call_tir(te_func, (y,), ((n + 1),), (n,), dtype="float32") + gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,)) return gv """ return self.emit(self.call_te(func, *args, **kwargs)) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index c761bf226aca..9ea138cabed0 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -800,30 +800,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": return PrimArrPattern(shape) +# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo def _is_call_tir( func_pattern: DFPattern, args: Union[List, Tuple, TuplePattern] = None, - shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None, ) -> CallPattern: if args is None: args = wildcard() elif isinstance(args, (list, tuple)): args = TuplePattern(args) - if shape is None: - shape = wildcard() - elif isinstance(shape, (list, Array)): - shape = PrimArrPattern(shape) - elif isinstance(shape, (tuple)): - shape = is_tuple(shape) # multiple shape patterns - - return is_op("relax.call_tir")(func_pattern, args, shape) + return is_op("relax.call_tir")(func_pattern, args) +# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo def is_call_tir( func_name: str, args: Union[List, Tuple, TuplePattern] = None, - shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None, ) -> CallPattern: """ Syntax sugar for creating a CallPattern for call_tir that calls an function through global var. @@ -834,8 +827,6 @@ def is_call_tir( Name of the CPS function to call. 
args : Union[List[DFPattern], Tuple[DFPattern]], optional Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments - shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional - Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s) Returns ------- @@ -843,13 +834,12 @@ def is_call_tir( The resulting CallPattern """ func_pattern = GlobalVarPattern(func_name) - return _is_call_tir(func_pattern, args, shape) + return _is_call_tir(func_pattern, args) def is_call_tir_extern( func_name: str, args: Union[List, Tuple, TuplePattern] = None, - shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None, ) -> CallPattern: """Syntax sugar for creating a CallPattern for call_tir that calls an extern function @@ -859,8 +849,6 @@ def is_call_tir_extern( Name of the CPS function to call. args : Union[List[DFPattern], Tuple[DFPattern]], optional Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments - shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional - Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s) Returns ------- @@ -868,7 +856,7 @@ def is_call_tir_extern( The resulting CallPattern """ func_pattern = ExternFuncPattern(func_name) - return _is_call_tir(func_pattern, args, shape) + return _is_call_tir(func_pattern, args) def is_call_packed( diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index a0f0822aa3b6..df3ce6eed030 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -24,8 +24,8 @@ from . import _ffi_api from ..expr import Expr, ShapeExpr, Call, ExternFunc from ..expr import Tuple as RxTuple -from ..ty import DynTensorType, TupleType -from ...ir import Array, Type, PrimExpr +from ..struct_info import StructInfo, TensorStructInfo +from ...ir import PrimExpr from ..utils import args_converter @@ -47,8 +47,7 @@ def null_value() -> Call: def call_tir( func: Union[str, Expr], args: Expr, - shape: Union[RxTuple, ShapeExpr, List[int]], - dtype: Union[str, List[str]], + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, ) -> Call: """ @@ -62,11 +61,10 @@ def call_tir( args : Expr The input arguments. - shape: Union[RxTuple, ShapeExpr, List[int]] - The output shape. Tuple(ShapeExpr) if multiple outputs, ShapeExpr if single output. - - dtype: Union[str, List[str]] - The output dtype. List[str] if multiple outputs, str if single output. + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_tir output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. 
Is null if not used @@ -79,52 +77,16 @@ def call_tir( if isinstance(func, str): func = ExternFunc(func) - def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr: - shape_array = [] - for x in shape: - if isinstance(x, int): - shape_array.append(tvm.tir.IntImm("int64", x)) - elif isinstance(x, tvm.tir.IntImm): - shape_array.append(x if x.dtype == "int64" else tvm.tir.IntImm("int64", x.value)) - elif isinstance(x, PrimExpr): - if x.dtype != "int64": - raise TypeError("Expect int64 dtype for shape") - shape_array.append(x) - else: - raise TypeError("Expect int or PrimExpr for shape") - return ShapeExpr(shape_array) - - if isinstance(shape, (list, tuple, Array)): - if all([not isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]): - shape = _create_shape(shape) # type: ignore - elif all([isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]): - shape = RxTuple( - [ - _create_shape(x) if not isinstance(x, ShapeExpr) else x # type: ignore - for x in shape - ] - ) - else: - raise TypeError( - f"The shape is expected to be ShapeExpr or Tuple[ShapeExpr], bot got: f{shape}" - ) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) - if isinstance(dtype, str): - output_type = DynTensorType(len(shape), dtype) - elif isinstance(dtype, (list, tuple)): - if len(shape) != len(dtype): - raise ValueError("The number of output_shape and output_dtype of call_tir mismatch") - output_type = TupleType([DynTensorType(len(x), y) for x, y in zip(shape, dtype)]) - else: - raise TypeError("Not supported dtype for call_tir: " + str(type(dtype))) + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] if isinstance(tir_vars, (list, tuple)): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir(func, args, shape, output_type, tir_vars) # type: ignore + return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore @args_converter.auto @@ -132,7 +94,7 @@ def call_builtin( func: Union[str, Expr], args: Expr, *, - type_args: Optional[Union[Type, List[Type]]] = None, + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None, int_args: Optional[List[int]] = None, dtype_arg: Optional[str] = None, str_args: Optional[List[str]] = None, @@ -148,8 +110,8 @@ def call_builtin( args : Expr The input arguments. - type_args: Optional[Union[Type, List[Type]]] - The type arguments to the call node. + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] + The struct info arguments to the call node. int_args: Optional[List[int]] List of additional int arguments passed to the builtin. @@ -171,11 +133,11 @@ def call_builtin( if isinstance(func, str): func = ExternFunc(func) - if type_args is not None and not isinstance(type_args, (list, tuple)): - type_args = [type_args] + if sinfo_args is not None and not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] return _ffi_api.call_builtin( # type: ignore - func, args, type_args, int_args, dtype_arg, str_args, require_ctx # type: ignore + func, args, sinfo_args, int_args, dtype_arg, str_args, require_ctx # type: ignore ) @@ -209,7 +171,7 @@ def make_closure( def invoke_closure( closure: Expr, args: Expr, - type_args: Union[List[Type], Type], + sinfo_args: Union[List[StructInfo], StructInfo], ) -> Object: """ Invoke a closure. @@ -222,8 +184,8 @@ def invoke_closure( args : Expr The input arguments. 
- type_args: Union[Tuple[Type], Type] - The type_args of the CallNode + type_args: Union[List[StructInfo], StructInfo] + The structure info arguments of the CallNode Returns ------- @@ -231,10 +193,10 @@ def invoke_closure( The result. """ - if not isinstance(type_args, (list, tuple)): - type_args = (type_args,) + if not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] - return _ffi_api.invoke_closure(closure, args, type_args) # type: ignore + return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore def render_object(val: tvm.Object) -> str: diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py index e2b62c3d91d6..fd5aab89fa76 100644 --- a/python/tvm/relax/testing/relay_translator.py +++ b/python/tvm/relax/testing/relay_translator.py @@ -169,7 +169,9 @@ def visit_func(node): if translate_op_with_tir and op_name in translate_op_with_tir: tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name) - call = relax.call_tir(tir_gvar, new_args, out_type.shape, out_type.dtype) + call = relax.call_tir( + tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype) + ) var = bb.emit(call) else: with target: diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a797d698dff6..78cabdddfb4b 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -24,12 +24,8 @@ import tvm from tvm import DataType, relax -from tvm.ir import PrimExpr, Type +from tvm.ir import PrimExpr from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const -from tvm.relax.analysis import get_static_type - -# Todo(ruihang): introduced to make call_packed work. To be removed in the followup PR. -from tvm.relax.analysis import struct_info_from_type ############################### Operators ############################### from tvm.relax.op import ( @@ -211,7 +207,7 @@ def output(*vars: Tuple[Var]) -> None: def call_packed( func: str, *args: Expr, - type_args: Optional[Union[StructInfo, List[StructInfo]]] = None, + sinfo_args: Union[StructInfo, List[StructInfo]], **kwargs: Any, ) -> Call: """Create a relax Call, which calls a packed function. @@ -221,8 +217,8 @@ def call_packed( The name of extern function. args : List[Expr] The arguments. - type_args: Optional[Union[StructInfo, List[StructInfo]]] - List of Types + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments. kwargs: Expr The keyword arguments. @@ -231,8 +227,6 @@ def call_packed( call: Call The created Relax Call """ - # Todo(ruihang): reorganize API in the followup PR of A1. 
- sinfo_args = type_args op = ExternFunc(func) if sinfo_args is None: raise ValueError("R.call_packed is required to have type_args") @@ -240,21 +234,13 @@ def call_packed( sinfo_args = list(sinfo_args) elif not isinstance(sinfo_args, list): sinfo_args = [sinfo_args] - for i, argument in enumerate(sinfo_args): - if callable(argument): - argument = argument() + for i, sinfo_arg in enumerate(sinfo_args): + if callable(sinfo_arg): + sinfo_arg = sinfo_arg() # Convert possible StructInfoProxy to StructInfo - if isinstance(argument, ObjectGeneric): - argument = argument.asobject() - if isinstance(argument, StructInfo): - sinfo_args[i] = struct_info_from_type(get_static_type(argument)) - elif isinstance(argument, Type): - sinfo_args[i] = struct_info_from_type(argument) - else: - raise TypeError( - "call_packed `type_args` is expected to be list of StructInfo/Type, " - f"but got {type(arg)}" - ) + if isinstance(sinfo_arg, ObjectGeneric): + sinfo_arg = sinfo_arg.asobject() + sinfo_args[i] = sinfo_arg is_default = False if "attrs_type_key" in kwargs: @@ -270,8 +256,8 @@ def call_packed( return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) -def _tensor_type_wrapper(func): - """A wrapper to convert StructInfo to relax.DynTensorType""" +def _sinfo_arg_wrapper(func): + """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" def _convert_tensor_type(args): if isinstance(args, (list, py_tuple)): @@ -283,7 +269,7 @@ def _convert_tensor_type(args): args = args() if isinstance(args, ObjectGeneric): args = args.asobject() - return get_static_type(args) if isinstance(args, StructInfo) else args + return args @functools.wraps(func) def wrapped(*args, **kwargs): @@ -292,9 +278,9 @@ def wrapped(*args, **kwargs): return wrapped # type: ignore -invoke_closure = _tensor_type_wrapper(invoke_closure) # pylint: disable=invalid-name +invoke_closure = _sinfo_arg_wrapper(invoke_closure) # pylint: disable=invalid-name -call_builtin = _tensor_type_wrapper(call_builtin) # pylint: disable=invalid-name +call_builtin = _sinfo_arg_wrapper(call_builtin) # pylint: disable=invalid-name ############################### Bindings ############################### diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 9384979405d8..2de06fe5d6f2 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -101,11 +101,6 @@ StructInfo StructInfoFromType(const Type& type) { } } -// Todo(ruihang): introduced to make call_packed work. To be removed in the followup PR. 
-TVM_REGISTER_GLOBAL("relax.analysis.StructInfoFromType").set_body_typed([](const Type& type) {
-  return StructInfoFromType(type);
-});
-
 //--------------------------
 // EraseToWellDefined
 //--------------------------
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
index aea622ca82db..4034e8f57096 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -75,9 +75,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
     } else {
       auto attrs = DefaultBuiltinAttrs();
       attrs->dtype_arg = dtype;
-      // Todo(ruihang): reorganize in the followup PR
       return Call(call_builtin_op_, {builtin_compute_alloc_shape_, Tuple({shape})}, Attrs(attrs),
-                  {StructInfoFromType(GetStaticType(GetStructInfo(shape)))});
+                  {GetStructInfo(shape)});
     }
   }
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 2bec1288e01d..dd6a1bb3e1c6 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -81,7 +81,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
     PrimExprSlotCollector collector;
     collector.slot_vec_ = slot_vec;
     collector.slot_map_ = slot_map;
-    // collect shape delcarations in func params
+    // collect shape declaration in func params
     for (auto param : func->params) {
       collector.VisitStructInfo(GetStructInfo(param));
       collector.VisitExpr(param);
@@ -283,7 +283,7 @@ class VMShapeLowerMutator
   }
 
   //-------------------------------------------------------
-  // PrimExpr slot hanlding
+  // PrimExpr slot handling
   //-------------------------------------------------------
 
   static DataType ShapeDType() { return DataType::Int(64); }
@@ -310,9 +310,9 @@ class VMShapeLowerMutator
   //-------------------------------------------------------
   // Helper functions to construct BuiltinFuncAttrs
   //-------------------------------------------------------
-  // Buildin attrs that contains extra int arguments.
+  // Builtin attrs that contains extra int arguments.
  Attrs ExtraIntArgs(std::vector<int64_t> int_args, Optional<String> err_ctx = NullOpt) {
-    // intiialize with default value
+    // initialize with default value
    auto n = make_object<BuiltinFuncAttrs>();
    Array<IntImm> arr;
    for (int64_t val : int_args) {
@@ -332,9 +332,9 @@ class VMShapeLowerMutator
    return Attrs(n);
  }
 
-  // Buildin attrs that contains extra int arguments.
+  // Builtin attrs that contains extra int arguments.
  Attrs ExtraTensorInfoArgs(int ndim, DataType dtype, String err_ctx) {
-    // intiialize with default value
+    // initialize with default value
    auto n = make_object<BuiltinFuncAttrs>();
    n->int_args = {IntImm(DataType::Int(64), ndim)};
    n->dtype_arg = dtype;
@@ -351,15 +351,14 @@ class VMShapeLowerMutator
    if (heap_size->value > 0) {
      TensorStructInfo heap_sinfo(ShapeDType(), 1);
      Var var("shape_heap", heap_sinfo);
-      // set up the buildin func.
+      // set up the builtin func.
      auto n = make_object<BuiltinFuncAttrs>();
      n->int_args = {heap_size};
      n->dtype_arg = DataType::Void();
      n->str_args = NullValue<Array<String>>();
      n->require_ctx = true;
-      // Todo(ruihang): reorganize in the followup PR
      Call call(call_builtin_op_, {builtin_alloc_shape_heap_, Tuple(Array<Expr>())}, Attrs(n),
-                {StructInfoFromType(GetStaticType(heap_sinfo))});
+                {heap_sinfo});
      UpdateStructInfo(call, heap_sinfo);
      return VarBinding(var, call);
    } else {
@@ -439,7 +438,7 @@ class VMShapeLowerMutator
  /*!
   * \brief Execute the match todo items.
   *
-   * This functoin can populate vars in the match items when seeing it for the first time.
+   * This function can populate vars in the match items when seeing it for the first time.
   * These new vars will be added to this->ready_vars_.
   *
   * If an item contains PrimExpr that are yet to be computed (but may be computable through
@@ -586,7 +585,7 @@ class VMShapeLowerMutator
   * \param struct_info The struct info to be matched.
   * \param value The input value.
   * \param always_check Whether we insert runtime check even if we can prove
-   *        that value's struct info already satisifies the condition.
+   *        that value's struct info already satisfies the condition.
   *        This option is necessary for argument checking per our calling convention.
   *
   * \param err_ctx Extra error context to bring more informative error reporting.
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 8758abd8f87d..b3f8be1e9cff 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -548,8 +548,7 @@ ConstantPattern IsConst() { return ConstantPattern(make_object<ConstantPatternNode>()); }
 ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
 ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
 
-CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args,
-                      Optional<Array<PrimExpr>> oshape) {
+CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
   DFPattern arg_pattern;
   if (!var_args.defined()) {
     arg_pattern = Wildcard();
@@ -557,23 +556,11 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args,
     arg_pattern = var_args.value();
   }
 
-  DFPattern shape_pattern;
-  if (!oshape.defined()) {
-    shape_pattern = Wildcard();
-  } else {
-    shape_pattern = PrimArrPattern(oshape.value());
-  }
-
-  return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, shape_pattern);
+  return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern);
 }
 
-CallPattern IsCallTIR(const String& name, TuplePattern var_args, Array<Array<PrimExpr>> oshapes) {
-  Array<DFPattern> shape_patterns;
-  shape_patterns.reserve(oshapes.size());
-  for (auto shape : oshapes) shape_patterns.push_back(PrimArrPattern(std::move(shape)));
-
-  return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args,
-                                IsTuple(std::move(shape_patterns)));
+CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
+  return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
 }
 
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 225db03d7961..e251f87c6b8d 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -524,26 +524,21 @@ TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty")
 
 // Special opaque derivation function for ExternFunc
 // Take look at sinfo_args to figure out the return StructInfo.
-// TODO(relax-team): revisit sinfo_args related deduction.
-TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_ty_args")
+TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args")
    .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo {
-      // Todo(ruihang): reorganize in the followup PR
-      if (call->sinfo_args.defined()) {
-        if (call->sinfo_args.size() == 0) {
-          return ObjectStructInfo();
-        } else if (call->sinfo_args.size() == 1) {
-          return call->sinfo_args[0];
-        } else {
-          return StructInfoFromType(GetStaticType(TupleStructInfo(call->sinfo_args)));
-        }
-      } else {
+      ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined";
+      if (call->sinfo_args.empty()) {
        return ObjectStructInfo();
+      } else if (call->sinfo_args.size() == 1) {
+        return call->sinfo_args[0];
+      } else {
+        return TupleStructInfo(call->sinfo_args);
      }
    });
 
 // Get the derive function.
 FuncStructInfo GetExternFuncStructInfo() {
-  EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_ty_args");
+  EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args");
  StructInfoDeriveFunc derive;
  derive = fn;
  return FuncStructInfo::OpaqueFunc(derive);
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 190a60d36321..1fe7d7c62c48 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -71,55 +71,46 @@ StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) {
 
 // call_tir
 
-StructInfo CallTIRStructInfoFromShapeType(Expr shape, Type type) {
-  if (auto* tuple = shape.as<TupleNode>()) {
-    auto* ptr_type = type.as<TupleTypeNode>();
-    ICHECK(ptr_type != nullptr) << "Expect tuple type and shape to be consistent.";
-    ICHECK_EQ(ptr_type->fields.size(), tuple->fields.size());
-    Array<StructInfo> arr;
-    for (size_t i = 0; i < ptr_type->fields.size(); ++i) {
-      arr.push_back(CallTIRStructInfoFromShapeType(tuple->fields[i], ptr_type->fields[i]));
-    }
-    return TupleStructInfo(arr);
-  } else {
-    auto* ptr_type = type.as<DynTensorTypeNode>();
-    ICHECK(ptr_type != nullptr) << "Expect singleton shape to correspond to DynTensorType.";
-    return TensorStructInfo(shape, ptr_type->dtype);
-  }
-}
-
 StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
-  // Todo(ruihang): reorganize in the followup PR
-  Expr output_shape = call->args[2];
  if (call->sinfo_args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "sinfo_args should have exact 1 output struct info.");
  }
-  Type output_type = GetStaticType(call->sinfo_args[0]);
-  return CallTIRStructInfoFromShapeType(output_shape, output_type);
+  return call->sinfo_args[0];
 }
 
 RELAY_REGISTER_OP("relax.call_tir")
-    .set_num_inputs(4)
+    .set_num_inputs(3)
    .add_argument("func", "Expr", "The destination-passing-style function.")
    .add_argument("args", "Tuple", "The input arguments.")
-    .add_argument("output_shape", "Expr", "The output shape.")
    .add_argument("packed_ints", "Expr",
                  "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
                  "args if unused")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR);
 
-Expr MakeCallTIR(Expr func, Tuple args, Expr output_shape, Type output_type,
+Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
                 Optional<Expr> packed_ints) {
-  // Todo(ruihang): reorganize in the followup PR
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. "
+                               "However, one given structure info is "
+                            << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
+  }
+
  static const Op& op = Op::Get("relax.call_tir");
  Call call;
  if (!packed_ints) {
    // don't use additional optional argument
-    call = Call(op, {func, args, output_shape}, {}, {StructInfoFromType(output_type)});
+    call = Call(op, {func, args}, {}, {out_sinfo});
  } else {
-    call = Call(op, {func, args, output_shape, packed_ints.value()}, {},
-                {StructInfoFromType(output_type)});
+    call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo});
  }
  return call;
 }
@@ -128,13 +119,12 @@ TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);
 
 // call builtin
 StructInfo InferStructInfoCallBuiltin(const Call& call, const BlockBuilder& ctx) {
-  // Todo(ruihang): reorganize in the followup PR
  if (call->sinfo_args.size() == 0) {
    // by default return void.
    return TupleStructInfo(Array<StructInfo>());
  } else {
-    ICHECK(call->sinfo_args[0].defined()) << call;
-    return StructInfoFromType(GetStaticType(call->sinfo_args[0]));
+    ICHECK_EQ(call->sinfo_args.size(), 1);
+    return call->sinfo_args[0];
  }
 }
 
@@ -144,11 +134,8 @@ TVM_REGISTER_OP("relax.call_builtin")
    .add_argument("args", "Tuple", "The input arguments.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallBuiltin);
 
-Expr MakeCallBuiltin(Expr func, Tuple args, Array<Type> type_args, Array<IntImm> int_args,
+Expr MakeCallBuiltin(Expr func, Tuple args, Array<StructInfo> sinfo_args, Array<IntImm> int_args,
                     DataType dtype_arg, Array<String> str_args, bool require_ctx) {
-  // Todo(ruihang): reorganize in the followup PR
-  Array<StructInfo> sinfo_args = type_args.Map([](Type type) { return StructInfoFromType(type); });
-
  auto attrs = make_object<BuiltinFuncAttrs>();
  attrs->int_args = int_args.Map([](IntImm value) {
    if (value->dtype != DataType::Int(64)) {
@@ -259,13 +246,12 @@ TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure);
 
 // invoke_closure
 StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) {
-  // Todo(ruihang): reorganize in the followup PR
  if (call->sinfo_args.empty()) {
    return ObjectStructInfo();
  } else if (call->sinfo_args.size() == 1) {
-    return StructInfoFromType(GetStaticType(call->sinfo_args[0]));
+    return call->sinfo_args[0];
  } else {
-    return StructInfoFromType(GetStaticType(TupleStructInfo(call->sinfo_args)));
+    return TupleStructInfo(call->sinfo_args);
  }
 }
 
@@ -275,9 +261,7 @@ RELAY_REGISTER_OP("relax.invoke_closure")
    .add_argument("args", "Tuple", "The captured variables.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoInvokeClosure);
 
-Expr InvokeClosure(Expr closure, Tuple args, Array<Type> type_args) {
-  // Todo(ruihang): reorganize in the followup PR
-  Array<StructInfo> sinfo_args = type_args.Map([](Type type) { return StructInfoFromType(type); });
+Expr InvokeClosure(Expr closure, Tuple args, Array<StructInfo> sinfo_args) {
  static const Op& op = Op::Get("relax.invoke_closure");
  return Call(op, {closure, args}, {}, sinfo_args);
 }
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
index 21e144f48df7..b4593f94586d 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -99,11 +99,11 @@ class CallTIRMutator : public ExprMutator {
        args = Downcast<Tuple>(call->args[1])->fields;
        args.insert(args.end(), outs.begin(), outs.end());
 
-        if (call->args.size() == 3) {
+        if (call->args.size() == 2) {
          builder_->Emit(Call(call->args[0], args), "_");
        } else {
          // unpack semantics
-          args.push_back(call->args[3]);
+          args.push_back(call->args[2]);
          builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_");
        }
      } else {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 1b75c9fd410e..15e2c15a78c0 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -35,12 +35,22 @@ class ConstantFolder : public ExprMutator {
 private:
  /*!
-   * \brief Pattern match expr to a constant shape and get runtime shape tuple from it.
+   * \brief Pattern match the shape inside the given struct info to a
+   * constant shape and get runtime shape tuple from it.
+   * \param struct_info The given struct info whose shape inside is to be casted.
   * \return The runtime shape tuple, or nullopt if it is not a constant shape.
+   * \note Only TensorStructInfo is supported at this moment. Return NullOpt
+   * if the input struct info is not TensorStructInfo.
   */
-  static Optional<runtime::ShapeTuple> MatchConstShape(const Expr& expr) {
-    auto* shape = expr.as<ShapeExprNode>();
-    if (!shape) return NullOpt;
+  static Optional<runtime::ShapeTuple> MatchConstShape(const StructInfo& struct_info) {
+    // Only support single output for call_tir at this moment.
+    const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
+    if (tensor_sinfo == nullptr) {
+      return NullOpt;
+    }
+
+    const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape";
 
    std::vector<int64_t> shape_values;
    for (const auto v : shape->values) {
@@ -145,12 +155,13 @@ class ConstantFolder : public ExprMutator {
  Expr VisitCallTIR(Call call) {
    // call_tir needs to have at least three arguments
-    ICHECK_GE(call->args.size(), 3);
+    ICHECK_GE(call->args.size(), 2);
    Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
    ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
    Optional<Array<runtime::NDArray>> arr_args =
        MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
-    Optional<runtime::ShapeTuple> shape = MatchConstShape(call->args[2]);
+    ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";
+    Optional<runtime::ShapeTuple> shape = MatchConstShape(call->sinfo_args[0]);
    bool output_not_tuple = call->sinfo_args.size() == 1;
    // Pattern 0: call constant function, const argument with const shape.
    if (func && arr_args && shape && output_not_tuple) {
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 123b42ef9bc9..fa5c296d278e 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -387,9 +387,9 @@ class FusedTIRConstructor : public ExprVisitor {
  static size_t GetCallTIROutputSize(const CallNode* call) {
    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
    ICHECK(call->op.same_as(call_tir_op_));
-    const Expr& output_shapes = call->args[2];
-    if (const auto* tuple_output_shapes = output_shapes.as<TupleNode>()) {
-      return tuple_output_shapes->fields.size();
+    ICHECK_EQ(call->sinfo_args.size(), 1);
+    if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
+      return tuple_sinfo->fields.size();
    } else {
      return 1;
    }
@@ -662,10 +660,8 @@ class TIRFuseMutator : public ExprMutator {
        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
      }
      // Step b. Create call_tir
-      Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list),
-                               GetCallTIRShape(GetStructInfo(call))};
-      return Call(call_tir_op_, call_args, call->attrs,
-                  {StructInfoFromType(call->checked_type())});
+      Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
+      return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
    } else {
      // Case 1.2. The callee function is not primitive, nothing to do.
      return call;
    }
  } else if (call->op == call_tir_op_) {
    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
    tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
    GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
-    return Call(call->op, {new_gv, call->args[1], call->args[2]}, call->attrs, call->sinfo_args,
-                call->span);
+    return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span);
  } else {
    // Case 3. CallNode in other types. Leave it as it is.
    return call;
  }
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
index ad5217aedbdd..f08499036b1c 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -58,7 +58,7 @@ class LambdaLifter : public ExprMutator {
        clo_arg = this->var_remap_.at(var->vid);
      }
      return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {},
-                  {StructInfoFromType(call_node->checked_type_)});
+                  {GetStructInfo(GetRef<Call>(call_node))});
    }
  }
  if (auto global_var_node = call_node->op.as<GlobalVarNode>()) {
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index 7140330d4498..3ad9df2c07fa 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -69,12 +69,7 @@ class CodeGenRunner : ExprMutator {
      Expr new_op = VisitExpr(func);
      if (new_op->IsInstance<ExternFuncNode>()) {
        Array<Expr> new_args({new_op});
-        Array<Expr> tmp_args;
-        for (const auto& arg : call_node->args) {
-          tmp_args.push_back(VisitExpr(arg));
-        }
-        new_args.push_back(Tuple(tmp_args));
-        new_args.push_back(GetShapeOf(func->body));
+        new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); })));
 
        static const Op& call_op = Op::Get("relax.call_tir");
 
@@ -86,8 +81,7 @@ class CodeGenRunner : ExprMutator {
        func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
        builder_->UpdateFunction(gvar, func);
 
-        return Call(call_op, new_args, tvm::Attrs(),
-                    {StructInfoFromType(GetStaticType(func->ret_struct_info))});
+        return Call(call_op, new_args, tvm::Attrs(), {func->ret_struct_info});
      }
    }
    Array<Expr> new_args;
diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc
index 7edd2d66fd30..539ec3df62f2 100644
--- a/src/relay/printer/relax_script_printer.cc
+++ b/src/relay/printer/relax_script_printer.cc
@@ -92,32 +92,28 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::CallNode* op) {
 
  if (op->op == call_tir_op) {
    doc << "R.call_tir";
-    for (int i = 0; i < 3; ++i) {
+    ICHECK(op->args.size() == 2 || op->args.size() == 3);
+    for (int i = 0; i < 2; ++i) {
      args.push_back(Print(op->args[i]));
    }
    doc << "(" << Doc::Concat(args, Doc::Text(", "));
 
-    Type output_type = GetStaticType(op->sinfo_args[0]);
-    if (const auto* out_type = output_type.as<DynTensorTypeNode>()) {
-      doc << ", dtype=" << PrintDType(out_type->dtype);
-    } else if (const auto* out_type = output_type.as<TupleTypeNode>()) {
-      std::vector<Doc> dtypes;
-      for (auto field : out_type->fields) {
-        if (const auto* field_type = field.as<DynTensorTypeNode>()) {
-          Doc dtype;
-          dtype << PrintDType(field_type->dtype);
-          dtypes.push_back(dtype);
-        } else {
-          LOG(FATAL) << "TypeError: Invalid type: " << field_type->GetTypeKey();
-        }
+    ICHECK(op->sinfo_args.size() == 1);
+    Doc out_sinfo_doc;
+    if (const auto* tuple_out_sinfo = op->sinfo_args[0].as<TupleStructInfoNode>()) {
+      std::vector<Doc> field_sinfo_doc;
+      field_sinfo_doc.reserve(tuple_out_sinfo->fields.size());
+      for (const StructInfo& field_sinfo : tuple_out_sinfo->fields) {
+        field_sinfo_doc.push_back(Print(field_sinfo));
      }
-      doc << ", dtype=(" << Doc::Concat(dtypes, Doc::Text(", ")) << ")";
+      out_sinfo_doc << "[" << Doc::Concat(field_sinfo_doc, Doc::Text(", ")) << "]";
    } else {
-      LOG(FATAL) << "TypeError: Invalid type: " << output_type->GetTypeKey();
+      out_sinfo_doc << Print(op->sinfo_args[0]);
    }
+    doc << ", out_sinfo=" << out_sinfo_doc;
 
-    if (op->args.size() == 4) {
-      doc << ", tir_vars=" << Print(op->args[3]);
+    if (op->args.size() == 3) {
+      doc << ", tir_vars=" << Print(op->args[2]);
    }
    doc << ")";
@@ -145,16 +141,14 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::CallNode* op) {
  }
 
  if (!op->sinfo_args.empty()) {
-    doc << ", type_args=";
-    Array<Type> type_args =
-        op->sinfo_args.Map([](StructInfo sinfo) { return GetStaticType(sinfo); });
-    std::vector<Doc> type_args_doc = PrintTypeArgs(type_args);
+    doc << ", sinfo_args=[";
 
-    if (type_args_doc.size() == 1) {
-      doc << "(" << type_args_doc[0] << " ,)";
-    } else {
-      doc << "(" << Doc::Concat(type_args_doc, Doc::Text(", ")) << ")";
+    std::vector<Doc> sinfo_args_docs;
+    sinfo_args_docs.reserve(op->sinfo_args.size());
+    for (const StructInfo& sinfo_arg : op->sinfo_args) {
+      sinfo_args_docs.push_back(Print(sinfo_arg));
    }
+    doc << Doc::Concat(sinfo_args_docs, Doc::Text(", ")) << "]";
  }
 
  doc << ")";
@@ -504,22 +498,6 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
  return kwargs;
 }
 
-std::vector<Doc> RelaxScriptPrinter::PrintTypeArgs(const Array<Type>& type_args) {
-  std::vector<Doc> type_args_doc;
-  if (!type_args.empty()) {
-    for (const auto& type : type_args) {
-      if (const auto* tensor = type.as<DynTensorTypeNode>()) {
-        Doc doc;
-        doc << "R.Tensor(ndim=" << tensor->ndim << ", dtype=" << PrintDType(tensor->dtype) << ")";
-        type_args_doc.push_back(doc);
-      } else {
-        type_args_doc.push_back(this->VisitType(type));
-      }
-    }
-  }
-  return type_args_doc;
-}
-
 Doc RelaxScriptPrinter::VisitAttrDefault_(const Object* op) {
  return PrintAttr(GetRef<ObjectRef>(op));
 }
diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h
index a04a40c29bd0..94802e81195f 100644
--- a/src/relay/printer/text_printer.h
+++ b/src/relay/printer/text_printer.h
@@ -309,7 +309,6 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
  Doc PrintAttr(const ObjectRef& attr);
  std::vector<Doc> PrintAttrs(const Attrs& attrs);
-  std::vector<Doc> PrintTypeArgs(const Array<Type>& type_args);
  Doc VisitAttrDefault_(const Object* op) override;
  Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true);
  Doc VisitAttr_(const ArrayNode* op) override;
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index ee235cc6e9a4..23ad616ca2ed 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -109,8 +109,8 @@ class IdentityUnused:
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             lv0 = x
-            unused0 = R.call_tir("my_sigmoid", (x,), (32, 32), dtype="float32")
-            unused1 = R.call_tir("my_sigmoid", (unused0,), (32, 32), dtype="float32")
+            unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+            unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
             R.output(lv0)
         return lv0
 
@@ -135,10 +135,10 @@ class IdentityUnused:
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             lv0 = x
-            unused0 = R.call_tir("my_sigmoid", (x,), (32, 32), dtype="float32")
-            unused1 = R.call_tir("my_sigmoid", (unused0,), (32, 32), dtype="float32")
+            unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+            unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, type_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) return z optimized = remove_all_unused(IdentityUnused["main"]) @@ -150,7 +150,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, type_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) return z tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) @@ -164,7 +164,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, type_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) return lv0 optimized = remove_all_unused(IdentityUnused["main"]) @@ -177,7 +177,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x R.output(lv0) # This might bring side effect so cannot be removed. - z = R.call_packed("vm.builtin.copy", lv0, type_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) return lv0 tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) @@ -188,7 +188,7 @@ def test_edge_binding_block_fake_unused_remove_all_unused(): class IdentityUnused: @R.function def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): - z = R.call_packed("vm.builtin.copy", x, type_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) return x optimized = remove_all_unused(IdentityUnused["main"]) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index c74a2ed6125f..547791544454 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -362,7 +362,7 @@ def f( t = R.add(w, z) sh: R.Shape = R.shape_of(t) o: R.Object = R.call_packed( - "contrib.tensor_array_stack", x, y, type_args=R.Object(), test_attr=True + "contrib.tensor_array_stack", x, y, sinfo_args=R.Object(), test_attr=True ) return o @@ -425,7 +425,7 @@ def test_call_tir(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") - gv0 = R.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 foo_str = strip_whitespace( @@ -453,16 +453,17 @@ def foo(x: R.Tensor(("m", "n"), "float32")): "op": 'Op(name="relax.call_tir")', "args": """[ ExternFunc(global_symbol="test.op.identity"), - Tuple(fields=[ - Var(name_hint="x")]), - ShapeExpr(values=[PrimExpr(value=`m: int64`), - PrimExpr(value=`n: int64`) - ]) + Tuple(fields=[Var(name_hint="x")]) ]""", "sinfo_args": """[ TensorStructInfo( dtype=float32, - ndim=2 + shape=ShapeExpr( + values=[ + PrimExpr(value=`m: int64`), + PrimExpr(value=`n: int64`) + ] + ) ) ]""", }, diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py index 8e3a75615db4..14c90bf0051d 100644 --- a/tests/python/relax/test_autotir_integration.py +++ b/tests/python/relax/test_autotir_integration.py @@ -65,12 +65,12 @@ def tir_relu(x:T.handle, y:T.handle): @R.function def main(x:R.Tensor((m,n), "float32"), w:R.Tensor((n,k), "float32")) -> R.Tensor: with R.dataflow(): - sh = 
R.call_packed("vm.builtin.shape_of", x) + sh = R.call_packed("vm.builtin.shape_of", x, sinfo_args=R.Tensor) x0 = R.match_cast(sh, R.Tensor((m, n), "float32")) - sh1 = R.call_packed("vm.builtin.shape_of", w) + sh1 = R.call_packed("vm.builtin.shape_of", w, sinfo_args=R.Tensor) x1 = R.match_cast(sh1, R.Tensor((n, k), "float32")) - lv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") - lv1 = R.call_tir(tir_relu, (lv0), (m, k), dtype="float32) + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((m, k), dtype="float32)) R.output(lv1) return lv1 """ @@ -110,8 +110,8 @@ def tir_relu(x: T.handle, y: T.handle): @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") - lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32") + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -211,12 +211,12 @@ def multiply1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float @R.function def main(x: R.Tensor((128, 128), "float32")) -> R.Tensor(dtype="float32"): with R.dataflow(): - lv0 = R.call_tir(add1, (x,), (128, 128), dtype="float32") - lv1 = R.call_tir(multiply1, (lv0,), (128, 128), dtype="float32") - lv2 = R.call_tir(add2, (lv1,), (128, 128), dtype="float32") - lv3 = R.call_tir(multiply1, (lv2,), (128, 128), dtype="float32") - lv4 = R.call_tir(add3, (lv3,), (128, 128), dtype="float32") - gv = R.call_tir(add1, (lv4,), (128, 128), dtype="float32") + lv0 = R.call_tir(add1, (x,), R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_tir(multiply1, (lv0,), R.Tensor((128, 128), dtype="float32")) + lv2 = R.call_tir(add2, (lv1,), R.Tensor((128, 128), dtype="float32")) + lv3 = R.call_tir(multiply1, (lv2,), R.Tensor((128, 128), dtype="float32")) + lv4 = R.call_tir(add3, (lv3,), R.Tensor((128, 128), dtype="float32")) + gv = R.call_tir(add1, (lv4,), R.Tensor((128, 128), dtype="float32")) R.output(gv) return gv diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 58b0cd0e5422..1c93f9e95261 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -117,7 +117,7 @@ def main(x: R.Tensor(["n", 2, "m"], "float32")): [], int_args=[2], require_ctx=True, - type_args=[R.Tensor(ndim=1, dtype="int64")], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_builtin( "vm.builtin.check_tensor_info", @@ -185,7 +185,7 @@ def main( [], int_args=[4], require_ctx=True, - type_args=[R.Tensor(ndim=1, dtype="int64")], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_builtin( "vm.builtin.check_tensor_info", @@ -247,7 +247,7 @@ def main( MK.USE_IMM, 2, ], - type_args=[R.Shape(ndim=3)], + sinfo_args=[R.Shape(ndim=3)], ) return s @@ -286,7 +286,7 @@ def main( [], int_args=[3], require_ctx=True, - type_args=[R.Tensor(ndim=1, dtype="int64")], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) # recursively unpack tuple for static info check _ = R.call_builtin("vm.builtin.check_tuple_info", [x], int_args=[2], str_args=[""]) @@ -360,7 +360,7 @@ def main( [], int_args=[2], require_ctx=True, - type_args=[R.Tensor(ndim=1, dtype="int64")], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_builtin( 
"vm.builtin.check_tensor_info", @@ -377,7 +377,9 @@ def main( ) _ = R.call_builtin("vm.builtin.check_tuple_info", [y], int_args=[1], str_args=[""]) # emit runtime function call since y do not have the right type. - y1 = R.call_builtin("vm.builtin.tuple_getitem", [y], int_args=[0], type_args=[R.Object]) + y1 = R.call_builtin( + "vm.builtin.tuple_getitem", [y], int_args=[0], sinfo_args=[R.Object] + ) # run check _ = R.call_builtin( "vm.builtin.check_tensor_info", diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py index 0cb6dc30acae..1b424b97923a 100644 --- a/tests/python/relax/test_binding_rewrite.py +++ b/tests/python/relax/test_binding_rewrite.py @@ -228,8 +228,8 @@ class IdentityChainedUnused: def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x - unused0 = R.call_tir("my_sigmoid", (x,), (32, 32), dtype="float32") - unused1 = R.call_tir("my_sigmoid", (unused0,), (32, 32), dtype="float32") + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) R.output(lv0) return lv0 @@ -263,19 +263,19 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): # lv4 with R.dataflow(): lv0: R.Tensor((32, 32), "float32") = R.call_tir( - "my_relu", (x,), (32, 32), dtype="float32" + "my_relu", (x,), R.Tensor((32, 32), dtype="float32") ) lv1: R.Tensor((32, 32), "float32") = R.call_tir( - "my_sigmoid", (x,), (32, 32), dtype="float32" + "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") ) lv2: R.Tensor((32, 32), "float32") = R.call_tir( - "my_add", (x, lv0), (32, 32), dtype="float32" + "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") ) lv3: R.Tensor((32, 32), "float32") = R.call_tir( - "my_mul", (x, lv0), (32, 32), dtype="float32" + "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") ) lv4: R.Tensor((32, 32), "float32") = R.call_tir( - "my_whatever", (lv2, lv3), (32, 32), dtype="float32" + "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") ) R.output(lv4) return lv4 diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 70511f9fee75..b5ee395753de 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -175,12 +175,12 @@ def before_main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: m, n, k = T.var("int64"), T.var("int64"), T.var("int64") - gv0 = R.call_tir("tir_matmul", (x, w), (m, k), dtype="float32") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @R.function def after_main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: - gv0 = R.call_tir("tir_matmul", (x, w), (32, 32), dtype="float32") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) return gv0 input_mod = InputModule @@ -380,7 +380,7 @@ def reshape( @R.function def rx_func(x: R.Tensor((200,), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"): - gv = R.call_tir(reshape, (x,), (10, 20), dtype="float32") + gv = R.call_tir(reshape, (x,), R.Tensor((10, 20), dtype="float32")) return gv bb = rx.BlockBuilder() @@ -421,7 +421,7 @@ def full(T_full: T.Buffer[(T.int64(16), T.int64(32)), "float32"]): def rx_func( dummy_param: R.Tensor((200,), dtype="float32") ) -> R.Tensor((16, 32), dtype="float32"): - gv = R.call_tir(full, (), (16, 32), dtype="float32") + gv = R.call_tir(full, (), R.Tensor((16, 
32), dtype="float32")) return gv bb = rx.BlockBuilder() @@ -466,7 +466,7 @@ def rx_func( ) -> R.Tensor(("n", "m"), dtype="float32"): n = T.var("int64") m = T.var("int64") - gv = R.call_tir(reshape, (x,), (n, m), dtype="float32") + gv = R.call_tir(reshape, (x,), R.Tensor((n, m), dtype="float32")) return gv bb = rx.BlockBuilder() @@ -523,7 +523,7 @@ def get_tir_func(): call_node = rx_func.body.blocks[0].bindings[0].value assert isinstance(call_node, rx.Call) assert call_node.op == relay.op.get("relax.call_tir") - assert len(call_node.args) == 3 + assert len(call_node.args) == 2 assert call_node.args[0].name_hint == "te_func" assert call_node.args[1][0] == x assert call_node.args[1][1] == y @@ -583,10 +583,10 @@ def te_func(A): call_node = rx_func.body.blocks[0].bindings[0].value assert call_node.op == relay.op.get("relax.call_tir") assert call_node.args[0].name_hint == "te_func" - assert isinstance(call_node.args[2], rx.Tuple) - assert len(call_node.args[2]) == 2 - assert isinstance(call_node.args[2][0], rx.ShapeExpr) - assert isinstance(call_node.args[2][1], rx.ShapeExpr) + assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo) + assert len(call_node.sinfo_args[0].fields) == 2 + assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr) + assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr) def test_emit_te_extern(): @@ -609,12 +609,12 @@ def test_emit_te_extern(): call_node = rx_func.body.blocks[0].bindings[0].value assert isinstance(call_node, rx.Call) assert call_node.op == relay.op.get("relax.call_tir") - assert len(call_node.args) == 3 + assert len(call_node.args) == 2 assert call_node.args[0].name_hint == "matmul" assert call_node.args[1][0] == x assert call_node.args[1][1] == y - assert call_node.args[2][0] == n - assert call_node.args[2][1] == n + assert call_node.sinfo_args[0].shape[0] == n + assert call_node.sinfo_args[0].shape[1] == n def test_emit_tuple_get_item(): diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index cc2ac7e7ea3e..e53620a3ba96 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -55,8 +55,8 @@ def tir_relu(x: T.handle, y: T.handle): @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") - lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32") + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -294,7 +294,7 @@ def test_is_call_tir(): def simple_call_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Tensor: - gv0 = R.call_packed("test.vm.mul", x, w, type_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) return gv0 @@ -378,10 +378,10 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # relu sigmoid # \ / # add - lv0 = R.call_tir("tir_matmul", (x, w), (32, 32), dtype="float32") - lv1 = R.call_tir("tir_relu", (lv0,), (32, 32), dtype="float32") - lv2 = R.call_tir("tir_sigmoid", (lv0), (32, 32), dtype="float32") - lv3 = R.call_tir("tir_add", (lv1, lv2), (32, 32), dtype="float32") + lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32), 
dtype="float32")) + lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) R.output(lv3) return lv3 @@ -440,8 +440,8 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # / \ # \ / # add - lv0 = R.call_tir("my_relu", (x,), (32, 32), dtype="float32") - lv1 = R.call_tir("my_add", (lv0, lv0), (32, 32), dtype="float32") + lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -454,9 +454,9 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # relu relu # \ / # add - lv0 = R.call_tir("my_relu", (x,), (32, 32), dtype="float32") - lv1 = R.call_tir("my_relu", (x,), (32, 32), dtype="float32") - lv2 = R.call_tir("my_add", (lv0, lv1), (32, 32), dtype="float32") + lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 @@ -507,13 +507,13 @@ def main( # \ / # concat with R.dataflow(): - lv0 = R.call_tir("conv1x1", (x, w0), (32, 32), dtype="float32") - lv1 = R.call_tir("bias_add", (lv0, bias0), (32, 32), dtype="float32") - lv2 = R.call_tir("my_relu", (lv1), (32, 32), dtype="float32") - lv3 = R.call_tir("conv1x1", (x, w1), (32, 32), dtype="float32") - lv4 = R.call_tir("bias_add", (lv3, bias1), (32, 32), dtype="float32") - lv5 = R.call_tir("my_relu", (lv4), (32, 32), dtype="float32") - lv6 = R.call_tir("concat", (lv2, lv5), (32, 64), dtype="float32") + lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_tir("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) R.output(lv6) return lv6 @@ -626,8 +626,8 @@ def main( c: R.Tensor((48, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("matmul", (a, b), (32, 48), dtype="float32") - lv1 = R.call_tir("matmul", (lv0, c), (32, 32), dtype="float32") + lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) + lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -661,10 +661,12 @@ def main( c: R.Tensor((16, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("my_concat", (b, c), (32, 32), dtype="float32") - lv1 = R.call_tir("my_matmul", (a, lv0), (32, 32), dtype="float32") + lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) lv2 = R.call_tir( - "my_split", (lv1,), ((16, 32), (16, 32)), dtype=("float32", "float32") + "my_split", + (lv1,), + [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")], ) lv3 = R.TupleGetItem(lv2, 0) lv4 = R.TupleGetItem(lv2, 1) @@ -709,18 +711,18 @@ def main( ) -> R.Tensor: b, s, n, h = T.var("int64"), T.var("int64"), T.var("int64"), T.var("int64") with R.dataflow(): - fcq = R.call_tir("my_fc", (x, wq), (b, s, n, h), dtype="float32") - tpq = R.call_tir("my_transpose", (fcq,), (b, 
s, h, n), dtype="float32") + fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) + tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")) - fck = R.call_tir("my_fc", (x, wk), (b, s, n, h), dtype="float32") - tpk = R.call_tir("my_transpose", (fck,), (b, s, h, n), dtype="float32") + fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) + tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")) mul = R.multiply(tpq, tpk) scale = R.multiply(mul, R.const(1.1, "float32")) - softmax = R.call_tir("softmax", (scale,), (b, s, n, h), dtype="float32") + softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")) - fcv = R.call_tir("my_fc", (x, wv), (b, s, n, h), dtype="float32") - tpv = R.call_tir("my_transpose", (fcv,), (b, s, h, n), dtype="float32") + fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) + tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")) out = R.multiply(softmax, tpv) R.output(out) @@ -750,14 +752,14 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # add5 add6 # \ / # add7 - lv0 = R.call_tir("tir_matmul", (x, w), (32, 32), dtype="float32") - lv1 = R.call_tir("tir_matmul", (x, w), (32, 32), dtype="float32") - lv2 = R.call_tir("tir_sigmoid", (lv0), (32, 32), dtype="float32") - lv3 = R.call_tir("tir_sigmoid", (lv1), (32, 32), dtype="float32") - lv4 = R.call_tir("tir_add", (lv0, lv1), (32, 32), dtype="float32") - lv5 = R.call_tir("tir_add", (lv2, lv4), (32, 32), dtype="float32") - lv6 = R.call_tir("tir_add", (lv3, lv4), (32, 32), dtype="float32") - lv7 = R.call_tir("tir_add", (lv5, lv6), (32, 32), dtype="float32") + lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")) + lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")) R.output(lv7) return lv7 @@ -809,9 +811,9 @@ def test_incremental_solving(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # relu -> sigmoid -> neg - lv0 = R.call_tir("tir_relu", (x), (32, 32), dtype="float32") - lv1 = R.call_tir("tir_sigmoid", (lv0), (32, 32), dtype="float32") - lv2 = R.call_tir("tir_neg", (lv1), (32, 32), dtype="float32") + lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 @@ -838,8 +840,8 @@ def test_incremental_solving_counter(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # sigmoid -> neg - lv0 = R.call_tir("tir_sigmoid", (x), (32, 32), dtype="float32") - lv1 = R.call_tir("tir_neg", (lv0), (32, 32), dtype="float32") + lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py 
index b4d4c0ce773c..3a5697d3e371 100644 --- a/tests/python/relax/test_op.py +++ b/tests/python/relax/test_op.py @@ -39,8 +39,8 @@ def identity_tir(a: T.handle, b: T.handle) -> None: def test_call_tir() -> None: v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) - v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], [54, 96], "float32") - v1 = rx.call_tir(identity_tir, [v0], [54, 96], "float32") + v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) def test_implicit_op(): diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 46bd73bfaf69..01e55b568489 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -27,12 +27,13 @@ # c.f. tests/python/unittest/test_tvmscript_error_report.py -def check_call(call, op, args): +def check_call(call, op, args, sinfo_args=[]): assert isinstance(call, relax.Call) if isinstance(op, str): op = relay.op.get(op) assert call.op == op assert_structural_equal(call.args, args) + assert_structural_equal(call.sinfo_args, sinfo_args) def test_annotations(): @@ -48,7 +49,7 @@ def f( q: R.Tensor(ndim=2) = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.shape_of(t) - o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) return o x, y, r = f.params @@ -147,7 +148,7 @@ def test_unexpected_tir_cast_args(): @R.function def f(x: R.Tensor(("m",), "float32")): m = T.var("int64") - return R.call_tir("foo", (x,), (T.cast("int32", m, 1),), dtype="float32") + return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) def test_unexpected_tir_max_args(): @@ -157,7 +158,7 @@ def test_unexpected_tir_max_args(): @R.function def f(x: R.Tensor(("m", "n"), "float32")): m = T.var("int64") - return relax.call_tir("foo", (x,), (T.max(m),), dtype="float32") + return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32")) def test_match_cast(): @@ -469,13 +470,18 @@ def test_call_tir(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") - gv0 = relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + gv0 = relax.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 call_tir_node = foo.body.blocks[0].bindings[0].value assert call_tir_node.attrs is None - assert_structural_equal( - call_tir_node.sinfo_args[0], relax.TensorStructInfo(ndim=2, dtype="float32") + (x,) = foo.params + m, n = x.struct_info.shape + check_call( + call_tir_node, + "relax.call_tir", + [relax.ExternFunc("test.op.identity"), relax.Tuple([x])], + sinfo_args=[R.Tensor((m, n), dtype="float32")], ) @@ -496,7 +502,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] += A[vi, vk] * B[vj, vk] B = T.var("int64") - z = R.call_tir(my_matmul, (x, y), (B, 128), dtype="float32") + z = R.call_tir(my_matmul, (x, y), R.Tensor((B, 128), dtype="float32")) return z x, y = f.params @@ -509,7 +515,8 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: check_call( z_bind.value, "relax.call_tir", - [mm_bind.var, relax.Tuple([x, y]), relax.ShapeExpr([B, tir.IntImm("int64", 128)])], + [mm_bind.var, relax.Tuple([x, y])], + sinfo_args=[R.Tensor((B, 128), dtype="float32")], ) @@ -522,7 +529,7 @@ def f(x: R.Tensor((3, 3), "float32")): x, x, mp=False, - type_args=(R.Tensor(ndim=2, dtype="float32")), + 
sinfo_args=(R.Tensor(ndim=2, dtype="float32")), ) w = R.call_packed( @@ -530,16 +537,16 @@ def f(x: R.Tensor((3, 3), "float32")): x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs", - type_args=(R.Shape), + sinfo_args=(R.Shape), ) - o = R.call_packed("contrib.tensor_array_stack", x, z, type_args=(R.Object)) + o = R.call_packed("contrib.tensor_array_stack", x, z, sinfo_args=(R.Object)) k = R.call_packed( "contrib.construct_tuple", x, x, - type_args=(R.Tuple(R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor), R.Tensor)), + sinfo_args=(R.Tuple(R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor), R.Tensor)), ) return k @@ -579,7 +586,7 @@ def f(x: R.Tensor((3, 3), "float32")): ) -def test_call_packed_no_type_args_fail(): +def test_call_packed_no_sinfo_args_fail(): with pytest.raises(tvm.error.DiagnosticError): @R.function @@ -589,13 +596,13 @@ def f(x: R.Tensor((3, 3), "float32")): return z -def test_call_packed_wrong_type_args_fail(): +def test_call_packed_wrong_sinfo_args_fail(): with pytest.raises(tvm.error.DiagnosticError): @R.function def f(x: R.Tensor((3, 3), "float32")): z: R.Tensor((n, m), "float32") = relax.call_packed( - "contrib.my_matmul", x, x, type_args=(Tuple) + "contrib.my_matmul", x, x, sinfo_args=(Tuple) ) return z @@ -627,7 +634,7 @@ def test_primexpr_arithmetic(): def f(x: R.Tensor(("n", "m"), "float32")): n, m = T.var("int64"), T.var("int64") z: R.Tensor((n * m,), "float32") = R.call_packed( - "my_flatten", (x,), type_args=(R.Tensor(ndim=1, dtype="float32")) + "my_flatten", (x,), sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) sh: R.Shape = (n + m, n // m) return z @@ -643,7 +650,7 @@ def f(x: R.Tensor(("n", "m"), "float32")): def test_call_tir_extern(): @R.function def f(x: R.Tensor) -> R.Tensor: - z = R.call_tir("my_extern", (x,), (10,), dtype="float32") + z = R.call_tir("my_extern", (x,), R.Tensor((10,), dtype="float32")) return z x = f.params[0] @@ -652,11 +659,8 @@ def f(x: R.Tensor) -> R.Tensor: check_call( z_bind.value, "relax.call_tir", - [ - relax.ExternFunc("my_extern"), - relax.Tuple([x]), - relax.ShapeExpr([tir.IntImm("int64", 10)]), - ], + [relax.ExternFunc("my_extern"), relax.Tuple([x])], + sinfo_args=[R.Tensor((10,), dtype="float32")], ) @@ -672,7 +676,7 @@ def scalar_add(a: T.handle, b: T.handle, c: T.handle) -> None: with T.block("add"): C[()] = A[()] + B[()] - z = relax.call_tir(scalar_add, (x, y), (), dtype="float32") + z = relax.call_tir(scalar_add, (x, y), R.Tensor((), dtype="float32")) return z x, y = f.params @@ -684,7 +688,8 @@ def scalar_add(a: T.handle, b: T.handle, c: T.handle) -> None: check_call( z_bind.value, "relax.call_tir", - [add_bind.var, relax.Tuple([x, y]), relax.ShapeExpr([])], + [add_bind.var, relax.Tuple([x, y])], + sinfo_args=[R.Tensor((), dtype="float32")], ) @@ -711,13 +716,13 @@ def f(x: R.Tensor(("n", "n"))) -> R.Tensor: @R.function def g(y: R.Tensor(("n", "n"))) -> R.Tensor: n = T.var("int64") - return R.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + return R.call_tir(my_matmul, (y, y), R.Tensor((n, n), dtype="float32")) @R.function def j(y: R.Tensor(("n", "n"))) -> R.Tensor: n = T.var("int64") with R.dataflow(): - gv = R.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + gv = R.call_tir(my_matmul, (y, y), R.Tensor((n, n), dtype="float32")) gv1 = (gv, gv) gv2 = gv1[1] R.output(gv2) @@ -725,7 +730,7 @@ def j(y: R.Tensor(("n", "n"))) -> R.Tensor: @R.function def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: - gv0 = R.call_packed("test.vm.mul", x, w, 
type_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) return gv0 my_module = MyModule @@ -840,7 +845,7 @@ def main( ) -> R.Tensor: R.func_attr({"global_symbol": "main"}) m, n, k = T.var("int64"), T.var("int64"), T.var("int64") - gv0 = R.call_tir("tir_matmul", (x, w), (m, k), dtype="float32") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 assert InputModule["main"].attrs["global_symbol"] == "main" diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index 03a06b8b863e..d25641e5e530 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -178,7 +178,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = 0.0 C[vi, vj] += A[vi, vk] * B[vj, vk] - z = R.call_tir(my_matmul, (x, y), (B, 128), dtype="float32") + z = R.call_tir(my_matmul, (x, y), R.Tensor((B, 128), dtype="float32")) return z check_roundtrip(foo) @@ -190,16 +190,16 @@ def foo(x: R.Tensor((3, 3), "float32")): # test that we can intro dim vars n, m = T.var("int64"), T.var("int64") z: R.Tensor((n, m), "float32") = R.call_packed( - "contrib.my_matmul", x, x, mp=False, type_args=R.Tensor(ndim=2, dtype="float32") + "contrib.my_matmul", x, x, mp=False, sinfo_args=R.Tensor((n, m), dtype="float32") ) w = R.call_packed( "contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs", - type_args=R.Shape, + sinfo_args=R.Shape(), ) - o = R.call_packed("contrib.tensor_array_stack", x, z, type_args=R.Object) + o = R.call_packed("contrib.tensor_array_stack", x, z, sinfo_args=R.Object()) return z check_roundtrip(foo) @@ -208,7 +208,7 @@ def foo(x: R.Tensor((3, 3), "float32")): def test_relax_base_op(): @R.function def foo(x: R.Tensor((2, 4), dtype="float32")): - gv = R.call_builtin("test_intrin", [x], type_args=R.Object) + gv = R.call_builtin("test_intrin", [x], sinfo_args=R.Object) return gv check_roundtrip(foo) @@ -219,7 +219,7 @@ def test_primexpr_arithmetic(): def foo(x: R.Tensor(("n", "m"), "float32")): n, m = T.var("int64"), T.var("int64") z: R.Tensor((n * m,), "float32") = R.call_packed( - "my_flatten", (x,), type_args=R.Tensor(ndim=1, dtype="float32") + "my_flatten", (x,), sinfo_args=R.Tensor((n * m,), dtype="float32") ) sh: R.Shape = (n + m, n // m) return z @@ -230,7 +230,7 @@ def foo(x: R.Tensor(("n", "m"), "float32")): def test_call_tir_extern(): @R.function def foo(x: R.Tensor): - z = R.call_tir("my_extern", (x,), (10,), dtype="float32") + z = R.call_tir("my_extern", (x,), R.Tensor((10,), dtype="float32")) return z check_roundtrip(foo) @@ -350,7 +350,7 @@ def f(x: R.Tensor(("n", "n"))) -> R.Tensor: @R.function def g(y: R.Tensor(("n", "n"))) -> R.Tensor(("n", "n"), "float32"): n = T.var("int64") - r = relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + r = relax.call_tir(my_matmul, (y, y), R.Tensor((n, n), dtype="float32")) return r @R.function @@ -358,7 +358,7 @@ def h( x: R.Tensor(("n", "n")), y: R.Tensor(("n", "n")), z: R.Tensor(("n", "n")) ) -> R.Tensor: n = T.var("int64") - _ = R.call_tir(my_matmul, (x, y), (n, n), dtype="float32") + _ = R.call_tir(my_matmul, (x, y), R.Tensor((n, n), dtype="float32")) return z my_module = MyModule @@ -369,7 +369,7 @@ def test_tir_max(): @R.function def tir_max(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") - gv = relax.call_tir("my_extern", (x,), (T.max(n, m),), dtype="float32") + gv = relax.call_tir("my_extern", (x,), 
R.Tensor((T.max(n, m),), dtype="float32")) return gv check_roundtrip(tir_max) @@ -379,7 +379,9 @@ def test_tir_cast(): @R.function def tir_cast(x: R.Tensor(("m",), "float32")): m = T.var("int64") - gv = R.call_tir("my_extern", (x,), (T.cast(T.cast(m, "int32"), "int64"),), dtype="float32") + gv = R.call_tir( + "my_extern", (x,), R.Tensor((T.cast(T.cast(m, "int32"), "int64"),), dtype="float32") + ) return gv check_roundtrip(tir_cast) @@ -472,7 +474,7 @@ def test_call_pretty_print(): call = relax.Call(extern_func, [x], sinfo_args=[R.Tensor(ndim=1, dtype="float32")]) assert ( call.__str__() - == 'R.call_packed("my_func", x, type_args=(R.Tensor(ndim=1, dtype="float32") ,))' + == 'R.call_packed("my_func", x, sinfo_args=[R.Tensor(dtype="float32", ndim=1)])' ) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 8a1f2e04f454..aff1c121973c 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -215,8 +215,8 @@ class TestToNonDataflow: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), (m, n), dtype="float32") - gv0 = R.call_tir("test.op.identity", (lv0,), (m, n), dtype="float32") + lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) R.output(gv0) return gv0 @@ -259,7 +259,7 @@ class TestCallTIRRewrite: @R.function def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") - gv0 = R.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 mod = TestCallTIRRewrite @@ -281,7 +281,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], relax.ShapeExpr) - assert structural_equal(s1.args[0], s0.args[2]) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" @@ -294,7 +294,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: m, n = T.var("int64"), T.var("int64") alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") _ = R.call_packed( - "test.op.identity", x, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")) + "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) ) gv0 = alloc return gv0 diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index a574b7e3e5fc..edfc646e2108 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -45,7 +45,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: m, n, k = T.var("int64"), T.var("int64"), T.var("int64") - gv0 = R.call_tir("tir_matmul", (x, w), (m, k), dtype="float32") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -75,7 +75,7 @@ def main( ) -> R.Tensor: R.func_attr({"global_symbol": "main"}) m, n, k = T.var("int64"), T.var("int64"), T.var("int64") - gv0 = R.call_tir("tir_matmul", (x, w), (m, k), dtype="float32") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 
before = Before diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index e8d6206e48d2..b96fb89e6c0a 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -48,7 +48,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: def main( x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") ) -> R.Tensor((16, 16), "float32"): - gv0 = R.call_tir(tir_matmul, (x, w), (16, 16), dtype="float32") + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 x_np = np.random.rand(16, 16).astype(np.float32) diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index d642a408deac..32ee3e700080 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -67,7 +67,7 @@ def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) - @R.function def before(c0: R.Tensor((16, 16), "float32")): - lv0 = relax.call_tir(addone, (c0,), (16, 16), dtype="float32") + lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) return lv0 @R.function @@ -97,7 +97,7 @@ def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None @R.function def before(c0: R.Tensor((2, 3), "float32")): - lv0 = relax.call_tir(func, (c0,), (3, 2), dtype="float32") + lv0 = relax.call_tir(func, (c0,), R.Tensor((3, 2), dtype="float32")) return lv0 @R.function @@ -126,8 +126,8 @@ def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> No @R.function def before(c0: R.Tensor((2, 2), "float32")): - lv0 = relax.call_tir(addone, (c0,), (2, 2), dtype="float32") - lv1 = relax.call_tir(addone, (lv0,), (2, 2), dtype="float32") + lv0 = relax.call_tir(addone, (c0,), R.Tensor((2, 2), dtype="float32")) + lv1 = relax.call_tir(addone, (lv0,), R.Tensor((2, 2), dtype="float32")) return lv1 @R.function @@ -159,7 +159,7 @@ def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) @R.function def before(c0: R.Tensor((16, 16), "float32")): with R.dataflow(): - gv0 = relax.call_tir(identity, (c0,), (16, 16), dtype="float32") + gv0 = relax.call_tir(identity, (c0,), R.Tensor((16, 16), dtype="float32")) R.output(gv0) return gv0 @@ -209,13 +209,13 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown - lv0 = relax.call_tir(addone, (c0,), (n, 16), dtype="float32") + lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) # this line can be folded - lv1 = relax.call_tir(addone, (c0,), (16, 16), dtype="float32") + lv1 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) # this line can be folded because all inputs are const - lv2 = relax.call_tir(sub, (c0, lv1), (16, 16), dtype="float32") + lv2 = relax.call_tir(sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) # this line can not be folded because x's shape is unknown - lv3 = relax.call_tir(sub, (lv2, x), (16, 16), dtype="float32") + lv3 = relax.call_tir(sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) return lv3 @R.function @@ -228,13 +228,13 @@ def expected( n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown - lv0 = relax.call_tir(addone, (c0,), 
(n, 16), dtype="float32") + lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) # this line can be folded lv1 = c1 # this line can be folded because all inputs are const lv2 = c2 # this line can not be folded because x's shape is unknown - lv3 = relax.call_tir(sub, (c2, x), (16, 16), dtype="float32") + lv3 = relax.call_tir(sub, (c2, x), R.Tensor((16, 16), dtype="float32")) return lv3 c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) @@ -259,7 +259,7 @@ def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> No @R.function def before(c0: R.Tensor((16, 16), "int32")): - lv0 = relax.call_tir(addone, (c0,), (16, 16), dtype="int32") + lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32")) return lv0 @R.function diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index 785b95860297..71f088b01b60 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -94,7 +94,7 @@ def main( ) -> R.Tensor((2, 3), "float32"): outer_func = lifted_func_0 in_call = outer_func(x) - res = R.invoke_closure(in_call, (y,), type_args=(R.Tensor(ndim=2, dtype="float32"))) + res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))) return res @R.function @@ -117,7 +117,7 @@ def main( @R.function def outer_func(c1: R.Tensor((2, 3), "float32")): @R.function - def inner_func(x1: R.Tensor((2, 3), "float32")): + def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): s: R.Tensor((2, 3), "float32") = R.add(x1, c1) return s @@ -144,7 +144,7 @@ def lifted_func_0( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): cond: R.Tensor((), "bool") = R.call_packed( - "test.vm.less", i, R.const(10), type_args=(R.Tensor(ndim=0, dtype="bool")) + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") if cond: @@ -161,7 +161,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: gv = R.invoke_closure( while_loop, (relax.const(0), x), - type_args=(R.Tensor(ndim=2, dtype="float32")), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), ) return gv @@ -175,7 +175,7 @@ def while_loop( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): cond: R.Tensor((), "bool") = R.call_packed( - "test.vm.less", i, R.const(10), type_args=(R.Tensor(ndim=0, dtype="bool")) + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") if cond: @@ -290,7 +290,7 @@ def sub( @R.function def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)): - s = R.call_tir(sub, (c0, x), (16, 16), dtype="float32") + s = R.call_tir(sub, (c0, x), R.Tensor((16, 16), dtype="float32")) return s before = Before diff --git 
a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 8d92a686907a..ff695b9436a3 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -61,8 +61,8 @@ def tir_relu(x: T.handle, y: T.handle): @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") - lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32") + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 9037b907b445..9e9533a5ed23 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -124,8 +124,8 @@ class ANFMod2: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), (m, n), dtype="float32") - gv0 = R.call_tir("test.op.identity", (lv0,), (m, n), dtype="float32") + lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) R.output(gv0) return gv0 diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py b/tests/python/relax/test_transform_remove_unused_funcs.py index 8b29596f6705..c4c30a0ccf4f 100644 --- a/tests/python/relax/test_transform_remove_unused_funcs.py +++ b/tests/python/relax/test_transform_remove_unused_funcs.py @@ -51,7 +51,7 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 def main( x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") ) -> R.Tensor((16, 16), "float32"): - gv0 = R.call_tir(tir_add, (x, w), (16, 16), dtype="float32") + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 mod = InputModule @@ -85,7 +85,7 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 def foo( x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") ) -> R.Tensor((16, 16), "float32"): - gv0 = R.call_tir(tir_add, (x, w), (16, 16), dtype="float32") + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 mod = InputModule @@ -121,7 +121,7 @@ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "flo @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): m, k = T.var("int64"), T.var("int64") - gv0 = R.call_tir(tir_add, (x, w), (m + 1, k), dtype="float32") + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) return gv0 mod = InputModule diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py index 3c20a10a7085..b12ff016705d 100644 --- a/tests/python/relax/test_tuning_api.py +++ b/tests/python/relax/test_tuning_api.py @@ -57,7 +57,7 @@ def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> No # Input IRModule. @R.function def before(c0: R.Tensor((16, 16), "int32")): - lv0 = R.call_tir(addone, (c0,), (16, 16), dtype="int32") + lv0 = R.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32")) return lv0 # Expected IRModule after transformation. 
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 15d84af26c82..12d8b114b862 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -25,7 +25,7 @@ def test_function_simple(): """ @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - out = R.call_tir("extern_func", x, (128, 128), dtype="float32") + out = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) return out """ # create with Script IRBuilder @@ -35,7 +35,9 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) R.func_attr({"Primitive": 1}) x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) - out = R.emit(R.call_tir("extern_func", x, (128, 128), dtype="float32")) + out = R.emit( + R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) IRBuilder.name("out", out) R.func_ret_value(out) func = ir_builder.get() @@ -43,7 +45,9 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,), attrs={"Primitive": 1}): - out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + out = bb.emit( + relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) bb.emit_func_output(out) mod = bb.get() @@ -98,7 +102,7 @@ def test_dataflow_block(): def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): # block 0 with R.dataflow(): - lv0 = R.call_tir("extern_func", (x,), (128, 128), dtype="float32") + lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) gv: Tensor((128, 128), "float32") = lv0 R.output(gv) return gv @@ -109,7 +113,11 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): R.func_name("foo") x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) with R.dataflow() as df: - lv0 = R.emit(R.call_tir("extern_func", x, (128, 128), dtype="float32")) + lv0 = R.emit( + R.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) IRBuilder.name("lv0", lv0) gv = R.emit(lv0) IRBuilder.name("gv", gv) @@ -123,7 +131,11 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): bb = relax.BlockBuilder() with bb.function("foo", (x,)): with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + lv0 = bb.emit( + relax.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) gv = bb.emit_output(lv0) bb.emit_func_output(gv) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 296802c041d1..f03a7fab6e38 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -43,13 +43,13 @@ def test_simple_func(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): R.func_attr({"Primitive": 1}) - gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) return gv0 x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,), attrs={"Primitive": 1}): - out = 
bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) bb.emit_func_output(out) _check(foo, bb.get()["foo"]) @@ -60,7 +60,7 @@ def test_error_report(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv0 = gv1 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) return gv0 @@ -81,7 +81,7 @@ def tir_func( @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): # TODO(Siyuan): Need to change to `TestModule.tir_func` - gv0 = R.call_tir(tir_func, x, (128, 128), dtype="float32") + gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) return gv0 x = relax.Var("x", R.Tensor((128, 128), "float32")) @@ -132,14 +132,14 @@ def test_symbolic_shape(): def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.var("int64", "m") n = T.var("int64", "n") - gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 @R.function def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.var("int64") n = T.var("int64") - gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 with pytest.raises(tvm.error.DiagnosticError): @@ -148,7 +148,7 @@ def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): m = T.var("int64") n = T.var("int32") # The shape dtype should be int64 - gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 def _expected(name: str): @@ -156,7 +156,7 @@ def _expected(name: str): x = relax.Var("x", R.Tensor([m, n], "float32")) bb = relax.BlockBuilder() with bb.function(name, (x,)): - out = bb.emit(relax.call_tir("extern_func", x, (m, n), dtype="float32")) + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))) bb.emit_func_output(out) return bb.get()[name] @@ -220,15 +220,15 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): def test_tuple_return(): @R.function def foo(x: R.Tensor((4, 4), "float32")): - gv0 = R.call_tir("extern_func_0", x, (4, 4), dtype="float32") - gv1 = R.call_tir("extern_func_1", x, (4, 4), dtype="float32") + gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) + gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) return (gv0, gv1) x = relax.Var("x", R.Tensor((4, 4), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,)): - gv0 = bb.emit(relax.call_tir("extern_func_0", x, (4, 4), dtype="float32")) - gv1 = bb.emit(relax.call_tir("extern_func_1", x, (4, 4), dtype="float32")) + gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) bb.emit_func_output(relax.Tuple((gv0, gv1))) _check(foo, bb.get()["foo"]) @@ -276,8 +276,8 @@ def test_dataflow_block(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): with R.dataflow(): - lv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") - lv1 = R.call_tir("extern_func", lv0, 
(128, 128), dtype="float32") + lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) gv = lv1 R.output(gv) return gv @@ -286,8 +286,8 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) bb = relax.BlockBuilder() with bb.function("foo", (x,)): with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) - lv1 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) + lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) gv = bb.emit_output(lv1) bb.emit_func_output(gv) @@ -297,22 +297,22 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) def test_dataflow_block_advanced(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") - gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) with R.dataflow(): m = T.var("int64") n = T.var("int64") - lv0 = R.call_tir("extern_func", gv1, (128, 128), dtype="float32") + lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") - gv2 = R.call_tir("extern_func", gv2, (128, 128), dtype="float32") + gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv4 = gv3 gv5 = gv2 R.output(gv5, gv4) - gv6 = R.call_tir("extern_func", gv5, (128, 128), dtype="float32") - gv7 = R.call_tir("extern_func", gv6, (128, 128), dtype="float32") + gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) return gv7 x = relax.Var("x", R.Tensor((128, 128), "float32")) @@ -320,19 +320,21 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") with bb.function("foo", (x,)): - gv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) - gv1 = bb.emit(relax.call_tir("extern_func", gv0, (128, 128), dtype="float32")) + gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))) with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", gv1, (128, 128), dtype="float32")) + lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))) lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) - gv21 = bb.emit(relax.call_tir("extern_func", gv2, (128, 128), dtype="float32")) + gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv21 = bb.emit( + relax.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + ) gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) gv31 = 
bb.match_cast(lv0, R.Tensor((m, n), "float32")) gv32 = bb.emit_output(gv31) gv22 = bb.emit_output(gv21) - gv4 = bb.emit(relax.call_tir("extern_func", gv22, (128, 128), dtype="float32")) - gv5 = bb.emit(relax.call_tir("extern_func", gv4, (128, 128), dtype="float32")) + gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))) + gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))) bb.emit_func_output(gv5) _check(foo, bb.get()["foo"]) @@ -344,9 +346,9 @@ def test_dataflow_binding_after_output(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): with R.dataflow(): - gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) R.output(gv) - lv = R.call_tir("extern_func", gv, (128, 128), dtype="float32") + lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32")) return gv @@ -355,9 +357,9 @@ def test_dataflow_output_global_var(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) with R.dataflow(): - gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) R.output(gv0, gv1) return gv1 @@ -368,7 +370,7 @@ def test_dataflow_multiple_output(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): with R.dataflow(): - gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) R.output(gv) R.output(gv) return gv @@ -379,7 +381,7 @@ def test_dataflow_output_outside_dataflow_block(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) R.output(gv) return gv @@ -411,19 +413,19 @@ def test_function_without_return(): @R.function def foo(x: R.Tensor((128, 128), "float32")): - gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) def test_tensor_type_without_args(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - v = R.call_tir("tir_relu", x, (32, 32), dtype="float32") + v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32")) return v x = relax.Var("x", R.Tensor((32, 32), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x)): - v = bb.emit(relax.call_tir("tir_relu", x, (32, 32), dtype="float32")) + v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))) bb.emit_func_output(v) _check(foo, bb.get()["foo"]) @@ -445,7 +447,7 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): def test_call_packed(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - z = R.call_packed("vm.builtin.copy", x, type_args=R.Tensor((32, 32), "float32")) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) return z x = relax.Var("x", R.Tensor((32, 32), "float32")) @@ -456,7 +458,7 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: relax.ExternFunc("vm.builtin.copy"), (x,), None, - sinfo_args=[R.Tensor("float32", ndim=2)], + 
sinfo_args=[R.Tensor((32, 32), "float32")], ) ) bb.emit_func_output(z) @@ -480,7 +482,7 @@ def foo( _: R.Tensor((1, 1), "int8") = R.builtin.alloc_tensor( (1, 1), dtype="int8", runtime_device_index=0 ) - o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) return o def _check_struct_info(binding, expected_sinfo): @@ -537,11 +539,11 @@ def bar(x: R.Tensor): def test_call_tir_empty_shape(): @R.function def foo(x: R.Tensor((), "float32")): - z = R.call_tir("scalar_add", x, (), dtype="float32") + z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32")) return z (z_bind,) = foo.body.blocks[0].bindings - shape_expr = z_bind.value.args[2] + shape_expr = z_bind.value.sinfo_args[0].shape assert isinstance(shape_expr, relax.ShapeExpr) assert len(shape_expr.values) == 0 @@ -565,7 +567,7 @@ def main( dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) ) -> R.Tensor(("n * 2",), "float32"): n = T.var("int64") - y = R.call_tir(copy, (x,), ((n * 2,)), dtype="float32", tir_vars=(n,)) + y = R.call_tir(copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) return y @T.prim_func @@ -627,7 +629,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = 0.0 C[vi, vj] += A[vi, vk] * B[vj, vk] - z = relax.call_tir(my_matmul, (x, y), (128, 128), dtype="float32") + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) return z bindings = TestModule["f"].body.blocks[0].bindings @@ -842,7 +844,7 @@ def bar( x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): m = T.var("int64") - z = R.call_tir("test_intrin", (x, y), (T.max(m, 20) + 1,), dtype="float32") + z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -850,7 +852,9 @@ def bar( y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) bb = relax.BlockBuilder() with bb.function("bar", (x, y)): - z = bb.emit(relax.call_tir("test_intrin", (x, y), (tir.max(m, 20) + 1,), dtype="float32")) + z = bb.emit( + relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")) + ) bb.emit_func_output(z) _check(bar, bb.get()["bar"]) @@ -859,7 +863,7 @@ def bar( @R.function def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): m = T.var("int64") - z = R.call_tir("test_intrin", y, (m * 2,), dtype="float32") + z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -867,7 +871,7 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) bb = relax.BlockBuilder() with bb.function("baz", (x, y)): - z = bb.emit(relax.call_tir("test_intrin", (y), (m * 2,), dtype="float32")) + z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) bb.emit_func_output(z) _check(baz, bb.get()["baz"]) @@ -944,7 +948,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): def test_prim_value(): @R.function def foo(): - gv = R.call_packed("test", 1, type_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) return gv _check(foo, None) @@ -953,7 +957,7 @@ def foo(): def test_string_imm(): @R.function def foo(): - gv = R.call_packed("test", "hello", type_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", 
"hello", sinfo_args=R.Tensor((32, 32), "float32")) return gv _check(foo, None) @@ -962,7 +966,7 @@ def foo(): def test_datatype_imm(): @R.function def foo(): - gv = R.call_packed("test", R.dtype("float32"), type_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) return gv _check(foo, None) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 628fc9f65ccb..f21d759c283c 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -38,7 +38,7 @@ class TestVMCompileStage0: @R.function def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): z = R.call_packed( - "test.vm.identity", x, y, type_args=(R.Tensor(ndim=2, dtype="float32")) + "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) ) return y @@ -119,7 +119,7 @@ class TestVMCompileStage3: @R.function def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: with R.dataflow(): - y = R.call_tir("test.vm.identity", (x), (32, 16), dtype="float32") + y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32")) R.output(y) return y @@ -143,7 +143,7 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: with R.dataflow(): n, m = T.var("int64"), T.var("int64") _ = R.match_cast(x, R.Tensor((n, m), "float32")) - y = R.call_tir("test.vm.tile", (x), (n, m * 2), dtype="float32") + y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) R.output(y) return y @@ -185,7 +185,7 @@ def func( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: m, k = T.var("int64"), T.var("int64") - gv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 mod = TestVMCompileE2E2 @@ -490,7 +490,7 @@ def tuple_get_item( t = (x, y) a = t[0] b = t[1] - c = R.call_packed("test.vm.add", a, b, type_args=(R.Tensor(ndim=2, dtype="float32"))) + c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) return c mod = TestVMTupleGetItem @@ -559,7 +559,7 @@ def relax_matmul_tir( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Tensor((32, 32), dtype="float32"): with R.dataflow(): - gv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) R.output(gv0) return gv0 @@ -567,7 +567,7 @@ def relax_matmul_tir( def relax_matmul_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Object: - gv0 = R.call_packed("test.vm.mul", x, w, type_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) return gv0 @R.function @@ -594,17 +594,17 @@ class TestVMRecursion: @R.function def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: cond = R.call_packed( - "test.vm.equal_zero", n, type_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) if cond: res = R.const(1.0) else: gv0 = R.call_packed( - "test.vm.subtract_one", n, type_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) tmp = recursion(gv0) res = R.call_packed( - "test.vm.add", tmp, tmp, type_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) return res @@ -626,7 +626,7 @@ def 
test_vm_closure(exec_mode): class TestClosure: @R.function def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): - return R.call_packed("test.vm.add", x, env, type_args=(R.Tensor)) + return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor)) @R.function def main( @@ -634,7 +634,7 @@ def main( y: R.Tensor((2, 3), "float32"), ): clo = R.make_closure(lifted_func_1, (x,)) - res = R.invoke_closure(clo, (y,), type_args=(R.Tensor)) + res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor)) return res mod = TestClosure @@ -653,7 +653,9 @@ def test_time_evaluator(exec_mode): class TestTimeEvaluator: @R.function def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): - return R.call_packed("test.vm.add", x, y, type_args=(R.Tensor(ndim=1, dtype="float32"))) + return R.call_packed( + "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) target = tvm.target.Target("llvm", host="llvm") ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode) @@ -716,7 +718,7 @@ def test_vm_nested_tuple( @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: - gv0 = R.call_tir("test_vm_mul", (x, w), (32, 32), dtype="float32") + gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index e3acba91cdc1..667450ce03f9 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -40,7 +40,7 @@ class TestVMMove: @R.function def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("vm.builtin.copy", x, type_args=(R.Tensor(ndim=2, dtype="float32"))) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) return z mod = TestVMMove @@ -81,7 +81,7 @@ class TestVMMove: @R.function def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("vm.builtin.copy", x, type_args=(R.Tensor(ndim=2, dtype="float32"))) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) return z mod = TestVMMove @@ -105,9 +105,9 @@ class TestVMCompileIf: def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: - w = R.call_packed("test.vm.add", x, x, type_args=(R.Tensor)) + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) else: - w = R.call_packed("test.vm.mul", x, x, type_args=(R.Tensor)) + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) return w mod = TestVMCompileIf @@ -158,13 +158,13 @@ def main(x: R.Tensor(ndim=2, dtype="float32")): "test.vm.add", relax.const([1, 2]), relax.const([3, 4]), - type_args=(R.Tensor(ndim=2, dtype="float32")), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), ) b = R.call_packed( "test.vm.add", a, x, - type_args=(R.Tensor(ndim=2, dtype="float32")), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), ) return b @@ -197,7 +197,7 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): [], int_args=[3], require_ctx=True, - type_args=[R.Tensor(ndim=1, dtype="int64")], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_builtin( "vm.builtin.check_tensor_info", @@ -225,7 +225,7 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): MK.USE_IMM, 2, ], - type_args=[R.Shape(ndim=3)], + sinfo_args=[R.Shape(ndim=3)], ) return s diff --git 
a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 93889114aaa5..17377bb9a683 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -37,7 +37,7 @@ class Before: @R.function def foo(x: R.Tensor): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("test.vm.add", x, x, type_args=(R.Tensor)) + z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) return z @tvm.script.ir_module @@ -108,9 +108,9 @@ class Before: def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: - w = R.call_packed("test.vm.add", x, x, type_args=(R.Tensor)) + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) else: - w = R.call_packed("test.vm.mul", x, x, type_args=(R.Tensor)) + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) return w @tvm.script.ir_module @@ -195,7 +195,7 @@ class Before: def main(x: R.Tensor): R.func_attr({"global_symbol": "main"}) y = R.const([1, 2]) - z = R.call_packed("test.vm.add", x, y, type_args=(R.Tensor)) + z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor)) return z @tvm.script.ir_module