Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 9, 2023
1 parent efa5927 commit 00a5f6c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ def gen_ir(data, indices, updates, out):
ib.scope_attr(bx1, "thread_extent", num_blocks_1)
ib.scope_attr(tx1, "thread_extent", max_threads)

ind_fused = bx2 * max_threads + tx2
ind_fused = bx1 * max_threads + tx1
with ib.if_scope(ind_fused < ind_full_range):
index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype))
indices_ptr[ind_fused] += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
indices_ptr[ind_fused] += tir.Select(
index_check, axis_range, tir.const(0, indices.dtype)
)
# TODO(vvchernov): assert for index out of bounds

# TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd)
Expand Down

0 comments on commit 00a5f6c

Please sign in to comment.