Skip to content

Commit

Permalink
scatter elements was basically implemented for topi/cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 8, 2023
1 parent bd94ecd commit 6708902
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression
from .rcnn import *
from .scatter import *
from .scatter_elements import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
Expand Down
157 changes: 157 additions & 0 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Scatter operator """
import tvm
from tvm import te
from tvm import tir
from ..utils import ceil_div, get_const_int
from ..math import cast


def scatter_elements(data, indices, updates, axis=0, reduction="update"):
"""Scatter elements from updates to corresponding indices of copied data.
Data, indices, updates and output have the same shape.
Indices can not have duplicates (if idx1 != idx2, then indices[idx1] != indices[idx2])
if reduction == "update".
.. code-block::
output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0
output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1
where the update function f is determinted by the reduction.
Five types of the function are supported: "update", "add", "mul", "min" and "max" (see below)
Parameters
----------
data : tvm.te.Tensor
The source array.
indices : tvm.te.Tensor
The indices of the values to extract.
updates : tvm.te.Tensor
The updates to apply at the Indices
axis : optional, int
The axis to scatter on. It is zero by default.
reduction : optional, string
The update mode for the algorithm, either "update", "add", "mul", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Returns
-------
ret : tvm.te.Tensor
"""
if not isinstance(axis, int):
axis = get_const_int(axis)

shape = data.shape
axis_range = cast(shape[axis], indices.dtype)

if axis < 0:
axis = len(shape) + axis

# Prepare ranges and strides
before_axis_range = 1
after_axis_range = 1
for i, value in enumerate(shape, 0):
if i < axis:
before_axis_range *= value
elif i > axis:
after_axis_range *= value
before_axis_stride = axis_range * after_axis_range
full_range = before_axis_range * before_axis_stride
full_range_excl_axis = before_axis_range * after_axis_range

def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
ib = tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
# Copy initial input data to output
with ib.new_scope():
num_blocks = ceil_div(full_range, max_threads)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", num_blocks)
ib.scope_attr(tx, "thread_extent", max_threads)

index = bx * max_threads + tx
with ib.if_scope(index < full_range):
out[index] = data[index]

# TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd)
with ib.new_scope():
num_blocks_2 = ceil_div(full_range_excl_axis, max_threads)
bx2 = te.thread_axis("blockIdx.x")
tx2 = te.thread_axis("threadIdx.x")
ib.scope_attr(bx2, "thread_extent", num_blocks_2)
ib.scope_attr(tx2, "thread_extent", max_threads)

fused = bx2 * max_threads + tx2
with ib.if_scope(fused < full_range_excl_axis):
i = fused // after_axis_range
j = fused % after_axis_range
pre_index = i * before_axis_stride + j
with ib.for_range(0, axis_range, "k") as k:
index1 = pre_index + k * after_axis_range
# TODO(vvchernov): assert for out of bounds, separated check for indices
k_new = indices[index1]
index_check = tir.LT(k_new, tir.const(0, indices.dtype))
k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
index2 = pre_index + k_new * after_axis_range
if reduction == "update":
out[index2] = updates[index1]
elif reduction == "add":
out[index2] += updates[index1]
elif reduction == "mul":
out[index2] *= updates[index1]
elif reduction == "min":
tir.min(out[index2], updates[index1])
elif reduction == "max":
tir.max(out[index2], updates[index1])
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:",
reduction,
)

return ib.get()

out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_elements_cuda",
tag="scatter_elements_cuda",
)

0 comments on commit 6708902

Please sign in to comment.