Skip to content

Commit

Permalink
[ONNX] Support ScatterElements with reduction (#13894)
Browse files Browse the repository at this point in the history
* add ScatterElements converter to ONNX front-end

* native front-end for ScatterElements was implemented

* update ScatterElements in ONNX high-level front-end

* update comments

* register ScatterElementsAttrs

* register scatter elements strategy

* implement generic scatter elements in topi

* fix min-max redefinition

* fix IntImm conversion and update scatter element implementation

* fix parallel approach

* CI tests for scatter elements were added

* small update of description

* sphinx issue was fixed

* fix scatter deprecation in the CI test

* fix

* fix scatter version support

* fix negative indices

* add scatter elements strategy for cuda, gpu

* update assert comment, update check of negative indices, hide tests for 18 version

* fixes

* extend error log for convenient analysis

* lint fix

* fix

* sync dtypes

* update cpu tir for scatter elements by scan example

* scatter elements was basically implemented for topi/cuda

* fix cpu scatter elements

* fix gpu scatter elements

* fix

* small update

* transfer indices check out of general loop

* trancsfer ranges and strides calculation to gpu device

* fixes

* fix axis

* clean code

* fix after review

* fix lint

---------

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Feb 16, 2023
1 parent 0dd3d4a commit 0c965f4
Show file tree
Hide file tree
Showing 14 changed files with 601 additions and 8 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
}
};

struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
String reduction;

TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relay.attrs.ScatterElementsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Reduction mode of the scatter elements, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
}
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
String mode;

Expand Down
52 changes: 50 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2848,11 +2848,59 @@ class Scatter(OnnxOpConverter):
"""Operator converter for Scatter."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
def _impl_v9(cls, inputs, attr, params):
axis = attr.get("axis", 0)
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)


class ScatterElements(OnnxOpConverter):
"""Operator converter for ScatterElements."""

@classmethod
def _args_check(cls, inputs, attr, red_valids=None):
ret = []
assert (
len(inputs) == 3
), "ScatterElements takes 3 inputs (data, indices, updates), {} given".format(len(inputs))
assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]

axis = attr.get("axis", 0)
rank = len(infer_shape(inputs[0]))
assert rank > 0, "Data rank higher than 0 is expected"
assert -rank <= axis < rank, "Axis is out of bounds"
ret.append(axis)

if red_valids:
reduction = attr.get("reduction", None)
if reduction is None:
reduction = b"update"
reduction = reduction.decode("utf-8")
assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format(
red_valids, reduction
)
ret.append(reduction)

return ret

@classmethod
def _impl_v11(cls, inputs, attr, params):
axis = cls._args_check(inputs, attr)

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, "update")

@classmethod
def _impl_v16(cls, inputs, attr, params):
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul"])

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)

@classmethod
def _impl_v18(cls, inputs, attr, params):
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul", "min", "max"])

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)


class ScatterND(OnnxOpConverter):
"""Operator converter for ScatterND."""

Expand Down Expand Up @@ -6588,7 +6636,7 @@ def _get_convert_map(opset):
"Compress": Compress.get_converter(opset),
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"ScatterElements": ScatterElements.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"EyeLike": EyeLike.get_converter(opset),
"Squeeze": Squeeze.get_converter(opset),
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ def compute_scatter_add(attrs, inputs, output_type):

_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter_elements
@_reg.register_compute("scatter_elements")
def compute_scatter_elements(attrs, inputs, output_type):
"""Compute definition of scatter_elements"""
return [topi.scatter_elements(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]


_reg.register_strategy("scatter_elements", strategy.scatter_elements_strategy)

# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
Expand Down Expand Up @@ -679,6 +688,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):

_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)


Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ class ScatterAddAttrs(Attrs):
"""Attributes used in scatter_add operators"""


@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs")
class ScatterElementsAttrs(Attrs):
"""Attributes used in scatter_elements operators"""


@tvm._ffi.register_object("relay.attrs.ScatterNDAttrs")
class ScatterNDAttrs(Attrs):
"""Attributes used in scatter_nd operators"""
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,20 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_elements_strategy.register(["cuda", "gpu"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter elements cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="scatter_elements.cuda",
plevel=10,
)
# TODO(vvchernov): There is possible specification for rank=1 as for scatter
return strategy


@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
return strategy


# scatter_elements
@override_native_generic_func("scatter_elements_strategy")
def scatter_elements_strategy(attrs, inputs, out_type, target):
"""scatter_elements generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_elements(topi.scatter_elements),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_elements.generic",
)
return strategy


def wrap_compute_scatter_elements(topi_compute):
"""Wrap scatter_elements topi compute"""

def _compute_scatter_elements(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]

return _compute_scatter_elements


# scatter_nd
@override_native_generic_func("scatter_nd_strategy")
def scatter_nd_strategy(attrs, inputs, out_type, target):
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,41 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_elements(data, indices, updates, axis=0, reduction="update"):
"""Scatter elements with updating data by reduction of values in updates
at positions defined by indices.
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The index locations to update.
updates : relay.Expr
The values to update.
axis : int
The axis to scatter elements on. It is zero by default.
reduction : string, optional
The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_elements(data, indices, updates, axis, reduction)


def scatter_nd(data, indices, updates, mode="update"):
"""Scatter values from an array and update.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .broadcast import *
from .sort import *
from .scatter import *
from .scatter_elements import *
from .sparse_fill_empty_rows import *
from .sparse_reshape import *
from .scatter_add import *
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression
from .rcnn import *
from .scatter import *
from .scatter_elements import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
Expand Down
Loading

0 comments on commit 0c965f4

Please sign in to comment.