From 0c965f442a3456cc1052338b0146aacb4fea367e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 17 Feb 2023 01:43:01 +0400 Subject: [PATCH] [ONNX] Support ScatterElements with reduction (#13894) * add ScatterElements converter to ONNX front-end * native front-end for ScatterElements was implemented * update ScatterElements in ONNX high-level front-end * update comments * register ScatterElementsAttrs * register scatter elements strategy * implement generic scatter elements in topi * fix min-max redefinition * fix IntImm conversion and update scatter element implementation * fix parallel approach * CI tests for scatter elements were added * small update of description * sphinx issue was fixed * fix scatter deprecation in the CI test * fix * fix scatter version support * fix negative indices * add scatter elements strategy for cuda, gpu * update assert comment, update check of negative indices, hide tests for 18 version * fixes * extend error log for convenient analysis * lint fix * fix * sync dtypes * update cpu tir for scatter elements by scan example * scatter elements was basically implemented for topi/cuda * fix cpu scatter elements * fix gpu scatter elements * fix * small update * transfer indices check out of general loop * trancsfer ranges and strides calculation to gpu device * fixes * fix axis * clean code * fix after review * fix lint --------- Co-authored-by: Valery Chernov --- include/tvm/relay/attrs/transform.h | 12 ++ python/tvm/relay/frontend/onnx.py | 52 ++++++- python/tvm/relay/op/_transform.py | 10 ++ python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/strategy/cuda.py | 14 ++ python/tvm/relay/op/strategy/generic.py | 22 +++ python/tvm/relay/op/transform.py | 35 +++++ python/tvm/topi/__init__.py | 1 + python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/scatter_elements.py | 167 +++++++++++++++++++++ python/tvm/topi/scatter_elements.py | 150 ++++++++++++++++++ src/relay/op/tensor/transform.cc | 56 +++++++ src/tir/ir/expr.cc | 4 
+- tests/python/frontend/onnx/test_forward.py | 80 +++++++++- 14 files changed, 601 insertions(+), 8 deletions(-) create mode 100644 python/tvm/topi/cuda/scatter_elements.py create mode 100644 python/tvm/topi/scatter_elements.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 82403d7c40ee..b5333961ebf9 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -164,6 +164,18 @@ struct ScatterAddAttrs : public tvm::AttrsNode { } }; +struct ScatterElementsAttrs : public tvm::AttrsNode { + Integer axis; + String reduction; + + TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relay.attrs.ScatterElementsAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); + TVM_ATTR_FIELD(reduction).set_default("update").describe( + "Reduction mode of the scatter elements, " + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); + } +}; + struct ScatterNDAttrs : public tvm::AttrsNode { String mode; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index aebc6daa5ebe..2a1890627225 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2848,11 +2848,59 @@ class Scatter(OnnxOpConverter): """Operator converter for Scatter.""" @classmethod - def _impl_v1(cls, inputs, attr, params): + def _impl_v9(cls, inputs, attr, params): axis = attr.get("axis", 0) return _op.scatter(inputs[0], inputs[1], inputs[2], axis) +class ScatterElements(OnnxOpConverter): + """Operator converter for ScatterElements.""" + + @classmethod + def _args_check(cls, inputs, attr, red_valids=None): + ret = [] + assert ( + len(inputs) == 3 + ), "ScatterElements takes 3 inputs (data, indices, updates), {} given".format(len(inputs)) + assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"] + + axis = attr.get("axis", 0) + rank = len(infer_shape(inputs[0])) + assert rank > 0, "Data rank higher than 0 is expected" + assert 
-rank <= axis < rank, "Axis is out of bounds" + ret.append(axis) + + if red_valids: + reduction = attr.get("reduction", None) + if reduction is None: + reduction = b"update" + reduction = reduction.decode("utf-8") + assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format( + red_valids, reduction + ) + ret.append(reduction) + + return ret + + @classmethod + def _impl_v11(cls, inputs, attr, params): + (axis,) = cls._args_check(inputs, attr) # _args_check returns [axis]; unpack the int + + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, "update") + + @classmethod + def _impl_v16(cls, inputs, attr, params): + axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul"]) + + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction) + + @classmethod + def _impl_v18(cls, inputs, attr, params): + axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul", "min", "max"]) + + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction) + + class ScatterND(OnnxOpConverter): """Operator converter for ScatterND.""" @@ -6588,7 +6636,7 @@ def _get_convert_map(opset): "Compress": Compress.get_converter(opset), "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": Scatter.get_converter(opset), - "ScatterElements": Scatter.get_converter(opset), + "ScatterElements": ScatterElements.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), "EyeLike": EyeLike.get_converter(opset), "Squeeze": Squeeze.get_converter(opset), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d4e4a527835a..e40179ed2d03 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -204,6 +204,15 @@ def compute_scatter_add(attrs, inputs, output_type): _reg.register_strategy("scatter_add", strategy.scatter_add_strategy) +# scatter_elements +@_reg.register_compute("scatter_elements") +def compute_scatter_elements(attrs, inputs, output_type): + """Compute 
definition of scatter_elements""" + return [topi.scatter_elements(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)] + + +_reg.register_strategy("scatter_elements", strategy.scatter_elements_strategy) + # scatter_nd @_reg.register_compute("scatter_nd") def compute_scatter_nd(attrs, inputs, output_type): @@ -679,6 +688,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims): _reg.register_shape_func("scatter", False, elemwise_shape_func) _reg.register_shape_func("scatter_add", False, elemwise_shape_func) +_reg.register_shape_func("scatter_elements", False, elemwise_shape_func) _reg.register_shape_func("scatter_nd", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index b76097722c07..ea7c415b511f 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -639,6 +639,11 @@ class ScatterAddAttrs(Attrs): """Attributes used in scatter_add operators""" +@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs") +class ScatterElementsAttrs(Attrs): + """Attributes used in scatter_elements operators""" + + @tvm._ffi.register_object("relay.attrs.ScatterNDAttrs") class ScatterNDAttrs(Attrs): """Attributes used in scatter_nd operators""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index fa295c93a19f..fc1691fe9ef0 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1099,6 +1099,20 @@ def scatter_add_cuda(attrs, inputs, out_type, target): return strategy +@scatter_elements_strategy.register(["cuda", "gpu"]) +def scatter_elements_cuda(attrs, inputs, out_type, target): + """scatter elements cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_elements(topi.cuda.scatter_elements), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="scatter_elements.cuda", + plevel=10, + ) + # TODO(vvchernov): There is possible specification for rank=1 as for 
scatter + return strategy + + @scatter_nd_strategy.register(["cuda", "gpu"]) def scatter_nd_cuda(attrs, inputs, out_type, target): """scatter_nd cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 1cf55f7145cd..4e0448f1799b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1580,6 +1580,28 @@ def scatter_add_strategy(attrs, outs, out_type, target): return strategy +# scatter_elements +@override_native_generic_func("scatter_elements_strategy") +def scatter_elements_strategy(attrs, inputs, out_type, target): + """scatter_elements generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_elements(topi.scatter_elements), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_elements.generic", + ) + return strategy + + +def wrap_compute_scatter_elements(topi_compute): + """Wrap scatter_elements topi compute""" + + def _compute_scatter_elements(attrs, inputs, _): + return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)] + + return _compute_scatter_elements + + # scatter_nd @override_native_generic_func("scatter_nd_strategy") def scatter_nd_strategy(attrs, inputs, out_type, target): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 782797dadb83..3df13da04426 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -403,6 +403,41 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """Scatter elements with updating data by reduction of values in updates + at positions defined by indices. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. 
+ + axis : int + The axis to scatter elements on. It is zero by default. + + reduction : string, optional + The reduction mode for scatter. Choice is from ["update", "add", "mul", "min", "max"] + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the input data will be multiplied by the update values + If min, the minimum of the update values and the input data is taken + If max, the maximum of the update values and the input data is taken + It is "update" by default + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.scatter_elements(data, indices, updates, axis, reduction) + + def scatter_nd(data, indices, updates, mode="update"): """Scatter values from an array and update. diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 785ba395d25b..75867136e09e 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -38,6 +38,7 @@ from .broadcast import * from .sort import * from .scatter import * +from .scatter_elements import * from .sparse_fill_empty_rows import * from .sparse_reshape import * from .scatter_add import * diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 31433e3a1a54..b746c95c0fc1 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -46,6 +46,7 @@ from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression from .rcnn import * from .scatter import * +from .scatter_elements import * from .sort import * from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py new file mode 100644 index 000000000000..8ed3e2972081 --- /dev/null +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Scatter operator """ +import tvm +from tvm import te, tir +from ..utils import ceil_div, get_const_int +from ..math import cast + + +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """Scatter elements from updates to corresponding indices of copied data. + + Output has the same shape as data; indices and updates share one shape of the same rank. + Indices can not have duplicates (if idx1 != idx2, then indices[idx1] != indices[idx2]) + if reduction == "update". + + .. code-block:: + + output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0 + output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1 + + where the update function f is determined by the reduction. + Five types of the function are supported: "update", "add", "mul", "min" and "max" (see below) + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The positions along axis at which the updates are written. + + updates : tvm.te.Tensor + The values to apply at the indices. + + axis : optional, int + The axis to scatter on. It is zero by default. 
+ + reduction : optional, string + The update mode for the algorithm, either "update", "add", "mul", "min" or "max" + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default + + Returns + ------- + ret : tvm.te.Tensor + """ + if not isinstance(axis, int): + axis = get_const_int(axis) + + def gen_ir(data, indices, updates, out, axis): + ib = tir.ir_builder.create() + + data_ptr = ib.buffer_ptr(data) + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + out_ptr = ib.buffer_ptr(out) + + # Prepare ranges and strides + shape = data.shape + if axis < 0: + axis = len(shape) + axis + axis_range = cast(shape[axis], indices.dtype) + + before_axis_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + if i < axis: + before_axis_range *= value + elif i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + full_range = before_axis_range * before_axis_stride + + ind_shape = indices.shape + ind_axis_range = ind_shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range + ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + # Copy initial input data to output + with ib.new_scope(): + num_blocks = ceil_div(full_range, max_threads) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", num_blocks) + ib.scope_attr(tx, 
"thread_extent", max_threads) + + index = bx * max_threads + tx + with ib.if_scope(index < full_range): + out_ptr[index] = data_ptr[index] + + # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) + with ib.new_scope(): + num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads) + bx2 = te.thread_axis("blockIdx.x") + tx2 = te.thread_axis("threadIdx.x") + ib.scope_attr(bx2, "thread_extent", num_blocks_2) + ib.scope_attr(tx2, "thread_extent", max_threads) + + ind_fused = bx2 * max_threads + tx2 + with ib.if_scope(ind_fused < ind_full_range_excl_axis): + i = ind_fused // ind_after_axis_range + j = ind_fused % ind_after_axis_range + with ib.for_range(0, ind_axis_range, "k") as k: + # Offset along indices or updates + index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j + # Get index and shift to positive side if need + new_index = indices_ptr[index1] + shifted_index = new_index + (new_index < 0) * axis_range + # Offset along data + index2 = i * before_axis_stride + shifted_index * after_axis_range + j + if reduction == "update": + out_ptr[index2] = updates_ptr[index1] + elif reduction == "add": + out_ptr[index2] += updates_ptr[index1] + elif reduction == "mul": + out_ptr[index2] *= updates_ptr[index1] + elif reduction == "min": + out_ptr[index2] = tir.min(out_ptr[index2], updates_ptr[index1]) + elif reduction == "max": + out_ptr[index2] = tir.max(out_ptr[index2], updates_ptr[index1]) + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", + reduction, + ) + + return ib.get() + + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_elements_cuda", + tag="scatter_elements_cuda", + ) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py new file 
mode 100644 index 000000000000..e7948f6cb6c7 --- /dev/null +++ b/python/tvm/topi/scatter_elements.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ScatterElements operator""" +from tvm import te +from tvm import tir +from . import utils +from .math import cast + + +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """Scatter elements from updates to corresponding indices of copied data. + + Output has the same shape as data; indices and updates share one shape of the same rank. + Indices can not have duplicates (if idx1 != idx2, then indices[idx1] != indices[idx2]) + if reduction == "update". + + .. code-block:: + + output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0 + output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1 + + where the update function f is determined by the reduction. + Five types of the function are supported: "update", "add", "mul", "min" and "max" (see below) + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The positions along axis at which the updates are written. + + updates : tvm.te.Tensor + The values to apply at the indices. + + axis : optional, int + The axis to scatter on. It is zero by default. 
+ + reduction : optional, string + The update mode for the algorithm, either "update", "add", "mul", "min" or "max" + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default + + Returns + ------- + ret : tvm.te.Tensor + """ + if not isinstance(axis, int): + axis = utils.get_const_int(axis) + + # Prepare ranges and strides + shape = data.shape + if axis < 0: + axis = len(shape) + axis + axis_range = cast(shape[axis], indices.dtype) + + full_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + full_range *= value + if i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + + ind_shape = indices.shape + ind_axis_range = ind_shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range + + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + # pylint: disable=invalid-name + ib = tir.ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) + out = ib.buffer_ptr(out_ptr) + + # Copy initial input data to output + with ib.for_range(0, full_range, "i", kind="parallel") as i: + out[i] = data[i] + + # TODO(vvchernov): find optimal parallel approach + with ib.for_range( + 0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel" + ) as fused: + i = fused // ind_after_axis_range + j = fused % ind_after_axis_range + with ib.for_range(0, ind_axis_range, "k") as k: + # Offset along indices or updates + index1 
= i * ind_before_axis_stride + k * ind_after_axis_range + j + # TODO(vvchernov): assert for out of bounds, separated check for indices + k_new = indices[index1] + index_check = tir.LT(k_new, tir.const(0, indices.dtype)) + k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + # Offset along data + index2 = i * before_axis_stride + k_new * after_axis_range + j + if reduction == "update": + out[index2] = updates[index1] + elif reduction == "add": + out[index2] += updates[index1] + elif reduction == "mul": + out[index2] *= updates[index1] + elif reduction == "min": + out[index2] = tir.min(out[index2], updates[index1]) + elif reduction == "max": + out[index2] = tir.max(out[index2], updates[index1]) + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", + reduction, + ) + + return ib.get() + + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_elements.generic", + tag="scatter_elements.generic", + ) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5c5cd6f4b721..01e5a7f5f359 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1191,6 +1191,62 @@ RELAY_REGISTER_OP("scatter_add") .set_attr("TOpPattern", kOpaque) .set_support_level(10); +// scatter_elements operator +TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); + +bool ScatterElementsRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, updates, output] + ICHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + const auto* indices = types[1].as(); + const auto* updates = types[2].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "ScatterElements: expect input data type to be TensorType but got " << types[0]; + return 
false; + } + if (indices == nullptr) { + ICHECK(types[1].as()) + << "ScatterElements: expect indices type to be TensorType but got " << types[1]; + return false; + } + if (updates == nullptr) { + ICHECK(types[2].as()) + << "ScatterElements: expect updates type to be TensorType but got " << types[2]; + return false; + } + // TODO(vvchernov): ONNX requires int32 and int64 + ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) + << "ScatterElements: indices must be a tensor of integers."; + + // Assign output + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +Expr MakeScatterElements(Expr data, Expr indices, Expr updates, int axis, String reduction) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->reduction = std::move(reduction); + static const Op& op = Op::Get("scatter_elements"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter_elements").set_body_typed(MakeScatterElements); + +// scatter_elements op has extern schedules: convert to Opaque to prevent compilation failures +RELAY_REGISTER_OP("scatter_elements") + .describe(R"code(Scatter elements with updating data by reduction of values in updates +at positions defined by indices.)code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") + .set_support_level(3) + .add_type_rel("ScatterElements", ScatterElementsRel) + .set_attr("TOpPattern", kOpaque); + // scatter_nd operator TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 40606761f804..d5caeab53922 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -393,7 +393,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(false_value.defined()) << "ValueError: true_value is 
undefined"; ICHECK(condition.dtype().is_bool()); ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); - ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; + ICHECK(false_value.dtype() == true_value.dtype()) + << "TypeError: mismatched types. " + << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); ObjectPtr node = make_object(); node->dtype = true_value.dtype(); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 470a67e86c93..293f4d38e649 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -843,9 +843,7 @@ def verify_scatter(in_shape, indices, axis): indices = np.array(indices, dtype="int32") updates = np.random.uniform(size=indices.shape).astype("float32") - y = helper.make_node( - "ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis - ) + y = helper.make_node("Scatter", ["data", "indices", "updates"], ["output"], axis=axis) graph = helper.make_graph( [y], @@ -858,7 +856,9 @@ def verify_scatter(in_shape, indices, axis): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], ) model = helper.make_model(graph, producer_name="scatter_test") - verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev) + # Scatter operator has been supported from version 9 and + # deprecated since version 11 of the default ONNX operator set + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev, opset=9) verify_scatter((4,), [1], 0) verify_scatter((1, 4), [[0]], 0) @@ -868,6 +868,77 @@ def verify_scatter(in_shape, indices, axis): verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0) +@tvm.testing.parametrize_targets +def test_scatter_elements(target, dev): + """test_scatter_elements""" + + def verify_scatter_elements(in_shape, indices, axis=0, reduction="update"): + x = 
np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + updates = np.random.uniform(size=indices.shape).astype("float32") + + scatter_elements_node = helper.make_node( + "ScatterElements", + ["data", "indices", "updates"], + ["output"], + axis=axis, + reduction=reduction, + ) + + graph = helper.make_graph( + [scatter_elements_node], + "scatter_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="scatter_elements_test") + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev) + + # Usual scatter for 1d input + verify_scatter_elements((4,), [2, 3]) + # Usual scatter with specified positive axis + verify_scatter_elements((2, 2), [[1, 0], [0, 1]], 1) + # Usual scatter for 3d input with spicified negative indices and axis + verify_scatter_elements((3, 3, 3), [[[-1, -3]]], -1) + # Usual scatter for 4d input + verify_scatter_elements((4, 3, 5, 6), [[[[2, 1, 0, 0]]]]) + # Scatter elements with addition reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "add", + ) + # Scatter elements with reduction and specified axis + verify_scatter_elements((3, 3, 3), [[[2, 2, 2], [1, 1, 1], [0, 0, 0]]], 2, "add") + # Scatter elements with multiplication reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "mul", + ) + # TODO(vvchernov): min and max options are supported from 18 version, but CI supports 17 only + # # Scatter elements with min reduction of duplicates + # 
verify_scatter_elements( + # (3, 3, 3), + # [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + # 0, + # "min", + # ) + # # Scatter elements with max reduction of duplicates + # verify_scatter_elements( + # (3, 3, 3), + # [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + # 0, + # "max", + # ) + + @tvm.testing.parametrize_targets def test_slice(target, dev): """test_slice""" @@ -5397,7 +5468,6 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", "test_roialign_aligned_true", - "test_scatter_elements_with_duplicate_indices", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_sequence_map_add_1_sequence_1_tensor",