Skip to content

Commit

Permalink
fix cpu scatter elements
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 8, 2023
1 parent 6708902 commit f0fdaf1
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions python/tvm/topi/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,33 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
if not isinstance(axis, int):
axis = utils.get_const_int(axis)

shape = data.shape
axis_range = cast(shape[axis], indices.dtype)

if axis < 0:
axis = len(shape) + axis

shape = data.shape
axis_range = cast(shape[axis], indices.dtype)

# Prepare ranges and strides
before_axis_range = 1
full_range = 1
after_axis_range = 1
for i, value in enumerate(shape, 0):
if i < axis:
before_axis_range *= value
elif i > axis:
full_range *= value
if i > axis:
after_axis_range *= value
before_axis_stride = axis_range * after_axis_range
full_range = before_axis_range * before_axis_stride

ind_shape = indices.shape
ind_axis_range = ind_shape[axis]

# Prepare ranges and strides
ind_before_axis_range = 1
ind_after_axis_range = 1
for i, value in enumerate(ind_shape, 0):
if i < axis:
ind_before_axis_range *= value
elif i > axis:
ind_after_axis_range *= value
ind_before_axis_stride = ind_axis_range * ind_after_axis_range

def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
# pylint: disable=invalid-name
Expand All @@ -98,18 +109,19 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):

# TODO(vvchernov): find optimal parallel approach
with ib.for_range(
0, before_axis_range * after_axis_range, "fused", kind="parallel"
0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel"
) as fused:
i = fused // after_axis_range
j = fused % after_axis_range
pre_index = i * before_axis_stride + j
with ib.for_range(0, axis_range, "k") as k:
index1 = pre_index + k * after_axis_range
i = fused // ind_after_axis_range
j = fused % ind_after_axis_range
with ib.for_range(0, ind_axis_range, "k") as k:
# Offset allong indices or updates
index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j
# TODO(vvchernov): assert for out of bounds, separated check for indices
k_new = indices[index1]
index_check = tir.LT(k_new, tir.const(0, indices.dtype))
k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
index2 = pre_index + k_new * after_axis_range
# Offset allong data
index2 = i * before_axis_stride + k_new * after_axis_range + j
if reduction == "update":
out[index2] = updates[index1]
elif reduction == "add":
Expand Down

0 comments on commit f0fdaf1

Please sign in to comment.