Skip to content

Commit

Permalink
TE Integration (apache#36)
Browse files Browse the repository at this point in the history
* Init.

* Proof of concept.

* Rebase on the newest branch

* Move to emit_te

* Update emit_te

* Make RXPlaceholderOpNode as a subclass of PlaceholderOpNode

* Update

* run vm test_te

* Update argument conversion

* Reset create_primfunc

* Update doc

* Update test

* Add error message

* Update

* Update

* Address comment

* unit test check structural and validate_te_args

* raise ValueError when multiple outputs

* address comments

* example usage emit_te

* Rename to context_mod

* Handle multiple call

* Address comments

* Address comments

* Use unique name

* remove

* rename args to te_args

* address comments

* fix TVMscript manually

* spelling

Co-authored-by: Andrew Liu <[email protected]>
  • Loading branch information
2 people authored and yongwww committed Aug 14, 2022
1 parent 8a9b978 commit 18385c5
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode {
}

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};

/*!
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
# helper functions
const = expr.const
extern = expr.extern
te_tensor = expr.te_tensor

# Type
Type = ty.Type
Expand Down
161 changes: 156 additions & 5 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Developer API of constructing Relax AST."""
from typing import List, Optional, Union, Dict
import typing
from typing import List, Optional, Union, Dict, Any, Callable
from tvm.relay.expr import Tuple
from tvm.runtime import Object
from tvm import relax as rx
from tvm import tir
from .expr import *
from .op.base import call_dps
from tvm._ffi.base import _LIB, check_call
from . import _ffi_api

Expand Down Expand Up @@ -72,7 +75,7 @@ class BlockBuilder(Object):
dtype1 = rx.DynTensorType(rank=1, dtype="float16")
x = rx.Var("x", [m, n], dtype0)
y = rx.Var("y", [n], dtype1)
ib = rx.IRBuilder()
ib = rx.BlockBuilder()
with ib.function([x, y], "func"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.add(x, y))
Expand All @@ -84,17 +87,69 @@ class BlockBuilder(Object):

def __init__(self):
self._blocks = []
self._context_mod = tvm.IRModule()
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)

def _begin_dataflow_block(self) -> None:
_ffi_api.BlockBuilderBeginDataflowBlock(self)

def _begin_binding_block(self) -> None:
_ffi_api.BlockBuilderBeginBindingBlock(self)

def _end_block(self) -> BindingBlock:
return _ffi_api.BlockBuilderEndBlock(self)

def _convert_te_arg(self,
te_args: Any
) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
"""Helper function to convert Relax expressions to te tensor.
In the common case, the type of te_args is a Relax expression and is converted into a te tensor.
If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array),
we recursive and convert any value of type Relax expression into a te tensor.
Common values of type int, float, and str are preserved.
Parameters
----------
te_args : Any
Argument to convert to te
Returns
-------
ret : (Any, [tvm.te.Tensor])
A tuple of the converted te_args, and a list of te tensors for each converted Relax expression
"""
te_args_list = []

def _convert_te_arg_helper(arg):
if isinstance(arg, Expr):
arg = te_tensor(arg)
te_args_list.append(arg)
return arg
elif isinstance(arg, (list, tvm.ir.Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
return tuple([_convert_te_arg_helper(x) for x in arg])
elif isinstance(arg, (dict, tvm.ir.Map)):
for key in arg:
assert isinstance(key, str), "emit_te only supports dict with string as the key currently"
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
elif isinstance(arg, (int, float, str)):
return arg
else:
raise TypeError("not supported type in emit_te: {}".format(type(arg)))

new_arg = _convert_te_arg_helper(te_args)
return new_arg, te_args_list

def _check_te_args(self, args: List[tvm.te.Tensor]):
"""check te arguments."""
#TODO(hypercubestart, ziheng) support full dynamic shape in the future
for x in args:
for s in x.shape:
if not isinstance(s, (tir.Var, tir.IntImm)):
raise ValueError("emit_te not support symbolic shape"
"contains expression now: {}".format(x.shape))

def function(self,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
name: Optional[str] = "") -> FunctionScope:
Expand Down Expand Up @@ -139,7 +194,7 @@ def emit(self, call: relay.Call) -> Var:
Parameters
----------
call : tvm.relay.Call
call : tvm.relax.Call
The call node to be emitted.
Returns
Expand All @@ -149,12 +204,97 @@ def emit(self, call: relay.Call) -> Var:
"""
return _ffi_api.BlockBuilderEmit(self, call)

def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
"""Emit a call node according to the te function.
This function converts arguments from relax expression to te tensor,
The callback func should return a te tensor.
Parameters
----------
func : Callable
A function that return a te tensor.
Returns
-------
ret : tvm.relax.Var
A newly created variable that gets binded to the call code.
Example
-------
.. code-block:: python
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [n, m], type_anno)
def te_func(args, args_dict, msg):
A = args[0]
B = args_dict["B"]
return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
with bb.function([x, y], "rx_func"):
out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
bb.emit_func_output(out)
will result in TVMScript
.. code-block:: python
@tvm.script.ir_module
class Module:
@T.prim_func
def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "te_func"})
m = T.var("int64")
n = T.var("int64")
rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
compute = T.match_buffer(var_compute, [128, 128], dtype="float32")
# body
# with T.block("root")
for i0, i1 in T.grid(128, 128):
with T.block("compute"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]])
T.writes([compute[i, j]])
compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j]
@R.function
def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
# block 0
gv = relax.call_dps((128, 128), "te_func", (x, y))
return gv
"""
new_args, te_arg_list = self._convert_te_arg(args)
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)

te_args = te_arg_list + te_kwarg_list
self._check_te_args(te_args)

# TODO(hypercubestart, ziheng) handle multiple output case
te_out = func(*new_args, **new_kwargs)
assert isinstance(te_out, tvm.te.tensor.Tensor), "only support te tensor as function output"

inputs = [*te_args, te_out]
tir_func = tvm.te.create_prim_func(inputs)
func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
tir_func = tir_func.with_attr("global_symbol", func_name)
gvar = GlobalVar(func_name)
self._context_mod[gvar] = tir_func
call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
return _ffi_api.BlockBuilderEmit(self, call)


def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
"""Emit a MatchShape.
Parameters
----------
value : tvm.relay.Expr
value : tvm.relax.Expr
The value of the MatchShape to be emitted.
pattern : List[PrimExpr]
Expand Down Expand Up @@ -224,8 +364,19 @@ def get(self) -> Function:
ret : tvm.relax.Function
A Relax function node being built.
"""
# TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions
seqe = rx.SeqExpr(self._blocks, self._func_ret)
func = rx.Function(
self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name)
)
return func

def context_mod(self):
"""Return the context module that might contain tir functions.
Returns
-------
mod : tvm.IRModule
The context module that contains tir functions during emit.
"""
return self._context_mod
8 changes: 7 additions & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,11 @@ def __init__(self, global_symbol: String, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span)


def extern(name, span: Span = None):
def extern(name: str, span: Span = None):
"""Create extern function."""
return ExternFunc(name, span)


def te_tensor(value: Expr, name: str = "rxplaceholder"):
"""Create te tensor from relax expression."""
return _ffi_api.TETensor(value, name)
4 changes: 2 additions & 2 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
from ...ir import BaseFunc
from ...ir import BaseFunc, Array
from ..expr import Expr, ShapeExpr, Tuple, Call
from . import _ffi_api
from typing import Union, List
Expand Down Expand Up @@ -41,7 +41,7 @@ def call_dps(
ret: Call
A call node for the call_dps operator.
"""
if isinstance(shape, (list, tuple)):
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)
if isinstance(args, (list, tuple)):
args = Tuple(args)
Expand Down
5 changes: 5 additions & 0 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,5 +550,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
return builder->Normalize(expr);
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName")
.set_body_typed([](BlockBuilder builder, String name_hint) {
return builder->name_table()->GetUniqueName(name_hint);
});

} // namespace relax
} // namespace tvm
64 changes: 64 additions & 0 deletions src/relax/ir/emit_te.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file relax/src/ir/emit_te.cc
*/
#include <tvm/relax/type.h>
#include "./emit_te.h"

