diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index adaff6a75bd49..b81a341ae0ea4 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -66,22 +66,33 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = utils.get_const_int(axis) - shape = data.shape - axis_range = cast(shape[axis], indices.dtype) - if axis < 0: axis = len(shape) + axis + shape = data.shape + axis_range = cast(shape[axis], indices.dtype) + # Prepare ranges and strides - before_axis_range = 1 + full_range = 1 after_axis_range = 1 for i, value in enumerate(shape, 0): - if i < axis: - before_axis_range *= value - elif i > axis: + full_range *= value + if i > axis: after_axis_range *= value before_axis_stride = axis_range * after_axis_range - full_range = before_axis_range * before_axis_stride + + ind_shape = indices.shape + ind_axis_range = ind_shape[axis] + + # Prepare ranges and strides + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name @@ -98,18 +109,19 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # TODO(vvchernov): find optimal parallel approach with ib.for_range( - 0, before_axis_range * after_axis_range, "fused", kind="parallel" + 0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel" ) as fused: - i = fused // after_axis_range - j = fused % after_axis_range - pre_index = i * before_axis_stride + j - with ib.for_range(0, axis_range, "k") as k: - index1 = pre_index + k * after_axis_range + i = fused // ind_after_axis_range + j = fused % ind_after_axis_range + with ib.for_range(0, ind_axis_range, "k") as k: + # Offset along indices or 
updates + index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j # TODO(vvchernov): assert for out of bounds, separated check for indices k_new = indices[index1] index_check = tir.LT(k_new, tir.const(0, indices.dtype)) k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) - index2 = pre_index + k_new * after_axis_range + # Offset along data + index2 = i * before_axis_stride + k_new * after_axis_range + j if reduction == "update": out[index2] = updates[index1] elif reduction == "add":