diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 13f39317dbe4c..6f669b7c634f6 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode {
   }
 
   static constexpr const char* _type_key = "PlaceholderOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
+  TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode);
 };
 
 /*!
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index a2f53d5756308..77717a40ddc86 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -49,6 +49,7 @@
 # helper functions
 const = expr.const
 extern = expr.extern
+te_tensor = expr.te_tensor
 
 # Type
 Type = ty.Type
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index cd1fca8a6e883..4fac922580abb 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Developer API of constructing Relax AST."""
-from typing import List, Optional, Union, Dict
+import typing
+from typing import List, Optional, Union, Dict, Any, Callable
 from tvm.relay.expr import Tuple
 from tvm.runtime import Object
 from tvm import relax as rx
+from tvm import tir
 from .expr import *
+from .op.base import call_dps
 from tvm._ffi.base import _LIB, check_call
 from . import _ffi_api
@@ -72,7 +75,7 @@ class BlockBuilder(Object):
         dtype1 = rx.DynTensorType(rank=1, dtype="float16")
         x = rx.Var("x", [m, n], dtype0)
         y = rx.Var("y", [n], dtype1)
-        ib = rx.IRBuilder()
+        ib = rx.BlockBuilder()
         with ib.function([x, y], "func"):
             with ib.dataflow() as df:
                 lv0 = ib.emit(rx.add(x, y))
@@ -84,6 +87,7 @@ class BlockBuilder(Object):
 
     def __init__(self):
         self._blocks = []
+        self._context_mod = tvm.IRModule()
         self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)
 
     def _begin_dataflow_block(self) -> None:
@@ -91,10 +95,61 @@ def _begin_dataflow_block(self) -> None:
 
     def _begin_binding_block(self) -> None:
         _ffi_api.BlockBuilderBeginBindingBlock(self)
-    
+
     def _end_block(self) -> BindingBlock:
         return _ffi_api.BlockBuilderEndBlock(self)
 
+    def _convert_te_arg(self,
+                        te_args: Any
+                        ) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
+        """Helper function to convert Relax expressions to te tensors.
+        In the common case, te_args is a single Relax expression, which is converted
+        into a te tensor. If te_args is a nested or recursive datatype (i.e. list, dict,
+        tvm.ir.Map, tvm.ir.Array), we recurse into it and convert any value of type
+        Relax expression into a te tensor.
+        Common values of type int, float, and str are preserved.
+
+        Parameters
+        ----------
+        te_args : Any
+            Argument to convert to te tensors.
+
+        Returns
+        -------
+        ret : (Any, [tvm.te.Tensor])
+            A tuple of the converted te_args, and a list of te tensors for each
+            converted Relax expression.
+        """
+        te_args_list = []
+
+        def _convert_te_arg_helper(arg):
+            if isinstance(arg, Expr):
+                arg = te_tensor(arg)
+                te_args_list.append(arg)
+                return arg
+            elif isinstance(arg, (list, tvm.ir.Array)):
+                return [_convert_te_arg_helper(x) for x in arg]
+            elif isinstance(arg, tuple):
+                return tuple([_convert_te_arg_helper(x) for x in arg])
+            elif isinstance(arg, (dict, tvm.ir.Map)):
+                for key in arg:
+                    assert isinstance(key, str), "emit_te only supports dicts with string keys for now"
+                return {k: _convert_te_arg_helper(arg[k]) for k in arg}
+            elif isinstance(arg, (int, float, str)):
+                return arg
+            else:
+                raise TypeError("not supported type in emit_te: {}".format(type(arg)))
+
+        new_arg = _convert_te_arg_helper(te_args)
+        return new_arg, te_args_list
+
+    def _check_te_args(self, args: List[tvm.te.Tensor]):
+        """Check te arguments."""
+        # TODO(hypercubestart, ziheng) support full dynamic shape in the future
+        for x in args:
+            for s in x.shape:
+                if not isinstance(s, (tir.Var, tir.IntImm)):
+                    raise ValueError("emit_te does not support symbolic shapes that contain "
+                                     "expressions for now: {}".format(x.shape))
+
     def function(self,
                  params: Optional[Union[Var, Tuple, List[Var]]] = None,
                  name: Optional[str] = "") -> FunctionScope:
@@ -139,7 +194,7 @@ def emit(self, call: relay.Call) -> Var:
 
         Parameters
         ----------
-        call : tvm.relay.Call
+        call : tvm.relax.Call
             The call node to be emitted.
 
         Returns
@@ -149,12 +204,97 @@
         """
         return _ffi_api.BlockBuilderEmit(self, call)
 
+    def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
+        """Emit a call node according to the te function.
+        This function converts arguments from Relax expressions to te tensors.
+        The callback func should return a te tensor.
+
+        Parameters
+        ----------
+        func : Callable
+            A function that returns a te tensor.
+
+        args : Any, optional
+            Positional arguments passed through to func, with any Relax expression
+            converted to a te tensor.
+
+        kwargs : Any, optional
+            Keyword arguments passed through to func, converted in the same way.
+
+        Returns
+        -------
+        ret : tvm.relax.Var
+            A newly created variable that gets bound to the call node.
+
+        Example
+        -------
+
+        .. code-block:: python
+
+            bb = rx.BlockBuilder()
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            type_anno = rx.DynTensorType(2, "float32")
+            x = rx.Var("x", [n, m], type_anno)
+            y = rx.Var("y", [n, m], type_anno)
+
+            def te_func(args, args_dict, msg):
+                A = args[0]
+                B = args_dict["B"]
+                return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
+
+            with bb.function([x, y], "rx_func"):
+                out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
+                bb.emit_func_output(out)
+
+        will result in the following TVMScript:
+
+        .. code-block:: python
+
+            @tvm.script.ir_module
+            class Module:
+                @T.prim_func
+                def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None:
+                    # function attr dict
+                    T.func_attr({"global_symbol": "te_func"})
+                    m = T.var("int64")
+                    n = T.var("int64")
+                    rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
+                    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
+                    compute = T.match_buffer(var_compute, [128, 128], dtype="float32")
+                    # body
+                    # with T.block("root")
+                    for i0, i1 in T.grid(128, 128):
+                        with T.block("compute"):
+                            i, j = T.axis.remap("SS", [i0, i1])
+                            T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]])
+                            T.writes([compute[i, j]])
+                            compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j]
+
+                @R.function
+                def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
+                    # block 0
+                    gv = relax.call_dps((128, 128), "te_func", (x, y))
+                    return gv
+        """
+        new_args, te_arg_list = self._convert_te_arg(args)
+        new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)
+
+        te_args = te_arg_list + te_kwarg_list
+        self._check_te_args(te_args)
+
+        # TODO(hypercubestart, ziheng) handle multiple output case
+        te_out = func(*new_args, **new_kwargs)
+        assert isinstance(te_out, tvm.te.tensor.Tensor), "only te tensors are supported as the function output"
+
+        inputs = [*te_args, te_out]
+        tir_func = tvm.te.create_prim_func(inputs)
+        func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
+        tir_func = tir_func.with_attr("global_symbol", func_name)
+        gvar = GlobalVar(func_name)
+        self._context_mod[gvar] = tir_func
+        call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
+        return _ffi_api.BlockBuilderEmit(self, call)
+
     def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
         """Emit a MatchShape.
 
         Parameters
         ----------
-        value : tvm.relay.Expr
+        value : tvm.relax.Expr
             The value of the MatchShape to be emitted.
 
         pattern : List[PrimExpr]
@@ -224,8 +364,19 @@ def get(self) -> Function:
         ret : tvm.relax.Function
             A Relax function node being built.
         """
+        # TODO(hypercubestart, ziheng) get should return an IRModule with both Relax and TIR functions
         seqe = rx.SeqExpr(self._blocks, self._func_ret)
         func = rx.Function(
             self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name)
         )
         return func
+
+    def context_mod(self):
+        """Return the context module, which may contain TIR functions generated by emit_te.
+
+        Returns
+        -------
+        mod : tvm.IRModule
+            The context module that collects TIR functions during emit.
+        """
+        return self._context_mod
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 2e570eb8ab916..9ae771a72340e 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -169,5 +169,11 @@ def __init__(self, global_symbol: String, span: Span = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span)
 
 
-def extern(name, span: Span = None):
+def extern(name: str, span: Span = None):
+    """Create an extern function."""
     return ExternFunc(name, span)
+
+
+def te_tensor(value: Expr, name: str = "rxplaceholder"):
+    """Create a te tensor from a Relax expression."""
+    return _ffi_api.TETensor(value, name)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 2fa88cb954fb2..72d8d6b981e7a 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -13,7 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-from ...ir import BaseFunc
+from ...ir import BaseFunc, Array
 from ..expr import Expr, ShapeExpr, Tuple, Call
 from . import _ffi_api
 from typing import Union, List
@@ -41,7 +41,7 @@ def call_dps(
     ret: Call
         A call node for the call_dps operator.
     """
-    if isinstance(shape, (list, tuple)):
+    if isinstance(shape, (list, tuple, Array)):
         shape = ShapeExpr(shape)
     if isinstance(args, (list, tuple)):
         args = Tuple(args)
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 5b95e191247cd..aa4ca0b7c509f 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -550,5 +550,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
       return builder->Normalize(expr);
     });
 
+TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName")
+    .set_body_typed([](BlockBuilder builder, String name_hint) {
+      return builder->name_table()->GetUniqueName(name_hint);
+    });
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc
new file mode 100644
index 0000000000000..181e893a587a7
--- /dev/null
+++ b/src/relax/ir/emit_te.cc
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/emit_te.cc
+ */
+#include <tvm/relax/type.h>
+
+#include "./emit_te.h"
+
+namespace tvm {
+namespace relax {
+
+// RXPlaceholderOpNode
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<RXPlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const RXPlaceholderOpNode*>(node.get());
+      p->stream << "rxplaceholder(" << op->name << ", " << op << ")";
+    });
+
+TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode);
+
+te::Tensor TETensor(Expr value, std::string name) {
+  auto n = make_object<RXPlaceholderOpNode>();
+  n->name = name;
+  n->value = value;
+
+  Expr shape_expr = value->shape();
+  CHECK(shape_expr->IsInstance<ShapeExprNode>())
+      << "ValueError: Expression does not have a known symbolic shape, please consider using "
+      << "match_shape to constrain the shape before passing into te_tensor";
+  Array<PrimExpr> shape = Downcast<ShapeExpr>(shape_expr)->values;
+  n->shape = shape;
+  Type type = value->checked_type();
+  ICHECK(type->IsInstance<DynTensorTypeNode>())
+      << "ValueError: Expression should have an inferred DynTensorType: " << type->GetTypeKey();
+  DataType dtype = Downcast<DynTensorType>(type)->dtype;
+  n->dtype = dtype;
+  return te::PlaceholderOp(n).output(0);
+}
+
+TVM_REGISTER_GLOBAL("relax.TETensor")
+    .set_body_typed([](Expr value, std::string name) {
+      return TETensor(value, name);
+    });
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h
new file mode 100644
index 0000000000000..acdecd325e48a
--- /dev/null
+++ b/src/relax/ir/emit_te.h
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/emit_te.h
+ * \brief Tensor expression extension in Relax.
+ */
+#ifndef TVM_RELAX_IR_EMIT_TE_H_
+#define TVM_RELAX_IR_EMIT_TE_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/te/operation.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A placeholder op that represents a relax expression.
+ */
+class RXPlaceholderOpNode : public te::PlaceholderOpNode {
+ public:
+  /*! \brief The relax expression. */
+  Expr value;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("value", &value);
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+  }
+
+  static constexpr const char* _type_key = "RXPlaceholderOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode);
+};
+
+/*!
+ * \brief Create a te tensor from a Relax expression.
+ * \param value The Relax expression.
+ * \param name The name of the tensor.
+ * \return The created te tensor.
+ */ +te::Tensor TETensor(Expr value, std::string name = "rxplaceholder"); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_IR_EMIT_TE_H_ diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 676bb8eeab964..497ae31ae75e5 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -17,10 +17,13 @@ from __future__ import annotations # must import to defer parsing of annotations import tvm -from tvm import tir +from tvm import tir, te from tvm import relay from tvm import relax as rx +import numpy as np +from tvm.ir.base import assert_structural_equal +from tvm.relax import op def test_block_builder(): m = tir.Var("m", "int32") @@ -226,6 +229,78 @@ def test_normalize(): assert add_call.shape[1] == n +def test_emit_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", [n, m], type_anno) + y = rx.Var("y", [n, m], type_anno) + z = rx.Var("z", [n, m], type_anno) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function([x, y, z], "rx_func"): + out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello") + bb.emit_func_output(out) + + func = bb.get() + mod = bb.context_mod() + + gvar = tvm.relay.GlobalVar("rx_func") + mod[gvar] = func + + def get_tir_func(): + A = te.placeholder((n, m), dtype="float32", name="A") + B = te.placeholder((n, m), dtype="float32", name="B") + C = te.placeholder((n, m), dtype="float32", name="C") + out = te_func((A, B), {"C": C}, "") + return tvm.te.create_prim_func([A, B, C, out]) + + # check TIR structure matches expected + assert_structural_equal(mod["te_func"].body, get_tir_func().body) + + # check Relax function calls TIR function with call_dps call + assert func.params[0] == x + assert func.params[1] == y + assert func.params[2] == z + assert func.name.name_hint == "rx_func" + assert func.body.body == out + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 1 + assert isinstance(func.body.blocks[0].bindings[0].value, rx.Call) + assert func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps") + assert len(func.body.blocks[0].bindings[0].value.args) == 3 + assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func" + assert func.body.blocks[0].bindings[0].value.args[2][0] == x + assert func.body.blocks[0].bindings[0].value.args[2][1] == y + assert func.body.blocks[0].bindings[0].value.args[2][2] == z + + +def test_emit_te_multiple(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", [n, m], type_anno) + y = rx.Var("y", [n, m], type_anno) + + def te_func(A): + B = te.compute((128, 128), lambda i, j: A[i, j] + 1) + return B + + with bb.function([x, y], "rx_func"): + x1 = bb.emit_te(te_func, x) + y1 = bb.emit_te(te_func, y) + bb.emit_func_output(y1) + + func = bb.get() + assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func" + assert func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func1" + if __name__ == "__main__": test_block_builder() test_function_single_block() @@ -233,3 +308,5 @@ def test_normalize(): test_binary_shape_type_deduction() test_emit_match_shape() test_normalize() + test_emit_te() + test_emit_te_multiple()