namespace tvm {
namespace relax {

// RXPlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RXPlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RXPlaceholderOpNode*>(node.get());
p->stream << "rxplaceholder(" << op->name << ", " << op << ")";
});

TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode);

te::Tensor TETensor(Expr value, std::string name) {
auto n = make_object<RXPlaceholderOpNode>();
n->name = name;
n->value = value;

Expr shape_expr = value->shape();
CHECK(shape_expr->IsInstance<ShapeExprNode>())
<< "ValueError: Expression does not have an known symbolic shape, please consider use match_shape "
<< "to constrain the shape before passing into te_tensor";
Array<PrimExpr> shape = Downcast<ShapeExpr>(shape_expr)->values;
n->shape = shape;
Type type = value->checked_type();
ICHECK(type->IsInstance<DynTensorTypeNode>())
<< "ValueError: Expression should have a inferred DynTensorType: "
<< type->GetTypeKey();
DataType dtype = Downcast<DynTensorType>(type)->dtype;
n->dtype = dtype;
return te::PlaceholderOp(n).output(0);
}

TVM_REGISTER_GLOBAL("relax.TETensor")
.set_body_typed([](Expr value, std::string name) {
return TETensor(value, name);
});

} // namespace relax
} // namespace tvm
Loading

0 comments on commit 18385c5

Please sign in to comment.