Skip to content

Commit

Permalink
Change new concat (#11800)
Browse files Browse the repository at this point in the history
* changed x86/concat to use lists of ints instead of te.tensor.Tensor for loop extents and array offsets

* typos fixed

* removed unused import

* fixed micro model test

* fixed micro model test
  • Loading branch information
Sebastian Boblest authored Jun 22, 2022
1 parent 32d16eb commit 5056eb7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
48 changes: 23 additions & 25 deletions python/tvm/topi/x86/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import tvm
from tvm import te
import numpy as np
from ..utils import get_const_int, const_vector
from ..utils import get_const_int


def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
"""Join a sequence of arrays along an existing axis. Optimized for CPU exeution.
"""Join a sequence of arrays along an existing axis.
Optimized for CPU execution.
Parameters
----------
Expand All @@ -38,48 +39,45 @@ def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
ret : tvm.te.Tensor
"""

def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf):
"""Custom conactenation execution."""
in_outers = [int(np.prod(i.shape[axis:])) for i in data]
in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]]

def gen_ir_1d(data_bufs, out_buf):
"""Custom concatenation execution."""
i_b = tvm.tir.ir_builder.create()
data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
out_buf = i_b.buffer_ptr(out_buf)
outers = i_b.buffer_ptr(in_outers_tensor)
cumsum = i_b.buffer_ptr(in_cumsum_tensor)

for i in range(len(data)):
with i_b.for_range(0, outers[i], name="j") as j:
out_buf[cumsum[i] + j] = data_bufs1[i][j]
with i_b.for_range(0, in_outers[i], name="j") as j:
out_buf[in_outers_cumsum[i] + j] = data_bufs1[i][j]
return i_b.get()

def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer):
"""Common case of conactenation execution."""
def gen_ir(data_bufs, out_buf, inner, outer):
"""Common case of concatenation execution."""
i_b = tvm.tir.ir_builder.create()
data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
out_buf = i_b.buffer_ptr(out_buf)
outers = i_b.buffer_ptr(in_outers_tensor)
cumsum = i_b.buffer_ptr(in_cumsum_tensor)
if inner > 1:
with i_b.for_range(0, inner, name="inn", kind="parallel") as inn:
pos = inn * outer
for i in range(len(data)):
offset = inn * outers[i]
with i_b.for_range(0, outers[i], name="j") as j:
out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j]
offset = inn * in_outers[i]
with i_b.for_range(0, in_outers[i], name="j") as j:
out_buf[pos + in_outers_cumsum[i] + j] = data_bufs1[i][offset + j]
else:
for i in range(len(data)):
with i_b.for_range(0, outers[i], name="j", kind="parallel") as j:
out_buf[cumsum[i] + j] = data_bufs1[i][j]
with i_b.for_range(0, in_outers[i], name="j", kind="parallel") as j:
out_buf[in_outers_cumsum[i] + j] = data_bufs1[i][j]
return i_b.get()

if axis < 0:
axis += len(data[0].shape)
concat_axis_sizes = [int(t.shape[axis]) for t in data]
join_size = int(np.sum(concat_axis_sizes))
in_outers = [int(np.prod(i.shape[axis:])) for i in data]
in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]]

dtype = data[0].dtype
out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :]
in_outers_tensor = const_vector(in_outers)
in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum")
right_val = np.prod(out_shape[axis:])
left_val = np.prod(out_shape[:axis])

Expand All @@ -92,8 +90,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
# badly parallelized case
return te.extern(
[out_shape],
list(data) + [in_outers_tensor, in_cumsum_tensor],
lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]),
list(data),
lambda ins, outs: gen_ir_1d(ins, outs[0]),
dtype=dtype,
name="concatenate_ext",
)
Expand All @@ -102,8 +100,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
outer = get_const_int(int(right_val))
return te.extern(
[out_shape],
list(data) + [in_outers_tensor, in_cumsum_tensor],
lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer),
list(data),
lambda ins, outs: gen_ir(ins, outs[0], inner, outer),
dtype=dtype,
name="concatenate_ext",
)
4 changes: 2 additions & 2 deletions tests/python/unittest/test_micro_model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def test_export_byoc_c_module():
"constants_size_bytes": 0,
"device": 1,
"io_size_bytes": 4800,
"workspace_size_bytes": 1264,
"workspace_size_bytes": 1200,
}
]
else:
Expand All @@ -469,7 +469,7 @@ def test_export_byoc_c_module():
"constants_size_bytes": 0,
"device": 1,
"io_size_bytes": 4800,
"workspace_size_bytes": 1248,
"workspace_size_bytes": 1200,
}
]

Expand Down

0 comments on commit 5056eb7

Please sign in to comment.