Skip to content

Commit

Permalink
update cpu tir for scatter elements by scan example
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 8, 2023
1 parent 4e0f37f commit bd94ecd
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions python/tvm/topi/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,34 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
out[i] = data[i]

# TODO(vvchernov): find optimal parallel approach
with ib.for_range(0, before_axis_range, "i", kind="parallel") as i:
with ib.for_range(0, after_axis_range, "j") as j:
with ib.for_range(0, axis_range, "k") as k:
pre_index = i * before_axis_stride + j
index1 = pre_index + k * after_axis_range
# TODO(vvchernov): assert for out of bounds, separated check for indices
k_new = indices[index1]
index_check = tir.LT(k_new, tir.const(0, indices.dtype))
k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
index2 = pre_index + k_new * after_axis_range
if reduction == "update":
out[index2] = updates[index1]
elif reduction == "add":
out[index2] += updates[index1]
elif reduction == "mul":
out[index2] *= updates[index1]
elif reduction == "min":
tir.min(out[index2], updates[index1])
elif reduction == "max":
tir.max(out[index2], updates[index1])
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:",
reduction,
)
with ib.for_range(
0, before_axis_range * after_axis_range, "fused", kind="parallel"
) as fused:
i = fused // after_axis_range
j = fused % after_axis_range
pre_index = i * before_axis_stride + j
with ib.for_range(0, axis_range, "k") as k:
index1 = pre_index + k * after_axis_range
# TODO(vvchernov): assert for out of bounds, separated check for indices
k_new = indices[index1]
index_check = tir.LT(k_new, tir.const(0, indices.dtype))
k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
index2 = pre_index + k_new * after_axis_range
if reduction == "update":
out[index2] = updates[index1]
elif reduction == "add":
out[index2] += updates[index1]
elif reduction == "mul":
out[index2] *= updates[index1]
elif reduction == "min":
tir.min(out[index2], updates[index1])
elif reduction == "max":
tir.max(out[index2], updates[index1])
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:",
reduction,
)

return ib.get()

Expand Down

0 comments on commit bd94ecd

Please sign in to comment.