Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support ScatterElements with reduction #13894

Merged
merged 37 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
74e413f
add ScatterElements converter to ONNX front-end
Feb 1, 2023
eafbe11
native front-end for ScatterElements was implemented
Feb 2, 2023
12cf532
update ScatterElements in ONNX high-level front-end
Feb 2, 2023
64656a7
update comments
Feb 3, 2023
90a4b30
register ScatterElementsAttrs
Feb 3, 2023
709b74e
register scatter elements strategy
Feb 3, 2023
1327f7d
implement generic scatter elements in topi
Feb 3, 2023
f16648c
fix min-max redefinition
Feb 4, 2023
39f4fc0
fix IntImm conversion and update scatter element implementation
Feb 4, 2023
c293094
fix parallel approach
Feb 4, 2023
e92ffb9
CI tests for scatter elements were added
Feb 5, 2023
f0fb416
small update of description
Feb 5, 2023
7f9128d
sphinx issue was fixed
Feb 5, 2023
1e3a663
fix scatter deprecation in the CI test
Feb 5, 2023
c2653b3
fix
Feb 5, 2023
a7e0aae
fix scatter version support
Feb 6, 2023
d527411
fix negative indices
Feb 6, 2023
3525fcb
add scatter elements strategy for cuda, gpu
Feb 6, 2023
38af70c
update assert comment, update check of negative indices, hide tests f…
Feb 6, 2023
ab7cf51
fixes
Feb 6, 2023
5984eb3
extend error log for convenient analysis
Feb 7, 2023
46199da
lint fix
Feb 7, 2023
fd2ad9a
fix
Feb 7, 2023
e14e4dd
sync dtypes
Feb 7, 2023
77e2308
update cpu tir for scatter elements by scan example
Feb 8, 2023
3276dd7
scatter elements was basically implemented for topi/cuda
Feb 8, 2023
da839a4
fix cpu scatter elements
Feb 8, 2023
b0e1f12
fix gpu scatter elements
Feb 8, 2023
c617f22
fix
Feb 8, 2023
21ce735
small update
Feb 8, 2023
1ebe71d
transfer indices check out of general loop
Feb 8, 2023
5721d94
trancsfer ranges and strides calculation to gpu device
Feb 9, 2023
2bbff6a
fixes
Feb 9, 2023
fe27ea8
fix axis
Feb 9, 2023
62c56b7
clean code
Feb 10, 2023
1fa653f
fix after review
Feb 16, 2023
ac7c230
fix lint
Feb 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
}
};

struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
String reduction;

TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relay.attrs.ScatterElementsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Reduction mode of the scatter elements, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
}
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
String mode;

Expand Down
52 changes: 50 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2848,11 +2848,59 @@ class Scatter(OnnxOpConverter):
"""Operator converter for Scatter."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
def _impl_v9(cls, inputs, attr, params):
axis = attr.get("axis", 0)
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)


class ScatterElements(OnnxOpConverter):
"""Operator converter for ScatterElements."""

@classmethod
def _args_check(cls, inputs, attr, red_valids=None):
ret = []
assert (
len(inputs) == 3
), "ScatterElements takes 3 inputs (data, indices, updates), {} given".format(len(inputs))
assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]

axis = attr.get("axis", 0)
rank = len(infer_shape(inputs[0]))
assert rank > 0, "Data rank higher than 0 is expected"
assert -rank <= axis < rank, "Axis is out of bounds"
ret.append(axis)

if red_valids:
reduction = attr.get("reduction", None)
if reduction is None:
reduction = b"update"
reduction = reduction.decode("utf-8")
assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format(
red_valids, reduction
)
ret.append(reduction)

return ret

@classmethod
def _impl_v11(cls, inputs, attr, params):
axis = cls._args_check(inputs, attr)

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, "update")

@classmethod
def _impl_v16(cls, inputs, attr, params):
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul"])

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)

@classmethod
def _impl_v18(cls, inputs, attr, params):
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul", "min", "max"])

return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)


class ScatterND(OnnxOpConverter):
"""Operator converter for ScatterND."""

Expand Down Expand Up @@ -6588,7 +6636,7 @@ def _get_convert_map(opset):
"Compress": Compress.get_converter(opset),
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"ScatterElements": ScatterElements.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"EyeLike": EyeLike.get_converter(opset),
"Squeeze": Squeeze.get_converter(opset),
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ def compute_scatter_add(attrs, inputs, output_type):

_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter_elements
@_reg.register_compute("scatter_elements")
def compute_scatter_elements(attrs, inputs, output_type):
"""Compute definition of scatter_elements"""
return [topi.scatter_elements(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]


_reg.register_strategy("scatter_elements", strategy.scatter_elements_strategy)

# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
Expand Down Expand Up @@ -679,6 +688,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):

_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)


Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ class ScatterAddAttrs(Attrs):
"""Attributes used in scatter_add operators"""


@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs")
class ScatterElementsAttrs(Attrs):
"""Attributes used in scatter_elements operators"""


@tvm._ffi.register_object("relay.attrs.ScatterNDAttrs")
class ScatterNDAttrs(Attrs):
"""Attributes used in scatter_nd operators"""
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,20 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_elements_strategy.register(["cuda", "gpu"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter elements cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="scatter_elements.cuda",
plevel=10,
)
# TODO(vvchernov): There is possible specification for rank=1 as for scatter
return strategy


@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
return strategy


# scatter_elements
@override_native_generic_func("scatter_elements_strategy")
def scatter_elements_strategy(attrs, inputs, out_type, target):
"""scatter_elements generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_elements(topi.scatter_elements),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_elements.generic",
)
return strategy


def wrap_compute_scatter_elements(topi_compute):
"""Wrap scatter_elements topi compute"""

def _compute_scatter_elements(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]

return _compute_scatter_elements


# scatter_nd
@override_native_generic_func("scatter_nd_strategy")
def scatter_nd_strategy(attrs, inputs, out_type, target):
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,41 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_elements(data, indices, updates, axis=0, reduction="update"):
"""Scatter elements with updating data by reduction of values in updates
at positions defined by indices.

Parameters
----------
data : relay.Expr
The input data to the operator.

indices : relay.Expr
The index locations to update.

updates : relay.Expr
The values to update.

axis : int
The axis to scatter elements on. It is zero by default.

reduction : string, optional
The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_elements(data, indices, updates, axis, reduction)


def scatter_nd(data, indices, updates, mode="update"):
"""Scatter values from an array and update.

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .broadcast import *
from .sort import *
from .scatter import *
from .scatter_elements import *
from .sparse_fill_empty_rows import *
from .sparse_reshape import *
from .scatter_add import *
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression
from .rcnn import *
from .scatter import *
from .scatter_elements import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
Expand Down
Loading