diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index f0a83555f9d73..8bacea5e29f7c 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -126,10 +126,12 @@ def gen_ir(data, indices, updates, out): ib.scope_attr(bx1, "thread_extent", num_blocks_1) ib.scope_attr(tx1, "thread_extent", max_threads) - ind_fused = bx2 * max_threads + tx2 + ind_fused = bx1 * max_threads + tx1 with ib.if_scope(ind_fused < ind_full_range): index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype)) - indices_ptr[ind_fused] += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + indices_ptr[ind_fused] += tir.Select( + index_check, axis_range, tir.const(0, indices.dtype) + ) # TODO(vvchernov): assert for index out of bounds # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd)