Skip to content

Commit

Permalink
trancsfer ranges and strides calculation to gpu device
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 9, 2023
1 parent 303b543 commit efa5927
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,36 +67,6 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
if not isinstance(axis, int):
axis = get_const_int(axis)

# Prepare ranges and strides
shape = data.shape
if axis < 0:
axis = len(shape) + axis
axis_range = cast(shape[axis], indices.dtype)

before_axis_range = 1
after_axis_range = 1
for i, value in enumerate(shape, 0):
if i < axis:
before_axis_range *= value
elif i > axis:
after_axis_range *= value
before_axis_stride = axis_range * after_axis_range
full_range = before_axis_range * before_axis_stride

ind_shape = indices.shape
ind_axis_range = shape[axis]

ind_before_axis_range = 1
ind_after_axis_range = 1
for i, value in enumerate(ind_shape, 0):
if i < axis:
ind_before_axis_range *= value
elif i > axis:
ind_after_axis_range *= value
ind_before_axis_stride = ind_axis_range * ind_after_axis_range
ind_full_range = ind_before_axis_range * ind_before_axis_stride
ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range

def gen_ir(data, indices, updates, out):
ib = tir.ir_builder.create()

Expand All @@ -105,6 +75,36 @@ def gen_ir(data, indices, updates, out):
updates_ptr = ib.buffer_ptr(updates)
out_ptr = ib.buffer_ptr(out)

# Prepare ranges and strides
shape = data.shape
if axis < 0:
axis = len(shape) + axis
axis_range = cast(shape[axis], indices.dtype)

before_axis_range = 1
after_axis_range = 1
for i, value in enumerate(shape, 0):
if i < axis:
before_axis_range *= value
elif i > axis:
after_axis_range *= value
before_axis_stride = axis_range * after_axis_range
full_range = before_axis_range * before_axis_stride

ind_shape = indices.shape
ind_axis_range = shape[axis]

ind_before_axis_range = 1
ind_after_axis_range = 1
for i, value in enumerate(ind_shape, 0):
if i < axis:
ind_before_axis_range *= value
elif i > axis:
ind_after_axis_range *= value
ind_before_axis_stride = ind_axis_range * ind_after_axis_range
ind_full_range = ind_before_axis_range * ind_before_axis_stride
ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
# Copy initial input data to output
with ib.new_scope():
Expand Down

0 comments on commit efa5927

Please sign in to comment.