[EmitTE] EmitTE Symbolic Shape (apache#53)
hypercubestart authored and yongwww committed Aug 14, 2022
1 parent 7f7eff8 commit 93cb308
Showing 5 changed files with 155 additions and 11 deletions.
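The commit teaches BlockBuilder.emit_te to accept TE computations whose shapes are symbolic expressions of tir.Vars. A minimal end-to-end sketch, condensed from test_vm_emit_te_floor_symbolic_shape added below (it assumes an LLVM-enabled TVM build):

import numpy as np
import tvm
from tvm import relax, te, tir

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", [n], relax.DynTensorType(1, "float32"))

def te_func(A):
    # the output length floordiv(n, 2) is a symbolic expression of the input dim n
    return te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1)

with bb.function([x], "rx_func"):
    out = bb.emit_te(te_func, x)
    bb.emit_func_output(out)

mod = bb.get()
target = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
inp = tvm.nd.array(np.random.rand(9).astype(np.float32))
res = vm["rx_func"](inp)  # first 9 // 2 == 4 elements of inp, each incremented by 1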
32 changes: 24 additions & 8 deletions python/tvm/relax/block_builder.py
@@ -148,14 +148,29 @@ def _convert_te_arg_helper(arg):
new_arg = _convert_te_arg_helper(te_args)
return new_arg, te_args_list

def _check_te_args(self, args: List[tvm.te.Tensor]):
"""check te arguments."""
#TODO(hypercubestart, ziheng) support full dynamic shape in the future
for x in args:
def _check_te_args(self, args: List[tvm.te.Tensor], te_out: tvm.te.Tensor):
"""check te arguments"""
#TODO(hypercubestart, ziheng) support case where match_buffer doesn't bind to a variable
tensors = args + [te_out]
bound_vars = set()
used_vars = set()

def _populate_used_vars(expr):
if isinstance(expr, tvm.tir.Var):
used_vars.add(expr)

for x in tensors:
for s in x.shape:
if not isinstance(s, (tir.Var, tir.IntImm)):
raise ValueError("emit_te not support symbolic shape"
"contains expression now: {}".format(x.shape))
tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars)
if isinstance(s, tir.Var):
bound_vars.add(s)

diff = used_vars - bound_vars

if len(diff) != 0:
# there are TIR variables in shape expressions that are not bound by match buffer
raise ValueError("emit_te does not support TE functions with unbound tir.Vars: {}".format(diff))


def function(self,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
@@ -280,12 +280,13 @@ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tenso
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)

te_args = te_arg_list + te_kwarg_list
self._check_te_args(te_args)

# TODO(hypercubestart, ziheng) handle multiple output case
te_out = func(*new_args, **new_kwargs)
assert isinstance(te_out, tvm.te.tensor.Tensor), "only support te tensor as function output"

self._check_te_args(te_args, te_out)

inputs = [*te_args, te_out]
tir_func = tvm.te.create_prim_func(inputs)
func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
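For reference, the logic added to _check_te_args above distinguishes two sets of TIR variables: variables that appear as a bare dimension of some tensor shape (bound, recoverable via match_buffer) and variables that are merely referenced inside a shape expression (used). Any used-but-unbound variable makes emit_te fail. A standalone sketch of that check (check_shape_vars is a hypothetical name, not part of the commit):

from typing import List
import tvm
from tvm import te, tir

def check_shape_vars(tensors: List[te.Tensor]) -> None:
    """Sketch of the new check: every tir.Var referenced in a shape
    expression must also appear as a bare dimension of some tensor."""
    bound_vars = set()
    used_vars = set()

    def _populate_used_vars(expr):
        if isinstance(expr, tir.Var):
            used_vars.add(expr)

    for t in tensors:
        for s in t.shape:
            # collect every variable referenced by the dimension expression
            tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars)
            # a bare variable dimension can later be bound via match_buffer
            if isinstance(s, tir.Var):
                bound_vars.add(s)

    unbound = used_vars - bound_vars
    if unbound:
        raise ValueError("unbound tir.Vars in shape expressions: {}".format(unbound))

n = tir.Var("n", "int64")
A = te.placeholder((n,), name="A")        # n appears bare, so it is bound
B = te.placeholder((n * 2,), name="B")    # n is only used inside an expression here
check_shape_vars([A, B])                  # passes: n is bound through A

m = tir.Var("m", "int64")
C = te.placeholder((m + 1,), name="C")    # m never appears bare
# check_shape_vars([C]) would raise ValueError: m is used but never bound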
2 changes: 1 addition & 1 deletion src/relax/transform/call_dps_rewrite.cc
@@ -52,7 +52,7 @@ class CallDPSMutator : public ExprMutator {

if (call->op == call_dps_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "alloc");
Array<Expr> args;
if (call->args[2].as<TupleNode>()) {
args = Downcast<Tuple>(call->args[2])->fields;
8 changes: 8 additions & 0 deletions src/relax/vm/executable.cc
@@ -74,6 +74,14 @@ std::string ExecutableNode::Stats() const {
}
oss.seekp(-2, oss.cur);
oss << "], ";
} else if (it.IsObjectRef<ShapeTuple>()){
ShapeTuple shape = it.operator ShapeTuple();
oss << "shapetuple[";
for (size_t i = 0; i < shape.size(); ++i) {
oss << shape.at(i) << ", ";
}
oss.seekp(-2, oss.cur);
oss << "], ";
} else {
try {
DLDataType dtype = it.operator DLDataType();
98 changes: 97 additions & 1 deletion tests/python/relax/test_vm.py
@@ -19,10 +19,11 @@
import numpy as np
import tvm
from tvm.relay import Call
from tvm import relax, tir
from tvm import relax, tir, te
from tvm.runtime import container
import numpy as np

from tvm.ir.base import assert_structural_equal
import tvm.script
from tvm.script import tir as T, relax as R

@@ -425,6 +426,98 @@ def test_vm_emit_te_extern():
expected = np.dot(data.asnumpy(), weight.asnumpy())
np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4)

def test_vm_emit_te_concat():
# concatenate of two vectors of size (n,) and (m,)
bb = relax.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
type_anno = relax.DynTensorType(1, "float32")
x = relax.Var("x", [n], type_anno)
y = relax.Var("y", [m], type_anno)

def te_func(A, B):
C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i-n]))
return C

with bb.function([x, y], "rx_func"):
x1 = bb.emit_te(te_func, x, y)
bb.emit_func_output(x1)

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
inp = tvm.nd.array(np.random.rand(1, ).astype(np.float32))
inp2 = tvm.nd.array(np.random.rand(2, ).astype(np.float32))
res = vm["rx_func"](inp, inp2)

np.testing.assert_allclose(res.asnumpy(), np.append(inp.asnumpy(), inp2.asnumpy()))

def test_vm_emit_te_floor_symbolic_shape():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
type_anno = relax.DynTensorType(1, "float32")
x = relax.Var("x", [n], type_anno)

def te_func(A):
C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1)
return C

with bb.function([x], "rx_func"):
x1 = bb.emit_te(te_func, x)
bb.emit_func_output(x1)

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
shape = (9, )
inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
res = vm["rx_func"](inp)

def expected_output():
output_shape = (shape[0] // 2, )
return inp.asnumpy()[:output_shape[0]] + 1

np.testing.assert_allclose(res.asnumpy(), expected_output())

def test_vm_relax_symbolic_shape():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
type_anno = relax.DynTensorType(1, "float32")
x = relax.Var("x", [n], type_anno)
y = relax.Var("y", [(n // 2) + 1], type_anno)

def te_func(A, B):
C = te.compute((n, ), lambda i: A[i] + B[i // 2])
return C

with bb.function([x, y], "rx_func"):
x1 = bb.emit_te(te_func, x, y)
bb.emit_func_output(x1)

mod = bb.get()

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
shape1 = (5, )
shape2 = (3, )
inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32))
inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32))
res = vm["rx_func"](inp, inp2)

def expected_output():
return inp.asnumpy() + np.repeat(inp2.asnumpy(), 2)[:5]

np.testing.assert_allclose(res.asnumpy(), expected_output())

if __name__ == "__main__":
test_vm_execute()
@@ -443,3 +536,6 @@ def test_vm_emit_te_extern():
test_vm_compile_e2e()
test_vm_compile_e2e_func_param_with_shape()
test_vm_emit_te_extern()
test_vm_emit_te_concat()
test_vm_emit_te_floor_symbolic_shape()
test_vm_relax_symbolic_shape()
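For contrast with the tests above, a sketch (not part of the commit) of a TE function the new _check_te_args rejects: the symbolic variable n only ever appears inside shape expressions, never as a bare dimension, so it cannot be bound and emit_te is expected to raise a ValueError about unbound tir.Vars.

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", [n * 2], relax.DynTensorType(1, "float32"))

def bad_te_func(A):
    # n appears only inside expressions (n * 2 above, n + 1 here),
    # never as a bare dimension, so it cannot be recovered via match_buffer
    return te.compute((n + 1,), lambda i: A[i])

with bb.function([x], "rx_func"):
    bb.emit_te(bad_te_func, x)  # expected to raise ValueError: unbound tir.Vars: {n}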
26 changes: 25 additions & 1 deletion tests/python/unittest/test_te_create_primfunc.py
@@ -13,7 +13,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# under the License
# pylint: disable=missing-function-docstring,missing-module-docstring
import numpy as np
import tvm
@@ -345,6 +345,29 @@ def test_data_dependent_access():
tvm.testing.assert_allclose(a_np[b_np], c.numpy())


def test_loop_var_datatype():
def test_helper(dtype):
n = te.var("n", dtype)
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B", dtype="int32")
C = te.compute((n,), lambda i: A[i] + B[i])

func = te.create_prim_func([C, A, B])

assert func.body.block.body.loop_var.dtype == dtype

func = tvm.build(func)

a_np = np.random.uniform(size=(10,)).astype(A.dtype)
b_np = np.random.uniform(size=(10,)).astype(B.dtype)
c = tvm.nd.array(np.zeros(10, dtype=C.dtype))
func(c, tvm.nd.array(a_np), tvm.nd.array(b_np))
tvm.testing.assert_allclose(a_np + b_np, c.numpy())

test_helper("int32")
test_helper("int64")


def test_select_simplify():
placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
@@ -568,3 +591,4 @@ def expected(
test_argmax_val_idx()
test_int64_indices()
test_zero_dim_add()
test_loop_var_datatype()
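The new test above pins down the property the Relax symbolic-shape tests in this commit rely on: te.create_prim_func must preserve the dtype of shape variables when it synthesizes loop variables, and the Relax tests use int64 shape vars. A minimal standalone illustration (a sketch, not part of the commit):

import tvm
from tvm import te

n = te.var("n", "int64")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

func = te.create_prim_func([A, B])
print(func.script())  # the generated loop variable over i should be typed int64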
